Skip to content

Commit 49f6221

Browse files

authored: feat: Support save model in CGAN (#64)
1 parent 6e4368b · commit 49f6221

2 files changed

Lines changed: 29 additions & 86 deletions

File tree

examples/regular/cgan_example.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@
4242
noise_dim = 32
4343
dim = 128
4444
batch_size = 128
45+
beta_1 = 0.5
46+
beta_2 = 0.9
4547

4648
log_step = 100
47-
epochs = 200+1
49+
epochs = 500 + 1
4850
learning_rate = 5e-4
49-
beta_1 = 0.5
50-
beta_2 = 0.9
5151
models_dir = './cache'
5252

5353
train_sample = fraud_w_classes.copy().reset_index(drop=True)
@@ -57,11 +57,14 @@
5757
train_sample[ data_cols ] = train_sample[ data_cols ] / 10 # scale to random noise size, one less thing to learn
5858
train_no_label = train_sample[ data_cols ]
5959

60-
gan_args = [batch_size, learning_rate, beta_1, beta_2, noise_dim, train_sample.shape[1], 2, (0, 1), dim]
61-
train_args = ['', label_cols[0], epochs, log_step, '']
60+
gan_args = [batch_size, learning_rate, beta_1, beta_2, noise_dim, train_sample.shape[1], dim]
61+
train_args = ['', -1, epochs, log_step, (0, 1)]
6262

6363
#Init the Conditional GAN providing the index of the label column as one of the arguments
64-
synthesizer = CGAN(gan_args)
64+
synthesizer = CGAN(gan_args, num_classes=2)
6565

6666
#Training the Conditional GAN
67-
synthesizer.train(train_sample, train_args)
67+
synthesizer.train(train_sample, train_args)
68+
69+
#Saving the synthesizer
70+
synthesizer.save('cgan_synthtrained.pkl')

src/ydata_synthetic/synthesizers/regular/cgan/model.py

Lines changed: 19 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,29 @@
1010

1111
from tensorflow.keras.optimizers import Adam
1212

13-
class CGAN():
13+
class CGAN(gan.Model):
1414

15-
def __init__(self, model_parameters):
16-
[self.batch_size, lr,self.beta_1, self.beta_2, self.noise_dim,
17-
self.data_dim, num_classes, self.classes, layers_dim] = model_parameters
15+
def __init__(self, model_parameters, num_classes):
16+
self.num_classes = num_classes
17+
super().__init__(model_parameters)
1818

19-
self.generator = Generator(self.batch_size, num_classes). \
20-
build_model(input_shape=(self.noise_dim,), dim=layers_dim, data_dim=self.data_dim)
19+
def define_gan(self):
20+
self.generator = Generator(self.batch_size, self.num_classes). \
21+
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)
2122

22-
self.discriminator = Discriminator(self.batch_size, num_classes). \
23-
build_model(input_shape=(self.data_dim,), dim=layers_dim)
23+
self.discriminator = Discriminator(self.batch_size, self.num_classes). \
24+
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
2425

25-
optimizer = Adam(lr, beta_1=self.beta_1, beta_2=self.beta_2)
26+
optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)
2627

2728
# Build and compile the discriminator
2829
self.discriminator.compile(loss='binary_crossentropy',
2930
optimizer=optimizer,
3031
metrics=['accuracy'])
3132

3233
# The generator takes noise as input and generates imgs
33-
z = Input(shape=(self.noise_dim,), batch_size=self.batch_size)
34-
label = Input(shape=(1,), batch_size=self.batch_size)
34+
z = Input(shape=(self.noise_dim,))
35+
label = Input(shape=(1,))
3536
record = self.generator([z, label])
3637

3738
# For the combined model we will only train the generator
@@ -42,8 +43,8 @@ def __init__(self, model_parameters):
4243

4344
# The combined model (stacked generator and discriminator)
4445
# Trains the generator to fool the discriminator
45-
self.combined = Model([z, label], validity)
46-
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
46+
self._model = Model([z, label], validity)
47+
self._model.compile(loss='binary_crossentropy', optimizer=optimizer)
4748

4849
def get_data_batch(self, train, batch_size, seed=0):
4950
# # random sampling - some samples will have excessively low or high sampling, but easy to implement
@@ -61,12 +62,14 @@ def get_data_batch(self, train, batch_size, seed=0):
6162
return np.reshape(x, (batch_size, -1))
6263

