Skip to content

Commit 1c84754

Browse files
authored
fix: cgan discriminator repeated label (#101)
1 parent ba4f86d commit 1c84754

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

examples/regular/cgan_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
train_data = data.loc[ data['Class']==1 ].copy()
2626

2727
#Create a new class column using KMeans - This will mainly be useful if we want to leverage conditional GAN
28-
print("Dataset info: Number of records - {} Number of varibles - {}".format(train_data.shape[0], train_data.shape[1]))
28+
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
2929
algorithm = cluster.KMeans
3030
args, kwds = (), {'n_clusters':2, 'random_state':0}
3131
labels = algorithm(*args, **kwds).fit_predict(train_data[ data_cols ])
@@ -63,7 +63,7 @@
6363
lr=learning_rate,
6464
betas=(beta_1, beta_2),
6565
noise_dim=noise_dim,
66-
n_cols=train_sample.shape[1],
66+
n_cols=train_sample.shape[1] - len(label_cols), # Don't count the label columns here
6767
layers_dim=dim)
6868

6969
train_args = TrainParameters(epochs=epochs,

src/ydata_synthetic/synthesizers/regular/cgan/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,15 @@ def train(self, data: Union[DataFrame, array],
9393
# ---------------------
9494
batch_x = self.get_data_batch(data, self.batch_size)
9595
label = batch_x[:, train_arguments.label_dim]
96+
data_cols = [i for i in range(batch_x.shape[1] - 1)] # All data without the label columns
9697
noise = tf.random.normal((self.batch_size, self.noise_dim))
9798

9899
# Generate a batch of new records
99100
gen_records = self.generator([noise, label], training=True)
100101

101102
# Train the discriminator
102-
d_loss_real = self.discriminator.train_on_batch([batch_x, label], valid)
103-
d_loss_fake = self.discriminator.train_on_batch([gen_records, label], fake)
103+
d_loss_real = self.discriminator.train_on_batch([batch_x[:, data_cols], label], valid) # Separate labels
104+
d_loss_fake = self.discriminator.train_on_batch([gen_records, label], fake) # Separate labels
104105
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
105106

106107
# ---------------------

0 commit comments

Comments (0)