Skip to content

Commit bf8656f

Browse files
author
Francisco Santos
committed
add example, remove added attribute of basemodel
remove changes on gitignore removed unused n_feats argument
1 parent 8454559 commit bf8656f

2 files changed

Lines changed: 60 additions & 2 deletions

File tree

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,5 +374,3 @@ DerivedData/
374374
# User created
375375
VERSION
376376
version.py
377-
local_test_*.py
378-
local_test_*.ipynb
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from pandas import DataFrame
2+
from numpy import squeeze
3+
4+
from ydata_synthetic.postprocessing.timeseries.inverse_preprocesser import inverse_transform
5+
from ydata_synthetic.preprocessing.timeseries import processed_stock
6+
from ydata_synthetic.synthesizers.timeseries import TSCWGAN
7+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
8+
9+
model = TSCWGAN
10+
11+
#Define the GAN and training parameters
12+
noise_dim = 32
13+
dim = 128
14+
seq_len = 48
15+
cond_dim = 24
16+
batch_size = 128
17+
18+
log_step = 100
19+
epochs = 300+1
20+
learning_rate = 5e-4
21+
beta_1 = 0.5
22+
beta_2 = 0.9
23+
models_dir = './cache'
24+
critic_iter = 5
25+
26+
# Get transformed data stock - Univariate
27+
data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = 'Open')
28+
data_sample = processed_data[0]
29+
30+
model_parameters = ModelParameters(batch_size=batch_size,
31+
lr=learning_rate,
32+
betas=(beta_1, beta_2),
33+
noise_dim=noise_dim,
34+
n_cols=seq_len,
35+
layers_dim=dim,
36+
condition = cond_dim)
37+
38+
train_args = TrainParameters(epochs=epochs,
39+
sample_interval=log_step,
40+
critic_iter=critic_iter)
41+
42+
#Training the TSCWGAN model
43+
synthesizer = model(model_parameters, gradient_penalty_weight=10)
44+
synthesizer.train(processed_data, train_args)
45+
46+
#Saving the synthesizer to later generate new events
47+
synthesizer.save(path='./tscwgan_stock.pkl')
48+
49+
#Loading the synthesizer
50+
synth = model.load(path='./tscwgan_stock.pkl')
51+
52+
#Sampling the data
53+
#Note that the data returned is not inverse processed.
54+
step = int(len(processed_data)/(5-1))
55+
cond_array = DataFrame(data=[squeeze(processed_data[i][:cond_dim], axis=1) for i in range(0, len(processed_data), step)])
56+
57+
data_sample = synth.sample(cond_array, 200)
58+
59+
# Inverting the scaling of the synthetic samples
60+
data_sample = inverse_transform(data_sample, scaler)

0 commit comments

Comments
 (0)