Skip to content

Commit 2326986

Browse files
author
Francisco Santos
committed
add example, remove added attribute of basemodel
1 parent 0cc0772 commit 2326986

2 files changed

Lines changed: 62 additions & 1 deletion

File tree

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from pandas import DataFrame
2+
from numpy import squeeze
3+
4+
from ydata_synthetic.postprocessing.timeseries.inverse_preprocesser import inverse_transform
5+
from ydata_synthetic.preprocessing.timeseries import processed_stock
6+
from ydata_synthetic.synthesizers.timeseries import TSCWGAN
7+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
8+
9+
model = TSCWGAN
10+
11+
#Define the GAN and training parameters
12+
noise_dim = 32
13+
dim = 128
14+
seq_len = 48
15+
cond_dim = 24
16+
batch_size = 128
17+
18+
log_step = 100
19+
epochs = 300+1
20+
learning_rate = 5e-4
21+
beta_1 = 0.5
22+
beta_2 = 0.9
23+
models_dir = './cache'
24+
critic_iter = 5
25+
26+
# Get transformed data stock - Univariate
27+
data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = 'Open')
28+
data_sample = processed_data[0]
29+
n_features = data_sample.shape[1]
30+
31+
model_parameters = ModelParameters(batch_size=batch_size,
32+
lr=learning_rate,
33+
betas=(beta_1, beta_2),
34+
noise_dim=noise_dim,
35+
n_cols=seq_len,
36+
layers_dim=dim,
37+
condition = cond_dim,
38+
n_features = n_features)
39+
40+
train_args = TrainParameters(epochs=epochs,
41+
sample_interval=log_step,
42+
critic_iter=critic_iter)
43+
44+
#Training the TSCWGAN model
45+
synthesizer = model(model_parameters, gradient_penalty_weight=10)
46+
synthesizer.train(processed_data, train_args)
47+
48+
#Saving the synthesizer to later generate new events
49+
synthesizer.save(path='./tscwgan_stock.pkl')
50+
51+
#Loading the synthesizer
52+
synth = model.load(path='./tscwgan_stock.pkl')
53+
54+
#Sampling the data
55+
#Note that the data returned is not inverse processed.
56+
step = int(len(processed_data)/(5-1))
57+
cond_array = DataFrame(data=[squeeze(processed_data[i][:cond_dim], axis=1) for i in range(0, len(processed_data), step)])
58+
59+
data_sample = synth.sample(cond_array, 200)
60+
61+
# Inverting the scaling of the synthetic samples
62+
data_sample = inverse_transform(data_sample, scaler)

src/ydata_synthetic/synthesizers/gan.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def __init__(
4444
self.noise_dim = model_parameters.noise_dim
4545
self.data_dim = model_parameters.n_cols
4646
self.layers_dim = model_parameters.layers_dim
47-
self.n_features = model_parameters.n_features
4847
self.define_gan()
4948

5049
def __call__(self, inputs, **kwargs):

0 commit comments

Comments
 (0)