@@ -42,6 +42,7 @@ def get_config(self):
4242 return config
4343
4444
45+ @register_keras_serializable (package = 'Synthetic Data' , name = 'ActivationInterface' )
4546class ActivationInterface (Layer ):
4647 """An interface layer connecting different parts of an incoming tensor to adequate activation functions.
4748 The tensor parts are qualified according to the passed processor object.
@@ -53,11 +54,11 @@ class ActivationInterface(Layer):
5354 processor's pipelines in/out feature maps. For simplicity this object can be taken directly from the data \
5455 processor col_transform_info."""
5556
56- def __init__ (self , processor_info : NamedTuple , name : Optional [str ] = None ):
57+ def __init__ (self , processor_info : NamedTuple , name : Optional [str ] = None , ** kwargs ):
5758 """Arguments:
5859 col_map (NamedTuple): Defines each of the processor pipelines input/output features.
5960 name (Optional[str]): Name of the layer"""
60- super ().__init__ (name )
61+ super ().__init__ (name = name , ** kwargs )
6162
6263 self ._processor_info = processor_info
6364
@@ -70,7 +71,7 @@ def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
7071
7172 def call (self , _input ): # pylint: disable=W0221
7273 num_cols , cat_cols = split (_input , [self ._num_lens , - 1 ], 1 , name = 'split_num_cats' )
73- cat_cols = split (cat_cols , self ._cat_lens , 1 , name = 'split_cats' )
74+ cat_cols = split (cat_cols , self ._cat_lens if self . _cat_lens else [ 0 ] , 1 , name = 'split_cats' )
7475
7576 num_cols = [Activation ('tanh' , name = 'num_cols_activation' )(num_cols )]
7677 cat_cols = [GumbelSoftmaxLayer (name = name )(col )[0 ] for name , col in \
0 commit comments