Skip to content

Commit 2c2b720

Browse files
author
Francisco Santos
committed
apply revisions + add typeguard
apply revisions
1 parent bf8656f commit 2c2b720

5 files changed

Lines changed: 23 additions & 47 deletions

File tree

examples/timeseries/tscwgan_example.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from pandas import DataFrame
21
from numpy import squeeze
32

4-
from ydata_synthetic.postprocessing.timeseries.inverse_preprocesser import inverse_transform
53
from ydata_synthetic.preprocessing.timeseries import processed_stock
64
from ydata_synthetic.synthesizers.timeseries import TSCWGAN
75
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
6+
from ydata_synthetic.postprocessing.regular.inverse_preprocesser import inverse_transform
87

98
model = TSCWGAN
109

@@ -24,7 +23,7 @@
2423
critic_iter = 5
2524

2625
# Get transformed data stock - Univariate
27-
data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = 'Open')
26+
data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = ['Open'])
2827
data_sample = processed_data[0]
2928

3029
model_parameters = ModelParameters(batch_size=batch_size,
@@ -51,10 +50,10 @@
5150

5251
#Sampling the data
5352
#Note that the data returned is not inverse processed.
54-
step = int(len(processed_data)/(5-1))
55-
cond_array = DataFrame(data=[squeeze(processed_data[i][:cond_dim], axis=1) for i in range(0, len(processed_data), step)])
53+
cond_index = 100 # Arbitrary sequence for conditioning
54+
cond_array = squeeze(processed_data[cond_index][:cond_dim], axis=1)
5655

57-
data_sample = synth.sample(cond_array, 200)
56+
data_sample = synth.sample(cond_array, 1000)
5857

5958
# Inverting the scaling of the synthetic samples
6059
data_sample = inverse_transform(data_sample, scaler)

src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/ydata_synthetic/preprocessing/timeseries/stock.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
Get the stock data from Yahoo finance data
33
Data from the period 01 January 2017 - 24 January 2021
44
"""
5-
from typing import Union, List
5+
from typing import Optional, List
66

77
import pandas as pd
8+
from typeguard import typechecked
89

910
from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading
1011

11-
def transformations(path, seq_len: int, cols: Union[str, List] = None):
12+
@typechecked
13+
def transformations(path, seq_len: int, cols: Optional[List] = None):
1214
"""Apply min max scaling and roll windows of a temporal dataset.
1315
1416
Args:
1517
path(str): path to a csv temporal dataframe
1618
seq_len(int): length of the rolled sequences
1719
cols (Union[str, List]): Column or list of columns to be used"""
18-
if isinstance(cols, str):
19-
cols = [cols]
2020
if isinstance(cols, list):
2121
stock_df = pd.read_csv(path)[cols]
2222
else:

src/ydata_synthetic/synthesizers/regular/cramergan/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def save(self, path):
187187
super().save(path)
188188

189189

190-
class Generator(Model):
190+
class Generator(tf.keras.Model):
191191
def __init__(self, batch_size):
192192
"""Simple generator with dense feedforward layers."""
193193
self.batch_size = batch_size
@@ -202,7 +202,7 @@ def build_model(self, input_shape, dim, data_dim, activation_info: Optional[Name
202202
x = GumbelSoftmaxActivation(activation_info)(x)
203203
return Model(inputs=input_, outputs=x)
204204

205-
class Critic(Model):
205+
class Critic(tf.keras.Model):
206206
def __init__(self, batch_size):
207207
"""Simple critic with dense feedforward and dropout layers."""
208208
self.batch_size = batch_size

src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44
And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
55
"""
66
from tqdm import trange
7+
from numpy import array, vstack
78
from numpy.random import normal
8-
from pandas import DataFrame
99

1010
from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims
1111
from tensorflow import data as tfdata
1212
from tensorflow.keras import Model, Sequential
1313
from tensorflow.keras.optimizers import Adam
1414
from tensorflow.keras.layers import Input, Conv1D, Dense, LeakyReLU, Flatten, Add
1515

16-
1716
from ydata_synthetic.synthesizers.gan import BaseModel
1817
from ydata_synthetic.synthesizers import TrainParameters
1918
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
@@ -61,10 +60,7 @@ def train(self, data, train_arguments: TrainParameters):
6160

6261
g_loss = self.update_generator(real_batch, noise_batch)
6362

64-
print(
65-
"Epoch: {} | critic_loss: {} | gen_loss: {}".format(
66-
epoch, c_loss, g_loss
67-
))
63+
print(f"Epoch: {epoch} | critic_loss: {c_loss} | gen_loss: {g_loss}")
6864

6965
self.g_optimizer = self.g_optimizer.get_config()
7066
self.c_optimizer = self.c_optimizer.get_config()
@@ -148,21 +144,19 @@ def get_batch_data(self, data, n_windows= None):
148144

149145
def sample(self, cond_array, n_samples):
150146
"""Provided that cond_array is passed, produce n_samples for each condition vector in cond_array."""
151-
assert len(cond_array.shape) == 2, "Condition array should have 2 dimensions."
152-
assert cond_array.shape[1] == self.cond_dim, \
153-
f"Each sequence in the condition array should have a {self.cond_dim} length."
154-
n_conds = cond_array.shape[0]
147+
assert len(cond_array.shape) == 1, "Condition array should be one-dimensional."
148+
assert cond_array.shape[0] == self.cond_dim, \
149+
f"The condition sequence should have a {self.cond_dim} length."
155150
steps = n_samples // self.batch_size + 1
156151
data = []
157152
z_dist = self.get_batch_noise()
158-
for seq in range(n_conds):
159-
cond_seq = expand_dims(convert_to_tensor(cond_array.iloc[seq], float32), axis=0)
160-
cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
161-
for step in trange(steps, desc=f'Synthetic data generation - Condition {seq+1}/{n_conds}'):
162-
gen_input = concat([cond_seq, next(z_dist)], axis=1)
163-
records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False)))
164-
data.append(records)
165-
return DataFrame(concat(data, axis=0))
153+
cond_seq = expand_dims(convert_to_tensor(cond_array, float32), axis=0)
154+
cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
155+
for step in trange(steps, desc=f'Synthetic data generation'):
156+
gen_input = concat([cond_seq, next(z_dist)], axis=1)
157+
records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False)))
158+
data.append(records)
159+
return array(vstack(data))
166160

167161

168162
class Generator(Model):

0 commit comments

Comments
 (0)