
Commit 972fb99

Author: Francisco Santos (committed)
Message: GS serializable + Remove optimization (no improv)
Parent: d29d3a0

1 file changed: src/ydata_synthetic/utils/gumbel_softmax.py (19 additions, 8 deletions)
@@ -7,6 +7,7 @@
 from tensorflow import (Tensor, TensorShape, concat, one_hot, split, squeeze,
                         stop_gradient)
 from tensorflow.keras.layers import Activation, Layer
+from tensorflow.keras.utils import register_keras_serializable
 from tensorflow.math import log
 from tensorflow.nn import softmax
 from tensorflow.random import categorical, uniform
@@ -19,12 +20,12 @@ def gumbel_noise(shape: TensorShape) -> Tensor:
     uniform_sample = uniform(shape, seed=0)
     return -log(-log(uniform_sample + TOL) + TOL)
 
-
+@register_keras_serializable(package='Synthetic Data', name='GumbelSoftmaxLayer')
 class GumbelSoftmaxLayer(Layer):
     "A Gumbel-Softmax layer implementation that should be stacked on top of a categorical feature logits."
 
-    def __init__(self, tau: float = 0.2, name: Optional[str] = None):
-        super().__init__(name = name)
+    def __init__(self, tau: float = 0.2, name: Optional[str] = None, **kwargs):
+        super().__init__(name=name, **kwargs)
         self.tau = tau
 
     # pylint: disable=W0221, E1120
@@ -35,6 +36,11 @@ def call(self, _input):
         hard_sample = stop_gradient(squeeze(one_hot(categorical(log(soft_sample), 1), _input.shape[-1]), 1))
         return hard_sample, soft_sample
 
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({'tau': self.tau})
+        return config
+
 
 class ActivationInterface(Layer):
     """An interface layer connecting different parts of an incoming tensor to adequate activation functions.
@@ -53,20 +59,25 @@ def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
             name (Optional[str]): Name of the layer"""
         super().__init__(name)
 
+        self._processor_info = processor_info
+
         self.cat_feats = processor_info.categorical
         self.num_feats = processor_info.numerical
 
         self._cat_lens = [len([col for col in self.cat_feats.feat_names_out if search(f'^{cat_feat}_.*$', col)]) \
             for cat_feat in self.cat_feats.feat_names_in]
         self._num_lens = len(self.num_feats.feat_names_out)
 
-        self._num_activ = Activation('tanh', name='num_cols_activation')
-        self._cat_activ = [GumbelSoftmaxLayer(name=name) for name in self.cat_feats.feat_names_in]
-
     def call(self, _input): # pylint: disable=W0221
         num_cols, cat_cols = split(_input, [self._num_lens, -1], 1, name='split_num_cats')
         cat_cols = split(cat_cols, self._cat_lens, 1, name='split_cats')
 
-        num_cols = [self._num_activ(num_cols)]
-        cat_cols = [activ(col)[0] for (activ, col) in zip(self._cat_activ, cat_cols)]
+        num_cols = [Activation('tanh', name='num_cols_activation')(num_cols)]
+        cat_cols = [GumbelSoftmaxLayer(name=name)(col)[0] for name, col in \
+            zip(self.cat_feats.feat_names_in, cat_cols)]
         return concat(num_cols+cat_cols, 1)
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({'processor_info': self._processor_info})
+        return config
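A minimal sketch (not part of the commit) of what the serialization change enables: with @register_keras_serializable, the added get_config, and **kwargs forwarded to the base Layer, a GumbelSoftmaxLayer can be round-tripped through the standard Keras get_config/from_config path with its temperature preserved. The tau value and layer name below are illustrative; the import path follows the file tree above.

    # Assumed usage, not taken from the commit: config round-trip of the layer.
    from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxLayer

    layer = GumbelSoftmaxLayer(tau=0.5, name='gs_example')    # hypothetical values
    config = layer.get_config()                               # now includes {'tau': 0.5}
    restored = GumbelSoftmaxLayer.from_config(config)         # rebuilt via cls(**config)
    assert restored.tau == 0.5

The second half of the commit message ("Remove optimization (no improv)") corresponds to dropping the pre-built activation sub-layers from __init__ and constructing them inside call instead; per the message, the earlier pre-building gave no measurable improvement.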
