Skip to content

Commit c618d90

Browse files
authored
feat: Add inverse transformations to supported datasets (#104)
1 parent a9de1ab commit c618d90

6 files changed

Lines changed: 75 additions & 25 deletions

File tree

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
pandas==1.2.*
22
numpy==1.19.*
3-
scikit-learn==0.22.*
3+
scikit-learn==1.0.*
44
matplotlib==3.3.2
55
seaborn==0.11.*
66
tensorflow==2.4.*
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
# Inverts all preprocessing pipelines provided in the preprocessing examples
from typing import Union

import pandas as pd

from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler


def inverse_transform(data: pd.DataFrame,
                      processor: Union[Pipeline, ColumnTransformer, PowerTransformer,
                                       OneHotEncoder, StandardScaler]) -> Union[pd.DataFrame, None]:
    """Inverts data transformations taking place in a standard sklearn processor.

    Supported processors are sklearn Pipelines, ColumnTransformers or base
    estimators like StandardScaler, PowerTransformer and OneHotEncoder.

    Args:
        data (pd.DataFrame): The data object that needs inversion of preprocessing.
        processor (Union[Pipeline, ColumnTransformer, PowerTransformer, OneHotEncoder, StandardScaler]):
            The (already fitted) processor that was applied on the original data.

    Returns:
        inv_data (pd.DataFrame): The data after inverting preprocessing, reordered
            to the processor's original input columns (``feature_names_in_``);
            ``None`` when the processor type is not supported.

    Raises:
        TypeError: If a ColumnTransformer is given non-DataFrame data.
    """
    inv_data = data.copy()
    if isinstance(processor, (PowerTransformer, OneHotEncoder, StandardScaler, Pipeline)):
        # Single estimator or whole pipeline: sklearn can invert it in one call.
        inv_data = pd.DataFrame(processor.inverse_transform(data),
                                columns=processor.feature_names_in_)
    elif isinstance(processor, ColumnTransformer):
        output_indices = processor.output_indices_
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently skip this validation.
        if not isinstance(data, pd.DataFrame):
            raise TypeError("The data to be inverted from a ColumnTransformer has to be a Pandas DataFrame.")
        # Walk the fitted transformers in reverse order, inverting each column group.
        for t_name, t, t_cols in processor.transformers_[::-1]:
            slice_ = output_indices[t_name]
            t_indices = list(range(slice_.start, slice_.stop,
                                   1 if slice_.step is None else slice_.step))
            if t == 'drop':
                # Dropped columns cannot be recovered.
                continue
            elif t == 'passthrough':
                # Columns were forwarded untouched; copy them back as-is.
                inv_cols = pd.DataFrame(data.iloc[:, t_indices].values,
                                        columns=t_cols, index=data.index)
                inv_col_names = inv_cols.columns
            else:
                inv_cols = pd.DataFrame(t.inverse_transform(data.iloc[:, t_indices].values),
                                        columns=t_cols, index=data.index)
                inv_col_names = inv_cols.columns
            if set(inv_col_names).issubset(set(inv_data.columns)):
                # Columns already exist in the output frame: overwrite in place.
                inv_data[inv_col_names] = inv_cols[inv_col_names]
            else:
                inv_data = pd.concat([inv_data, inv_cols], axis=1)
    else:
        print('The provided data processor is not supported and cannot be inverted with this method.')
        return None
    # Restore the column order the processor was originally fitted on.
    return inv_data[processor.feature_names_in_]

src/ydata_synthetic/preprocessing/regular/adult.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,28 @@
def transformations():
    """Fetch the PMLB 'adult' dataset and preprocess it with a ColumnTransformer.

    Numerical features are standard-scaled and categorical features one-hot
    encoded; the remaining columns are passed through unchanged so the whole
    transformation can later be inverted column-by-column.

    Returns:
        data (pd.DataFrame): the raw dataset as fetched from PMLB.
        processed_data (pd.DataFrame): the transformed dataset (sparse-backed).
        preprocessor (ColumnTransformer): the fitted processor, kept for inversion.
    """
    data = fetch_data('adult')

    numerical_features = ['age', 'fnlwgt',
                          'capital-gain', 'capital-loss',
                          'hours-per-week']
    numerical_transformer = Pipeline(steps=[
        ('scaler', StandardScaler())])

    categorical_features = ['workclass', 'education', 'marital-status',
                            'occupation', 'relationship',
                            'race', 'sex']
    categorical_transformer = Pipeline(steps=[
        ('onehot', OneHotEncoder(handle_unknown='ignore'))])

    # Untouched columns are declared explicitly (instead of being dropped) so
    # the inverse transformation can reconstruct the full original frame.
    remaining_features = ['education-num', 'native-country', 'target']
    remaining_transformer = 'passthrough'

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numerical_transformer, numerical_features),
            ('cat', categorical_transformer, categorical_features),
            ('remaining', remaining_transformer, remaining_features)])

    # One-hot output is sparse; keep it sparse-backed in the DataFrame.
    processed_data = pd.DataFrame.sparse.from_spmatrix(preprocessor.fit_transform(data))

    return data, processed_data, preprocessor
src/ydata_synthetic/preprocessing/regular/breast_cancer_wisconsin.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,18 @@
from pmlb import fetch_data


def transformations():
    """Fetch the PMLB 'breast_cancer_wisconsin' dataset and standard-scale it.

    Returns:
        data (pd.DataFrame): the raw dataset as fetched from PMLB.
        processed_data (pd.DataFrame): the standard-scaled dataset.
        scaler (StandardScaler): the fitted scaler, kept so the
            transformation can be inverted later.
    """
    data = fetch_data('breast_cancer_wisconsin')

    scaler = StandardScaler()
    processed_data = scaler.fit_transform(data)
    processed_data = pd.DataFrame(processed_data)

    return data, processed_data, scaler


if __name__ == '__main__':
    data = transformations()
    print(data)
src/ydata_synthetic/preprocessing/regular/cardiovascular.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def transformations(data):
1919
transformers=[
2020
('num', numerical_transformer, numerical_features),
2121
('cat', categorical_transformer, categorical_features)])
22-
22+
2323
processed_data = preprocessor.fit_transform(data)
2424
processed_data = pd.DataFrame.sparse.from_spmatrix(preprocessor.fit_transform(processed_data))
25-
return processed_data, preprocessor
25+
return data, processed_data, preprocessor
Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
# Data transformations to be applied
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

from sklearn.preprocessing import PowerTransformer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer


def transformations(data):
    """Apply a Yeo-Johnson power transformation to every non-label column.

    All columns except the 'Class' label are transformed (the previous comment
    claiming only 'Amount' is log-transformed was stale).

    Args:
        data (pd.DataFrame): the raw dataset; the 'Class' column is left untouched.

    Returns:
        data (pd.DataFrame): the unmodified input dataset.
        processed_data (pd.DataFrame): a copy with the feature columns power-transformed.
        preprocessor (ColumnTransformer): the fitted processor, kept for inversion.
    """
    # Work on a copy so the raw input frame is returned unchanged.
    processed_data = data.copy()
    data_cols = list(data.columns[data.columns != 'Class'])

    data_transformer = Pipeline(steps=[
        ('PowerTransformer', PowerTransformer(method='yeo-johnson', standardize=True, copy=True))])

    preprocessor = ColumnTransformer(
        transformers=[('power', data_transformer, data_cols)])
    processed_data[data_cols] = preprocessor.fit_transform(data[data_cols])

    return data, processed_data, preprocessor

0 commit comments

Comments
 (0)