Skip to content

Commit ae641a6

Browse files
authored
feat: Add new regular data model synthesizer CWGANGP (#153)
* CWGANGP

* fix column order in cgan sample
1 parent 4fc1f4f commit ae641a6

10 files changed

Lines changed: 347 additions & 43 deletions

File tree

examples/regular/cgan_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,4 @@
7777
#Sampling from the synthesizer
7878
cond_array = np.array([0])
7979
# Synthesizer samples are returned in the original format (inverse_transform of internal processing already took place)
80-
synthesizer = synthesizer.sample(cond_array, 1000)
80+
sample = synthesizer.sample(cond_array, 1000)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from ydata_synthetic.synthesizers.regular import CWGANGP
2+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
3+
4+
import pandas as pd
5+
import numpy as np
6+
from sklearn import cluster
7+
8+
model = CWGANGP
9+
10+
#Read the original data and have it preprocessed
11+
data = pd.read_csv('data/creditcard.csv', index_col=[0])
12+
13+
#List of columns different from the Class column
14+
num_cols = list(data.columns[ data.columns != 'Class' ])
15+
cat_cols = [] # Condition features are not preprocessed and therefore not listed here
16+
17+
print('Dataset columns: {}'.format(num_cols))
18+
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
19+
data = data[ sorted_cols ].copy()
20+
21+
#For the purpose of this example we will only synthesize the minority class
22+
train_data = data.loc[ data['Class']==1 ].copy()
23+
24+
#Create a new class column using KMeans - This will mainly be useful if we want to leverage conditional WGANGP
25+
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
26+
algorithm = cluster.KMeans
27+
args, kwds = (), {'n_clusters':2, 'random_state':0}
28+
labels = algorithm(*args, **kwds).fit_predict(train_data[ num_cols ])
29+
30+
print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) )
31+
32+
fraud_w_classes = train_data.copy()
33+
fraud_w_classes['Class'] = labels
34+
35+
#----------------------------
36+
# GAN Training
37+
#----------------------------
38+
39+
#Define the Conditional WGANGP and training parameters
40+
noise_dim = 32
41+
dim = 128
42+
batch_size = 128
43+
beta_1 = 0.5
44+
beta_2 = 0.9
45+
46+
log_step = 100
47+
epochs = 300 + 1
48+
learning_rate = 5e-4
49+
models_dir = './cache'
50+
51+
#Test here the new inputs
52+
gan_args = ModelParameters(batch_size=batch_size,
53+
lr=learning_rate,
54+
betas=(beta_1, beta_2),
55+
noise_dim=noise_dim,
56+
layers_dim=dim)
57+
58+
train_args = TrainParameters(epochs=epochs,
59+
cache_prefix='',
60+
sample_interval=log_step,
61+
label_dim=-1,
62+
labels=(0,1))
63+
64+
#Init the Conditional WGANGP providing the index of the label column as one of the arguments
65+
synthesizer = model(model_parameters=gan_args, num_classes=2, n_critic=3)
66+
67+
#Training the Conditional WGANGP
68+
synthesizer.train(data=fraud_w_classes, label_col="Class", train_arguments=train_args,
69+
num_cols=num_cols, cat_cols=cat_cols)
70+
71+
#Saving the synthesizer
72+
synthesizer.save('cwgangp_synthtrained.pkl')
73+
74+
#Loading the synthesizer
75+
synthesizer = model.load('cwgangp_synthtrained.pkl')
76+
77+
#Sampling from the synthesizer
78+
cond_array = np.array([0])
79+
# Synthesizer samples are returned in the original format (inverse_transform of internal processing already took place)
80+
sample = synthesizer.sample(cond_array, 1000)

src/ydata_synthetic/preprocessing/base_processor.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@
22
from __future__ import annotations
33

44
from abc import ABC, abstractmethod
5-
from collections import namedtuple
5+
from types import SimpleNamespace
66
from typing import List, Optional
77

8-
from numpy import concatenate, ndarray, split, zeros
9-
from pandas import DataFrame, Series, concat
8+
from numpy import ndarray
9+
from pandas import DataFrame, Series
1010
from sklearn.base import BaseEstimator, TransformerMixin
1111
from sklearn.exceptions import NotFittedError
1212
from typeguard import typechecked
1313

14-
ProcessorInfo = namedtuple("ProcessorInfo", ["numerical", "categorical"])
15-
PipelineInfo = namedtuple("PipelineInfo", ["feat_names_in", "feat_names_out"])
1614

