66from tqdm import trange
77from numpy import array , vstack
88from numpy .random import normal
9+ from typing import List
910
10- from tensorflow import concat , float32 , convert_to_tensor , reshape , GradientTape , reduce_mean , make_ndarray , make_tensor_proto , tile , expand_dims
11+ from tensorflow import concat , float32 , convert_to_tensor , reshape , GradientTape , reduce_mean , make_ndarray , make_tensor_proto , tile , constant
1112from tensorflow import data as tfdata
1213from tensorflow .keras import Model , Sequential
1314from tensorflow .keras .optimizers import Adam
1617from ydata_synthetic .synthesizers .gan import BaseModel
1718from ydata_synthetic .synthesizers import TrainParameters
1819from ydata_synthetic .synthesizers .loss import Mode , gradient_penalty
20+ from ydata_synthetic .synthesizers .timeseries import TimeSeriesDataProcessor
1921
2022class TSCWGAN (BaseModel ):
2123
2224 __MODEL__ = 'TSCWGAN'
2325
def __init__(self, model_parameters, gradient_penalty_weight=10):
    """Create a base TSCWGAN.

    Arguments:
        model_parameters: Model settings forwarded to the base class initializer;
            its `condition` attribute is stored as `self.cond_dim`.
        gradient_penalty_weight: Value stored as `self.gradient_penalty_weight`
            (defaults to 10).
    """
    # NOTE(review): the base initializer runs before cond_dim and
    # gradient_penalty_weight are assigned -- confirm it does not read them
    # (for instance through define_gan).
    super().__init__(model_parameters)
    self.gradient_penalty_weight = gradient_penalty_weight
    self.cond_dim = model_parameters.condition
2931
3032 def define_gan (self ):
3133 self .generator = Generator (self .batch_size ). \
@@ -44,14 +46,18 @@ def define_gan(self):
4446 score = concat ([cond , gen ], axis = 1 )
4547 score = self .critic (score )
4648
47- def train (self , data , train_arguments : TrainParameters ):
48- real_batches = self .get_batch_data (data )
49+ def train (self , data , train_arguments : TrainParameters , num_cols : List [str ], cat_cols : List [str ],
50+ preprocess : bool = True ):
51+ super ().train (data , num_cols , cat_cols , preprocess )
52+
53+ processed_data = self .processor .transform (data )
54+ real_batches = self .get_batch_data (processed_data )
4955 noise_batches = self .get_batch_noise ()
5056
5157 for epoch in trange (train_arguments .epochs ):
5258 for i in range (train_arguments .critic_iter ):
5359 real_batch = next (real_batches )
54- noise_batch = next (noise_batches )[:len (real_batch )] # Truncate the noise tensor in the shape of the real data tensor
60+ noise_batch = next (noise_batches )[:len (real_batch )] # Truncate noise tensor to real data shape
5561
5662 c_loss = self .update_critic (real_batch , noise_batch )
5763
@@ -142,21 +148,31 @@ def get_batch_data(self, data, n_windows= None):
142148 .shuffle (buffer_size = n_windows )
143149 .batch (self .batch_size ).repeat ())
144150
def sample(self, cond_array: array, n_samples: int, inverse_transform: bool = True):
    """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.

    Generation runs in whole batches: `n_samples // batch_size + 1` generator
    calls are made per condition, so the number of records returned per
    condition can exceed n_samples.

    Arguments:
        cond_array (numpy array): Two-dimensional array with the set of conditions
            for the sampling process, laid out as N_conditions x cond_dim.
        n_samples (int): Number of samples to be taken for each condition in cond_array.
        inverse_transform (bool): When True, the generated records are passed
            through `self.processor.inverse_transform` before being returned;
            when False, the stacked generator output is returned unchanged.

    Returns:
        The generated records for all conditions stacked vertically, in the
        order the conditions appear in cond_array.

    Raises:
        AssertionError: If cond_array is not two-dimensional or its second
            dimension differs from `self.cond_dim`.
    """
    assert len(cond_array.shape) == 2, "Condition array should be two-dimensional. N_conditions x cond_dim"
    assert cond_array.shape[1] == self.cond_dim, \
        f"The condition sequences should have a {self.cond_dim} length."
    steps = n_samples // self.batch_size + 1
    data = []
    z_dist = self.get_batch_noise()
    for condition in cond_array:
        # Each row of cond_array is one-dimensional. tile() needs `multiples` to
        # have one entry per input dimension, so add a leading batch axis first.
        cond_seq = reshape(convert_to_tensor(condition, float32), (1, -1))
        cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
        for _ in trange(steps, desc='Synthetic data generation'):
            gen_input = concat([cond_seq, next(z_dist)], axis=1)
            data.append(self.generator(gen_input, training=False))
    data = array(vstack(data))
    if inverse_transform:
        return self.processor.inverse_transform(data)
    return data
160176
161177
162178class Generator (Model ):
0 commit comments