Skip to content

Commit 4fc1f4f

Browse files
authored
chore: expose gs tau argument [SD-128] (#145)
* Expose the tau argument of the Gumbel-Softmax layer * Integrate the exposed tau argument in the regular synthesizers * Expose tau only on RegularModels
1 parent 465e502 commit 4fc1f4f

8 files changed

Lines changed: 36 additions & 25 deletions

File tree

src/ydata_synthetic/synthesizers/gan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from ydata_synthetic.synthesizers.saving_keras import make_keras_picklable
1818

1919
_model_parameters = ['batch_size', 'lr', 'betas', 'layers_dim', 'noise_dim',
20-
'n_cols', 'seq_len', 'condition', 'n_critic', 'n_features']
20+
'n_cols', 'seq_len', 'condition', 'n_critic', 'n_features', 'tau_gs']
2121
_model_parameters_df = [128, 1e-4, (None, None), 128, 264,
22-
None, None, None, 1, None]
22+
None, None, None, 1, None, 0.2]
2323

2424
_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval', 'labels']
2525

@@ -62,6 +62,8 @@ def __init__(
6262
self.data_dim = None
6363
self.layers_dim = model_parameters.layers_dim
6464
self.processor = None
65+
if self.__MODEL__ in RegularModels.__members__:
66+
self.tau = model_parameters.tau_gs
6567

6668
# pylint: disable=E1101
6769
def __call__(self, inputs, **kwargs):

src/ydata_synthetic/synthesizers/regular/cgan/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def label_col(self, data_label: Tuple[Union[DataFrame, array], str]):
5454
def define_gan(self, activation_info: Optional[NamedTuple] = None):
5555
self.generator = Generator(self.batch_size, self.num_classes). \
5656
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
57-
activation_info = activation_info)
57+
activation_info = activation_info, tau = self.tau)
5858

5959
self.discriminator = Discriminator(self.batch_size, self.num_classes). \
6060
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -200,7 +200,7 @@ def __init__(self, batch_size, num_classes):
200200
self.batch_size = batch_size
201201
self.num_classes = num_classes
202202

203-
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
203+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None, tau: Optional[float] = None):
204204
noise = Input(shape=input_shape, batch_size=self.batch_size)
205205
label = Input(shape=(1,), batch_size=self.batch_size, dtype='int32')
206206
label_embedding = Flatten()(Embedding(self.num_classes, 1)(label))
@@ -211,7 +211,7 @@ def build_model(self, input_shape, dim, data_dim, activation_info: Optional[Name
211211
x = Dense(dim * 4, activation='relu')(x)
212212
x = Dense(data_dim)(x)
213213
if activation_info:
214-
x = GumbelSoftmaxActivation(activation_info).call(x)
214+
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
215215
return Model(inputs=[noise, label], outputs=x)
216216

217217

src/ydata_synthetic/synthesizers/regular/cramergan/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, model_parameters, gradient_penalty_weight=10):
3030
def define_gan(self, activation_info: Optional[NamedTuple] = None):
3131
self.generator = Generator(self.batch_size). \
3232
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
33-
activation_info=activation_info)
33+
activation_info=activation_info, tau = self.tau)
3434

3535
self.critic = Critic(self.batch_size). \
3636
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -192,14 +192,14 @@ def __init__(self, batch_size):
192192
"""Simple generator with dense feedforward layers."""
193193
self.batch_size = batch_size
194194

195-
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
195+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None, tau: Optional[float] = None):
196196
input_ = Input(shape=input_shape, batch_size=self.batch_size)
197197
x = Dense(dim, activation='relu')(input_)
198198
x = Dense(dim * 2, activation='relu')(x)
199199
x = Dense(dim * 4, activation='relu')(x)
200200
x = Dense(data_dim)(x)
201201
if activation_info:
202-
x = GumbelSoftmaxActivation(activation_info)(x)
202+
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
203203
return Model(inputs=input_, outputs=x)
204204

205205
class Critic(tf.keras.Model):

src/ydata_synthetic/synthesizers/regular/dragan/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def define_gan(self, col_transform_info: Optional[NamedTuple] = None):
2727
# define generator/discriminator
2828
self.generator = Generator(self.batch_size). \
2929
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
30-
activation_info=col_transform_info)
30+
activation_info=col_transform_info, tau = self.tau)
3131

3232
self.discriminator = Discriminator(self.batch_size). \
3333
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -177,12 +177,12 @@ class Generator(Model):
177177
def __init__(self, batch_size):
178178
self.batch_size = batch_size
179179

180-
def build_model(self, input_shape, dim, data_dim, activation_info: NamedTuple = None):
180+
def build_model(self, input_shape, dim, data_dim, activation_info: NamedTuple = None, tau: Optional[float] = None):
181181
input = Input(shape=input_shape, batch_size = self.batch_size)
182182
x = Dense(dim, kernel_initializer=initializers.TruncatedNormal(mean=0., stddev=0.5), activation='relu')(input)
183183
x = Dense(dim * 2, activation='relu')(x)
184184
x = Dense(dim * 4, activation='relu')(x)
185185
x = Dense(data_dim)(x)
186186
if activation_info:
187-
x = GumbelSoftmaxActivation(activation_info)(x)
187+
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
188188
return Model(inputs=input, outputs=x)

src/ydata_synthetic/synthesizers/regular/vanillagan/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, model_parameters):
2323
def define_gan(self, activation_info: Optional[NamedTuple]):
2424
self.generator = Generator(self.batch_size).\
2525
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
26-
activation_info = activation_info)
26+
activation_info = activation_info, tau = self.tau)
2727

2828
self.discriminator = Discriminator(self.batch_size).\
2929
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -131,14 +131,14 @@ class Generator(tf.keras.Model):
131131
def __init__(self, batch_size):
132132
self.batch_size=batch_size
133133