1715
# pylint: disable=R0902
1816
@typechecked
@@ -50,23 +48,25 @@ def types(self) -> Series:
5048
return self._types
5149

5250
@property
53-
def col_transform_info(self) -> ProcessorInfo:
51+
def col_transform_info(self) -> SimpleNamespace:
5452
"""Returns a ProcessorInfo object specifying input/output feature mappings of this processor's pipelines."""
5553
self._check_is_fitted()
5654
if self._col_transform_info is None:
5755
self._col_transform_info = self.__create_metadata_synth()
5856
return self._col_transform_info
5957

60-
def __create_metadata_synth(self):
61-
num_info = PipelineInfo([], [])
62-
cat_info = PipelineInfo([], [])
63-
# Numerical ls named tuple
58+
def __create_metadata_synth(self) -> SimpleNamespace:
59+
def new_pipeline_info(feat_in, feat_out):
60+
return SimpleNamespace(feat_names_in = feat_in, feat_names_out = feat_out)
6461
if self.num_cols:
65-
num_info = PipelineInfo(self.num_pipeline.feature_names_in_, self.num_pipeline.get_feature_names_out())
66-
# Categorical ls named tuple
62+
num_info = new_pipeline_info(self.num_pipeline.feature_names_in_, self.num_pipeline.get_feature_names_out())
63+
else:
64+
num_info = new_pipeline_info([], [])
6765
if self.cat_cols:
68-
cat_info = PipelineInfo(self.cat_pipeline.feature_names_in_, self.cat_pipeline.get_feature_names_out())
69-
return ProcessorInfo(num_info, cat_info)
66+
cat_info = new_pipeline_info(self.cat_pipeline.feature_names_in_, self.cat_pipeline.get_feature_names_out())
67+
else:
68+
cat_info = new_pipeline_info([], [])
69+
return SimpleNamespace(numerical=num_info, categorical=cat_info)
7070

7171
def _check_is_fitted(self):
7272
"""Checks if the processor is fitted by testing the numerical pipeline.

src/ydata_synthetic/preprocessing/regular/processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class RegularModels(Enum):
2121
GAN = 'VanillaGAN'
2222
WGAN = 'WGAN'
2323
WGAN_GP = 'WGAN_GP'
24+
CWGAN_GP = 'CWGAN_GP'
2425

2526

2627
@typechecked

src/ydata_synthetic/synthesizers/gan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def sample(self, n_samples: int):
131131
def save(self, path):
132132
"Saves the pickled synthesizer instance in the given path."
133133
#Save only the generator?
134-
if self.__MODEL__=='WGAN' or self.__MODEL__=='WGAN_GP':
134+
if self.__MODEL__=='WGAN' or self.__MODEL__=='WGAN_GP' or self.__MODEL__=='CWGAN_GP':
135135
del self.critic
136136
make_keras_picklable()
137137
dump(self, path)

src/ydata_synthetic/synthesizers/regular/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
from ydata_synthetic.synthesizers.regular.wgangp.model import WGAN_GP
55
from ydata_synthetic.synthesizers.regular.dragan.model import DRAGAN
66
from ydata_synthetic.synthesizers.regular.cramergan.model import CRAMERGAN
7+
from ydata_synthetic.synthesizers.regular.cwgangp.model import CWGANGP
78

89
__all__ = [
910
"VanilllaGAN",
1011
"CGAN",
1112
"WGAN",
1213
"WGAN_GP",
1314
"DRAGAN",
14-
"CRAMERGAN"
15+
"CRAMERGAN",
16+
"CWGANGP"
1517
]

src/ydata_synthetic/synthesizers/regular/cgan/model.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""CGAN implementation"""
22
import os
33
from os import path
4-
from typing import List, Tuple, Union, Optional, NamedTuple
4+
from typing import List, Optional, NamedTuple
55

66
import numpy as np
77
from numpy import array, empty, hstack, ndarray, vstack, save
@@ -30,6 +30,7 @@ class CGAN(BaseModel):
3030
def __init__(self, model_parameters, num_classes):
3131
self.num_classes = num_classes
3232
self._label_col = None
33+
self._col_order = None
3334
super().__init__(model_parameters)
3435

3536
@property
@@ -38,18 +39,9 @@ def label_col(self) -> str:
3839
return self._label_col
3940

