Skip to content

Commit 809008a

Browse files
ricardodcpereira
authored and committed
feat: add doppelganger model
1 parent f53afd3 commit 809008a

12 files changed

Lines changed: 1493 additions & 61 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
# Importing necessary libraries
from ydata_synthetic.synthesizers.timeseries import TimeSeriesSynthesizer
from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
import pandas as pd
from os import path

# Stock data columns, used both to frame the raw blocks and to train the model.
STOCK_COLS = ["Open", "High", "Low", "Close", "Adj_Close", "Volume"]

# Read the data: processed_stock returns a list of sequence blocks of length
# seq_len; stack them back into a single flat DataFrame for the synthesizer.
stock_data = processed_stock(path='../../data/stock_data.csv', seq_len=24)
stock_data = [pd.DataFrame(sd, columns=STOCK_COLS) for sd in stock_data]
stock_data = pd.concat(stock_data).reset_index(drop=True)

# Define model parameters
model_args = ModelParameters(batch_size=100,
                             lr=0.001,
                             betas=(0.5, 0.9),
                             latent_dim=3,
                             gp_lambda=10,
                             pac=10)

train_args = TrainParameters(epochs=500, sequence_length=24,
                             measurement_cols=STOCK_COLS)

# Training the DoppelGANger synthesizer (reuse a previously trained model when available)
if path.exists('doppelganger_stock'):
    model_dop_gan = TimeSeriesSynthesizer.load('doppelganger_stock')
else:
    model_dop_gan = TimeSeriesSynthesizer(modelname='doppelganger', model_parameters=model_args)
    model_dop_gan.fit(stock_data, train_args, num_cols=STOCK_COLS)
    # Persist the trained model: without this, the path.exists() check above
    # can never succeed and the model is retrained from scratch on every run.
    model_dop_gan.save('doppelganger_stock')

# Generating new synthetic samples
synth_data = model_dop_gan.sample(n_samples=500)
print(synth_data[0])

examples/timeseries/stock_timegan.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,62 +4,53 @@
44

55
# Importing necessary libraries
66
from os import path
7-
import pandas as pd
7+
from ydata_synthetic.synthesizers.timeseries import TimeSeriesSynthesizer
8+
from ydata_synthetic.preprocessing.timeseries import processed_stock
9+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
810
import numpy as np
11+
import pandas as pd
912
import matplotlib.pyplot as plt
1013

11-
from ydata_synthetic.synthesizers import ModelParameters
12-
from ydata_synthetic.preprocessing.timeseries import processed_stock
13-
from ydata_synthetic.synthesizers.timeseries import TimeGAN
14-
1514
# Define model parameters
16-
seq_len=24
17-
n_seq = 6
18-
hidden_dim=24
19-
gamma=1
20-
21-
noise_dim = 32
22-
dim = 128
23-
batch_size = 128
15+
gan_args = ModelParameters(batch_size=128,
16+
lr=5e-4,
17+
noise_dim=32,
18+
layers_dim=128,
19+
latent_dim=24,
20+
gamma=1)
2421

25-
log_step = 100
26-
learning_rate = 5e-4
27-
28-
gan_args = ModelParameters(batch_size=batch_size,
29-
lr=learning_rate,
30-
noise_dim=noise_dim,
31-
layers_dim=dim)
22+
train_args = TrainParameters(epochs=50000,
23+
sequence_length=24,
24+
number_sequences=6)
3225

3326
# Read the data
34-
stock_data = processed_stock(path='../../data/stock_data.csv', seq_len=seq_len)
35-
print(len(stock_data),stock_data[0].shape)
27+
stock_data = pd.read_csv("../../data/stock_data.csv")
28+
cols = list(stock_data.columns)
3629

3730
# Training the TimeGAN synthesizer
3831
if path.exists('synthesizer_stock.pkl'):
39-
synth = TimeGAN.load('synthesizer_stock.pkl')
32+
synth = TimeSeriesSynthesizer.load('synthesizer_stock.pkl')
4033
else:
41-
synth = TimeGAN(model_parameters=gan_args, hidden_dim=24, seq_len=seq_len, n_seq=n_seq, gamma=1)
42-
synth.train(stock_data, train_steps=50000)
34+
synth = TimeSeriesSynthesizer(modelname='timegan', model_parameters=gan_args)
35+
synth.fit(stock_data, train_args, num_cols=cols)
4336
synth.save('synthesizer_stock.pkl')
4437

4538
# Generating new synthetic samples
46-
synth_data = synth.sample(len(stock_data))
47-
print(synth_data.shape)
48-
49-
# Reshaping the data
50-
cols = ['Open','High','Low','Close','Adj Close','Volume']
39+
stock_data_blocks = processed_stock(path='../../data/stock_data.csv', seq_len=24)
40+
synth_data = synth.sample(n_samples=len(stock_data_blocks))
41+
print(synth_data[0].shape)
5142

