Skip to content

Commit bd20953

Browse files
authored
feat: update to python 3.10, update examples (#223)
* feat: update to python 3.10, update examples * feat: add CWGANGP example * chore: remove unused imports * chore: remove trailing whitespace
1 parent 6c79c89 commit bd20953

9 files changed

Lines changed: 66 additions & 58 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
![](https://img.shields.io/github/workflow/status/ydataai/ydata-synthetic/prerelease)
22
![](https://img.shields.io/pypi/status/ydata-synthetic)
33
[![](https://pepy.tech/badge/ydata-synthetic)](https://pypi.org/project/ydata-synthetic/)
4-
![](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue)
4+
![](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue)
55
[![](https://img.shields.io/pypi/v/ydata-synthetic)](https://pypi.org/project/ydata-synthetic/)
66
![](https://img.shields.io/github/license/ydataai/ydata-synthetic)
77

examples/regular/models/adult_dragan.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
# DRAGAN training
1313
#Defining the training parameters of DRAGAN
14-
1514
noise_dim = 128
1615
dim = 128
1716
batch_size = 500
@@ -35,10 +34,10 @@
3534
synth = RegularSynthesizer(modelname='dragan', model_parameters=gan_args, n_discriminator=3)
3635
synth.fit(data = data, train_arguments = train_args, num_cols = num_cols, cat_cols = cat_cols)
3736

38-
synth.save('adult_synth.pkl')
37+
synth.save('adult_dragan_model.pkl')
3938

4039
#########################################################
4140
# Loading and sampling from a trained synthesizer #
4241
#########################################################
43-
synthesizer = RegularSynthesizer.load('adult_synth.pkl')
42+
synthesizer = RegularSynthesizer.load('adult_dragan_model.pkl')
4443
synthesizer.sample(1000)

examples/regular/models/adult_wgangp.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
from pmlb import fetch_data
22

3-
from ydata_synthetic.synthesizers.regular import WGAN_GP
3+
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
44
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
55

6-
model = WGAN_GP
7-
86
#Load data and define the data processor parameters
97
data = fetch_data('adult')
108
num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
119
cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
1210
'native-country', 'target']
1311

1412
#Defining the training parameters
15-
1613
noise_dim = 128
1714
dim = 128
1815
batch_size = 500
@@ -33,10 +30,13 @@
3330
train_args = TrainParameters(epochs=epochs,
3431
sample_interval=log_step)
3532

36-
synthesizer = model(gan_args, n_critic=2)
37-
synthesizer.train(data, train_args, num_cols, cat_cols)
33+
synth = RegularSynthesizer(modelname='wgangp', model_parameters=gan_args, n_critic=2)
34+
synth.fit(data, train_args, num_cols, cat_cols)
3835

39-
synthesizer.save('test.pkl')
36+
synth.save('adult_wgangp_model.pkl')
4037

41-
synthesizer = model.load('test.pkl')
42-
synth_data = synthesizer.sample(1000)
38+
#########################################################
39+
# Loading and sampling from a trained synthesizer #
40+
#########################################################
41+
synth = RegularSynthesizer.load('adult_wgangp_model.pkl')
42+
synth_data = synth.sample(1000)

examples/regular/models/cgan_example.py renamed to examples/regular/models/creditcard_cgan.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,37 @@
11
"""
22
CGAN architecture example file
33
"""
4-
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
5-
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
6-
74
import pandas as pd
8-
import numpy as np
95
from sklearn import cluster
106

7+
from ydata_synthetic.utils.cache import cache_file
8+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
9+
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
10+
1111
#Read the original data and have it preprocessed
12-
data = pd.read_csv('../../data/creditcard.csv', index_col=[0])
12+
data_path = cache_file('creditcard.csv', 'https://datahub.io/machine-learning/creditcard/r/creditcard.csv')
13+
data = pd.read_csv(data_path, index_col=[0])
1314

14-
#List of columns different from the Class column
15+
#Data processing and analysis
1516
num_cols = list(data.columns[ data.columns != 'Class' ])
16-
cat_cols = [] # Condition features are not preprocessed and therefore not listed here
17+
cat_cols = []
1718

1819
print('Dataset columns: {}'.format(num_cols))
19-
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
20-
data = data[ sorted_cols ].copy()
20+
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19',
21+
'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15',
22+
'V9', 'V23', 'Class']
23+
processed_data = data[ sorted_cols ].copy()
24+
processed_data['Class'] = processed_data['Class'].apply(lambda x: 1 if x == "'1'" else 0)
2125

2226
#For the purpose of this example we will only synthesize the minority class
23-
train_data = data.loc[ data['Class']==1 ].copy()
27+
train_data = processed_data.loc[processed_data['Class'] == 1].copy()
2428

2529
#Create a new class column using KMeans - This will mainly be useful if we want to leverage conditional GAN
2630
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
2731
algorithm = cluster.KMeans
2832
args, kwds = (), {'n_clusters':2, 'random_state':0}
2933
labels = algorithm(*args, **kwds).fit_predict(train_data[ num_cols ])
3034

31-
print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) )
32-
3335
fraud_w_classes = train_data.copy()
3436
fraud_w_classes['Class'] = labels
3537

@@ -72,10 +74,10 @@
7274
synth.fit(data=fraud_w_classes, label_cols=["Class"], train_arguments=train_args, num_cols=num_cols, cat_cols=cat_cols)
7375

7476
#Saving the synthesizer
75-
synth.save('cgan_synthtrained.pkl')
77+
synth.save('creditcard_cgan_model.pkl')
7678

7779
#Loading the synthesizer
78-
synthesizer = RegularSynthesizer.load('cgan_synthtrained.pkl')
80+
synthesizer = RegularSynthesizer.load('creditcard_cgan_model.pkl')
7981

8082
#Sampling from the synthesizer
8183
cond_array = pd.DataFrame(100*[1], columns=['Class'])

examples/regular/models/cramergan_example.py renamed to examples/regular/models/creditcard_cramergan.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,25 @@
77
import numpy as np
88
import pandas as pd
99

10+
from ydata_synthetic.utils.cache import cache_file
1011
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
1112
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
1213

1314
#Read the original data and have it preprocessed
14-
data = pd.read_csv('../../../data/creditcard.csv', index_col=[0])
15+
data_path = cache_file('creditcard.csv', 'https://datahub.io/machine-learning/creditcard/r/creditcard.csv')
16+
data = pd.read_csv(data_path, index_col=[0])
1517

16-
#List of columns different from the Class column
18+
#Data processing and analysis
1719
num_cols = list(data.columns[ data.columns != 'Class' ])
1820
cat_cols = ['Class']
1921

2022
print('Dataset columns: {}'.format(num_cols))
2123
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
22-
data = data[ sorted_cols ].copy()
24+
processed_data = data[ sorted_cols ].copy()
25+
processed_data['Class'] = processed_data['Class'].apply(lambda x: 1 if x == "'1'" else 0)
2326

2427
#For the purpose of this example we will only synthesize the minority class
25-
train_data = data.loc[ data['Class']==1 ].copy()
28+
train_data = processed_data.loc[processed_data['Class'] == 1].copy()
2629

2730
#Create a new class column using KMeans - This will mainly be useful if we want to leverage conditional GAN
2831
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
@@ -62,12 +65,12 @@
6265
synth.fit(data=train_data, train_arguments = train_args, num_cols = num_cols, cat_cols = cat_cols)
6366

6467
#Saving the synthesizer to later generate new events
65-
synth.save(path='cramergan_creditcard.pkl')
68+
synth.save(path='creditcard_cramergan_model.pkl')
6669

6770
#########################################################
6871
# Loading and sampling from a trained synthesizer #
6972
#########################################################
70-
synth = RegularSynthesizer.load(path='cramergan_creditcard.pkl')
73+
synth = RegularSynthesizer.load(path='creditcard_cramergan_model.pkl')
7174
#Sampling the data
7275
#Note that the data returned is not inverse processed.
7376
data_sample = synth.sample(100000)

examples/regular/models/cwgangp_example.py renamed to examples/regular/models/creditcard_cwgangp.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,33 @@
1-
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
2-
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
3-
41
import pandas as pd
52
import numpy as np
63
from sklearn import cluster
74

5+
from ydata_synthetic.utils.cache import cache_file
6+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
7+
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
8+
89
#Read the original data and have it preprocessed
9-
data = pd.read_csv('../../data/creditcard.csv', index_col=[0])
10+
data_path = cache_file('creditcard.csv', 'https://datahub.io/machine-learning/creditcard/r/creditcard.csv')
11+
data = pd.read_csv(data_path, index_col=[0])
1012

11-
#List of columns different from the Class column
12-
num_cols = list(data.columns[~data.columns.isin(['Class', 'Amount'])])
13-
cat_cols = [] # Condition features are not preprocessed and therefore not listed here
13+
#Data processing and analysis
14+
num_cols = list(data.columns[ data.columns != 'Class' ])
15+
cat_cols = [] #['Class']
1416

1517
print('Dataset columns: {}'.format(num_cols))
1618
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
17-
data = data[ sorted_cols ].copy()
19+
processed_data = data[ sorted_cols ].copy()
20+
processed_data['Class'] = processed_data['Class'].apply(lambda x: 1 if x == "'1'" else 0)
1821

1922
#For the purpose of this example we will only synthesize the minority class
20-
train_data = data.loc[ data['Class']==1 ].copy()
23+
train_data = processed_data.loc[processed_data['Class'] == 1].copy()
2124

2225
#Create a new class column using KMeans - This will mainly be useful if we want to leverage conditional WGANGP
2326
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
2427
algorithm = cluster.KMeans
2528
args, kwds = (), {'n_clusters':2, 'random_state':0}
2629
labels = algorithm(*args, **kwds).fit_predict(train_data[ num_cols ])
2730

28-
print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) )
29-
3031
fraud_w_classes = train_data.copy()
3132
fraud_w_classes['Class'] = labels
3233

@@ -66,16 +67,16 @@
6667
synth = RegularSynthesizer(modelname='cwgangp', model_parameters=gan_args, n_critic=5)
6768

6869
#Fitting the synthesizer
69-
synth.fit(data=fraud_w_classes, label_cols=["Class", "Amount"], train_arguments=train_args,
70+
synth.fit(data=fraud_w_classes, label_cols=["Class"], train_arguments=train_args,
7071
num_cols=num_cols, cat_cols=cat_cols)
7172

72-
synth.save('.model.pkl')
73+
synth.save('creditcard_cwgangp_model.pkl')
7374

7475
#########################################################
7576
# Loading and sampling from a trained synthesizer #
7677
#########################################################
77-
new_synth = RegularSynthesizer.load('.model.pkl')
78+
new_synth = RegularSynthesizer.load('creditcard_cwgangp_model.pkl')
7879

7980
sample_len = 2000
80-
cond_array = fraud_w_classes[["Class", "Amount"]]
81+
cond_array = fraud_w_classes[["Class"]]
8182
new_synth.sample(cond_array)

examples/regular/models/wgan_example.py renamed to examples/regular/models/creditcard_wgan.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import pandas as pd
55
import numpy as np
66

7+
from ydata_synthetic.utils.cache import cache_file
78
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
89
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
910

1011
#Read the original data and have it preprocessed
11-
data = pd.read_csv('../../../data/creditcard.csv', index_col=[0])
12+
data_path = cache_file('creditcard.csv', 'https://datahub.io/machine-learning/creditcard/r/creditcard.csv')
13+
data = pd.read_csv(data_path, index_col=[0])
1214

1315
#Data processing and analysis
1416
num_cols = list(data.columns[ data.columns != 'Class' ])
@@ -17,9 +19,10 @@
1719
print('Dataset columns: {}'.format(num_cols))
1820
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
1921
processed_data = data[ sorted_cols ].copy()
22+
processed_data['Class'] = processed_data['Class'].apply(lambda x: 1 if x == "'1'" else 0)
2023

2124
#For the purpose of this example we will only synthesize the minority class
22-
train_data = data.loc[ data['Class']==1 ].copy()
25+
train_data = processed_data.loc[processed_data['Class'] == 1].copy()
2326

2427
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
2528
algorithm = cluster.KMeans
@@ -61,12 +64,12 @@
6164
synth.fit(data=train_data, train_arguments = train_args, num_cols = num_cols, cat_cols = cat_cols)
6265

6366
#Saving the synthesizer to later generate new events
64-
synth.save(path='models/wgan_creditcard.pkl')
67+
synth.save(path='creditcard_wgan_model.pkl')
6568

6669
#########################################################
6770
# Loading and sampling from a trained synthesizer #
6871
#########################################################
69-
synth = RegularSynthesizer.load(path='models/wgan_creditcard.pkl')
72+
synth = RegularSynthesizer.load(path='creditcard_wgan_model.pkl')
7073

7174
#Sampling the data
7275
data_sample = synth.sample(100000)

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
requests>=2.24.0, <2.29
2-
pandas==1.4.*
2+
pandas==1.5.*
33
numpy==1.23.*
4-
scikit-learn==1.1.*
5-
matplotlib==3.5.*
6-
tensorflow==2.9.0
7-
easydict==1.9
4+
scikit-learn==1.2.*
5+
matplotlib==3.6.*
6+
tensorflow==2.11.0
7+
easydict==1.10
88
pmlb==1.0.*
99
tqdm<5.0
1010
typeguard==2.13.*

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
keywords='data science ydata',
4343
url='https://github.com/ydataai/ydata-synthetic',
4444
license="https://github.com/ydataai/ydata-synthetic/blob/master/LICENSE",
45-
python_requires=">=3.6, <3.9",
45+
python_requires=">=3.6, <3.11",
4646
packages=find_namespace_packages('src'),
4747
package_dir={'':'src'},
4848
include_package_data=True,

0 commit comments

Comments
 (0)