Skip to content

Commit 3f6cbe5

Browse files
author
Francisco Santos
committed
integrate TSDataProcessor, revise sample method
1 parent d842e0c commit 3f6cbe5

1 file changed

Lines changed: 33 additions & 17 deletions

File tree

  • src/ydata_synthetic/synthesizers/timeseries/tscwgan

src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from tqdm import trange
77
from numpy import array, vstack
88
from numpy.random import normal
9+
from typing import List
910

10-
from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims
11+
from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, constant
1112
from tensorflow import data as tfdata
1213
from tensorflow.keras import Model, Sequential
1314
from tensorflow.keras.optimizers import Adam
@@ -16,16 +17,17 @@
1617
from ydata_synthetic.synthesizers.gan import BaseModel
1718
from ydata_synthetic.synthesizers import TrainParameters
1819
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
20+
from ydata_synthetic.synthesizers.timeseries import TimeSeriesDataProcessor
1921

2022
class TSCWGAN(BaseModel):
2123

2224
__MODEL__='TSCWGAN'
2325

2426
def __init__(self, model_parameters, gradient_penalty_weight=10):
    """Create a base TSCWGAN.

    Args:
        model_parameters: Model hyperparameter container; must expose a
            `condition` attribute (length of the conditioning sequence).
            Presumably also carries whatever BaseModel.__init__ consumes —
            TODO confirm against BaseModel.
        gradient_penalty_weight (int): Weight of the gradient-penalty term
            used when updating the critic. Defaults to 10.
    """
    # Parent is initialized first (this commit moved it above the attribute
    # assignments) so BaseModel state exists before TSCWGAN-specific fields.
    super().__init__(model_parameters)
    self.gradient_penalty_weight = gradient_penalty_weight
    self.cond_dim = model_parameters.condition
2931

3032
def define_gan(self):
3133
self.generator = Generator(self.batch_size). \
@@ -44,14 +46,18 @@ def define_gan(self):
4446
score = concat([cond, gen], axis=1)
4547
score = self.critic(score)
4648

47-
def train(self, data, train_arguments: TrainParameters):
48-
real_batches = self.get_batch_data(data)
49+
def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str],
50+
preprocess: bool = True):
51+
super().train(data, num_cols, cat_cols, preprocess)
52+
53+
processed_data = self.processor.transform(data)
54+
real_batches = self.get_batch_data(processed_data)
4955
noise_batches = self.get_batch_noise()
5056

5157
for epoch in trange(train_arguments.epochs):
5258
for i in range(train_arguments.critic_iter):
5359
real_batch = next(real_batches)
54-
noise_batch = next(noise_batches)[:len(real_batch)] # Truncate the noise tensor in the shape of the real data tensor
60+
noise_batch = next(noise_batches)[:len(real_batch)] # Truncate noise tensor to real data shape
5561

5662
c_loss = self.update_critic(real_batch, noise_batch)
5763

@@ -142,21 +148,31 @@ def get_batch_data(self, data, n_windows= None):
142148
.shuffle(buffer_size=n_windows)
143149
.batch(self.batch_size).repeat())
144150

145-
def sample(self, cond_array, n_samples):
146-
"""Provided that cond_array is passed, produce n_samples for each condition vector in cond_array."""
147-
assert len(cond_array.shape) == 1, "Condition array should be one-dimensional."
148-
assert cond_array.shape[0] == self.cond_dim, \
149-
f"The condition sequence should have a {self.cond_dim} length."
151+
def sample(self, cond_array: array, n_samples: int, inverse_transform: bool = True):
    """Produce at least n_samples synthetic records for each condition in cond_array.

    The returned samples per condition are always a multiple of batch_size
    and therefore equal to or bigger than n_samples.

    Arguments:
        cond_array (numpy array): 2-D array (n_conditions x cond_dim) with
            the set of conditions for the sampling process.
        n_samples (int): Minimum number of samples taken for each condition
            in cond_array.
        inverse_transform (bool): If True (default), map the generated
            records back to the original data space using the fitted data
            processor; otherwise return the raw generator output.
    """
    assert len(cond_array.shape) == 2, "Condition array should be two-dimensional. N_conditions x cond_dim"
    assert cond_array.shape[1] == self.cond_dim, \
        f"The condition sequences should have a {self.cond_dim} length."
    # Round the per-condition sample count up to a whole number of batches.
    steps = n_samples // self.batch_size + 1
    data = []
    z_dist = self.get_batch_noise()
    for condition in cond_array:
        # Promote the rank-1 condition to shape (1, cond_dim) before tiling:
        # tf.tile requires len(multiples) == tensor rank, so tiling the
        # rank-1 tensor with [batch_size, 1] would raise. (The previous
        # revision used expand_dims for this; the reshape restores it.)
        cond_seq = reshape(convert_to_tensor(condition, float32), (1, -1))
        cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
        for step in trange(steps, desc='Synthetic data generation'):
            gen_input = concat([cond_seq, next(z_dist)], axis=1)
            records = self.generator(gen_input, training=False)
            data.append(records)
    data = array(vstack(data))
    if inverse_transform:
        return self.processor.inverse_transform(data)
    return data
160176

161177

162178
class Generator(Model):

0 commit comments

Comments (0)