Skip to content

Commit f4477b8

Browse files
jfsantos-ds (Francisco Santos)
authored and committed
feat: Gumbel Softmax and Activation Interface
1 parent ef3d5e2 commit f4477b8

4 files changed

Lines changed: 219 additions & 11 deletions

File tree

src/ydata_synthetic/preprocessing/base_processor.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,38 @@
1-
"Implements a BaseProcessor Class, not meant to be directly instantiated."
1+
"Base class of Data Preprocessors, do not instantiate this class directly."
22
from __future__ import annotations
33

44
from abc import ABC, abstractmethod
5+
from collections import namedtuple
56
from typing import List, Optional
67

7-
from numpy import ndarray
8-
from pandas import DataFrame, Series
8+
from numpy import concatenate, ndarray, split, zeros
9+
from pandas import DataFrame, Series, concat
910
from sklearn.base import BaseEstimator, TransformerMixin
1011
from sklearn.exceptions import NotFittedError
1112
from typeguard import typechecked
1213

14+
ProcessorInfo = namedtuple("ProcessorInfo", ["numerical", "categorical"])
15+
PipelineInfo = namedtuple("PipelineInfo", ["feat_names_in", "feat_names_out"])
1316

17+
# pylint: disable=R0902
1418
@typechecked
1519
class BaseProcessor(ABC, BaseEstimator, TransformerMixin):
1620
"""
17-
Base class for Data Preprocessing.
18-
It works like any other transformer in scikit learn with the methods fit, transform and inverse transform.
21+
This data processor works like a scikit learn transformer in with the methods fit, transform and inverse transform.
1922
Args:
2023
num_cols (list of strings):
2124
List of names of numerical columns.
2225
cat_cols (list of strings):
2326
List of names of categorical columns.
2427
"""
2528
def __init__(self, num_cols: Optional[List[str]] = None, cat_cols: Optional[List[str]] = None):
    # Default to empty lists so downstream truthiness checks work uniformly.
    self.num_cols = num_cols if num_cols is not None else []
    self.cat_cols = cat_cols if cat_cols is not None else []

    # Child processors are expected to override these with concrete pipelines.
    self._num_pipeline = None
    self._cat_pipeline = None

    # Metadata object mapping inputs/outputs of each pipeline
    self._col_transform_info = None
3436

3537
@property
3638
def num_pipeline(self) -> BaseEstimator:
@@ -47,6 +49,25 @@ def types(self) -> Series:
4749
"""Returns a Series with the dtypes of each column in the fitted DataFrame."""
4850
return self._types
4951

52+
@property
def col_transform_info(self) -> ProcessorInfo:
    """Returns a ProcessorInfo object specifying input/output feature mappings of this processor's pipelines."""
    self._check_is_fitted()
    # Built lazily on first access and cached afterwards.
    if self._col_transform_info is None:
        self._col_transform_info = self.__create_metadata_synth()
    return self._col_transform_info
59+
60+
def __create_metadata_synth(self):
    "Builds the ProcessorInfo namedtuple mapping each fitted pipeline's input features to its output features."
    def pipeline_info(pipeline):
        # Input/output feature names as exposed by the fitted sklearn pipeline.
        return PipelineInfo(pipeline.feature_names_in_, pipeline.get_feature_names_out())

    # Numerical and categorical entries are filled only when the respective columns exist.
    num_info = pipeline_info(self.num_pipeline) if self.num_cols else None
    cat_info = pipeline_info(self.cat_pipeline) if self.cat_cols else None
    return ProcessorInfo(num_info, cat_info)
