|
4 | 4 | And on: https://github.com/CasperHogenboom/WGAN_financial_time-series |
5 | 5 | """ |
6 | 6 | from tqdm import trange |
| 7 | +from numpy import array, vstack |
7 | 8 | from numpy.random import normal |
8 | | -from pandas import DataFrame |
9 | 9 |
|
10 | 10 | from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims |
11 | 11 | from tensorflow import data as tfdata |
12 | 12 | from tensorflow.keras import Model, Sequential |
13 | 13 | from tensorflow.keras.optimizers import Adam |
14 | 14 | from tensorflow.keras.layers import Input, Conv1D, Dense, LeakyReLU, Flatten, Add |
15 | 15 |
|
16 | | - |
17 | 16 | from ydata_synthetic.synthesizers.gan import BaseModel |
18 | 17 | from ydata_synthetic.synthesizers import TrainParameters |
19 | 18 | from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty |
@@ -61,10 +60,7 @@ def train(self, data, train_arguments: TrainParameters): |
61 | 60 |
|
62 | 61 | g_loss = self.update_generator(real_batch, noise_batch) |
63 | 62 |
|
64 | | - print( |
65 | | - "Epoch: {} | critic_loss: {} | gen_loss: {}".format( |
66 | | - epoch, c_loss, g_loss |
67 | | - )) |
| 63 | + print(f"Epoch: {epoch} | critic_loss: {c_loss} | gen_loss: {g_loss}") |
68 | 64 |
|
69 | 65 | self.g_optimizer = self.g_optimizer.get_config() |
70 | 66 | self.c_optimizer = self.c_optimizer.get_config() |
@@ -148,21 +144,19 @@ def get_batch_data(self, data, n_windows= None): |
148 | 144 |
|
149 | 145 | def sample(self, cond_array, n_samples): |
150 | 146 | """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.""" |
151 | | - assert len(cond_array.shape) == 2, "Condition array should have 2 dimensions." |
152 | | - assert cond_array.shape[1] == self.cond_dim, \ |
153 | | - f"Each sequence in the condition array should have a {self.cond_dim} length." |
154 | | - n_conds = cond_array.shape[0] |
| 147 | + assert len(cond_array.shape) == 1, "Condition array should be one-dimensional." |
| 148 | + assert cond_array.shape[0] == self.cond_dim, \ |
| 149 | + f"The condition sequence should have a {self.cond_dim} length." |
155 | 150 | steps = n_samples // self.batch_size + 1 |
156 | 151 | data = [] |
157 | 152 | z_dist = self.get_batch_noise() |
158 | | - for seq in range(n_conds): |
159 | | - cond_seq = expand_dims(convert_to_tensor(cond_array.iloc[seq], float32), axis=0) |
160 | | - cond_seq = tile(cond_seq, multiples=[self.batch_size, 1]) |
161 | | - for step in trange(steps, desc=f'Synthetic data generation - Condition {seq+1}/{n_conds}'): |
162 | | - gen_input = concat([cond_seq, next(z_dist)], axis=1) |
163 | | - records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False))) |
164 | | - data.append(records) |
165 | | - return DataFrame(concat(data, axis=0)) |
| 153 | + cond_seq = expand_dims(convert_to_tensor(cond_array, float32), axis=0) |
| 154 | + cond_seq = tile(cond_seq, multiples=[self.batch_size, 1]) |
| 155 | + for step in trange(steps, desc=f'Synthetic data generation'): |
| 156 | + gen_input = concat([cond_seq, next(z_dist)], axis=1) |
| 157 | + records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False))) |
| 158 | + data.append(records) |
| 159 | + return array(vstack(data)) |
166 | 160 |
|
167 | 161 |
|
168 | 162 | class Generator(Model): |
|
0 commit comments