4041
@label_col.setter
41-
def label_col(self, data_label: Tuple[Union[DataFrame, array], str]):
42-
"Validates the label_col format, raises ValueError if invalid."
43-
data, label_col = data_label
44-
assert label_col in data.columns, f"The column {label_col} could not be found on the provided dataset and \
45-
cannot be used as condition."
46-
assert data[label_col].isna().sum() == 0, "The label column contains NaN values, please impute or drop the \
47-
respective records before proceeding."
48-
assert is_float_dtype(data[label_col]) or is_integer_dtype(data[label_col]), "The label column is expected to be an \
49-
integer or a float dtype to ensure the function of the embedding layer."
50-
unique_frac = data[label_col].nunique()/len(data.index)
51-
assert unique_frac < 1, "The provided column {label_col} is constituted by unique values and is not suitable \
52-
to be used as condition."
42+
def label_col(self, label_col: str):
43+
"""Set the label_col property."""
44+
self._label_col = label_col
5345

5446
def define_gan(self, activation_info: Optional[NamedTuple] = None):
5547
self.generator = Generator(self.batch_size, self.num_classes). \
@@ -103,18 +95,20 @@ def get_data_batch(self, data, batch_size, seed=0):
10395
data_ix = np.random.choice(data.shape[0], replace=False, size=len(data)) # wasteful to shuffle every time
10496
return data[data_ix[start_i: stop_i]]
10597

106-
def train(self, data: Union[DataFrame, array], label_col: str, train_arguments: TrainParameters, num_cols: List[str],
98+
def train(self, data: DataFrame, label_col: str, train_arguments: TrainParameters, num_cols: List[str],
10799
cat_cols: List[str]):
108100
"""
109101
Args:
110-
data: A pandas DataFrame or a Numpy array with the data to be synthesized
102+
data: A pandas DataFrame with the data to be synthesized
111103
label: The name of the column to be used as a label and condition for the training
112104
train_arguments: GAN training arguments.
113105
num_cols: List of columns of the data object to be handled as numerical
114106
cat_cols: List of columns of the data object to be handled as categorical
115107
"""
116108
# Validating the label column
117-
self.label_col = (data, label_col)
109+
self._validate_label_col(data, label_col)
110+
self._col_order = data.columns
111+
self.label_col = label_col
118112

119113
# Separating labels from the rest of the data to fit the data processor
120114
data, label = data.loc[:, data.columns != label_col], expand_dims(data[label_col], 1)
@@ -182,15 +176,28 @@ def sample(self, condition: ndarray, n_samples: int,) -> ndarray:
182176
steps = n_samples // self.batch_size + 1
183177
data = []
184178
z_dist = self.get_batch_noise()
185-
condition = expand_dims(convert_to_tensor(condition, dtypes.float32), axis=0)
186-
cond_seq = tile(condition, multiples=[self.batch_size, 1])
179+
cond_seq = expand_dims(convert_to_tensor(condition, dtypes.float32), axis=0)
180+
cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
187181
for _ in trange(steps, desc='Synthetic data generation'):
188182
records = empty(shape=(self.batch_size, self.data_dim))
189183
records = self.generator([next(z_dist), cond_seq], training=False)
190184
data.append(records)
191185
data = self.processor.inverse_transform(array(vstack(data)))
192-
data[self.label_col] = tile(condition, multiples=[data.shape[0], 1])
193-
return data
186+
data[self.label_col] = condition[0]
187+
return data[self._col_order]
188+
189+
@staticmethod
190+
def _validate_label_col(data: DataFrame, label_col: str):
191+
"Validates the label_col format, raises ValueError if invalid."
192+
assert label_col in data.columns, f"The column {label_col} could not be found on the provided dataset and \
193+
cannot be used as condition."
194+
assert data[label_col].isna().sum() == 0, "The label column contains NaN values, please impute or drop the \
195+
respective records before proceeding."
196+
assert is_float_dtype(data[label_col]) or is_integer_dtype(data[label_col]), "The label column is expected to be an \
197+
integer or a float dtype to ensure the function of the embedding layer."
198+
unique_frac = data[label_col].nunique()/len(data.index)
199+
assert unique_frac < 1, "The provided column {label_col} is constituted by unique values and is not suitable \
200+
to be used as condition."
194201

195202

196203
# pylint: disable=R0903

src/ydata_synthetic/synthesizers/regular/cwgangp/__init__.py

Whitespace-only changes.

0 commit comments

Comments (0)