5243
# Plotting some generated samples. Both Synthetic and Original data are still standardized with values between [0,1]
5344
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 10))
5445
axes=axes.flatten()
5546

5647
time = list(range(1,25))
57-
obs = np.random.randint(len(stock_data))
48+
obs = np.random.randint(len(stock_data_blocks))
5849

5950
for j, col in enumerate(cols):
60-
df = pd.DataFrame({'Real': stock_data[obs][:, j],
61-
'Synthetic': synth_data[obs][:, j]})
51+
df = pd.DataFrame({'Real': stock_data_blocks[obs][:, j],
52+
'Synthetic': synth_data[obs].iloc[:, j]})
6253
df.plot(ax=axes[j],
6354
title = col,
6455
secondary_y='Synthetic data', style=['-', '--'])
65-
fig.tight_layout()
56+
fig.tight_layout()

src/ydata_synthetic/preprocessing/regular/processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def fit(self, X: DataFrame) -> RegularDataProcessor:
6262
("scaler", MinMaxScaler()),
6363
])
6464
self._cat_pipeline = Pipeline([
65-
("encoder", OneHotEncoder(sparse=False, handle_unknown='ignore')),
65+
("encoder", OneHotEncoder(sparse_output=False, handle_unknown='ignore')),
6666
])
6767

