Skip to content

Commit e3ac8e8

Browse files
author
Francisco Santos
committed
Simplify data preprocessor schema arg
1 parent f4477b8 commit e3ac8e8

2 files changed

Lines changed: 52 additions & 44 deletions

File tree

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,53 @@
11
"Activation Interface layer test suite."
2-
from numpy import cumsum, isin, split
2+
from itertools import cycle, islice
3+
from re import search
4+
5+
from numpy import array, cumsum, isin, split
36
from numpy import sum as npsum
47
from numpy.random import normal
8+
from pandas import DataFrame, concat
59
from pytest import fixture
610
from tensorflow.keras import Model
711
from tensorflow.keras.layers import Dense, Input
812

13+
from ydata_synthetic.preprocessing.regular.processor import \
14+
RegularDataProcessor
915
from ydata_synthetic.utils.gumbel_softmax import ActivationInterface
1016

17+
BATCH_SIZE = 10
1118

1219
@fixture(name='noise_batch')
1320
def fixture_noise_batch():
1421
"Sample noise for mock output generation."
15-
return normal(size=(10, 16))
16-
17-
@fixture(name='mock_col_map')
18-
def fixture_mock_col_map():
19-
"Mock data processing column map (var blocks i/o names)."
20-
return {'numerical': [
21-
[f'nfeat{n}' for n in range(6)],
22-
[f'nfeat{n}' for n in range(6)]],
23-
'categorical': [
24-
[f'cfeat{n}' for n in range(2)],
25-
sum([[f'cfeat0_{i}' for i in range(4)], [f'cfeat1_{i}' for i in range(2)]],[])]}
22+
return normal(size=(BATCH_SIZE, 16))
23+
24+
@fixture(name='mock_data')
25+
def fixture_mock_data():
26+
"Creates mock data for the tests."
27+
num_block = DataFrame(normal(size=(BATCH_SIZE, 6)), columns = [f'num_{i}' for i in range(6)])
28+
cat_block_1 = DataFrame(array(list(islice(cycle(range(2)), BATCH_SIZE))), columns = ['cat_0'])
29+
cat_block_2 = DataFrame(array(list(islice(cycle(range(4)), BATCH_SIZE))), columns = ['cat_1'])
30+
return concat([num_block, cat_block_1, cat_block_2], axis = 1)
31+
32+
@fixture(name='mock_processor')
33+
def fixture_mock_processor(mock_data):
34+
"Creates a mock data processor for the mock data."
35+
num_cols = [col for col in mock_data.columns if col.startswith('num')]
36+
cat_cols = [col for col in mock_data.columns if col.startswith('cat')]
37+
return RegularDataProcessor(num_cols, cat_cols).fit(mock_data)
2638

2739
# pylint: disable=C0103
2840
@fixture(name='mock_generator')
29-
def fixture_mock_generator(noise_batch, mock_col_map):
41+
def fixture_mock_generator(noise_batch, mock_processor):
3042
"A mock generator with the Activation Interface as final layer."
31-
input_ = Input(shape=noise_batch.shape[1], batch_size = noise_batch.shape[0])
43+
input_ = Input(shape=noise_batch.shape[1], batch_size = BATCH_SIZE)
3244
dim = 15
3345
data_dim = 12
3446
x = Dense(dim, activation='relu')(input_)
3547
x = Dense(dim * 2, activation='relu')(x)
3648
x = Dense(dim * 4, activation='relu')(x)
3749
x = Dense(data_dim)(x)
38-
x = ActivationInterface(mock_col_map, name='act_itf')(x)
50+
x = ActivationInterface(processor_info=mock_processor.col_transform_info, name='act_itf')(x)
3951
return Model(inputs=input_, outputs=x)
4052

4153
@fixture(name='mock_output')
@@ -44,16 +56,17 @@ def fixture_mock_output(noise_batch, mock_generator):
4456
return mock_generator(noise_batch).numpy()
4557