70+
5071
def _check_is_fitted(self):
5172
"""Checks if the processor is fitted by testing the numerical pipeline.
5273
Raises NotFittedError if not."""
@@ -86,8 +107,7 @@ def transform(self, X: DataFrame) -> ndarray:
86107
DataFrame used to fit the processor parameters.
87108
Should be aligned with the columns types defined in initialization.
88109
Returns:
89-
transformed (ndarray):
90-
Processed version of the passed DataFrame.
110+
transformed (ndarray): Processed version of the passed DataFrame.
91111
"""
92112
raise NotImplementedError
93113

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"Activation Interface layer test suite."
2+
from numpy import cumsum, isin, split
3+
from numpy import sum as npsum
4+
from numpy.random import normal
5+
from pytest import fixture
6+
from tensorflow.keras import Model
7+
from tensorflow.keras.layers import Dense, Input
8+
9+
from ydata_synthetic.utils.gumbel_softmax import ActivationInterface
10+
11+
12+
@fixture(name='noise_batch')
def fixture_noise_batch():
    "Sample noise for mock output generation."
    batch_size, noise_dim = 10, 16
    return normal(size=(batch_size, noise_dim))
16+
17+
@fixture(name='mock_col_map')
def fixture_mock_col_map():
    "Mock data processing column map (var blocks i/o names)."
    # Numerical pipeline: 6 features in, 6 features out (identity mapping).
    num_names = [f'nfeat{n}' for n in range(6)]
    # Categorical pipeline: 2 features in, expanded to 4 + 2 one-hot columns out.
    cat_in = [f'cfeat{n}' for n in range(2)]
    cat_out = [f'cfeat0_{i}' for i in range(4)] + [f'cfeat1_{i}' for i in range(2)]
    return {'numerical': [list(num_names), list(num_names)],
            'categorical': [cat_in, cat_out]}
26+
27+
# pylint: disable=C0103
@fixture(name='mock_generator')
def fixture_mock_generator(noise_batch, mock_col_map):
    "A mock generator with the Activation Interface as final layer."
    dim = 15
    data_dim = 12
    inputs = Input(shape=noise_batch.shape[1], batch_size=noise_batch.shape[0])
    hidden = Dense(dim, activation='relu')(inputs)
    hidden = Dense(dim * 2, activation='relu')(hidden)
    hidden = Dense(dim * 4, activation='relu')(hidden)
    logits = Dense(data_dim)(hidden)
    outputs = ActivationInterface(mock_col_map, name='act_itf')(logits)
    return Model(inputs=inputs, outputs=outputs)
40+
41+
@fixture(name='mock_output')
def fixture_mock_output(noise_batch, mock_generator):
    "Returns mock output of the model as a numpy object."
    batch_output = mock_generator(noise_batch)
    return batch_output.numpy()
45+
46+
# pylint: disable=W0632
def test_io(noise_batch, mock_col_map, mock_output):
    "Tests the output format of the activation interface for a known input."
    n_num_out = len(mock_col_map.get('numerical')[1])
    n_cat_out = len(mock_col_map.get('categorical')[1])
    assert mock_output.shape == (len(noise_batch), n_num_out + n_cat_out), "The output has wrong shape."
    num_part, cat_part = split(mock_output, [n_num_out], 1)
    # TanH output should be continuous; Gumbel-Softmax hard output should be one-hot.
    assert not isin(num_part, [0, 1]).all(), "The numerical block is not expected to contain 0 or 1."
    assert isin(cat_part, [0, 1]).all(), "The categorical block is expected to contain only 0 or 1."
    cat_i, cat_o = mock_col_map.get('categorical')
    # Recover each categorical feature's one-hot block width from the output names.
    widths = [len([col for col in cat_o if ''.join(col.split('_')[:-1]) == feat]) for feat in cat_i]
    cat_blocks = split(cat_part, cumsum(widths)[:-1], 1)
    # Each one-hot block must carry exactly one unit of mass per record.
    assert all(npsum(abs(block)) == noise_batch.shape[0] for block in cat_blocks), \
        "There are non one-hot encoded categorical blocks."
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"Test suite for the Gumbel-Softmax layer implementation."
2+
import tensorflow as tf
3+
from numpy import amax, amin, isclose, ones
4+
from numpy import sum as npsum
5+
from pytest import fixture
6+
from tensorflow.keras import layers
7+
8+
from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxLayer
9+
10+
11+
# pylint:disable=W0613
def custom_initializer(shape_list, dtype):
    "A constant weight initializer ensuring test reproducibility."
    weights = ones((5, 5))
    return tf.constant(weights, dtype=tf.dtypes.float32)
15+
16+
@fixture(name='rand_input')
def fixture_rand_input():
    "A random, reproducible, input for the mock model."
    sample = tf.random.normal([4, 5], seed=42)
    return tf.constant(sample)
20+
21+
def test_hard_sample_output_format(rand_input):
    """Tests that the hard output samples are in the expected formats.
    The hard sample should be returned as a one-hot tensor."""
    logits = layers.Dense(5, use_bias=False, kernel_initializer=custom_initializer)(rand_input)
    hard_sample, _ = GumbelSoftmaxLayer()(logits)
    # One-hot rows: total mass equals the number of records, and each row has all zeros but one.
    assert npsum(hard_sample) == hard_sample.shape[0], "The sum of the hard samples should equal the number."
    assert all(npsum(hard_sample == 0, 1) == hard_sample.shape[1] - 1), "The hard samples is not a one-hot tensor."
28+
29+
def test_soft_sample_output_format(rand_input):
    """Tests that the soft output samples are in the expected formats.
    The soft sample should be returned as a probabilities tensor."""
    logits = layers.Dense(5, use_bias=False, kernel_initializer=custom_initializer)(rand_input)
    _, soft_sample = GumbelSoftmaxLayer(tau=0.5)(logits)
    # Each row is a probability distribution: sums to ~1 and every entry lies in [0, 1].
    assert isclose(npsum(soft_sample), soft_sample.shape[0]), \
        "The sum of the soft samples should be close to the number of records."
    assert amax(soft_sample) <= 1, "Invalid probability values found."
    assert amin(soft_sample) >= 0, "Invalid probability values found."
38+
39+
def test_gradients(rand_input):
    "Performs basic numerical assertions on the gradients of the soft/hard samples."
    def sampler(tensor):
        # Affine projection followed by the Gumbel-Softmax layer under test.
        dense = layers.Dense(5, use_bias=False, kernel_initializer=custom_initializer)
        return GumbelSoftmaxLayer()(dense(tensor))

    with tf.GradientTape() as hard_tape:
        hard_tape.watch(rand_input)
        hard_sample, _ = sampler(rand_input)
    with tf.GradientTape() as soft_tape:
        soft_tape.watch(rand_input)
        _, soft_sample = sampler(rand_input)
    hard_grads = hard_tape.gradient(hard_sample, rand_input)
    soft_grads = soft_tape.gradient(soft_sample, rand_input)

    # The hard sample sits behind stop_gradient; only the soft sample should be differentiable.
    assert hard_grads is None, "The hard sample must not compute gradients."
    assert soft_grads is not None, "The soft sample is expected to compute gradients."
    assert npsum(abs(soft_grads)) != 0, "The soft sample is expected to have non-zero gradients."
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Gumbel-Softmax layer implementation.
2+
Reference: https://arxiv.org/pdf/1611.04051.pdf"""
3+
from typing import Dict, List, Optional
4+
5+
# pylint: disable=E0401
6+
from tensorflow import (Tensor, TensorShape, concat, one_hot, split, squeeze,
7+
stop_gradient)
8+
from tensorflow.keras.layers import Activation, Layer
9+
from tensorflow.math import log
10+
from tensorflow.nn import softmax
11+
from tensorflow.random import categorical, uniform
12+
13+
TOL = 1e-20
14+
15+
16+
def gumbel_noise(shape: TensorShape) -> Tensor:
    """Create a single sample from the standard (loc = 0, scale = 1) Gumbel distribution.

    Uses the double negative-log transform of a uniform sample; TOL guards against log(0)."""
    uniform_sample = uniform(shape, seed=0)
    inner = -log(uniform_sample + TOL)
    return -log(inner + TOL)
