@@ -95,15 +95,6 @@ def define_gan(self):
                                        outputs=Y_real,
                                        name="RealDiscriminator")
 
-        # ----------------------------
-        # Init the optimizers
-        # ----------------------------
-        self.autoencoder_opt = Adam(learning_rate=self.lr)
-        self.supervisor_opt = Adam(learning_rate=self.lr)
-        self.generator_opt = Adam(learning_rate=self.lr)
-        self.discriminator_opt = Adam(learning_rate=self.lr)
-        self.embedding_opt = Adam(learning_rate=self.lr)
-
         # ----------------------------
         # Define the loss functions
         # ----------------------------
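The per-network optimizer attributes deleted here come back as locals in `train()` (last hunks below), passed into each `@function`-compiled step as an argument. A minimal, hypothetical sketch of that pattern, assuming TensorFlow 2.x (`model`, `train_step`, and the data are illustrative, not part of this codebase):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
mse = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y, opt):
    # The optimizer arrives as an argument instead of being read from an
    # attribute created in define_gan(); tf.function traces once per
    # distinct optimizer object, which fits one-optimizer-per-phase use.
    with tf.GradientTape() as tape:
        loss = mse(y, model(x))
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
train_step(tf.random.normal((8, 4)), tf.random.normal((8, 1)), opt)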
@@ -112,31 +103,32 @@ def define_gan(self):
 
 
     @function
-    def train_autoencoder(self, x):
+    def train_autoencoder(self, x, opt):
         with GradientTape() as tape:
             x_tilde = self.autoencoder(x)
             embedding_loss_t0 = self._mse(x, x_tilde)
             e_loss_0 = 10 * sqrt(embedding_loss_t0)
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss_0, var_list)
-        self.autoencoder_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     @function
-    def train_supervisor(self, x):
+    def train_supervisor(self, x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
             g_loss_s = self._mse(h[:, 1:, :], h_hat_supervised[:, 1:, :])
 
         var_list = self.supervisor.trainable_variables + self.generator.trainable_variables
         gradients = tape.gradient(g_loss_s, var_list)
-        self.supervisor_opt.apply_gradients(zip(gradients, var_list))
+        apply_grads = [(grad, var) for (grad, var) in zip(gradients, var_list) if grad is not None]
+        opt.apply_gradients(apply_grads)
         return g_loss_s
 
     @function
-    def train_embedder(self, x):
+    def train_embedder(self, x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
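Why the new `grad is not None` filter in `train_supervisor`: `g_loss_s` depends only on the embedder and supervisor outputs, so the generator variables included in `var_list` come back from the tape with `None` gradients, and handing those straight to `apply_gradients` can warn or error depending on the TF version. A self-contained toy reproduction (variable names are illustrative):

import tensorflow as tf

a = tf.Variable(1.0)
b = tf.Variable(2.0)  # stands in for the generator weights: unused by the loss

with tf.GradientTape() as tape:
    loss = a * a  # b never enters the computation

grads = tape.gradient(loss, [a, b])  # -> [2.0, None]

opt = tf.keras.optimizers.Adam()
pairs = [(g, v) for g, v in zip(grads, [a, b]) if g is not None]
opt.apply_gradients(pairs)  # safe: the None pair has been dropped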
@@ -148,7 +140,7 @@ def train_embedder(self, x):
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss, var_list)
-        self.embedding_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     def discriminator_loss(self, x, z):
@@ -176,7 +168,7 @@ def calc_generator_moments_loss(y_true, y_pred):
         return g_loss_mean + g_loss_var
 
     @function
-    def train_generator(self, x, z):
+    def train_generator(self, x, z, opt):
         with GradientTape() as tape:
             y_fake = self.adversarial_supervised(z)
             generator_loss_unsupervised = self._bce(y_true=ones_like(y_fake),
@@ -199,17 +191,17 @@ def train_generator(self, x, z):
 
         var_list = self.generator_aux.trainable_variables + self.supervisor.trainable_variables
         gradients = tape.gradient(generator_loss, var_list)
-        self.generator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return generator_loss_unsupervised, generator_loss_supervised, generator_moment_loss
 
     @function
-    def train_discriminator(self, x, z):
+    def train_discriminator(self, x, z, opt):
         with GradientTape() as tape:
             discriminator_loss = self.discriminator_loss(x, z)
 
         var_list = self.discriminator.trainable_variables
         gradients = tape.gradient(discriminator_loss, var_list)
-        self.discriminator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return discriminator_loss
 
     def get_batch_data(self, data, n_windows):
@@ -229,16 +221,22 @@ def get_batch_noise(self):
 
     def train(self, data, train_steps):
         ## Embedding network training
+        autoencoder_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Embedding network training'):
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
-            step_e_loss_t0 = self.train_autoencoder(X_)
+            step_e_loss_t0 = self.train_autoencoder(X_, autoencoder_opt)
 
         ## Supervised network training
+        supervisor_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Supervised network training'):
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
-            step_g_loss_s = self.train_supervisor(X_)
+            step_g_loss_s = self.train_supervisor(X_, supervisor_opt)
 
         ## Joint training
+        generator_opt = Adam(learning_rate=self.lr)
+        embedder_opt = Adam(learning_rate=self.lr)
+        discriminator_opt = Adam(learning_rate=self.lr)
+
         step_g_loss_u = step_g_loss_s = step_g_loss_v = step_e_loss_t0 = step_d_loss = 0
         for _ in tqdm(range(train_steps), desc='Joint networks training'):
 
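Note that `train()` now builds one `Adam` per training phase rather than sharing instances created at model-definition time. Since Adam accumulates per-variable moment estimates as it steps, this gives each phase a clean slate; a brief sketch of the idea (assumes TF 2.x, toy variable):

import tensorflow as tf

v = tf.Variable(1.0)

# Adam builds first/second-moment slots for v on the first update, and that
# state persists for the optimizer's lifetime.
pretrain_opt = tf.keras.optimizers.Adam(learning_rate=0.01)
pretrain_opt.apply_gradients([(tf.constant(0.5), v)])

# A fresh instance for the next phase starts with zeroed moments.
joint_opt = tf.keras.optimizers.Adam(learning_rate=0.01)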
@@ -250,18 +248,18 @@ def train(self, data, train_steps):
 
             # --------------------------
             # Train the generator
             # --------------------------
-            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_)
+            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_, generator_opt)
 
             # --------------------------
             # Train the embedder
             # --------------------------
-            step_e_loss_t0 = self.train_embedder(X_)
+            step_e_loss_t0 = self.train_embedder(X_, embedder_opt)
 
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
             Z_ = next(self.get_batch_noise())
             step_d_loss = self.discriminator_loss(X_, Z_)
             if step_d_loss > 0.15:
-                step_d_loss = self.train_discriminator(X_, Z_)
+                step_d_loss = self.train_discriminator(X_, Z_, discriminator_opt)
 
     def sample(self, n_samples):
         steps = n_samples // self.batch_size + 1
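One detail in this hunk worth calling out: the discriminator loss is computed on a fresh batch first, and `train_discriminator` runs only when that loss exceeds 0.15, presumably to keep the discriminator from racing ahead of the generator during joint training. A toy standalone version of the gate (names and values other than 0.15 are illustrative):

D_LOSS_THRESHOLD = 0.15  # threshold used in the diff above

def gated_update(current_loss, update_step):
    # Run the update only when the model is underperforming.
    return update_step() if current_loss > D_LOSS_THRESHOLD else current_loss

print(gated_update(0.40, lambda: 0.30))  # loss high -> step runs, returns 0.3
print(gated_update(0.05, lambda: 0.30))  # loss low -> skipped, returns 0.05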
@@ -273,8 +271,6 @@ def sample(self, n_samples):
         return np.array(np.vstack(data))
 
 
-
-
 class Generator(Model):
     def __init__(self, hidden_dim, net_type='GRU'):
         self.hidden_dim = hidden_dim