Skip to content

Commit dcdab7f

Browse files
authored
feat: add Fabric Regular Synthesizer to Streamlit app (#252)
* feat: Fabric Regular Synthesizer in Streamlit app * feat: add ydata-sdk as requirement for streamlit * feat: allow to overwrite default datatype for Fabric Regular Synthesizer * fix: restore streamlit dependency * feat: rename the SDK synthesizer, improve documentation * fix: type exception
1 parent 149e9ef commit dcdab7f

4 files changed

Lines changed: 127 additions & 16 deletions

File tree

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@
5353
"streamlit==1.18.1",
5454
"typing-extensions==3.10.0",
5555
"streamlit_pandas_profiling==0.1.3",
56-
"ydata-profiling==4.0.0"
56+
"ydata-profiling==4.0.0",
57+
"ydata-sdk>=0.2.1"
5758
],
5859
},
5960
)

src/ydata_synthetic/streamlit_app/About.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ def main():
4444
- WGAN
4545
- WGANGP
4646
- CTGAN
47+
- **ydata-sdk Synthesizer**
48+
''')
49+
50+
st.success('''In particular, **ydata-sdk Synthesizer** uses [`ydata-sdk`](https://docs.sdk.ydata.ai/) to leverage the state-of-the-art synthesizer model developed by YData.''')
51+
st.info('''
52+
Using **ydata-sdk Synthesizer** requires a valid token. The token is attached to a Fabric account.
53+
In case you do not have an account, you can create one at https://ydata.ai/ydata-fabric-free-trial.
54+
To obtain the token, please, login to https://fabric.ydata.ai.
55+
The token is available on the homepage once you are connected.
4756
''')
4857

4958
#best practives for synthetic data generation

src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from typing import Union
2+
import os
3+
import json
24
import streamlit as st
35

6+
from ydata.sdk.synthesizers import RegularSynthesizer
7+
from ydata.sdk.common.client import get_client
8+
49
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
510
from ydata_synthetic.synthesizers.regular.model import Model
611

@@ -12,7 +17,7 @@ def get_available_models(type: Union[str, DataType]):
1217

1318
dtype = DataType(type)
1419
if dtype == DataType.TABULAR:
15-
models_list = [e.value.upper() for e in Model if e.value not in ['cgan', 'cwgangp']]
20+
models_list = [e.value.upper() for e in Model if e.value not in ['cgan', 'cwgangp']] + ['ydata-sdk Synthesizer']
1621
else:
1722
st.warning('Time-Series models are not yet supported .')
1823
models_list = ([''])
@@ -35,7 +40,7 @@ def run():
3540
models_list = get_available_models(type=datatype)
3641
model_name = st.selectbox('Select your model', models_list)
3742

38-
if model_name !='':
43+
if model_name not in ['', 'ydata-sdk Synthesizer']:
3944
st.text("Select your synthesizer model parameters")
4045
col1, col2 = st.columns(2)
4146
with col1:
@@ -50,14 +55,14 @@ def run():
5055

5156
# Create the Train parameters
5257
gan_args = ModelParameters(batch_size=batch_size,
53-
lr=lr,
54-
betas=(beta_1, beta_2),
55-
noise_dim=noise_dim,
56-
layers_dim=layer_dim)
58+
lr=lr,
59+
betas=(beta_1, beta_2),
60+
noise_dim=noise_dim,
61+
layers_dim=layer_dim)
5762

5863
model = init_synth(datatype=datatype, modelname=model_name, model_parameters=gan_args)
5964

60-
if model!=None:
65+
if model != None:
6166
st.text("Set your synthesizer training parameters")
6267
#Get the training parameters
6368
epochs, label_col = training_parameters(model_name, df.columns)
@@ -72,11 +77,64 @@ def run():
7277
else:
7378
model.fit(data=df, num_cols=num_cols, cat_cols=cat_cols, train_arguments=train_args)
7479

75-
st.success('Synthesizer was trained succesfully!!')
76-
80+
st.success('Synthesizer was trained succesfully!')
7781
st.info(f"The trained model will be saved at {model_path}.")
7882

7983
model.save(model_path)
8084

