@@ -56,15 +56,17 @@ def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
5656 self .cat_feats = processor_info .categorical
5757 self .num_feats = processor_info .numerical
5858
59- self ._cat_lens = [len ([col for col in self .cat_feats .feat_names_out if \
60- col . startswith ( cat_feat ) and search ( '_[0-9]*$' , col )]) for cat_feat in self .cat_feats .feat_names_in ]
59+ self ._cat_lens = [len ([col for col in self .cat_feats .feat_names_out if search ( f'^ { cat_feat } _.*$' , col )]) \
60+ for cat_feat in self .cat_feats .feat_names_in ]
6161 self ._num_lens = len (self .num_feats .feat_names_out )
6262
63+ self ._num_activ = Activation ('tanh' , name = 'num_cols_activation' )
64+ self ._cat_activ = [GumbelSoftmaxLayer (name = name ) for name in self .cat_feats .feat_names_in ]
65+
6366 def call (self , _input ): # pylint: disable=W0221
6467 num_cols , cat_cols = split (_input , [self ._num_lens , - 1 ], 1 , name = 'split_num_cats' )
6568 cat_cols = split (cat_cols , self ._cat_lens , 1 , name = 'split_cats' )
6669
67- num_cols = [Activation ('tanh' , name = 'num_cols_activation' )(num_cols )]
68- cat_cols = [GumbelSoftmaxLayer (name = name )(col )[0 ] for name , col in \
69- zip (self .cat_feats .feat_names_in , cat_cols )]
70+ num_cols = [self ._num_activ (num_cols )]
71+ cat_cols = [activ (col )[0 ] for (activ , col ) in zip (self ._cat_activ , cat_cols )]
7072 return concat (num_cols + cat_cols , 1 )
0 commit comments