77from tensorflow import (Tensor , TensorShape , concat , one_hot , split , squeeze ,
88 stop_gradient )
99from tensorflow .keras .layers import Activation , Layer
10+ from tensorflow .keras .utils import register_keras_serializable
1011from tensorflow .math import log
1112from tensorflow .nn import softmax
1213from tensorflow .random import categorical , uniform
@@ -19,12 +20,12 @@ def gumbel_noise(shape: TensorShape) -> Tensor:
1920 uniform_sample = uniform (shape , seed = 0 )
2021 return - log (- log (uniform_sample + TOL ) + TOL )
2122
22-
23+ @ register_keras_serializable ( package = 'Synthetic Data' , name = 'GumbelSoftmaxLayer' )
2324class GumbelSoftmaxLayer (Layer ):
2425 "A Gumbel-Softmax layer implementation that should be stacked on top of a categorical feature logits."
2526
26- def __init__ (self , tau : float = 0.2 , name : Optional [str ] = None ):
27- super ().__init__ (name = name )
27+ def __init__ (self , tau : float = 0.2 , name : Optional [str ] = None , ** kwargs ):
28+ super ().__init__ (name = name , ** kwargs )
2829 self .tau = tau
2930
3031 # pylint: disable=W0221, E1120
@@ -35,6 +36,11 @@ def call(self, _input):
3536 hard_sample = stop_gradient (squeeze (one_hot (categorical (log (soft_sample ), 1 ), _input .shape [- 1 ]), 1 ))
3637 return hard_sample , soft_sample
3738
39+ def get_config (self ):
40+ config = super ().get_config ().copy ()
41+ config .update ({'tau' : self .tau })
42+ return config
43+
3844
3945class ActivationInterface (Layer ):
4046 """An interface layer connecting different parts of an incoming tensor to adequate activation functions.
@@ -53,20 +59,25 @@ def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
5359 name (Optional[str]): Name of the layer"""
5460 super ().__init__ (name )
5561
62+ self ._processor_info = processor_info
63+
5664 self .cat_feats = processor_info .categorical
5765 self .num_feats = processor_info .numerical
5866
5967 self ._cat_lens = [len ([col for col in self .cat_feats .feat_names_out if search (f'^{ cat_feat } _.*$' , col )]) \
6068 for cat_feat in self .cat_feats .feat_names_in ]
6169 self ._num_lens = len (self .num_feats .feat_names_out )
6270
63- self ._num_activ = Activation ('tanh' , name = 'num_cols_activation' )
64- self ._cat_activ = [GumbelSoftmaxLayer (name = name ) for name in self .cat_feats .feat_names_in ]
65-
6671 def call (self , _input ): # pylint: disable=W0221
6772 num_cols , cat_cols = split (_input , [self ._num_lens , - 1 ], 1 , name = 'split_num_cats' )
6873 cat_cols = split (cat_cols , self ._cat_lens , 1 , name = 'split_cats' )
6974
70- num_cols = [self ._num_activ (num_cols )]
71- cat_cols = [activ (col )[0 ] for (activ , col ) in zip (self ._cat_activ , cat_cols )]
75+ num_cols = [Activation ('tanh' , name = 'num_cols_activation' )(num_cols )]
76+ cat_cols = [GumbelSoftmaxLayer (name = name )(col )[0 ] for name , col in \
77+ zip (self .cat_feats .feat_names_in , cat_cols )]
7278 return concat (num_cols + cat_cols , 1 )
79+
80+ def get_config (self ):
81+ config = super ().get_config ().copy ()
82+ config .update ({'processor_info' : self ._processor_info })
83+ return config
0 commit comments