6868
self.num_pipeline.fit(X[self.num_cols]) if self.num_cols else zeros([len(X), 0])
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Optional
4+
from dataclasses import dataclass
5+
6+
from numpy import concatenate, ndarray, zeros, ones, expand_dims, reshape, sum as npsum, repeat, array_split, asarray
7+
from pandas import DataFrame
8+
from typeguard import typechecked
9+
10+
from ydata_synthetic.preprocessing.regular.processor import RegularDataProcessor
11+
12+
13+
@dataclass
14+
class ColumnMetadata:
15+
"""
16+
Dataclass that stores the metadata of each column.
17+
"""
18+
discrete: bool
19+
output_dim: int
20+
name: str
21+
22+
23+
@typechecked
24+
class DoppelGANgerProcessor(RegularDataProcessor):
25+
"""
26+
Main class for class the DoppelGANger preprocessing.
27+
It works like any other transformer in scikit learn with the methods fit, transform and inverse transform.
28+
Args:
29+
num_cols (list of strings):
30+
List of names of numerical columns.
31+
measurement_cols (list of strings):
32+
List of measurement columns.
33+
sequence_length (int):
34+
Sequence length.
35+
"""
36+
SUPPORTED_MODEL = 'DoppelGANger'
37+
38+
def __init__(self, num_cols: Optional[List[str]] = None,
39+
cat_cols: Optional[List[str]] = None,
40+
measurement_cols: Optional[List[str]] = None,
41+
sequence_length: Optional[int] = None):
42+
super().__init__(num_cols, cat_cols)
43+
44+
if num_cols is None:
45+
num_cols = []
46+
if cat_cols is None:
47+
cat_cols = []
48+
if measurement_cols is None:
49+
measurement_cols = []
50+
self.sequence_length = sequence_length
51+
self._measurement_num_cols = [c for c in self.num_cols if c in measurement_cols]
52+
self._measurement_cat_cols = [c for c in self.cat_cols if c in measurement_cols]
53+
self._attribute_num_cols = [c for c in self.num_cols if c not in measurement_cols]
54+
self._attribute_cat_cols = [c for c in self.cat_cols if c not in measurement_cols]
55+
self._measurement_cols_metadata = None
56+
self._attribute_cols_metadata = None
57+
self._measurement_one_hot_cat_cols = None
58+
self._attribute_one_hot_cat_cols = None
59+
self._has_attributes = self._attribute_num_cols or self._attribute_cat_cols
60+
61+
@property
62+
def measurement_cols_metadata(self):
63+
return self._measurement_cols_metadata
64+
65+
@property
66+
def attribute_cols_metadata(self):
67+
return self._attribute_cols_metadata
68+
69+
def add_gen_flag(self, data_features: ndarray, sample_len: int):
70+
num_sample = data_features.shape[0]
71+
length = data_features.shape[1]
72+
if length % sample_len != 0:
73+
raise Exception("length must be a multiple of sample_len")
74+
data_gen_flag = ones((num_sample, length))
75+
data_gen_flag = expand_dims(data_gen_flag, 2)
76+
shift_gen_flag = concatenate(
77+
[data_gen_flag[:, 1:, :],
78+
zeros((data_gen_flag.shape[0], 1, 1))],
79+
axis=1)
80+
data_gen_flag_t = reshape(
81+
data_gen_flag,
82+
[num_sample, int(length / sample_len), sample_len])
83+
data_gen_flag_t = npsum(data_gen_flag_t, 2)
84+
data_gen_flag_t = data_gen_flag_t > 0.5
85+
data_gen_flag_t = repeat(data_gen_flag_t, sample_len, axis=1)
86+
data_gen_flag_t = expand_dims(data_gen_flag_t, 2)
87+
data_features = concatenate(
88+
[data_features,
89+
shift_gen_flag,
90+
(1 - shift_gen_flag) * data_gen_flag_t],
91+
axis=2)
92+
93+
return data_features
94+
95+
def transform(self, X: DataFrame) -> tuple[ndarray, ndarray]:
96+
"""Transforms the passed DataFrame with the fit DataProcessor.
97+
Args:
98+
X (DataFrame):
99+
DataFrame used to fit the processor parameters.
100+
Should be aligned with the columns types defined in initialization.
101+
Returns:
102+
transformed (ndarray, ndarray):
103+
Processed version of the passed DataFrame.
104+
"""
105+
self._check_is_fitted()
106+
107+
measurement_cols = self._measurement_num_cols + self._measurement_cat_cols
108+
if not measurement_cols:
109+
raise ValueError("At least one measurement column must be supplied.")
110+
if not all(c in self.num_cols + self.cat_cols for c in measurement_cols):
111+
raise ValueError("At least one of the supplied measurement columns does not exist in the dataset.")
112+
if self.sequence_length is None:
113+
raise ValueError("The sequence length is mandatory.")
114+
115+
num_data = DataFrame(self.num_pipeline.transform(X[self.num_cols]) if self.num_cols else zeros([len(X), 0]), columns=self.num_cols)
116+
one_hot_cat_cols = self.cat_pipeline.get_feature_names_out()
117+
cat_data = DataFrame(self.cat_pipeline.transform(X[self.cat_cols]) if self.cat_cols else zeros([len(X), 0]), columns=one_hot_cat_cols)
118+
119+
self._measurement_one_hot_cat_cols = [c for c in one_hot_cat_cols if c.split("_")[0] in self._measurement_cat_cols]
120+
measurement_num_data = num_data[self._measurement_num_cols].to_numpy() if self._measurement_num_cols else zeros([len(X), 0])
121+
self._measurement_cols_metadata = [ColumnMetadata(discrete=False, output_dim=1, name=c) for c in self._measurement_num_cols]
122+
measurement_cat_data = cat_data[self._measurement_one_hot_cat_cols].to_numpy() if self._measurement_one_hot_cat_cols else zeros([len(X), 0])
123+
self._measurement_cols_metadata += [ColumnMetadata(discrete=True, output_dim=X[c].nunique(), name=c) for c in self._measurement_cat_cols]
124+
data_features = concatenate([measurement_num_data, measurement_cat_data], axis=1)
125+
126+
if self._has_attributes:
127+
self._attribute_one_hot_cat_cols = [c for c in one_hot_cat_cols if c.split("_")[0] in self._attribute_cat_cols]
128+
attribute_num_data = num_data[self._attribute_num_cols].to_numpy() if self._attribute_num_cols else zeros([len(X), 0])
129+
self._attribute_cols_metadata = [ColumnMetadata(discrete=False, output_dim=1, name=c) for c in self._attribute_num_cols]
130+
attribute_cat_data = cat_data[self._attribute_one_hot_cat_cols].to_numpy() if self._attribute_one_hot_cat_cols else zeros([len(X), 0])
131+
self._attribute_cols_metadata += [ColumnMetadata(discrete=True, output_dim=X[c].nunique(), name=c) for c in self._attribute_cat_cols]
132+
data_attributes = concatenate([attribute_num_data, attribute_cat_data], axis=1)
133+
else:
134+
self._attribute_one_hot_cat_cols = []
135+
data_attributes = zeros((data_features.shape[0], 1))
136+
self._attribute_cols_metadata = [ColumnMetadata(discrete=False, output_dim=1, name="zeros_attribute")]
137+
138+
num_samples = int(X.shape[0] / self.sequence_length)
139+
data_features = asarray(array_split(data_features, num_samples))
140+
data_attributes = asarray(array_split(data_attributes, num_samples))
141+
142+
data_features = self.add_gen_flag(data_features, sample_len=self.sequence_length)
143+
self._measurement_cols_metadata += [ColumnMetadata(discrete=True, output_dim=2, name="gen_flags")]
144+
return data_features, data_attributes.mean(axis=1)
145+
146+
def inverse_transform(self, X_features: ndarray, X_attributes: ndarray) -> list[DataFrame]:
147+
"""Inverts the data transformation pipelines on a passed DataFrame.
148+
Args:
149+
X_features (ndarray):
150+
Numpy array with the measurement data to be brought back to the original format.
151+
X_attributes (ndarray):
152+
Numpy array with the attribute data to be brought back to the original format.
153+
Returns:
154+
result (DataFrame):
155+
DataFrame with all performed transformations inverted.
156+
"""
157+
self._check_is_fitted()
158+
159+
num_samples = X_attributes.shape[0]
160+
if self._has_attributes:
161+
X_attributes = repeat(X_attributes.reshape((num_samples, 1, X_attributes.shape[1])), repeats=X_features.shape[1], axis=1)
162+
generated_data = concatenate((X_features, X_attributes), axis=2)
163+
else:
164+
generated_data = X_features
165+
output_cols = self._measurement_num_cols + self._measurement_one_hot_cat_cols + self._attribute_num_cols + self._attribute_one_hot_cat_cols
166+
one_hot_cat_cols = self._measurement_one_hot_cat_cols + self._attribute_one_hot_cat_cols
167+
168+
samples = []
169+
for i in range(num_samples):
170+
df = DataFrame(generated_data[i], columns=output_cols)
171+
df_num = self.num_pipeline.inverse_transform(df[self.num_cols]) if self.num_cols else zeros([len(df), 0])
172+
df_cat = self.cat_pipeline.inverse_transform(df[one_hot_cat_cols].round(0)) if self.cat_cols else zeros([len(df), 0])
173+
df = DataFrame(concatenate((df_num, df_cat), axis=1), columns=self.num_cols+self.cat_cols)
174+
df = df.loc[:, self._col_order_]
175+
for col in df.columns:
176+
df[col] = df[col].astype(self._types[col])
177+
samples.append(df)
178+
179+
return samples