134-
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
134+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None, tau: Optional[float] = None):
135135
input= Input(shape=input_shape, batch_size=self.batch_size)
136136
x = Dense(dim, activation='relu')(input)
137137
x = Dense(dim * 2, activation='relu')(x)
138138
x = Dense(dim * 4, activation='relu')(x)
139139
x = Dense(data_dim)(x)
140140
if activation_info:
141-
x = GumbelSoftmaxActivation(activation_info)(x)
141+
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
142142
return Model(inputs=input, outputs=x)
143143

144144
class Discriminator(tf.keras.Model):

src/ydata_synthetic/synthesizers/regular/wgan/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def wasserstein_loss(self, y_true, y_pred):
4545
def define_gan(self, activation_info: Optional[NamedTuple] = None):
4646
self.generator = Generator(self.batch_size). \
4747
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
48-
activation_info=activation_info)
48+
activation_info=activation_info, tau = self.tau)
4949

5050
self.critic = Critic(self.batch_size). \
5151
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -155,14 +155,14 @@ class Generator(tf.keras.Model):
155155
def __init__(self, batch_size):
156156
self.batch_size = batch_size
157157

158-
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
158+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None, tau: Optional[float] = None):
159159
input = Input(shape=input_shape, batch_size=self.batch_size)
160160
x = Dense(dim, activation='relu')(input)
161161
x = Dense(dim * 2, activation='relu')(x)
162162
x = Dense(dim * 4, activation='relu')(x)
163163
x = Dense(data_dim)(x)
164164
if activation_info:
165-
x = GumbelSoftmaxActivation(activation_info)(x)
165+
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
166166
return Model(inputs=input, outputs=x)
167167

168168
class Critic(tf.keras.Model):

src/ydata_synthetic/synthesizers/regular/wgangp/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, model_parameters, n_critic, gradient_penalty_weight=10):
2828
def define_gan(self, activation_info: Optional[NamedTuple] = None):
2929
self.generator = Generator(self.batch_size). \
3030
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
31-
activation_info=activation_info)
31+
activation_info=activation_info, tau = self.tau)
3232

3333
self.critic = Critic(self.batch_size). \
3434
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -176,14 +176,14 @@ class Generator(tf.keras.Model):
176176
def __init__(self, batch_size):
177177
self.batch_size = batch_size
178178

179-
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
179+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None, tau: Optional[float] = None):
180180
input = Input(shape=input_shape, batch_size=self.batch_size)
181181
x = Dense(dim, activation='relu')(input)
182182
x = Dense(dim * 2, activation='relu')(x)
183183
x = Dense(dim * 4, activation='relu')(x)
184184
x = Dense(data_dim)(x)
185185
if activation_info:
186-
x = GumbelSoftmaxActivation(activation_info)(x)
186+
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
187187
return Model(inputs=input, outputs=x)
188188

189189
class Critic(tf.keras.Model):

src/ydata_synthetic/utils/gumbel_softmax.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,14 @@ def gumbel_noise(shape: TensorShape) -> Tensor:
2222

2323
@register_keras_serializable(package='Synthetic Data', name='GumbelSoftmaxLayer')
2424
class GumbelSoftmaxLayer(Layer):
25-
"A Gumbel-Softmax layer implementation that should be stacked on top of a categorical feature logits."
25+
"""A Gumbel-Softmax layer implementation that should be stacked on top of a categorical feature logits.
2626
27-
def __init__(self, tau: float = 0.2, name: Optional[str] = None, **kwargs):
27+
Arguments:
28+
tau (float): Temperature parameter of the GS layer
29+
name (Optional[str]): Name for a single categorical block
30+
"""
31+
32+
def __init__(self, tau: float, name: Optional[str] = None, **kwargs):
2833
super().__init__(name=name, **kwargs)
2934
self.tau = tau
3035

@@ -54,11 +59,15 @@ class GumbelSoftmaxActivation(Layer):
5459
processor's pipelines in/out feature maps. For simplicity this object can be taken directly from the data \
5560
processor col_transform_info."""
5661

57-
def __init__(self, activation_info: NamedTuple, name: Optional[str] = None, **kwargs):
62+
def __init__(self, activation_info: NamedTuple, name: Optional[str] = None, tau: Optional[float] = None, **kwargs):
5863
"""Arguments:
5964
col_map (NamedTuple): Defines each of the processor pipelines input/output features.
60-
name (Optional[str]): Name of the layer"""
65+
name (Optional[str]): Name of the GumbelSoftmaxActivation layer
66+
tau (Optional[float]): Temperature parameter of the GS layer, must be a float bigger than 0"""
6167
super().__init__(name=name, **kwargs)
68+
self.tau = 0.2 if not tau else tau # Defaults to the default value proposed in the original article
69+
assert isinstance(self.tau, (int, float)) and self.tau > 0, "Optional argument tau must be numerical and \
70+
bigger than 0."
6271

6372
self._activation_info = activation_info
6473

@@ -74,7 +83,7 @@ def call(self, _input): # pylint: disable=W0221
7483
cat_cols = split(cat_cols, self._cat_lens if self._cat_lens else [0], 1, name='split_cats')
7584

7685
num_cols = [Activation('tanh', name='num_cols_activation')(num_cols)]
77-
cat_cols = [GumbelSoftmaxLayer(name=name)(col)[0] for name, col in \
86+
cat_cols = [GumbelSoftmaxLayer(tau=self.tau, name=name)(col)[0] for name, col in \
7887
zip(self.cat_feats.feat_names_in, cat_cols)]
7988
return concat(num_cols+cat_cols, 1)
8089

0 commit comments

Comments (0)