Skip to content

Commit 65ffba4

Browse files
author
Francisco Santos
committed
Interface serializable + fix no cat feats error
1 parent 972fb99 commit 65ffba4

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/ydata_synthetic/utils/gumbel_softmax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def get_config(self):
4242
return config
4343

4444

45+
@register_keras_serializable(package='Synthetic Data', name='ActivationInterface')
4546
class ActivationInterface(Layer):
4647
"""An interface layer connecting different parts of an incoming tensor to adequate activation functions.
4748
The tensor parts are qualified according to the passed processor object.
@@ -53,11 +54,11 @@ class ActivationInterface(Layer):
5354
processor's pipelines in/out feature maps. For simplicity this object can be taken directly from the data \
5455
processor col_transform_info."""
5556

56-
def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
57+
def __init__(self, processor_info: NamedTuple, name: Optional[str] = None, **kwargs):
5758
"""Arguments:
5859
col_map (NamedTuple): Defines each of the processor pipelines input/output features.
5960
name (Optional[str]): Name of the layer"""
60-
super().__init__(name)
61+
super().__init__(name=name, **kwargs)
6162

6263
self._processor_info = processor_info
6364

@@ -70,7 +71,7 @@ def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
7071

7172
def call(self, _input): # pylint: disable=W0221
7273
num_cols, cat_cols = split(_input, [self._num_lens, -1], 1, name='split_num_cats')
73-
cat_cols = split(cat_cols, self._cat_lens, 1, name='split_cats')
74+
cat_cols = split(cat_cols, self._cat_lens if self._cat_lens else [0], 1, name='split_cats')
7475

7576
num_cols = [Activation('tanh', name='num_cols_activation')(num_cols)]
7677
cat_cols = [GumbelSoftmaxLayer(name=name)(col)[0] for name, col in \

0 commit comments

Comments
 (0)