@@ -27,8 +27,7 @@ def __init__(self, model_parameters):
2727
2828 def define_gan (self , activation_info : Optional [NamedTuple ]):
2929 self .generator = Generator (self .batch_size ).\
30- build_model (input_shape = (self .noise_dim ,), dim = self .layers_dim , data_dim = self .data_dim ,
31- activation_info = activation_info , tau = self .tau )
30+ build_model (input_shape = (self .noise_dim ,), dim = self .layers_dim , data_dim = self .data_dim ,)
3231
3332 self .discriminator = Discriminator (self .batch_size ).\
3433 build_model (input_shape = (self .data_dim ,), dim = self .layers_dim )
@@ -136,14 +135,12 @@ class Generator(tf.keras.Model):
136135 def __init__ (self , batch_size ):
137136 self .batch_size = batch_size
138137
139- def build_model (self , input_shape , dim , data_dim , activation_info : Optional [ NamedTuple ] = None , tau : Optional [ float ] = None ):
138+ def build_model (self , input_shape , dim , data_dim ):
140139 input = Input (shape = input_shape , batch_size = self .batch_size )
141140 x = Dense (dim , activation = 'relu' )(input )
142141 x = Dense (dim * 2 , activation = 'relu' )(x )
143142 x = Dense (dim * 4 , activation = 'relu' )(x )
144143 x = Dense (data_dim )(x )
145- if activation_info :
146- x = GumbelSoftmaxActivation (activation_info , tau = tau )(x )
147144 return Model (inputs = input , outputs = x )
148145
149146class Discriminator (tf .keras .Model ):
0 commit comments