src/ydata_synthetic/synthesizers/base.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,23 @@
2626
from ydata_synthetic.preprocessing.timeseries.timeseries_processor import (
2727
TimeSeriesDataProcessor, TimeSeriesModels)
2828
from ydata_synthetic.preprocessing.regular.ctgan_processor import CTGANDataProcessor
29+
from ydata_synthetic.preprocessing.timeseries.doppelganger_processor import DoppelGANgerProcessor
2930
from ydata_synthetic.synthesizers.saving_keras import make_keras_picklable
3031

3132
_model_parameters = ['batch_size', 'lr', 'betas', 'layers_dim', 'noise_dim',
3233
'n_cols', 'seq_len', 'condition', 'n_critic', 'n_features',
3334
'tau_gs', 'generator_dims', 'critic_dims', 'l2_scale',
34-
'latent_dim', 'gp_lambda', 'pac']
35+
'latent_dim', 'gp_lambda', 'pac', 'gamma']
3536
_model_parameters_df = [128, 1e-4, (None, None), 128, 264,
3637
None, None, None, 1, None, 0.2, [256, 256],
37-
[256, 256], 1e-6, 128, 10.0, 10]
38+
[256, 256], 1e-6, 128, 10.0, 10, 1]
3839

3940
_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval',
40-
'labels', 'n_clusters', 'epsilon', 'log_frequency']
41+
'labels', 'n_clusters', 'epsilon', 'log_frequency',
42+
'measurement_cols', 'sequence_length', 'number_sequences']
4143

4244
ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df)
43-
TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, 10, 0.005, True))
45+
TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, 10, 0.005, True, None, 1, 1))
4446

4547
@typechecked
4648
class BaseModel(ABC):
@@ -185,6 +187,12 @@ def fit(self,
185187
epsilon = train_arguments.epsilon
186188
self.processor = CTGANDataProcessor(n_clusters=n_clusters, epsilon=epsilon,
187189
num_cols=num_cols, cat_cols=cat_cols).fit(data)
190+
elif self.__MODEL__ == DoppelGANgerProcessor.SUPPORTED_MODEL:
191+
measurement_cols = train_arguments.measurement_cols
192+
sequence_length = train_arguments.sequence_length
193+
self.processor = DoppelGANgerProcessor(num_cols=num_cols, cat_cols=cat_cols,
194+
measurement_cols=measurement_cols,
195+
sequence_length=sequence_length).fit(data)
188196
else:
189197
print(f'A DataProcessor is not available for the {self.__MODEL__}.')
190198

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from ydata_synthetic.synthesizers.timeseries.timegan.model import TimeGAN
1+
from ydata_synthetic.synthesizers.timeseries.model import TimeSeriesSynthesizer
22

33
__all__ = [
4-
'TimeGAN',
4+
'TimeSeriesSynthesizer'
55
]

src/ydata_synthetic/synthesizers/timeseries/doppelganger/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)