@@ -22,9 +22,14 @@ def gumbel_noise(shape: TensorShape) -> Tensor:
2222
2323@register_keras_serializable (package = 'Synthetic Data' , name = 'GumbelSoftmaxLayer' )
2424class GumbelSoftmaxLayer (Layer ):
25- "A Gumbel-Softmax layer implementation that should be stacked on top of a categorical feature logits."
25+ """ A Gumbel-Softmax layer implementation that should be stacked on top of a categorical feature logits.
2626
27- def __init__ (self , tau : float = 0.2 , name : Optional [str ] = None , ** kwargs ):
27+ Arguments:
28+ tau (float): Temperature parameter of the GS layer
29+ name (Optional[str]): Name for a single categorical block
30+ """
31+
32+ def __init__ (self , tau : float , name : Optional [str ] = None , ** kwargs ):
2833 super ().__init__ (name = name , ** kwargs )
2934 self .tau = tau
3035
@@ -54,11 +59,15 @@ class GumbelSoftmaxActivation(Layer):
5459 processor's pipelines in/out feature maps. For simplicity this object can be taken directly from the data \
5560 processor col_transform_info."""
5661
57- def __init__ (self , activation_info : NamedTuple , name : Optional [str ] = None , ** kwargs ):
62+ def __init__ (self , activation_info : NamedTuple , name : Optional [str ] = None , tau : Optional [ float ] = None , ** kwargs ):
5863 """Arguments:
5964 col_map (NamedTuple): Defines each of the processor pipelines input/output features.
60- name (Optional[str]): Name of the layer"""
65+ name (Optional[str]): Name of the GumbelSoftmaxActivation layer
66+ tau (Optional[float]): Temperature parameter of the GS layer, must be a float bigger than 0"""
6167 super ().__init__ (name = name , ** kwargs )
68+ self .tau = 0.2 if not tau else tau # Defaults to the default value proposed in the original article
69+ assert isinstance (self .tau , (int , float )) and self .tau > 0 , "Optional argument tau must be numerical and \
70+ bigger than 0."
6271
6372 self ._activation_info = activation_info
6473
@@ -74,7 +83,7 @@ def call(self, _input): # pylint: disable=W0221
7483 cat_cols = split (cat_cols , self ._cat_lens if self ._cat_lens else [0 ], 1 , name = 'split_cats' )
7584
7685 num_cols = [Activation ('tanh' , name = 'num_cols_activation' )(num_cols )]
77- cat_cols = [GumbelSoftmaxLayer (name = name )(col )[0 ] for name , col in \
86+ cat_cols = [GumbelSoftmaxLayer (tau = self . tau , name = name )(col )[0 ] for name , col in \
7887 zip (self .cat_feats .feat_names_in , cat_cols )]
7988 return concat (num_cols + cat_cols , 1 )
8089
0 commit comments