Skip to content

Commit f25ff47

Browse files
authored
fix: remove gauumbel softmax dependencies (#235)
1 parent cea1d8e commit f25ff47

7 files changed

Lines changed: 3 additions & 11 deletions

File tree

src/ydata_synthetic/synthesizers/regular/cgan/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#Import ydata synthetic classes
2222
from ....synthesizers import TrainParameters
2323
from ....synthesizers.gan import ConditionalModel
24-
from ....utils.gumbel_softmax import GumbelSoftmaxActivation
2524

2625
class CGAN(ConditionalModel):
2726
"CGAN model for discrete conditions"

src/ydata_synthetic/synthesizers/regular/cramergan/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from ....synthesizers import TrainParameters
1818
from ....synthesizers.gan import BaseModel
1919
from ....synthesizers.loss import Mode, gradient_penalty
20-
from ....utils.gumbel_softmax import GumbelSoftmaxActivation
2120

2221
class CRAMERGAN(BaseModel):
2322

src/ydata_synthetic/synthesizers/regular/cwgangp/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
#Import ydata synthetic classes
1717
from ....synthesizers import TrainParameters
18-
from ....synthesizers.gan import BaseModel, ConditionalModel
18+
from ....synthesizers.gan import ConditionalModel
1919
from ....synthesizers.regular.wgangp.model import WGAN_GP
2020

2121
class CWGANGP(ConditionalModel, WGAN_GP):

src/ydata_synthetic/synthesizers/regular/dragan/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#Import ydata synthetic classes
1515
from ....synthesizers.gan import BaseModel
1616
from ....synthesizers.loss import Mode, gradient_penalty
17-
from ....utils.gumbel_softmax import GumbelSoftmaxActivation
1817

1918
class DRAGAN(BaseModel):
2019

src/ydata_synthetic/synthesizers/regular/vanillagan/model.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def __init__(self, model_parameters):
2727

2828
def define_gan(self, activation_info: Optional[NamedTuple]):
2929
self.generator = Generator(self.batch_size).\
30-
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
31-
activation_info = activation_info, tau = self.tau)
30+
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,)
3231

3332
self.discriminator = Discriminator(self.batch_size).\
3433
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -136,14 +135,12 @@ class Generator(tf.keras.Model):
136135
def __init__(self, batch_size):
137136
self.batch_size=batch_size
138137

139-
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None, tau: Optional[float] = None):
138+
def build_model(self, input_shape, dim, data_dim):
140139
input= Input(shape=input_shape, batch_size=self.batch_size)
141140
x = Dense(dim, activation='relu')(input)
142141
x = Dense(dim * 2, activation='relu')(x)
143142
x = Dense(dim * 4, activation='relu')(x)
144143
x = Dense(data_dim)(x)
145-
if activation_info:
146-
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
147144
return Model(inputs=input, outputs=x)
148145

149146
class Discriminator(tf.keras.Model):

src/ydata_synthetic/synthesizers/regular/wgan/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#Import ydata synthetic classes
1919
from ....synthesizers import TrainParameters
2020
from ....synthesizers.gan import BaseModel
21-
from ....utils.gumbel_softmax import GumbelSoftmaxActivation
2221

2322
#Auxiliary Keras backend class to calculate the Random Weighted average
2423
#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras

src/ydata_synthetic/synthesizers/regular/wgangp/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#Import ydata synthetic classes
1818
from ....synthesizers import TrainParameters
1919
from ....synthesizers.gan import BaseModel
20-
from ....utils.gumbel_softmax import GumbelSoftmaxActivation
2120

2221
class WGAN_GP(BaseModel):
2322

0 commit comments

Comments
 (0)