Commit e2f4159

fix: remove optimizers as model attributes (#160)
* remove optimizers from dragan
* remove optimizers from cramergan
* remove optimizers from wgangp
* remove optimizers from cwgangp
1 parent cb8e3d2 commit e2f4159

4 files changed

Lines changed: 39 additions & 55 deletions
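
Context for the change, as a minimal sketch: a tf.keras optimizer carries mutable training state, and storing one as an instance attribute leaves a non-serializable object on the synthesizer after training ends. Returning the optimizers from define_gan and passing them through the training calls keeps them out of the saved object. The Synth class below is a hypothetical stand-in for illustration, not code from this repository:

# Hypothetical stand-in illustrating the refactor; `Synth` is not a repo class.
from tensorflow.keras.optimizers import Adam

class Synth:
    def __init__(self, g_lr=5e-4, d_lr=5e-4, beta_1=0.5, beta_2=0.9):
        self.g_lr, self.d_lr = g_lr, d_lr
        self.beta_1, self.beta_2 = beta_1, beta_2

    # Before: optimizers become instance attributes and travel with the
    # object into any pickle/save of the synthesizer.
    def define_gan_before(self):
        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)

    # After: optimizers are locals handed back to the caller; train() keeps
    # them for the duration of training and nothing persists on the model.
    def define_gan(self):
        g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        return g_optimizer, c_optimizer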


src/ydata_synthetic/synthesizers/regular/cramergan/model.py

Lines changed: 12 additions & 26 deletions
@@ -35,19 +35,21 @@ def define_gan(self, activation_info: Optional[NamedTuple] = None):
         self.critic = Critic(self.batch_size). \
             build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
 
-        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
-        self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
 
         # The generator takes noise as input and generates records
         z = Input(shape=(self.noise_dim,), batch_size=self.batch_size)
         fake = self.generator(z)
         logits = self.critic(fake)
 
+        return g_optimizer, c_optimizer
+
     def gradient_penalty(self, real, fake):
         gp = gradient_penalty(self.f_crit, real, fake, mode=Mode.CRAMER)
         return gp
 
-    def update_gradients(self, x):
+    def update_gradients(self, x, g_optimizer, c_optimizer):
         """Compute and apply the gradients for both the Generator and the Critic.
 
         :param x: real data event
@@ -71,13 +73,13 @@ def update_gradients(self, x):
         g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)
 
         # Update the weights of the generator
-        self.g_optimizer.apply_gradients(
+        g_optimizer.apply_gradients(
             zip(g_gradients, self.generator.trainable_variables)
         )
 
         c_gradient = d_tape.gradient(c_loss, self.critic.trainable_variables)
         # Update the weights of the critic using the optimizer
-        self.c_optimizer.apply_gradients(
+        c_optimizer.apply_gradients(
             zip(c_gradient, self.critic.trainable_variables)
         )
 
@@ -131,8 +133,8 @@ def get_data_batch(train, batch_size, seed=0):
         train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end of the set
         return train[train_ix[start_i: stop_i]]
 
-    def train_step(self, train_data):
-        critic_loss, g_loss = self.update_gradients(train_data)
+    def train_step(self, train_data, optimizers):
+        critic_loss, g_loss = self.update_gradients(train_data, *optimizers)
         return critic_loss, g_loss
 
     def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
@@ -147,7 +149,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
 
         data = self.processor.transform(data)
         self.data_dim = data.shape[1]
-        self.define_gan(self.processor.col_transform_info)
+        optimizers = self.define_gan(self.processor.col_transform_info)
 
         iterations = int(abs(data.shape[0] / self.batch_size) + 1)
 
@@ -158,7 +160,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
         for epoch in trange(train_arguments.epochs):
             for iteration in range(iterations):
                 batch_data = self.get_data_batch(data, self.batch_size)
-                c_loss, g_loss = self.train_step(batch_data)
+                c_loss, g_loss = self.train_step(batch_data, optimizers)
 
                 if iteration % train_arguments.sample_interval == 0:
                     # Test here data generation step
@@ -168,23 +170,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
                     model_checkpoint_base_name = './cache/' + train_arguments.cache_prefix + '_{}_model_weights_step_{}.h5'
                     self.generator.save_weights(model_checkpoint_base_name.format('generator', iteration))
                     self.critic.save_weights(model_checkpoint_base_name.format('critic', iteration))
-
-            print(
-                "Epoch: {} | critic_loss: {} | gen_loss: {}".format(
-                    epoch, c_loss, g_loss
-                ))
-
-        self.g_optimizer=self.g_optimizer.get_config()
-        self.critic_optimizer=self.c_optimizer.get_config()
-
-    def save(self, path):
-        """Strip down the optimizers from the model then save."""
-        for attr in ['g_optimizer', 'c_optimizer']:
-            try:
-                delattr(self, attr)
-            except AttributeError:
-                continue
-        super().save(path)
+            print(f"Epoch: {epoch} | critic_loss: {c_loss} | gen_loss: {g_loss}")
 
 
 class Generator(tf.keras.Model):
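
The deleted save override above existed only to strip the optimizer attributes before delegating to the base class; with define_gan no longer setting them, it is dead code. A self-contained sketch of the failure mode it worked around, under the assumption that the base class save pickles the instance (Base, Synth, and Unpicklable are hypothetical, not repo classes):

# Sketch of why optimizer attributes break a pickle-based save.
import pickle

class Unpicklable:
    def __reduce__(self):  # stand-in for optimizer state that won't pickle
        raise TypeError("cannot pickle optimizer state")

class Base:
    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self, f)

class Synth(Base):
    def fit_old(self):
        self.g_optimizer = Unpicklable()  # attribute makes Base.save raise
    def fit_new(self):
        g_optimizer = Unpicklable()       # local only: Base.save succeeds
        return g_optimizer

s = Synth()
s.fit_new()
s.save("/tmp/synth.pkl")                  # works: no optimizer attributes left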

src/ydata_synthetic/synthesizers/regular/cwgangp/model.py

Lines changed: 5 additions & 4 deletions
@@ -45,8 +45,9 @@ def define_gan(self, activation_info: Optional[NamedTuple] = None):
         self.critic = Critic(self.batch_size, self.num_classes). \
             build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
 
-        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
-        self.critic_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        return g_optimizer, c_optimizer
 
     def gradient_penalty(self, real, fake, label):
         epsilon = random.uniform([real.shape[0], 1], 0.0, 1.0, dtype=dtypes.float32)
@@ -130,7 +131,7 @@ def train(self, data: DataFrame, label_col: str, train_arguments: TrainParameter
 
         processed_data = self.processor.transform(data)
         self.data_dim = processed_data.shape[1]
-        self.define_gan(self.processor.col_transform_info)
+        optimizers = self.define_gan(self.processor.col_transform_info)
 
         # Merging labels with processed data
         processed_data = hstack([processed_data, label])
@@ -145,7 +146,7 @@ def train(self, data: DataFrame, label_col: str, train_arguments: TrainParameter
             batch_x = self.get_data_batch(processed_data, self.batch_size) # Batches are retrieved with labels
             batch_x, label = batch_x[:, :-1], batch_x[:, -1] # Separate labels from batch
 
-            cri_loss, ge_loss = self.train_step((batch_x, label))
+            cri_loss, ge_loss = self.train_step((batch_x, label), optimizers)
 
            print(
                 "Epoch: {} | critic_loss: {} | gen_loss: {}".format(

src/ydata_synthetic/synthesizers/regular/dragan/model.py

Lines changed: 12 additions & 13 deletions
@@ -32,17 +32,19 @@ def define_gan(self, col_transform_info: Optional[NamedTuple] = None):
         self.discriminator = Discriminator(self.batch_size). \
             build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
 
-        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)
-        self.d_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)
+        g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)
+        d_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)
+        return g_optimizer, d_optimizer
 
     def gradient_penalty(self, real, fake):
         gp = gradient_penalty(self.discriminator, real, fake, mode= Mode.DRAGAN)
         return gp
 
-    def update_gradients(self, x):
+    def update_gradients(self, x, g_optimizer, d_optimizer):
         """
         Compute the gradients for both the Generator and the Discriminator
-        :param x: real data event
+        x (tf.tensor): real data event
+        *_optimizer (tf.OptimizerV2): Optimizer for the * model
         :return: generator gradients, discriminator gradients
         """
         # Update the gradients of critic for n_critic times (Training the critic)
@@ -52,7 +54,7 @@ def update_gradients(self, x):
             # Get the gradients of the critic
             d_gradient = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
             # Update the weights of the critic using the optimizer
-            self.d_optimizer.apply_gradients(
+            d_optimizer.apply_gradients(
                 zip(d_gradient, self.discriminator.trainable_variables)
             )
 
@@ -64,7 +66,7 @@ def update_gradients(self, x):
         gen_gradients = g_tape.gradient(gen_loss, self.generator.trainable_variables)
 
         # Update the weights of the generator
-        self.g_optimizer.apply_gradients(
+        g_optimizer.apply_gradients(
             zip(gen_gradients, self.generator.trainable_variables)
         )
 
@@ -112,8 +114,8 @@ def get_data_batch(self, train, batch_size):
             .batch(batch_size).shuffle(buffer_size)
         return train_loader
 
-    def train_step(self, train_data):
-        d_loss, g_loss = self.update_gradients(train_data)
+    def train_step(self, train_data, optimizers):
+        d_loss, g_loss = self.update_gradients(train_data, *optimizers)
         return d_loss, g_loss
 
     def train(self, data, train_arguments, num_cols, cat_cols):
@@ -128,7 +130,7 @@ def train(self, data, train_arguments, num_cols, cat_cols):
 
         processed_data = self.processor.transform(data)
         self.data_dim = processed_data.shape[1]
-        self.define_gan(self.processor.col_transform_info)
+        optimizers = self.define_gan(self.processor.col_transform_info)
 
         train_loader = self.get_data_batch(processed_data, self.batch_size)
 
@@ -139,7 +141,7 @@ def train(self, data, train_arguments, num_cols, cat_cols):
         for epoch in tqdm.trange(train_arguments.epochs):
             for batch_data in train_loader:
                 batch_data = tf.cast(batch_data, dtype=tf.float32)
-                d_loss, g_loss = self.train_step(batch_data)
+                d_loss, g_loss = self.train_step(batch_data, optimizers)
 
             print(
                 "Epoch: {} | disc_loss: {} | gen_loss: {}".format(
@@ -155,9 +157,6 @@ def train(self, data, train_arguments, num_cols, cat_cols):
             self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
             self.discriminator.save_weights(model_checkpoint_base_name.format('discriminator', epoch))
 
-        self.g_optimizer=self.g_optimizer.get_config()
-        self.d_optimizer=self.d_optimizer.get_config()
-
 
 class Discriminator(Model):
     def __init__(self, batch_size):

src/ydata_synthetic/synthesizers/regular/wgangp/model.py

Lines changed: 10 additions & 12 deletions
@@ -33,8 +33,9 @@ def define_gan(self, activation_info: Optional[NamedTuple] = None):
         self.critic = Critic(self.batch_size). \
             build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
 
-        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
-        self.critic_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        return g_optimizer, c_optimizer
 
     def gradient_penalty(self, real, fake):
         epsilon = tf.random.uniform([real.shape[0], 1], 0.0, 1.0, dtype=tf.dtypes.float32)
@@ -47,7 +48,7 @@ def gradient_penalty(self, real, fake):
         d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
         return d_regularizer
 
-    def update_gradients(self, x):
+    def update_gradients(self, x, g_optimizer, c_optimizer):
         """
         Compute the gradients for both the Generator and the Critic
         :param x: real data event
@@ -60,7 +61,7 @@ def update_gradients(self, x):
             # Get the gradients of the critic
             d_gradient = d_tape.gradient(critic_loss, self.critic.trainable_variables)
             # Update the weights of the critic using the optimizer
-            self.critic_optimizer.apply_gradients(
+            c_optimizer.apply_gradients(
                 zip(d_gradient, self.critic.trainable_variables)
             )
 
@@ -72,7 +73,7 @@ def update_gradients(self, x):
         gen_gradients = g_tape.gradient(gen_loss, self.generator.trainable_variables)
 
         # Update the weights of the generator
-        self.g_optimizer.apply_gradients(
+        g_optimizer.apply_gradients(
             zip(gen_gradients, self.generator.trainable_variables)
         )
 
@@ -124,8 +125,8 @@ def get_data_batch(self, train, batch_size, seed=0):
         return train[train_ix[start_i: stop_i]]
 
     @tf.function
-    def train_step(self, train_data):
-        cri_loss, ge_loss = self.update_gradients(train_data)
+    def train_step(self, train_data, optimizers):
+        cri_loss, ge_loss = self.update_gradients(train_data, *optimizers)
         return cri_loss, ge_loss
 
     def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
@@ -140,7 +141,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
 
         processed_data = self.processor.transform(data)
         self.data_dim = processed_data.shape[1]
-        self.define_gan(self.processor.col_transform_info)
+        optimizers = self.define_gan(self.processor.col_transform_info)
 
         iterations = int(abs(data.shape[0]/self.batch_size)+1)
 
@@ -151,7 +152,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
         for epoch in trange(train_arguments.epochs):
             for _ in range(iterations):
                 batch_data = self.get_data_batch(processed_data, self.batch_size).astype(np.float32)
-                cri_loss, ge_loss = self.train_step(batch_data)
+                cri_loss, ge_loss = self.train_step(batch_data, optimizers)
 
                 print(
                     "Epoch: {} | disc_loss: {} | gen_loss: {}".format(
@@ -167,9 +168,6 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
                 self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
                 self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))
 
-        self.g_optimizer=self.g_optimizer.get_config()
-        self.critic_optimizer=self.critic_optimizer.get_config()
-
 
 class Generator(tf.keras.Model):
     def __init__(self, batch_size):
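
One detail worth noting in the wgangp diff: train_step is wrapped in @tf.function and now receives the optimizers as arguments. Because the same optimizer objects are passed on every call, the function traces once and reuses the compiled graph. A minimal self-contained sketch of that pattern (toy model, not repo code):

# Minimal sketch (toy model, not repo code): a @tf.function train step that
# takes the optimizer as an argument instead of reading it off `self`.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))          # create variables before tracing
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y, optimizer):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

optimizer = tf.keras.optimizers.Adam(1e-3)  # created once, outside the loop
x = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))
for _ in range(3):
    loss = train_step(x, y, optimizer)      # same object each call: one trace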
