
Commit f631807

fix: save load TimeGAN (#58)
* fix: Remove the optimizers from the class attributes to enable save and load.
* fix: Remove unused gradients from the supervised training step.
1 parent 1ad9c7e commit f631807
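
Why this enables save and load: the synthesizer object is serialized wholesale when saved, and tf.keras Optimizer instances generally do not survive pickling on TF 2.x (once built they hold tf.Variable slots and other unpicklable internals), so keeping five of them as attributes made the whole object unserializable. A minimal sketch of the failure mode, assuming pickle-based persistence; `Holder` is a stand-in class, not code from this repo:

```python
import pickle

import tensorflow as tf

class Holder:
    """Stand-in for a synthesizer that keeps an optimizer as an attribute."""
    def __init__(self):
        self.opt = tf.keras.optimizers.Adam(learning_rate=5e-4)

try:
    pickle.dumps(Holder())
except Exception as err:
    # On most TF 2.x versions this fails because optimizer internals
    # (variables, locks) are not picklable.
    print(f"pickling failed: {type(err).__name__}")
```

The diff below moves each optimizer into `train()` as a local that is passed into the `@function`-decorated train steps, so nothing optimizer-related remains on the instance.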

1 file changed: src/ydata_synthetic/synthesizers/timeseries/timegan/model.py (22 additions & 26 deletions)
@@ -95,15 +95,6 @@ def define_gan(self):
                                          outputs=Y_real,
                                          name="RealDiscriminator")
 
-        # ----------------------------
-        # Init the optimizers
-        # ----------------------------
-        self.autoencoder_opt = Adam(learning_rate=self.lr)
-        self.supervisor_opt = Adam(learning_rate=self.lr)
-        self.generator_opt = Adam(learning_rate=self.lr)
-        self.discriminator_opt = Adam(learning_rate=self.lr)
-        self.embedding_opt = Adam(learning_rate=self.lr)
-
         # ----------------------------
         # Define the loss functions
         # ----------------------------
@@ -112,31 +103,32 @@ def define_gan(self):
 
 
     @function
-    def train_autoencoder(self, x):
+    def train_autoencoder(self, x, opt):
         with GradientTape() as tape:
             x_tilde = self.autoencoder(x)
             embedding_loss_t0 = self._mse(x, x_tilde)
             e_loss_0 = 10 * sqrt(embedding_loss_t0)
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss_0, var_list)
-        self.autoencoder_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     @function
-    def train_supervisor(self, x):
+    def train_supervisor(self, x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
             g_loss_s = self._mse(h[:, 1:, :], h_hat_supervised[:, 1:, :])
 
         var_list = self.supervisor.trainable_variables + self.generator.trainable_variables
         gradients = tape.gradient(g_loss_s, var_list)
-        self.supervisor_opt.apply_gradients(zip(gradients, var_list))
+        apply_grads = [(grad, var) for (grad, var) in zip(gradients, var_list) if grad is not None]
+        opt.apply_gradients(apply_grads)
         return g_loss_s
 
     @function
-    def train_embedder(self,x):
+    def train_embedder(self,x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
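
The second half of the fix lands in `train_supervisor`: its loss only flows through the embedder and the supervisor, yet `var_list` also contains the generator's variables, so `tape.gradient` returns `None` for those entries and `apply_gradients` would fail on the `None` pairs. The new list comprehension simply drops them. A self-contained illustration of that TensorFlow behavior (the variables here are toys, not repo code):

```python
import tensorflow as tf

used = tf.Variable(2.0)
unused = tf.Variable(3.0)  # plays the role of the generator's variables

with tf.GradientTape() as tape:
    loss = used * used  # the loss never touches `unused`

grads = tape.gradient(loss, [used, unused])
print(grads)  # second entry is None: the loss does not depend on `unused`

# Filtering out the None gradients, as the patch does, keeps apply_gradients happy:
opt = tf.keras.optimizers.Adam()
pairs = [(g, v) for g, v in zip(grads, [used, unused]) if g is not None]
opt.apply_gradients(pairs)
```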
@@ -148,7 +140,7 @@ def train_embedder(self,x):
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss, var_list)
-        self.embedding_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     def discriminator_loss(self, x, z):
@@ -176,7 +168,7 @@ def calc_generator_moments_loss(y_true, y_pred):
         return g_loss_mean + g_loss_var
 
     @function
-    def train_generator(self, x, z):
+    def train_generator(self, x, z, opt):
         with GradientTape() as tape:
             y_fake = self.adversarial_supervised(z)
             generator_loss_unsupervised = self._bce(y_true=ones_like(y_fake),
@@ -199,17 +191,17 @@ def train_generator(self, x, z):
 
         var_list = self.generator_aux.trainable_variables + self.supervisor.trainable_variables
         gradients = tape.gradient(generator_loss, var_list)
-        self.generator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return generator_loss_unsupervised, generator_loss_supervised, generator_moment_loss
 
     @function
-    def train_discriminator(self, x, z):
+    def train_discriminator(self, x, z, opt):
         with GradientTape() as tape:
             discriminator_loss = self.discriminator_loss(x, z)
 
         var_list = self.discriminator.trainable_variables
         gradients = tape.gradient(discriminator_loss, var_list)
-        self.discriminator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return discriminator_loss
 
     def get_batch_data(self, data, n_windows):
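
All five train steps are decorated with `@function` (TensorFlow's `tf.function`), and a Python object such as an optimizer is part of the trace signature: each distinct instance triggers one retrace. Because `train()` below creates each optimizer once and reuses it for an entire phase, every step still compiles a single graph per phase. A minimal sketch of the pattern, with a toy Keras model standing in for the TimeGAN networks:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])  # toy stand-in network

@tf.function
def train_step(x, opt):
    # The optimizer arrives as an argument instead of living on a class,
    # mirroring the pattern this commit introduces.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x)))
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss

opt = tf.keras.optimizers.Adam(learning_rate=5e-4)  # created once, outside the loop
for _ in range(3):
    train_step(tf.random.normal((8, 4)), opt)  # same instance, so traced only once
```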
@@ -229,16 +221,22 @@ def get_batch_noise(self):
 
     def train(self, data, train_steps):
         ## Embedding network training
+        autoencoder_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Embedding network training'):
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
-            step_e_loss_t0 = self.train_autoencoder(X_)
+            step_e_loss_t0 = self.train_autoencoder(X_, autoencoder_opt)
 
         ## Supervised Network training
+        supervisor_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Supervised network training'):
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
-            step_g_loss_s = self.train_supervisor(X_)
+            step_g_loss_s = self.train_supervisor(X_, supervisor_opt)
 
         ## Joint training
+        generator_opt = Adam(learning_rate=self.lr)
+        embedder_opt = Adam(learning_rate=self.lr)
+        discriminator_opt = Adam(learning_rate=self.lr)
+
         step_g_loss_u = step_g_loss_s = step_g_loss_v = step_e_loss_t0 = step_d_loss = 0
         for _ in tqdm(range(train_steps), desc='Joint networks training'):
 
@@ -250,18 +248,18 @@ def train(self, data, train_steps):
             # --------------------------
             # Train the generator
             # --------------------------
-            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_)
+            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_, generator_opt)
 
             # --------------------------
             # Train the embedder
             # --------------------------
-            step_e_loss_t0 = self.train_embedder(X_)
+            step_e_loss_t0 = self.train_embedder(X_, embedder_opt)
 
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
             Z_ = next(self.get_batch_noise())
             step_d_loss = self.discriminator_loss(X_, Z_)
             if step_d_loss > 0.15:
-                step_d_loss = self.train_discriminator(X_, Z_)
+                step_d_loss = self.train_discriminator(X_, Z_, discriminator_opt)
 
     def sample(self, n_samples):
         steps = n_samples // self.batch_size + 1
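
Note that the phase-local optimizers stay independent: Adam materializes its per-variable moment slots lazily on the first `apply_gradients`, so each network's update statistics never leak across phases, and the instances are cheap to construct up front. A toy demonstration of that lazy slot creation (not repo code):

```python
import tensorflow as tf

v = tf.Variable(1.0)
opt = tf.keras.optimizers.Adam(learning_rate=5e-4)

# No slot variables exist yet; Adam creates its first/second-moment slots
# for `v` only when the first gradient is applied.
opt.apply_gradients([(tf.constant(0.5), v)])
print(v.numpy())  # slightly below 1.0 after one Adam step
```

One trade-off of creating optimizers inside `train()` is that their state is discarded when training ends, which is exactly what makes the object safe to serialize afterwards.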
@@ -273,8 +271,6 @@ def sample(self, n_samples):
         return np.array(np.vstack(data))
 
 
-
-
 class Generator(Model):
     def __init__(self, hidden_dim, net_type='GRU'):
         self.hidden_dim = hidden_dim
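
With the optimizers gone from the instance, a save/load round trip becomes possible. A hypothetical usage sketch: `gan_args`, `stock_data`, `seq_len`, and `n_seq` are placeholders, and the constructor and `save`/`load` signatures are assumptions, so check the repo's examples for the real API:

```python
from ydata_synthetic.synthesizers.timeseries import TimeGAN

# Placeholder arguments; the real signature may differ in this version.
synth = TimeGAN(model_parameters=gan_args, hidden_dim=24,
                seq_len=seq_len, n_seq=n_seq, gamma=1)
synth.train(stock_data, train_steps=5000)

synth.save('timegan_stock.pkl')  # succeeds now: no optimizer attributes to serialize
restored = TimeGAN.load('timegan_stock.pkl')
samples = restored.sample(n_samples=100)
```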
