Skip to content

Commit d29d3a0

Browse files
author
Francisco Santos
committed
Generalize cat_lens property + optimize runtime
1 parent e3ac8e8 commit d29d3a0

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

src/ydata_synthetic/utils/gumbel_softmax.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,17 @@ def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
5656
self.cat_feats = processor_info.categorical
5757
self.num_feats = processor_info.numerical
5858

59-
self._cat_lens = [len([col for col in self.cat_feats.feat_names_out if \
60-
col.startswith(cat_feat) and search('_[0-9]*$', col)]) for cat_feat in self.cat_feats.feat_names_in]
59+
self._cat_lens = [len([col for col in self.cat_feats.feat_names_out if search(f'^{cat_feat}_.*$', col)]) \
60+
for cat_feat in self.cat_feats.feat_names_in]
6161
self._num_lens = len(self.num_feats.feat_names_out)
6262

63+
self._num_activ = Activation('tanh', name='num_cols_activation')
64+
self._cat_activ = [GumbelSoftmaxLayer(name=name) for name in self.cat_feats.feat_names_in]
65+
6366
def call(self, _input): # pylint: disable=W0221
6467
num_cols, cat_cols = split(_input, [self._num_lens, -1], 1, name='split_num_cats')
6568
cat_cols = split(cat_cols, self._cat_lens, 1, name='split_cats')
6669

67-
num_cols = [Activation('tanh', name='num_cols_activation')(num_cols)]
68-
cat_cols = [GumbelSoftmaxLayer(name=name)(col)[0] for name, col in \
69-
zip(self.cat_feats.feat_names_in, cat_cols)]
70+
num_cols = [self._num_activ(num_cols)]
71+
cat_cols = [activ(col)[0] for (activ, col) in zip(self._cat_activ, cat_cols)]
7072
return concat(num_cols+cat_cols, 1)

0 commit comments

Comments
 (0)