6364
def train(self, data, train_arguments):
64-
[cache_prefix, label_dim, epochs, sample_interval, data_dir] = train_arguments
65+
[cache_prefix, label_dim, epochs, sample_interval, classes] = train_arguments
6566

6667
# Adversarial ground truths
6768
valid = np.ones((self.batch_size, 1))
6869
fake = np.zeros((self.batch_size, 1))
6970

71+
#define here the classes?
72+
7073
for epoch in range(epochs):
7174
# ---------------------
7275
# Train Discriminator
@@ -88,7 +91,7 @@ def train(self, data, train_arguments):
8891
# ---------------------
8992
noise = tf.random.normal((self.batch_size, self.noise_dim))
9093
# Train the generator (to have the discriminator label samples as valid)
91-
g_loss = self.combined.train_on_batch([noise, label], valid)
94+
g_loss = self._model.train_on_batch([noise, label], valid)
9295

9396
# Plot the progress
9497
print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
@@ -105,7 +108,7 @@ def train(self, data, train_arguments):
105108

106109
#Here is generating synthetic data
107110
z = tf.random.normal((432, self.noise_dim))
108-
label_z = tf.random.uniform((432,), minval=min(self.classes), maxval=max(self.classes)+1, dtype=tf.dtypes.int32)
111+
label_z = tf.random.uniform((432,), minval=min(classes), maxval=max(classes)+1, dtype=tf.dtypes.int32)
109112
gen_data = self.generator([z, label_z])
110113

111114
class Generator():
@@ -144,66 +147,3 @@ def build_model(self, input_shape, dim):
144147
x = Dense(dim, activation='relu')(x)
145148
x = Dense(1, activation='sigmoid')(x)
146149
return Model(inputs=[events, label], outputs=x)
147-
148-
149-
if __name__ == '__main__':
150-
import pandas as pd
151-
from src.ydata_synthetic.preprocessing import transformations
152-
import sklearn.cluster as cluster
153-
154-
data = pd.read_csv('/home/fabiana/PycharmProjects/YData/gan-playground/examples/data/creditcard.csv')
155-
156-
data_cols = list(data.columns[data.columns != 'Class'])
157-
label_cols = ['Class']
158-
159-
print('Dataset columns: {}'.format(data_cols))
160-
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3',
161-
'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9',
162-
'V23',
163-
'Class']
164-
processed_data = data[sorted_cols].copy()
165-
166-
data = transformations(data)
167-
168-
# For the purpose of this example we will only synthesize the minority class
169-
train_data = data.loc[data['Class'] == 1].copy()
170-
171-
print(
172-
"Dataset info: Number of records - {} Number of varibles - {}".format(train_data.shape[0],
173-
train_data.shape[1]))
174-
algorithm = cluster.KMeans
175-
args, kwds = (), {'n_clusters': 2, 'random_state': 0}
176-
labels = algorithm(*args, **kwds).fit_predict(train_data[data_cols])
177-
178-
print(pd.DataFrame([[np.sum(labels == i)] for i in np.unique(labels)], columns=['count'],
179-
index=np.unique(labels)))
180-
181-
fraud_w_classes = train_data.copy()
182-
fraud_w_classes['Class'] = labels
183-
184-
noise_dim = 32
185-
dim = 128
186-
batch_size = 128
187-
188-
log_step = 100
189-
epochs = 500 + 1
190-
learning_rate = 5e-4
191-
models_dir = './cache'
192-
193-
train_sample = data.copy().reset_index(drop=True)
194-
train_sample = pd.get_dummies(train_sample, columns=['Class'], prefix='Class', drop_first=True)
195-
label_cols = [i for i in train_sample.columns if 'Class' in i]
196-
data_cols = [i for i in train_sample.columns if i not in label_cols]
197-
train_sample[data_cols] = train_sample[data_cols] / 10 # scale to random noise size, one less thing to learn
198-
train_no_label = train_sample[data_cols]
199-
200-
gan_args = [batch_size, learning_rate, noise_dim, train_sample.shape[1], 2, (0, 1), dim]
201-
train_args = ['', -1, epochs, log_step, '']
202-
203-
synthesizer = CGAN(gan_args)
204-
synthesizer.train(train_sample, train_args)
205-
206-
207-
208-
209-

0 commit comments

Comments (0)