11"""CGAN implementation"""
22import os
33from os import path
4- from typing import List , Tuple , Union , Optional , NamedTuple
4+ from typing import List , Optional , NamedTuple
55
66import numpy as np
77from numpy import array , empty , hstack , ndarray , vstack , save
@@ -30,6 +30,7 @@ class CGAN(BaseModel):
3030 def __init__ (self , model_parameters , num_classes ):
3131 self .num_classes = num_classes
3232 self ._label_col = None
33+ self ._col_order = None
3334 super ().__init__ (model_parameters )
3435
3536 @property
@@ -38,18 +39,9 @@ def label_col(self) -> str:
3839 return self ._label_col
3940
4041 @label_col .setter
41- def label_col (self , data_label : Tuple [Union [DataFrame , array ], str ]):
42- "Validates the label_col format, raises ValueError if invalid."
43- data , label_col = data_label
44- assert label_col in data .columns , f"The column { label_col } could not be found on the provided dataset and \
45- cannot be used as condition."
46- assert data [label_col ].isna ().sum () == 0 , "The label column contains NaN values, please impute or drop the \
47- respective records before proceeding."
48- assert is_float_dtype (data [label_col ]) or is_integer_dtype (data [label_col ]), "The label column is expected to be an \
49- integer or a float dtype to ensure the function of the embedding layer."
50- unique_frac = data [label_col ].nunique ()/ len (data .index )
51- assert unique_frac < 1 , "The provided column {label_col} is constituted by unique values and is not suitable \
52- to be used as condition."
42+ def label_col (self , label_col : str ):
43+ """Set the label_col property."""
44+ self ._label_col = label_col
5345
5446 def define_gan (self , activation_info : Optional [NamedTuple ] = None ):
5547 self .generator = Generator (self .batch_size , self .num_classes ). \
@@ -103,18 +95,20 @@ def get_data_batch(self, data, batch_size, seed=0):
10395 data_ix = np .random .choice (data .shape [0 ], replace = False , size = len (data )) # wasteful to shuffle every time
10496 return data [data_ix [start_i : stop_i ]]
10597
106- def train (self , data : Union [ DataFrame , array ] , label_col : str , train_arguments : TrainParameters , num_cols : List [str ],
98+ def train (self , data : DataFrame , label_col : str , train_arguments : TrainParameters , num_cols : List [str ],
10799 cat_cols : List [str ]):
108100 """
109101 Args:
110- data: A pandas DataFrame or a Numpy array with the data to be synthesized
102+ data: A pandas DataFrame with the data to be synthesized
111103 label: The name of the column to be used as a label and condition for the training
112104 train_arguments: GAN training arguments.
113105 num_cols: List of columns of the data object to be handled as numerical
114106 cat_cols: List of columns of the data object to be handled as categorical
115107 """
116108 # Validating the label column
117- self .label_col = (data , label_col )
109+ self ._validate_label_col (data , label_col )
110+ self ._col_order = data .columns
111+ self .label_col = label_col
118112
119113 # Separating labels from the rest of the data to fit the data processor
120114 data , label = data .loc [:, data .columns != label_col ], expand_dims (data [label_col ], 1 )
@@ -182,15 +176,28 @@ def sample(self, condition: ndarray, n_samples: int,) -> ndarray:
182176 steps = n_samples // self .batch_size + 1
183177 data = []
184178 z_dist = self .get_batch_noise ()
185- condition = expand_dims (convert_to_tensor (condition , dtypes .float32 ), axis = 0 )
186- cond_seq = tile (condition , multiples = [self .batch_size , 1 ])
179+ cond_seq = expand_dims (convert_to_tensor (condition , dtypes .float32 ), axis = 0 )
180+ cond_seq = tile (cond_seq , multiples = [self .batch_size , 1 ])
187181 for _ in trange (steps , desc = 'Synthetic data generation' ):
188182 records = empty (shape = (self .batch_size , self .data_dim ))
189183 records = self .generator ([next (z_dist ), cond_seq ], training = False )
190184 data .append (records )
191185 data = self .processor .inverse_transform (array (vstack (data )))
192- data [self .label_col ] = tile (condition , multiples = [data .shape [0 ], 1 ])
193- return data
186+ data [self .label_col ] = condition [0 ]
187+ return data [self ._col_order ]
188+
189+ @staticmethod
190+ def _validate_label_col (data : DataFrame , label_col : str ):
191+ "Validates the label_col format, raises ValueError if invalid."
192+ assert label_col in data .columns , f"The column { label_col } could not be found on the provided dataset and \
193+ cannot be used as condition."
194+ assert data [label_col ].isna ().sum () == 0 , "The label column contains NaN values, please impute or drop the \
195+ respective records before proceeding."
196+ assert is_float_dtype (data [label_col ]) or is_integer_dtype (data [label_col ]), "The label column is expected to be an \
197+ integer or a float dtype to ensure the function of the embedding layer."
198+ unique_frac = data [label_col ].nunique ()/ len (data .index )
199+ assert unique_frac < 1 , "The provided column {label_col} is constituted by unique values and is not suitable \
200+ to be used as condition."
194201
195202
196203# pylint: disable=R0903
0 commit comments