11"Activation Interface layer test suite."
2- from numpy import cumsum , isin , split
2+ from itertools import cycle , islice
3+ from re import search
4+
5+ from numpy import array , cumsum , isin , split
36from numpy import sum as npsum
47from numpy .random import normal
8+ from pandas import DataFrame , concat
59from pytest import fixture
610from tensorflow .keras import Model
711from tensorflow .keras .layers import Dense , Input
812
13+ from ydata_synthetic .preprocessing .regular .processor import \
14+ RegularDataProcessor
915from ydata_synthetic .utils .gumbel_softmax import ActivationInterface
1016
17+ BATCH_SIZE = 10
1118
@fixture(name='noise_batch')
def fixture_noise_batch():
    """Sample noise for mock output generation.

    Returns:
        An array of shape (BATCH_SIZE, 16) drawn with numpy.random.normal.
    """
    noise_shape = (BATCH_SIZE, 16)
    return normal(size=noise_shape)
23+
@fixture(name='mock_data')
def fixture_mock_data():
    """Creates mock data for the tests.

    Returns:
        A DataFrame with BATCH_SIZE rows: six normally distributed columns
        ('num_0' .. 'num_5') followed by two integer columns, 'cat_0' cycling
        through range(2) and 'cat_1' cycling through range(4).
    """
    numeric_names = [f'num_{i}' for i in range(6)]
    frames = [DataFrame(normal(size=(BATCH_SIZE, 6)), columns=numeric_names)]
    # Each categorical column repeats range(n_levels), truncated to BATCH_SIZE rows.
    for name, n_levels in (('cat_0', 2), ('cat_1', 4)):
        levels = array(list(islice(cycle(range(n_levels)), BATCH_SIZE)))
        frames.append(DataFrame(levels, columns=[name]))
    return concat(frames, axis=1)
31+
@fixture(name='mock_processor')
def fixture_mock_processor(mock_data):
    """Creates a mock data processor for the mock data.

    Columns are assigned by name prefix: 'num*' columns are passed as the
    numerical columns and 'cat*' columns as the categorical ones.

    Returns:
        Whatever RegularDataProcessor.fit returns for the mock data.
    """
    numerical, categorical = [], []
    for name in mock_data.columns:
        if name.startswith('num'):
            numerical.append(name)
        if name.startswith('cat'):
            categorical.append(name)
    processor = RegularDataProcessor(numerical, categorical)
    return processor.fit(mock_data)
2638
2739# pylint: disable=C0103
@fixture(name='mock_generator')
def fixture_mock_generator(noise_batch, mock_processor):
    """A mock generator with the Activation Interface as final layer.

    Args:
        noise_batch: Noise array; its second dimension sets the input width.
        mock_processor: Fitted processor whose col_transform_info describes
            the numerical and categorical output columns.

    Returns:
        A Keras Model mapping a noise batch to the ActivationInterface output.
    """
    col_info = mock_processor.col_transform_info
    # Keras documents `shape` as a tuple, so wrap the noise width explicitly.
    input_ = Input(shape=(noise_batch.shape[1],), batch_size=BATCH_SIZE)
    dim = 15
    # Derive the output width from the processor rather than hard-coding it,
    # so the generator stays in sync with the mock data layout.
    data_dim = len(col_info.numerical.feat_names_out) + len(col_info.categorical.feat_names_out)
    x = Dense(dim, activation='relu')(input_)
    x = Dense(dim * 2, activation='relu')(x)
    x = Dense(dim * 4, activation='relu')(x)
    x = Dense(data_dim)(x)
    x = ActivationInterface(processor_info=col_info, name='act_itf')(x)
    return Model(inputs=input_, outputs=x)
4052
4153@fixture (name = 'mock_output' )
@@ -44,16 +56,17 @@ def fixture_mock_output(noise_batch, mock_generator):
4456 return mock_generator (noise_batch ).numpy ()
4557
4658# pylint: disable=W0632
def test_io(mock_processor, mock_output):
    """Tests the output format of the activation interface for a known input.

    Checks the output shape, that the numerical block is not made up solely
    of 0/1 values, that the categorical block holds only 0/1 values, and that
    every categorical feature block is one-hot encoded row by row.
    """
    col_info = mock_processor.col_transform_info
    num_lens = len(col_info.numerical.feat_names_out)
    cat_lens = len(col_info.categorical.feat_names_out)
    assert mock_output.shape == (BATCH_SIZE, num_lens + cat_lens), "The output has wrong shape."
    num_part, cat_part = split(mock_output, [num_lens], 1)
    assert not isin(num_part, [0, 1]).all(), "The numerical block is not expected to contain 0 or 1."
    assert isin(cat_part, [0, 1]).all(), "The categorical block is expected to contain only 0 or 1."
    cat_i, cat_o = col_info.categorical
    # Match on "<feature>_" followed by at least one digit so that a feature
    # name that is a prefix of another (e.g. cat_1 / cat_10) is not miscounted.
    block_widths = [
        len([col for col in cat_o if col.startswith(f'{feat}_') and search(r'_[0-9]+$', col)])
        for feat in cat_i
    ]
    cat_blocks = split(cat_part, cumsum(block_widths)[:-1], 1)
    # One-hot means exactly one active entry per row; a whole-block total equal
    # to BATCH_SIZE would also accept rows with zero or several active entries.
    assert all((npsum(abs(block), axis=1) == 1).all() for block in cat_blocks), (
        "There are non one-hot encoded categorical blocks.")
0 commit comments