20+
21+
22+
class GumbelSoftmaxLayer(Layer):
    "A Gumbel-Softmax layer implementation that should be stacked on top of a categorical feature logits."

    def __init__(self, tau: float = 0.2, name: Optional[str] = None):
        """Arguments:
            tau (float): Softmax temperature; lower values sharpen the soft sample.
            name (Optional[str]): Name of the layer."""
        super().__init__(name=name)
        self.tau = tau

    # pylint: disable=W0221, E1120
    def call(self, _input):
        """Computes Gumbel-Softmax for the logits output of a particular categorical feature."""
        # Perturb the logits with Gumbel noise, then apply a temperature-scaled softmax.
        perturbed = _input + gumbel_noise(_input.shape)
        soft_sample = softmax(perturbed / self.tau, -1)
        # Draw a category per record and one-hot encode it; stop_gradient keeps the
        # discrete sample out of the backward pass.
        drawn = categorical(log(soft_sample), 1)
        hard_sample = stop_gradient(squeeze(one_hot(drawn, _input.shape[-1]), 1))
        return hard_sample, soft_sample
36+
37+
38+
class ActivationInterface(Layer):
    """An interface layer connecting different parts of an incoming tensor to adequate activation functions.
    The tensor parts are qualified according to the passed processor object.
    Processed categorical features are sent to specific Gumbel-Softmax layers.
    Processed features of different kind are sent to a TanH activation.
    Finally all output parts are concatenated and returned in the same order.

    The parts of an incoming tensor are qualified by leveraging a data processor's in/out feature map.

    Example of how to get a col_map from a Data Processor ProcessorInfo attribute:
    >>> col_map = {k: [v.feat_names_in, v.feat_names_out] for k, v in ProcessorInfo._asdict().items() if v}"""

    def __init__(self, col_map: Dict[str, List[List[str]]], name: Optional[str] = None):
        """Arguments:
            col_map (Dict[str, List[List[str]]]): A map defining the processor pipelines input/output features.
            name (Optional[str]): Name of the layer"""
        # Pass the name as a keyword: Layer.__init__'s first positional parameter is
        # `trainable`, so `super().__init__(name)` would set trainable=name and drop the name.
        super().__init__(name=name)

        self.cat_names_i, cat_names_o = col_map.get("categorical", [[], []])
        num_names_i, num_names_o = col_map.get("numerical", [[], []])

        self._cat_lens = None
        self._num_lens = None

        if self.cat_names_i:  # Get the length of each processed categorical feature's output block
            # Output columns are named '<feature>_<category>'; strip the trailing category with
            # '_'.join (not ''.join) so feature names that themselves contain '_' still match.
            self._cat_lens = [len([col for col in cat_names_o
                                   if '_'.join(col.split('_')[:-1]) == cat_feat]) for cat_feat in self.cat_names_i]
        if num_names_i:  # Get the length of the numerical features output block
            self._num_lens = len(num_names_o)

    def call(self, _input):  # pylint: disable=W0221
        """Splits the incoming tensor into its numerical and categorical parts, activates each
        part adequately (TanH / Gumbel-Softmax) and concatenates the results in the same order."""
        num_cols, cat_cols = split(_input, [self._num_lens if self._num_lens else 0, -1], 1, name='split_num_cats')
        cat_cols = split(cat_cols, self._cat_lens if self._cat_lens else 1, 1, name='split_cats')

        num_cols = [Activation('tanh', name='num_cols_activation')(num_cols)] if self._num_lens else []
        # Invoke the Gumbel-Softmax layers (not .call) so Keras builds and tracks them;
        # keep only the hard (one-hot) sample for each categorical block.
        cat_cols = [GumbelSoftmaxLayer(name=name)(col)[0] for name, col in zip(self.cat_names_i, cat_cols)] \
            if self._cat_lens else []
        return concat(num_cols + cat_cols, 1)

0 commit comments

Comments
 (0)