Commit a9de1ab

feat: Add new Cramer Loss and Cramer GAN (#102)
1 parent 1c84754 commit a9de1ab

6 files changed: 323 additions & 6 deletions


Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
#Install the ydata-synthetic lib
# pip install ydata-synthetic
import sklearn.cluster as cluster
import numpy as np
import pandas as pd

from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.regular import CRAMERGAN
from ydata_synthetic.preprocessing.regular.credit_fraud import transformations

model = CRAMERGAN

#Read the original data; it is preprocessed below
data = pd.read_csv('data/creditcard.csv', index_col=[0])

#Data processing and analysis
data_cols = list(data.columns[ data.columns != 'Class' ])
label_cols = ['Class']

print('Dataset columns: {}'.format(data_cols))
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
processed_data = data[ sorted_cols ].copy()

#Before training the GAN, do not forget to apply the required data transformations
#For simplicity, a PowerTransformation is applied here
data = transformations(data)

#For the purpose of this example we will only synthesize the minority class
train_data = data.loc[ data['Class']==1 ].copy()

print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))

algorithm = cluster.KMeans
args, kwds = (), {'n_clusters':2, 'random_state':0}
labels = algorithm(*args, **kwds).fit_predict(train_data[ data_cols ])

print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) )

fraud_w_classes = train_data.copy()
fraud_w_classes['Class'] = labels

# GAN training
#Define the GAN and training parameters
noise_dim = 32
dim = 128
batch_size = 128

log_step = 100
epochs = 300+1
learning_rate = 5e-4
beta_1 = 0.5
beta_2 = 0.9
models_dir = './cache'

train_sample = fraud_w_classes.copy().reset_index(drop=True)
train_sample = pd.get_dummies(train_sample, columns=['Class'], prefix='Class', drop_first=True)
label_cols = [ i for i in train_sample.columns if 'Class' in i ]
data_cols = [ i for i in train_sample.columns if i not in label_cols ]
train_sample[ data_cols ] = train_sample[ data_cols ] / 10  # scale to random noise size, one less thing to learn
train_no_label = train_sample[ data_cols ]

model_parameters = ModelParameters(batch_size=batch_size,
                                   lr=learning_rate,
                                   betas=(beta_1, beta_2),
                                   noise_dim=noise_dim,
                                   n_cols=train_sample.shape[1],
                                   layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
                             sample_interval=log_step)

test_size = 492  # number of fraud cases
noise_dim = 32

#Training the CRAMERGAN model
synthesizer = model(model_parameters, gradient_penalty_weight=10)
synthesizer.train(train_sample, train_args)

#Saving the synthesizer to later generate new events
synthesizer.save(path='models/cramergan_creditcard.pkl')

#Loading the synthesizer
synth = model.load(path='models/cramergan_creditcard.pkl')

#Sampling the data
#Note that the returned data is not inverse-processed
data_sample = synth.sample(100000)
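
Not part of the commit: a minimal sanity-check sketch for the sampled records, assuming `data_sample` comes back as an array-like whose columns follow the order of `train_sample`. It reuses the same KMeans setup from the script to compare cluster balance between real and synthetic fraud records (both still in transformed/scaled space).

# Hypothetical follow-up, not part of this commit
synth_df = pd.DataFrame(np.asarray(data_sample), columns=train_sample.columns)  # assumes matching column order

real_clusters = cluster.KMeans(n_clusters=2, random_state=0).fit_predict(train_no_label)
synth_clusters = cluster.KMeans(n_clusters=2, random_state=0).fit_predict(synth_df[data_cols])

# Proportion of records per cluster, real vs. synthetic
print(pd.DataFrame({'real': pd.Series(real_clusters).value_counts(normalize=True),
                    'synthetic': pd.Series(synth_clusters).value_counts(normalize=True)}))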

src/ydata_synthetic/synthesizers/loss.py

Lines changed: 22 additions & 3 deletions
@@ -2,6 +2,13 @@
 from tensorflow import reshape, shape, math, GradientTape, reduce_mean
 from tensorflow import norm as tfnorm
 
+from enum import Enum
+
+class Mode(Enum):
+    WGANGP = 'wgangp'
+    DRAGAN = 'dragan'
+    CRAMER = 'cramer'
+
 ## Original code loss from
 ## https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Tensorflow-2/blob/master/tf2gan/loss.py
 def gradient_penalty(f, real, fake, mode):
@@ -23,12 +30,24 @@ def _interpolate(a, b=None):
         grad = t.gradient(pred, x)
         norm = tfnorm(reshape(grad, [shape(grad)[0], -1]), axis=1)
         gp = reduce_mean((norm - 1.)**2)
-
         return gp
 
-    if mode == 'dragan':
+    def _gradient_penalty_cramer(f_crit, real, fake):
+        epsilon = random.uniform([real.shape[0], 1], 0.0, 1.0)
+        x_hat = epsilon * real + (1 - epsilon) * fake[0]
+        with GradientTape() as t:
+            t.watch(x_hat)
+            f_x_hat = f_crit(x_hat, fake[1])
+        gradients = t.gradient(f_x_hat, x_hat)
+        c_dx = tfnorm(reshape(gradients, [shape(gradients)[0], -1]), axis=1)
+        c_regularizer = (c_dx - 1.0) ** 2
+        return c_regularizer
+
+    if mode == Mode.DRAGAN:
         gp = _gradient_penalty(f, real)
-    elif mode == 'wgangp':
+    elif mode == Mode.CRAMER:
+        gp = _gradient_penalty_cramer(f, real, fake)
+    elif mode == Mode.WGANGP:
         gp = _gradient_penalty(f, real, fake)
 
     return gp
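
For reference, the new Cramér branch interpolates between the real batch and the first generator sample and penalizes the critic surrogate `f_crit` (passed in from the CRAMERGAN model below) away from unit gradient norm. In the paper's notation, with critic network h and two independent generator samples x_g and x_g':

\[ \hat{x} = \epsilon x + (1 - \epsilon)\, x_g, \qquad \epsilon \sim U(0, 1) \]
\[ f(\hat{x}) = \lVert h(\hat{x}) - h(x_g') \rVert_2 - \lVert h(\hat{x}) \rVert_2 \]
\[ \mathrm{GP}(\hat{x}) = \big( \lVert \nabla_{\hat{x}} f(\hat{x}) \rVert_2 - 1 \big)^2 \]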

src/ydata_synthetic/synthesizers/regular/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 from ydata_synthetic.synthesizers.regular.vanillagan.model import VanilllaGAN
 from ydata_synthetic.synthesizers.regular.wgangp.model import WGAN_GP
 from ydata_synthetic.synthesizers.regular.dragan.model import DRAGAN
+from ydata_synthetic.synthesizers.regular.cramergan.model import CRAMERGAN
 
 __all__ = [
     "VanilllaGAN",
     "CGAN",
     "WGAN",
     "WGAN_GP",
-    "DRAGAN"
+    "DRAGAN",
+    "CRAMERGAN"
 ]

src/ydata_synthetic/synthesizers/regular/cramergan/__init__.py

Whitespace-only changes.
Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
import os
from os import path
import numpy as np
from tqdm import trange

from ydata_synthetic.synthesizers.gan import BaseModel
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
from ydata_synthetic.synthesizers import TrainParameters

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam

class CRAMERGAN(BaseModel):

    __MODEL__ = 'CRAMERGAN'

    def __init__(self, model_parameters, gradient_penalty_weight=10):
        """Create a base CramerGAN.

        Based on the CramerGAN paper - https://arxiv.org/pdf/1705.10743.pdf
        CramerGAN, a solution to biased Wasserstein gradients: https://arxiv.org/abs/1705.10743"""
        self.gradient_penalty_weight = gradient_penalty_weight
        super().__init__(model_parameters)

    def define_gan(self):
        self.generator = Generator(self.batch_size). \
            build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)

        self.critic = Critic(self.batch_size). \
            build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)

        # The generator takes noise as input and generates records
        z = Input(shape=(self.noise_dim,), batch_size=self.batch_size)
        fake = self.generator(z, training=True)
        logits = self.critic(fake, training=True)

        # Compile the critic
        self.critic.compile(loss=self.c_lossfn,
                            optimizer=self.c_optimizer,
                            metrics=['accuracy'])

        # Generator and critic model
        _model = Model(z, logits)
        _model.compile(loss=self.g_lossfn, optimizer=self.g_optimizer)

    def gradient_penalty(self, real, fake):
        gp = gradient_penalty(self.f_crit, real, fake, mode=Mode.CRAMER)
        return gp

    def update_gradients(self, x):
        """Compute and apply the gradients for both the Generator and the Critic.

        :param x: real data event
        :return: generator gradients, critic gradients
        """
        # Update the gradients of critic for n_critic times (Training the critic)

        ## New generator gradient tape
        noise = tf.random.normal([x.shape[0], self.noise_dim], dtype=tf.dtypes.float32)
        noise2 = tf.random.normal([x.shape[0], self.noise_dim], dtype=tf.dtypes.float32)

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fake = self.generator(noise, training=True)
            fake2 = self.generator(noise2, training=True)

            g_loss = self.g_lossfn(x, fake, fake2)

            c_loss = self.c_lossfn(x, fake, fake2)

        # Get the gradients of the generator
        g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)

        # Update the weights of the generator
        self.g_optimizer.apply_gradients(
            zip(g_gradients, self.generator.trainable_variables)
        )

        c_gradient = d_tape.gradient(c_loss, self.critic.trainable_variables)
        # Update the weights of the critic using the optimizer
        self.c_optimizer.apply_gradients(
            zip(c_gradient, self.critic.trainable_variables)
        )

        return c_loss, g_loss

    def g_lossfn(self, real, fake, fake2):
        """Compute the generator loss function according to the CramerGAN paper.

        :param real: A real sample
        :param fake: A fake sample
        :param fake2: A second fake sample
        :return: Loss of the generator
        """
        g_loss = tf.norm(self.critic(real, training=True) - self.critic(fake, training=True), axis=1) + \
                 tf.norm(self.critic(real, training=True) - self.critic(fake2, training=True), axis=1) - \
                 tf.norm(self.critic(fake, training=True) - self.critic(fake2, training=True), axis=1)
        return tf.reduce_mean(g_loss)

    def f_crit(self, real, fake):
        """
        Compute the critic distance function f between two samples.

        :param real: A real sample
        :param fake: A fake sample
        :return: Critic surrogate distance between the two samples
        """
        return tf.norm(self.critic(real, training=True) - self.critic(fake, training=True), axis=1) - tf.norm(self.critic(real, training=True), axis=1)

    def c_lossfn(self, real, fake, fake2):
        """Compute the critic loss function according to the CramerGAN paper.

        :param real: A real sample
        :param fake: A fake sample
        :param fake2: A second fake sample
        :return: Loss of the critic
        """
        f_real = self.f_crit(real, fake2)
        f_fake = self.f_crit(fake, fake2)
        loss_surrogate = f_real - f_fake
        gp = self.gradient_penalty(real, [fake, fake2])
        return tf.reduce_mean(- loss_surrogate + self.gradient_penalty_weight*gp)

    @staticmethod
    def get_data_batch(train, batch_size, seed=0):
        # np.random.seed(seed)
        # x = train.loc[ np.random.choice(train.index, batch_size) ].values
        # iterate through shuffled indices, so every sample gets covered evenly
        start_i = (batch_size * seed) % len(train)
        stop_i = start_i + batch_size
        shuffle_seed = (batch_size * seed) // len(train)
        np.random.seed(shuffle_seed)
        train_ix = np.random.choice(list(train.index), replace=False, size=len(train))  # wasteful to shuffle every time
        train_ix = list(train_ix) + list(train_ix)  # duplicate to cover ranges past the end of the set
        x = train.loc[train_ix[start_i: stop_i]].values
        return np.reshape(x, (batch_size, -1))

    def train_step(self, train_data):
        critic_loss, g_loss = self.update_gradients(train_data)
        return critic_loss, g_loss

    def train(self, data, train_arguments: TrainParameters):
        iterations = int(abs(data.shape[0] / self.batch_size) + 1)

        # Create a summary file
        train_summary_writer = tf.summary.create_file_writer(path.join('..\cramergan_test', 'summaries', 'train'))

        with train_summary_writer.as_default():
            for epoch in trange(train_arguments.epochs):
                for iteration in range(iterations):
                    batch_data = self.get_data_batch(data, self.batch_size)
                    c_loss, g_loss = self.train_step(batch_data)

                    if iteration % train_arguments.sample_interval == 0:
                        # Test here data generation step
                        # save model checkpoints
                        if path.exists('./cache') is False:
                            os.mkdir('./cache')
                        model_checkpoint_base_name = './cache/' + train_arguments.cache_prefix + '_{}_model_weights_step_{}.h5'
                        self.generator.save_weights(model_checkpoint_base_name.format('generator', iteration))
                        self.critic.save_weights(model_checkpoint_base_name.format('critic', iteration))

                print(
                    "Epoch: {} | critic_loss: {} | gen_loss: {}".format(
                        epoch, c_loss, g_loss
                    ))

        self.g_optimizer = self.g_optimizer.get_config()
        self.critic_optimizer = self.c_optimizer.get_config()

    def save(self, path):
        """Strip down the optimizers from the model, then save."""
        for attr in ['g_optimizer', 'c_optimizer']:
            try:
                delattr(self, attr)
            except AttributeError:
                continue
        super().save(path)


class Generator(tf.keras.Model):
    def __init__(self, batch_size):
        """Simple generator with dense feedforward layers."""
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim):
        input_ = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim, activation='relu')(input_)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dense(dim * 4, activation='relu')(x)
        x = Dense(data_dim)(x)
        return Model(inputs=input_, outputs=x)

class Critic(tf.keras.Model):
    def __init__(self, batch_size):
        """Simple critic with dense feedforward and dropout layers."""
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        input_ = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim * 4, activation='relu')(input_)
        x = Dropout(0.1)(x)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dropout(0.1)(x)
        x = Dense(dim, activation='relu')(x)
        x = Dense(1)(x)
        return Model(inputs=input_, outputs=x)
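
For orientation, `g_lossfn`, `f_crit`, and `c_lossfn` above implement the Cramér GAN objective. In the paper's notation, with critic network h, real batch x, and two independent generator samples x_g and x_g':

\[ L_g = \mathbb{E}\big[\lVert h(x) - h(x_g) \rVert_2 + \lVert h(x) - h(x_g') \rVert_2 - \lVert h(x_g) - h(x_g') \rVert_2\big] \]
\[ f(\tilde{x}) = \lVert h(\tilde{x}) - h(x_g') \rVert_2 - \lVert h(\tilde{x}) \rVert_2 \]
\[ L_c = \mathbb{E}\big[f(x_g) - f(x)\big] + \lambda\, \mathbb{E}\big[\mathrm{GP}(\hat{x})\big], \qquad \lambda = \texttt{gradient\_penalty\_weight} \]

where GP is the interpolation penalty computed in loss.py above.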

src/ydata_synthetic/synthesizers/regular/dragan/model.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from tensorflow.keras import Model, initializers
 
 from ydata_synthetic.synthesizers.gan import BaseModel
-from ydata_synthetic.synthesizers.loss import gradient_penalty
+from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
 
 class DRAGAN(BaseModel):
 
@@ -33,7 +33,7 @@ def define_gan(self):
         self.d_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)
 
     def gradient_penalty(self, real, fake):
-        gp = gradient_penalty(self.discriminator, real, fake, mode='dragan')
+        gp = gradient_penalty(self.discriminator, real, fake, mode=Mode.DRAGAN)
         return gp
 
     def update_gradients(self, x):

0 commit comments