4658
# pylint: disable=W0632
47-
def test_io(noise_batch, mock_col_map, mock_output):
59+
def test_io(mock_processor, mock_output):
4860
"Tests the output format of the activation interface for a known input."
49-
num_lens = len(mock_col_map.get('numerical')[1])
50-
cat_lens = len(mock_col_map.get('categorical')[1])
51-
assert mock_output.shape == (len(noise_batch), num_lens + cat_lens), "The output has wrong shape."
61+
num_lens = len(mock_processor.col_transform_info.numerical.feat_names_out)
62+
cat_lens = len(mock_processor.col_transform_info.categorical.feat_names_out)
63+
assert mock_output.shape == (BATCH_SIZE, num_lens + cat_lens), "The output has wrong shape."
5264
num_part, cat_part = split(mock_output, [num_lens], 1)
5365
assert not isin(num_part, [0, 1]).all(), "The numerical block is not expected to contain 0 or 1."
5466
assert isin(cat_part, [0, 1]).all(), "The categorical block is expected to contain only 0 or 1."
55-
cat_i, cat_o = mock_col_map.get('categorical')
56-
cat_blocks = cumsum([len([col for col in cat_o if ''.join(col.split('_')[:-1]) == feat]) for feat in cat_i])
67+
cat_i, cat_o = mock_processor.col_transform_info.categorical
68+
cat_blocks = cumsum([len([col for col in cat_o if col.startswith(feat) and search('_[0-9]*$', col)]) \
69+
for feat in cat_i])
5770
cat_blocks = split(cat_part, cat_blocks[:-1], 1)
58-
assert all(npsum(abs(block)) == noise_batch.shape[0] for block in cat_blocks), "There are non one-hot encoded \
71+
assert all(npsum(abs(block)) == BATCH_SIZE for block in cat_blocks), "There are non one-hot encoded \
5972
categorical blocks."

src/ydata_synthetic/utils/gumbel_softmax.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Gumbel-Softmax layer implementation.
22
Reference: https://arxiv.org/pdf/1611.04051.pdf"""
3-
from typing import Dict, List, Optional
3+
from re import search
4+
from typing import NamedTuple, Optional
45

56
# pylint: disable=E0401
67
from tensorflow import (Tensor, TensorShape, concat, one_hot, split, squeeze,
@@ -42,34 +43,28 @@ class ActivationInterface(Layer):
4243
Processed features of different kind are sent to a TanH activation.
4344
Finally all output parts are concatenated and returned in the same order.
4445
45-
The parts of an incoming tensor are qualified by leveraging a data processor's in/out feature map.
46+
The parts of an incoming tensor are qualified by leveraging a namedtuple pointing to each of the used data \
47+
processor's pipelines' in/out feature maps. For simplicity this object can be taken directly from the data \
48+
processor's col_transform_info."""
4649

47-
Example of how to get a col_map from a Data Processor ProcessorInfo attribute:
48-
>>> col_map = {k: [v.feat_names_in, v.feat_names_out] for k, v in ProcessorInfo._asdict().items() if v}"""
49-
50-
def __init__(self, col_map: Dict[str, List[List[str]]], name: Optional[str] = None):
50+
def __init__(self, processor_info: NamedTuple, name: Optional[str] = None):
5151
"""Arguments:
52-
col_map (Dict[str, List[List[str]]]): A map defining the processor pipelines input/output features.
52+
processor_info (NamedTuple): Defines each of the processor pipelines' input/output features.
5353
name (Optional[str]): Name of the layer"""
5454
super().__init__(name)
5555

56-
self.cat_names_i, cat_names_o = col_map.get("categorical", [[],[]])
57-
num_names_i, num_names_o = col_map.get("numerical", [[],[]])
58-
59-
self._cat_lens = None
60-
self._num_lens = None
56+
self.cat_feats = processor_info.categorical
57+
self.num_feats = processor_info.numerical
6158

62-
if self.cat_names_i: # Get the length of each processed categorical feature's output block
63-
self._cat_lens = [len([col for col in cat_names_o \
64-
if ''.join(col.split('_')[:-1]) == cat_feat]) for cat_feat in self.cat_names_i]
65-
if num_names_i: # Get the length of the numerical features output block
66-
self._num_lens = len(num_names_o)
59+
self._cat_lens = [len([col for col in self.cat_feats.feat_names_out if \
60+
col.startswith(cat_feat) and search('_[0-9]*$', col)]) for cat_feat in self.cat_feats.feat_names_in]
61+
self._num_lens = len(self.num_feats.feat_names_out)
6762

6863
def call(self, _input): # pylint: disable=W0221
69-
num_cols, cat_cols = split(_input, [self._num_lens if self._num_lens else 0, -1], 1, name='split_num_cats')
70-
cat_cols = split(cat_cols, self._cat_lens if self._cat_lens else 1, 1, name='split_cats')
64+
num_cols, cat_cols = split(_input, [self._num_lens, -1], 1, name='split_num_cats')
65+
cat_cols = split(cat_cols, self._cat_lens, 1, name='split_cats')
7166

72-
num_cols = [Activation('tanh', name='num_cols_activation')(num_cols)] if self._num_lens else []
73-
cat_cols = [GumbelSoftmaxLayer(name=name).call(col)[0] for name, col in zip(self.cat_names_i, cat_cols)] \
74-
if self._cat_lens else []
67+
num_cols = [Activation('tanh', name='num_cols_activation')(num_cols)]
68+
cat_cols = [GumbelSoftmaxLayer(name=name)(col)[0] for name, col in \
69+
zip(self.cat_feats.feat_names_in, cat_cols)]
7570
return concat(num_cols+cat_cols, 1)

0 commit comments

Comments
 (0)