Skip to content

Commit 3c5f1ee

Browse files
authored
fix: Fix data preprocessing and dragan performance prints (#105)
1 parent c618d90 commit 3c5f1ee

3 files changed

Lines changed: 7 additions & 6 deletions

File tree

examples/regular/cramergan_example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
#Before training the GAN do not forget to apply the required data transformations
2525
#To ease here we've applied a PowerTransformation
26-
data = transformations(data)
26+
_, data, _ = transformations(data)
27+
2728

2829
#For the purpose of this example we will only synthesize the minority class
2930
train_data = data.loc[ data['Class']==1 ].copy()

examples/regular/wgan_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
#Before training the GAN do not forget to apply the required data transformations
2323
#To ease here we've applied a PowerTransformation
24-
data = transformations(data)
24+
_, data, _ = transformations(data)
2525

2626
#For the purpose of this example we will only synthesize the minority class
2727
train_data = data.loc[ data['Class']==1 ].copy()

src/ydata_synthetic/synthesizers/regular/dragan/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ def train(self, data, train_arguments):
125125
batch_data = tf.cast(batch_data, dtype=tf.float32)
126126
d_loss, g_loss = self.train_step(batch_data)
127127

128-
print(
129-
"Epoch: {} | disc_loss: {} | gen_loss: {}".format(
130-
epoch, d_loss, g_loss
131-
))
128+
print(
129+
"Epoch: {} | disc_loss: {} | gen_loss: {}".format(
130+
epoch, d_loss, g_loss
131+
))
132132

133133
if epoch % train_arguments.sample_interval == 0:
134134
# Test here data generation step

0 commit comments

Comments
 (0)