85+
86+
87+
if model_name == 'ydata-sdk Synthesizer':
88+
valid_token = False
89+
st.text("Model parameters")
90+
col1, col2 = st.columns(2)
91+
with col1:
92+
token = st.text_input("SDK Token", type="password")
93+
os.environ['YDATA_TOKEN'] = token
94+
95+
with col2:
96+
st.write("##")
97+
try:
98+
get_client()
99+
st.text('✅ Valid')
100+
valid_token = True
101+
except Exception:
102+
st.text('❌ Invalid')
103+
104+
if not valid_token:
105+
st.error("""**ydata-sdk Synthesizer requires a valid token.**
106+
In case you do not have an account, please, create one at https://ydata.ai/ydata-fabric-free-trial.
107+
To obtain the token, please, login to https://fabric.ydata.ai.
108+
The token is available on the homepage once you are connected.
109+
""")
110+
111+
112+
with st.expander('**More settings**'):
113+
model_path = st.text_input("Saved trained model to path:", value="trained_synth.pkl")
114+
115+
st.subheader("3. Train your synthesizer")
116+
if st.button('Click here to start the training process', disabled=not valid_token):
117+
model = RegularSynthesizer()
118+
with st.spinner("Please wait while your synthesizer trains..."):
119+
dtypes = {}
120+
for c in num_cols:
121+
dtypes[c] = 'numerical'
122+
for c in cat_cols:
123+
dtypes[c] = 'categorical'
124+
model.fit(X=df, dtypes=dtypes)
125+
126+
st.success('Synthesizer was trained succesfully!')
127+
st.info(f"The trained model will be saved at {model_path}.")
128+
129+
model_data = {
130+
'uid': model.uid,
131+
'token': os.environ['YDATA_TOKEN']
132+
}
133+
with open(model_path, 'w') as outfile:
134+
json.dump(model_data, outfile)
135+
136+
137+
138+
81139
if __name__ == '__main__':
82140
run()

src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,57 @@
11
import streamlit as st
2+
import json
3+
import os
4+
5+
from ydata.sdk.synthesizers import RegularSynthesizer
6+
from ydata.sdk.common.client import get_client
27

38
from ydata_synthetic.streamlit_app.pages.functions.train import DataType
49
from ydata_synthetic.streamlit_app.pages.functions.generate import load_model, generate_profile
510

611
def run():
712
st.subheader("Generate synthetic data from a trained model")
8-
13+
from_SDK = False
14+
model_data = {}
15+
valid_token = False
916
col1, col2 = st.columns([4, 2])
1017
with col1:
1118
input_path = st.text_input("Provide the path to a trained model", value="trained_synth.pkl")
19+
# Try to load as a JSON as SDK
20+
try:
21+
f = open(input_path)
22+
model_data = json.load(f)
23+
from_SDK = True
24+
except:
25+
pass
26+
27+
if from_SDK:
28+
token = st.text_input("SDK Token", type="password", value=model_data.get('token'))
29+
os.environ['YDATA_TOKEN'] = token
30+
31+
1232
with col2:
1333
datatype = st.selectbox('Select your data type', (DataType.TABULAR.value,))
1434
datatype=DataType(datatype)
1535

36+
if from_SDK and 'YDATA_TOKEN' in os.environ:
37+
st.write("##")
38+
try:
39+
get_client()
40+
st.text('✅ Valid')
41+
valid_token = True
42+
except Exception:
43+
st.text('❌ Invalid')
44+
45+
if from_SDK and 'token' in model_data and not valid_token:
46+
st.warning("The token used during training is not valid anymore. Please, use a new token.")
47+
48+
if from_SDK and not valid_token:
49+
st.error("""**ydata-sdk Synthesizer requires a valid token.**
50+
In case you do not have an account, please, create one at https://ydata.ai/ydata-fabric-free-trial.
51+
To obtain the token, please, login to https://fabric.ydata.ai.
52+
The token is available on the homepage once you are connected.
53+
""")
54+
1655
col1, col2 = st.columns([4,2])
1756
with col1:
1857
n_samples = st.number_input("Number of samples to generate", min_value=0, value=1000)
@@ -21,14 +60,18 @@ def run():
2160
sample_path = st.text_input("Synthetic samples file path", value='synthetic.csv')
2261

2362
if st.button('Generate samples'):
24-
#load a trained model
25-
model = load_model(input_path=input_path,
26-
datatype=datatype)
63+
if from_SDK:
64+
model = RegularSynthesizer.get(uid=model_data.get('uid'))
65+
66+
else:
67+
model = load_model(input_path=input_path, datatype=datatype)
68+
69+
st.success('The model was properly loaded and is now ready to generate synthetic samples!')
2770

28-
st.success('Trained model was loaded. You can now generate synthetic samples')
2971

3072
#sample synthetic data
31-
synth_data = model.sample(n_samples)
73+
with st.spinner('Generating samples... This might take time.'):
74+
synth_data = model.sample(n_samples)
3275
st.write(synth_data)
3376

3477
#save the synthetic data samples to a given path

0 commit comments

Comments
 (0)