@@ -25,6 +25,7 @@ class TSCWGAN(BaseModel):
def __init__(self, model_parameters, gradient_penalty_weight=10):
    """Create a base TSCWGAN.

    Args:
        model_parameters: project configuration object; its `.condition`
            attribute supplies the conditioning dimension
            (assumption — confirm against BaseModel's contract).
        gradient_penalty_weight: WGAN-GP gradient-penalty coefficient
            (default 10).
    """
    self.gradient_penalty_weight = gradient_penalty_weight
    self.cond_dim = model_parameters.condition
    super().__init__(model_parameters)
3031 def define_gan (self ):
def __init__(self, batch_size):
    """Store the static batch size used to build fixed-size Input layers.

    Args:
        batch_size: static batch dimension later passed to
            `Input(batch_size=...)` in `build_model`.
    """
    # Fix: this class appears to subclass tf.keras.Model (mirroring Critic
    # below — confirm against the class header). Keras Model subclasses must
    # call super().__init__() before assigning attributes on self, otherwise
    # Keras' attribute tracking raises at assignment time.
    super().__init__()
    self.batch_size = batch_size
def build_model(self, input_shape, dim, data_dim):
    """Build the functional-API skip-connection generator.

    Args:
        input_shape: per-sample input shape (seq_len, noise_dim), where
            noise_dim = Z + condition dimension.
        dim: base channel/width multiplier for the hidden layers.
        data_dim: number of output features (conditions not included).

    Returns:
        A Keras Model mapping the noise (+condition) input to generated data.
    """
    # Expected full input shape is (batch_size, seq_len, noise_dim).
    noise_input = Input(shape=input_shape, batch_size=self.batch_size)

    # Initial projection of the noise sequence into latent width `dim`.
    x = Sequential(layers=[
        Conv1D(filters=dim, kernel_size=1, input_shape=input_shape),
        LeakyReLU(),
        Conv1D(dim, kernel_size=5, dilation_rate=2, padding="same"),
        LeakyReLU()
    ], name='input_to_latent')(noise_input)

    # Three residual convolutional blocks. The same Sequential object is
    # applied at every iteration, so all three applications share weights.
    block_cnn = Sequential(layers=[
        Conv1D(filters=dim, kernel_size=3, dilation_rate=2, padding="same"),
        LeakyReLU()
    ], name='block_cnn')
    for _ in range(3):
        # Equivalent to the original `if i == 0` special-cased loop:
        # x_{k+1} = x_k + block_cnn(x_k)
        x = Add()([x, block_cnn(x)])

    # Flatten the temporal axis and shift into a dense latent space.
    x = Sequential(layers=[
        Conv1D(filters=10, kernel_size=3, dilation_rate=2, padding="same"),
        LeakyReLU(),
        Flatten(),
        Dense(dim * 2),
        LeakyReLU()
    ], name='block_shift')(x)

    # Three residual dense blocks, again with shared weights.
    block = Sequential(layers=[
        Dense(dim * 2),
        LeakyReLU()
    ], name='block')
    for _ in range(3):
        x = Add()([x, block(x)])

    # Final projection to data_dim features.
    # Fix: layer name was misspelled 'latent_to_ouput' in the original
    # (NOTE: if weights are ever loaded by layer name, keep the old spelling).
    output = Dense(data_dim, name='latent_to_output')(x)
    return Model(inputs=noise_input, outputs=output, name='SkipConnectionGenerator')
220219
class Critic(Model):
    """Conditional Wasserstein Critic with skip connections."""

    def __init__(self, batch_size):
        """Store the static batch size used to build the input layer.

        Args:
            batch_size: static batch dimension passed to `Input(batch_size=...)`.
        """
        # Fix: Keras Model subclasses must call super().__init__() before
        # assigning attributes on self, otherwise Keras' attribute tracking
        # raises at assignment time.
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        """Build the functional-API skip-connection critic.

        Args:
            input_shape: per-sample shape of a record concatenated with its
                condition (X + condition).
            dim: base width multiplier for the dense layers.

        Returns:
            A Keras Model mapping a (record + condition) input to a scalar
            critic score.
        """
        record_input = Input(shape=input_shape, batch_size=self.batch_size)

        # Project the record into the latent width.
        x = Sequential(layers=[
            Dense(dim * 2),
            LeakyReLU()
        ], name='ts_to_latent')(record_input)

        # Seven residual dense blocks. The same Sequential object is applied
        # at every iteration, so all seven applications share weights.
        block = Sequential(layers=[
            Dense(dim * 2),
            LeakyReLU()
        ], name='block')
        for _ in range(7):
            # Equivalent to the original `if i == 0` special-cased loop:
            # x_{k+1} = x_k + block(x_k)
            x = Add()([x, block(x)])

        # Scalar score head; Dense default (linear) activation leaves the
        # score unbounded.
        output = Dense(1, name='latent_to_score')(x)
        return Model(inputs=record_input, outputs=output, name='SkipConnectionCritic')
0 commit comments