Skip to content

Commit 93aac14

Browse files
author
Francisco Santos
committed
apply revisions + add typeguard
1 parent 894d988 commit 93aac14

4 files changed

Lines changed: 8 additions & 7 deletions

File tree

examples/timeseries/tscwgan_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
critic_iter = 5
2525

2626
# Get transformed data stock - Univariate
27-
data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = 'Open')
27+
data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = ['Open'])
2828
data_sample = processed_data[0]
2929

3030
model_parameters = ModelParameters(batch_size=batch_size,

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ tensorflow==2.4.*
77
easydict==1.9
88
pmlb==1.0.*
99
tqdm<5.0
10+
typeguard==2.13.*

src/ydata_synthetic/preprocessing/timeseries/stock.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
Get the stock data from Yahoo finance data
33
Data from the period 01 January 2017 - 24 January 2021
44
"""
5-
from typing import Union, List
5+
from typing import Optional, List
66

77
import pandas as pd
8+
from typeguard import typechecked
89

910
from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading
1011

11-
def transformations(path, seq_len: int, cols: Union[str, List] = None):
12+
@typechecked
13+
def transformations(path, seq_len: int, cols: Optional[List] = None):
1214
"""Apply min max scaling and roll windows of a temporal dataset.
1315
1416
Args:
1517
path(str): path to a csv temporal dataframe
1618
seq_len(int): length of the rolled sequences
1719
cols (Union[str, List]): Column or list of columns to be used"""
18-
if isinstance(cols, str):
19-
cols = [cols]
2020
if isinstance(cols, list):
2121
stock_df = pd.read_csv(path)[cols]
2222
else:

src/ydata_synthetic/synthesizers/regular/cramergan/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def save(self, path):
171171
super().save(path)
172172

173173

174-
class Generator(Model):
174+
class Generator(tf.keras.Model):
175175
def __init__(self, batch_size):
176176
"""Simple generator with dense feedforward layers."""
177177
self.batch_size = batch_size
@@ -184,7 +184,7 @@ def build_model(self, input_shape, dim, data_dim):
184184
x = Dense(data_dim)(x)
185185
return Model(inputs=input_, outputs=x)
186186

187-
class Critic(Model):
187+
class Critic(tf.keras.Model):
188188
def __init__(self, batch_size):
189189
"""Simple critic with dense feedforward and dropout layers."""
190190
self.batch_size = batch_size

0 commit comments

Comments
 (0)