"""
Conditional time-series Wasserstein GAN.
Based on: https://www.naun.org/main/NAUN/neural/2020/a082016-004(2020).pdf
And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
"""
from tqdm import trange
from numpy.random import normal
from pandas import DataFrame

from tensorflow import (GradientTape, concat, convert_to_tensor, expand_dims,
                        float32, make_ndarray, make_tensor_proto, reduce_mean,
                        reshape, tile)
from tensorflow import data as tfdata
from tensorflow.keras import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Conv1D, Dense, LeakyReLU, Flatten, Add

from ydata_synthetic.synthesizers.gan import BaseModel
from ydata_synthetic.synthesizers import TrainParameters
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty


class TSCWGAN(BaseModel):

    __MODEL__ = 'TSCWGAN'

    def __init__(self, model_parameters, gradient_penalty_weight=10):
        """Create a base TSCWGAN."""
        self.gradient_penalty_weight = gradient_penalty_weight
        super().__init__(model_parameters)

    def define_gan(self):
        self.generator = Generator(self.batch_size). \
            build_model(input_shape=(self.noise_dim + self.cond_dim, 1), dim=self.layers_dim, data_dim=self.data_dim)
        self.critic = Critic(self.batch_size). \
            build_model(input_shape=(self.data_dim + self.cond_dim,), dim=self.layers_dim)

        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)

        # The generator takes a condition and noise as input and generates records
        noise = Input(shape=(self.noise_dim,), batch_size=self.batch_size)
        cond = Input(shape=(self.cond_dim,), batch_size=self.batch_size)
        gen = expand_dims(concat([cond, noise], axis=1), axis=-1)  # Add a channel axis for the Conv1D generator
        gen = self.generator(gen)
        score = concat([cond, gen], axis=1)
        score = self.critic(score)
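        # Note: the composed score graph above only wires the conditional flow
        # end to end; it is not stored as a combined model. The training loop
        # below updates the critic and the generator directly.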

    def train(self, data, train_arguments: TrainParameters):
        real_batches = self.get_batch_data(data)
        noise_batches = self.get_batch_noise()

        for epoch in trange(train_arguments.epochs):
            for _ in range(train_arguments.critic_iter):
                real_batch = next(real_batches)
                noise_batch = next(noise_batches)[:len(real_batch)]  # Truncate the noise batch to match the real batch size

                c_loss = self.update_critic(real_batch, noise_batch)

            real_batch = next(real_batches)
            noise_batch = next(noise_batches)[:len(real_batch)]

            g_loss = self.update_generator(real_batch, noise_batch)

            print(f"Epoch: {epoch} | critic_loss: {c_loss} | gen_loss: {g_loss}")

        # Keep only the optimizer configurations once training ends; the
        # stateful optimizer objects are no longer needed after this point.
        self.g_optimizer = self.g_optimizer.get_config()
        self.c_optimizer = self.c_optimizer.get_config()

    def update_critic(self, real_batch, noise_batch):
        with GradientTape() as c_tape:
            fake_batch, cond_batch = self._make_fake_batch(real_batch, noise_batch)

            # Real and fake records with conditions
            real_batch_ = concat([cond_batch, real_batch], axis=1)
            fake_batch_ = concat([cond_batch, fake_batch], axis=1)

            c_loss = self.c_lossfn(real_batch_, fake_batch_)

        c_gradient = c_tape.gradient(c_loss, self.critic.trainable_variables)

        # Update the weights of the critic using the optimizer
        self.c_optimizer.apply_gradients(
            zip(c_gradient, self.critic.trainable_variables)
        )
        return c_loss

    def update_generator(self, real_batch, noise_batch):
        with GradientTape() as g_tape:
            fake_batch, cond_batch = self._make_fake_batch(real_batch, noise_batch)

            # Fake records with conditions
            fake_batch_ = concat([cond_batch, fake_batch], axis=1)

            g_loss = self.g_lossfn(fake_batch_)

        g_gradient = g_tape.gradient(g_loss, self.generator.trainable_variables)

        # Update the weights of the generator using the optimizer
        self.g_optimizer.apply_gradients(
            zip(g_gradient, self.generator.trainable_variables)
        )
        return g_loss

    def c_lossfn(self, real_batch_, fake_batch_):
        score_fake = self.critic(fake_batch_)
        score_real = self.critic(real_batch_)
        grad_penalty = self.gradient_penalty(real_batch_, fake_batch_)
        # Wasserstein critic loss with the weighted gradient penalty
        c_loss = reduce_mean(score_fake) - reduce_mean(score_real) \
            + self.gradient_penalty_weight * grad_penalty
        return c_loss

    def g_lossfn(self, fake_batch_):
        score_fake = self.critic(fake_batch_)
        g_loss = -reduce_mean(score_fake)
        return g_loss
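
    # For reference, the two losses above implement the conditional WGAN objective:
    #   critic:    minimize E[D(fake)] - E[D(real)] + weight * gradient_penalty
    #   generator: minimize -E[D(fake)]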

    def _make_fake_batch(self, real_batch, noise_batch):
        """Generate a batch of fake records and return it with the batch of used conditions.
        Conditions are the first `cond_dim` elements of each record in the real batch."""
        cond_batch = real_batch[:, :self.cond_dim]
        gen_input = expand_dims(concat([cond_batch, noise_batch], axis=1), axis=-1)  # Channel axis for the Conv1D generator
        return self.generator(gen_input, training=True), cond_batch

    def gradient_penalty(self, real, fake):
        # DRAGAN-style penalty: the critic's gradient norm is regularized around
        # perturbed real samples rather than real/fake interpolations
        gp = gradient_penalty(self.critic, real, fake, mode=Mode.DRAGAN)
        return gp

    def _generate_noise(self):
        """Gaussian noise for the generator input."""
        while True:
            yield normal(size=self.noise_dim)

    def get_batch_noise(self):
        """Create a batch iterator for the generator's Gaussian noise input."""
        return iter(tfdata.Dataset.from_generator(self._generate_noise, output_types=float32)
                                  .batch(self.batch_size)
                                  .repeat())

    def get_batch_data(self, data, n_windows=None):
        if n_windows is None:
            n_windows = len(data)
        data = reshape(convert_to_tensor(data, dtype=float32), shape=(-1, self.data_dim))
        return iter(tfdata.Dataset.from_tensor_slices(data)
                                  .shuffle(buffer_size=n_windows)
                                  .batch(self.batch_size).repeat())

    def sample(self, cond_array, n_samples):
        """Generate synthetic records conditioned on `cond_array` (a DataFrame of condition vectors).

        For each condition vector, at least `n_samples` records are produced;
        generation is batched, so counts are rounded up to a multiple of the batch size."""
        assert len(cond_array.shape) == 2, "Condition array should have 2 dimensions."
        assert cond_array.shape[1] == self.cond_dim, \
            f"Each sequence in the condition array should have length {self.cond_dim}."
        n_conds = cond_array.shape[0]
        steps = n_samples // self.batch_size + 1
        data = []
        z_dist = self.get_batch_noise()
        for seq in range(n_conds):
            cond_seq = expand_dims(convert_to_tensor(cond_array.iloc[seq], float32), axis=0)
            cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
            for _ in trange(steps, desc=f'Synthetic data generation - Condition {seq+1}/{n_conds}'):
                gen_input = expand_dims(concat([cond_seq, next(z_dist)], axis=1), axis=-1)  # Channel axis for the Conv1D generator
                records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False)))
                data.append(records)
        return DataFrame(concat(data, axis=0).numpy())


class Generator(Model):
    """Conditional generator with skip connections."""
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim):
        # Define blocks
        input_to_latent = Sequential(layers=[
            Conv1D(filters=dim, kernel_size=1, input_shape=input_shape),
            LeakyReLU(),
            Conv1D(dim, kernel_size=5, dilation_rate=2, padding="same"),
            LeakyReLU()
        ], name='input_to_latent')
        block_cnn = Sequential(layers=[
            Conv1D(filters=dim, kernel_size=3, dilation_rate=2, padding="same"),
            LeakyReLU()
        ], name='block_cnn')
        block_shift = Sequential(layers=[
            Conv1D(filters=10, kernel_size=3, dilation_rate=2, padding="same"),
            LeakyReLU(),
            Flatten(),
            Dense(dim*2),
            LeakyReLU()
        ], name='block_shift')
        block = Sequential(layers=[
            Dense(dim*2),
            LeakyReLU()
        ], name='block')
        latent_to_output = Sequential(layers=[
            Dense(data_dim)
        ], name='latent_to_output')

        # Define input - Expected shape is (batch_size, noise_dim + cond_dim, 1)
        noise_input = Input(shape=input_shape, batch_size=self.batch_size)

        # Compose model
        x = input_to_latent(noise_input)
        # Three residual convolutional blocks (block_cnn is reused, so its weights are shared)
        for _ in range(3):
            x = Add()([block_cnn(x), x])
        x = block_shift(x)
        # Three residual dense blocks (block is reused, so its weights are shared)
        for _ in range(3):
            x = Add()([block(x), x])
        x = latent_to_output(x)
        # Output - Expected shape is (batch_size, data_dim); data_dim does not include conditions
        return Model(inputs=noise_input, outputs=x, name='SkipConnectionGenerator')


class Critic(Model):
    """Conditional Wasserstein critic with skip connections."""
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        # Define blocks
        ts_to_latent = Sequential(layers=[
            Dense(dim*2),
            LeakyReLU()
        ], name='ts_to_latent')
        block = Sequential(layers=[
            Dense(dim*2),
            LeakyReLU()
        ], name='block')
        latent_to_score = Sequential(layers=[
            Dense(1)
        ], name='latent_to_score')

        # Define input - Expected shape is (batch_size, data_dim + cond_dim)
        record_input = Input(shape=input_shape, batch_size=self.batch_size)

        # Compose model
        x = ts_to_latent(record_input)
        # Seven residual dense blocks (block is reused, so its weights are shared)
        for _ in range(7):
            x = Add()([block(x), x])
        x = latent_to_score(x)
        return Model(inputs=record_input, outputs=x, name='SkipConnectionCritic')
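

# --- Usage sketch (illustrative only, not part of the module) ---------------
# A minimal example of how this synthesizer might be driven. The exact fields
# of ModelParameters below are assumptions about the surrounding package, not
# guarantees of its API:
#
#     from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
#
#     model_args = ModelParameters(batch_size=128, lr=5e-4, noise_dim=32)
#     synth = TSCWGAN(model_args, gradient_penalty_weight=10)
#     synth.train(processed_data, TrainParameters(epochs=300, critic_iter=5))
#     samples = synth.sample(cond_array, n_samples=1000)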