Skip to content

Commit cea1d8e

Browse files
feat: add CTGAN model (#233)
* feat: add CTGAN model * fix: change imports
1 parent e4cace4 commit cea1d8e

10 files changed

Lines changed: 876 additions & 28 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from pmlb import fetch_data
2+
3+
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
4+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
5+
6+
# Load data and define the data processor parameters
7+
data = fetch_data('adult')
8+
num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
9+
cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
10+
'native-country', 'target']
11+
12+
# Defining the training parameters
13+
batch_size = 500
14+
epochs = 500+1
15+
learning_rate = 2e-4
16+
beta_1 = 0.5
17+
beta_2 = 0.9
18+
19+
ctgan_args = ModelParameters(batch_size=batch_size,
20+
lr=learning_rate,
21+
betas=(beta_1, beta_2))
22+
23+
train_args = TrainParameters(epochs=epochs)
24+
synth = RegularSynthesizer(modelname='ctgan', model_parameters=ctgan_args)
25+
synth.fit(data=data, train_arguments=train_args, num_cols=num_cols, cat_cols=cat_cols)
26+
27+
synth.save('adult_ctgan_model.pkl')
28+
29+
#########################################################
30+
# Loading and sampling from a trained synthesizer #
31+
#########################################################
32+
synth = RegularSynthesizer.load('adult_ctgan_model.pkl')
33+
synth_data = synth.sample(1000)
34+
print(synth_data)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""
2+
CTGAN architecture example file
3+
"""
4+
import pandas as pd
5+
from sklearn import cluster
6+
7+
from ydata_synthetic.utils.cache import cache_file
8+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
9+
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
10+
11+
# Read the original data and have it preprocessed
12+
data_path = cache_file('creditcard.csv', 'https://datahub.io/machine-learning/creditcard/r/creditcard.csv')
13+
data = pd.read_csv(data_path, index_col=[0])
14+
15+
# Data processing and analysis
16+
num_cols = list(data.columns[ data.columns != 'Class' ])
17+
cat_cols = []
18+
19+
print('Dataset columns: {}'.format(num_cols))
20+
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19',
21+
'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15',
22+
'V9', 'V23', 'Class']
23+
processed_data = data[ sorted_cols ].copy()
24+
processed_data['Class'] = processed_data['Class'].apply(lambda x: 1 if x == "'1'" else 0)
25+
26+
# For the purpose of this example we will only synthesize the minority class
27+
train_data = processed_data.loc[processed_data['Class'] == 1].copy()
28+
29+
# Create a new class column using KMeans - This will mainly be useful if we want to leverage conditional GAN
30+
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
31+
algorithm = cluster.KMeans
32+
args, kwds = (), {'n_clusters':2, 'random_state':0}
33+
labels = algorithm(*args, **kwds).fit_predict(train_data[num_cols])
34+
35+
fraud_w_classes = train_data.copy()
36+
fraud_w_classes['Class'] = labels
37+
38+
#----------------------------
39+
# CTGAN Training
40+
#----------------------------
41+
42+
batch_size = 500
43+
epochs = 500+1
44+
learning_rate = 2e-4
45+
beta_1 = 0.5
46+
beta_2 = 0.9
47+
48+
ctgan_args = ModelParameters(batch_size=batch_size,
49+
lr=learning_rate,
50+
betas=(beta_1, beta_2))
51+
52+
train_args = TrainParameters(epochs=epochs)
53+
54+
# Create a bining
55+
fraud_w_classes['Amount'] = pd.cut(fraud_w_classes['Amount'], 5).cat.codes
56+
57+
# Init the CTGAN
58+
synth = RegularSynthesizer(modelname='ctgan', model_parameters=ctgan_args)
59+
60+
#Training the CTGAN
61+
synth.fit(data=fraud_w_classes, train_arguments=train_args, num_cols=num_cols, cat_cols=cat_cols)
62+
63+
# Saving the synthesizer
64+
synth.save('creditcard_ctgan_model.pkl')
65+
66+
# Loading the synthesizer
67+
synthesizer = RegularSynthesizer.load('creditcard_ctgan_model.pkl')
68+
69+
# Sampling from the synthesizer
70+
sample = synthesizer.sample(1000)
71+
print(sample)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ numpy==1.23.*
44
scikit-learn==1.2.*
55
matplotlib==3.6.*
66
tensorflow==2.11.0
7+
tensorflow-probability==0.19.0
78
easydict==1.10
89
pmlb==1.0.*
910
tqdm<5.0
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Optional
4+
from typeguard import typechecked
5+
from dataclasses import dataclass
6+
import pandas as pd
7+
import numpy as np
8+
from sklearn.exceptions import NotFittedError, ConvergenceWarning
9+
from sklearn.utils._testing import ignore_warnings
10+
from sklearn.mixture import BayesianGaussianMixture
11+
from sklearn.preprocessing import OneHotEncoder
12+
13+
from ydata_synthetic.preprocessing.base_processor import BaseProcessor
14+
15+
@dataclass
16+
class ColumnMetadata:
17+
"""
18+
Dataclass that stores the metadata of each column.
19+
"""
20+
start_idx: int
21+
end_idx: int
22+
discrete: bool
23+
output_dim: int
24+
model: any
25+
components: list
26+
name: str
27+
28+
29+
@typechecked
30+
class CTGANDataProcessor(BaseProcessor):
31+
"""
32+
CTGAN data preprocessing class.
33+
It works like any other transformer in scikit-learn with the methods fit, transform and inverse_transform.
34+
Args:
35+
n_clusters (int), default=10:
36+
Number of clusters.
37+
epsilon (float), default=0.005:
38+
Epsilon value.
39+
num_cols (list of strings):
40+
List of names of numerical columns.
41+
cat_cols (list of strings):
42+
List of names of categorical columns.
43+
"""
44+
SUPPORTED_MODEL = 'CTGAN'
45+
46+
def __init__(self, n_clusters=10, epsilon=0.005,
47+
num_cols: Optional[List[str]] = None,
48+
cat_cols: Optional[List[str]] = None):
49+
super().__init__(num_cols, cat_cols)
50+
51+
self._n_clusters = n_clusters
52+
self._epsilon = epsilon
53+
self._metadata = None
54+
self._dtypes = None
55+
self._output_dimensions = None
56+
57+
@property
58+
def metadata(self) -> list[ColumnMetadata]:
59+
"""
60+
Returns the metadata for each column.
61+
"""
62+
return self._metadata
63+
64+
@property
65+
def output_dimensions(self) -> int:
66+
"""
67+
Returns the dataset dimensionality after the preprocessing.
68+
"""
69+
return int(self._output_dimensions)
70+
71+
@ignore_warnings(category=ConvergenceWarning)
72+
def fit(self, X: pd.DataFrame) -> CTGANDataProcessor:
73+
"""
74+
Fits the data processor to a passed DataFrame.
75+
76+
Args:
77+
X (DataFrame):
78+
DataFrame used to fit the processor parameters.
79+
Should be aligned with the num/cat columns defined in initialization.
80+
Returns:
81+
self (CTGANDataProcessor): The fitted data processor.
82+
"""
83+
self._dtypes = X.infer_objects().dtypes
84+
self._metadata = []
85+
cur_idx = 0
86+
for column in X.columns:
87+
column_data = X[[column]].values
88+
if column in self.cat_cols:
89+
ohe = OneHotEncoder(sparse_output=False)
90+
ohe.fit(column_data)
91+
n_categories = len(ohe.categories_[0])
92+
self._metadata.append(
93+
ColumnMetadata(
94+
start_idx=cur_idx,
95+
end_idx=cur_idx + n_categories,
96+
discrete=True,
97+
output_dim=n_categories,
98+
model=ohe,
99+
components=None,
100+
name=column
101+
)
102+
)
103+
cur_idx += n_categories
104+
else:
105+
bgm = BayesianGaussianMixture(
106+
n_components=self._n_clusters,
107+
weight_concentration_prior_type='dirichlet_process',
108+
weight_concentration_prior=0.001,
109+
n_init=1
110+
)
111+
bgm.fit(column_data)
112+
components = bgm.weights_ > self._epsilon
113+
output_dim = components.sum() + 1
114+
self._metadata.append(
115+
ColumnMetadata(
116+
start_idx=cur_idx,
117+
end_idx=cur_idx + output_dim,
118+
discrete=False,
119+
output_dim=output_dim,
120+
model=bgm,
121+
components=components,
122+
name=column
123+
)
124+
)
125+
cur_idx += output_dim
126+
self._output_dimensions = cur_idx
127+
return self
128+
129+
def transform(self, X: pd.DataFrame) -> np.ndarray:
130+
"""
131+
Transforms the passed DataFrame with the fitted data processor.
132+
133+
Args:
134+
X (DataFrame):
135+
DataFrame used to fit the processor parameters.
136+
Should be aligned with the columns types defined in initialization.
137+
Returns:
138+
Processed version of the passed DataFrame.
139+
"""
140+
if self._metadata is None:
141+
raise NotFittedError("This data processor has not yet been fitted.")
142+
143+
transformed_data = []
144+
for col_md in self._metadata:
145+
column_data = X[[col_md.name]].values
146+
if col_md.discrete:
147+
ohe = col_md.model
148+
transformed_data.append(ohe.transform(column_data))
149+
else:
150+
bgm = col_md.model
151+
components = col_md.components
152+
153+
means = bgm.means_.reshape((1, self._n_clusters))
154+
stds = np.sqrt(bgm.covariances_).reshape((1, self._n_clusters))
155+
features = (column_data - means) / (4 * stds)
156+
157+
probabilities = bgm.predict_proba(column_data)
158+
n_opts = components.sum()
159+
features = features[:, components]
160+
probabilities = probabilities[:, components]
161+
162+
opt_sel = np.zeros(len(column_data), dtype='int')
163+
for i in range(len(column_data)):
164+
norm_probs = probabilities[i] + 1e-6
165+
norm_probs = norm_probs / norm_probs.sum()
166+
opt_sel[i] = np.random.choice(np.arange(n_opts), p=norm_probs)
167+
168+
idx = np.arange((len(features)))
169+
features = features[idx, opt_sel].reshape([-1, 1])
170+
features = np.clip(features, -.99, .99)
171+
172+
probs_onehot = np.zeros_like(probabilities)
173+
probs_onehot[np.arange(len(probabilities)), opt_sel] = 1
174+
transformed_data.append(
175+
np.concatenate([features, probs_onehot], axis=1).astype(float))
176+
177+
return np.concatenate(transformed_data, axis=1).astype(float)
178+
179+
def inverse_transform(self, X: np.ndarray) -> pd.DataFrame:
180+
"""
181+
Reverts the data transformations on a passed DataFrame.
182+
183+
Args:
184+
X (ndarray):
185+
Numpy array to be brought back to the original data format.
186+
Should share the schema of data transformed by this data processor.
187+
Can be used to revert transformations of training data or for synthetic samples.
188+
Returns:
189+
DataFrame with all performed transformations reverted.
190+
"""
191+
if self._metadata is None:
192+
raise NotFittedError("This data processor has not yet been fitted.")
193+
194+
transformed_data = []
195+
col_names = []
196+
for col_md in self._metadata:
197+
col_data = X[:, col_md.start_idx:col_md.end_idx]
198+
if col_md.discrete:
199+
inv_data = col_md.model.inverse_transform(col_data)
200+
else:
201+
mean = col_data[:, 0]
202+
variance = col_data[:, 1:]
203+
mean = np.clip(mean, -1, 1)
204+
205+
v_t = np.ones((len(col_data), self._n_clusters)) * -100
206+
v_t[:, col_md.components] = variance
207+
variance = v_t
208+
means = col_md.model.means_.reshape([-1])
209+
stds = np.sqrt(col_md.model.covariances_).reshape([-1])
210+
211+
p_argmax = np.argmax(variance, axis=1)
212+
std_t = stds[p_argmax]
213+
mean_t = means[p_argmax]
214+
inv_data = mean * 4 * std_t + mean_t
215+
216+
transformed_data.append(inv_data)
217+
col_names.append(col_md.name)
218+
219+
transformed_data = np.column_stack(transformed_data)
220+
transformed_data = pd.DataFrame(transformed_data, columns=col_names).astype(self._dtypes)
221+
return transformed_data

0 commit comments

Comments
 (0)