Skip to content

Commit 81abe1d

Browse files
authored
feat: add new gmm based synth for fast synthesis (#269)
* feat: Add new GMM model for fast synthesis * feat: add save and load for new model * fix: synthesis base class * fix: linter * fix: linter warnings
1 parent 2f6fd89 commit 81abe1d

17 files changed

Lines changed: 415 additions & 36 deletions

File tree

README.md

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,30 @@ Join us on [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the
1212
# YData Synthetic
1313
A package to generate synthetic tabular and time-series data leveraging the state of the art generative models.
1414

15-
## 🎊 We have **big news**: v1.0.0 is here
16-
> We have exciting news for you. The new version of `ydata-synthetic` include new and exciting features:
15+
## 🎊 The exciting features:
16+
> These are must try features whne it comes to synthetic data generation:
17+
> - A new streamlit app that delivers the synthetic data generation experience with a UI interface. A low code experience for the quick generation of synthetic data
18+
> - A new fast synthetic data generation model based on Gaussian Mixture. So you can quickstart in the world of synthetic data generation without the need for a GPU.
1719
> - A conditional architecture for tabular data: CTGAN, which will make the process of synthetic data generation easier and with higher quality!
18-
> - A new streamlit app that delivers the synthetic data generation experience with a UI interface
19-
20+
2021
## Synthetic data
2122
### What is synthetic data?
2223
Synthetic data is artificially generated data that is not collected from real world events. It replicates the statistical components of real data without containing any identifiable information, ensuring individuals' privacy.
2324

2425
### Why Synthetic Data?
2526
Synthetic data can be used for many applications:
26-
- Privacy
27+
- Privacy compliance for data-sharing and Machine Learning development
2728
- Remove bias
2829
- Balance datasets
2930
- Augment datasets
3031

3132
# ydata-synthetic
32-
This repository contains material related with Generative Adversarial Networks for synthetic data generation, in particular regular tabular data and time-series.
33-
It consists a set of different GANs architectures developed using Tensorflow 2.0. Several example Jupyter Notebooks and Python scripts are included, to show how to use the different architectures.
33+
This repository contains material related with architectures and models for synthetic data, from Generative Adversarial Networks (GANs) to Gaussian Mixtures.
34+
The repo includes a full ecosystem for synthetic data generation, that includes different models for the generation of synthetic structure data and time-series.
35+
All the Deep Learning models are implemented leveraging Tensorflow 2.0.
36+
Several example Jupyter Notebooks and Python scripts are included, to show how to use the different architectures.
37+
38+
Are you ready to learn more about synthetic data and the bext-practices for synthetic data generation?
3439

3540
## Quickstart
3641
The source code is currently hosted on GitHub at: https://github.com/ydataai/ydata-synthetic
@@ -78,8 +83,8 @@ The below models are supported:
7883

7984
### Examples
8085
Here you can find usage examples of the package and models to synthesize tabular data.
81-
82-
- Tabular synthetic data generation with CTGAN on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Data-Centric-AI-Community/awesome-python-for-data-science/blob/main/workshop-ds/Workshop%20-%20Data-Centric%20AI%20pipelines%20-%20How%20and%20why.ipynb)
86+
- Fast tabular data synthesis on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/models/Fast_Adult_Census_Income_Data.ipynb)
87+
- Tabular synthetic data generation with CTGAN on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/models/CTGAN_Adult_Census_Income_Data.ipynb)
8388
- Time Series synthetic data generation with TimeGAN on stock dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/timeseries/TimeGAN_Synthetic_stock_data.ipynb)
8489
- More examples are continuously added and can be found in `/examples` directory.
8590

@@ -106,6 +111,7 @@ In this repository you can find the several GAN architectures that are used to c
106111
- [Cramer GAN (The Cramer Distance as a Solution to Biased Wasserstein Gradients)](https://arxiv.org/abs/1705.10743)
107112
- [CWGAN-GP (Conditional Wassertein GAN with Gradient Penalty)](https://cameronfabbri.github.io/papers/conditionalWGAN.pdf)
108113
- [CTGAN (Conditional Tabular GAN)](https://arxiv.org/pdf/1907.00503.pdf)
114+
- [Gaussian Mixture](https://towardsdatascience.com/gaussian-mixture-models-explained-6986aaf5a95)
109115

110116
### Sequential data
111117
- [TimeGAN](https://papers.nips.cc/paper/2019/file/c9efe5f26cd17ba6216bbe2a7d26d490-Paper.pdf)
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {
5+
"colab": {
6+
"provenance": []
7+
},
8+
"kernelspec": {
9+
"name": "python3",
10+
"display_name": "Python 3"
11+
},
12+
"language_info": {
13+
"name": "python"
14+
},
15+
"accelerator": "GPU",
16+
"gpuClass": "standard"
17+
},
18+
"cells": [
19+
{
20+
"cell_type": "code",
21+
"source": [
22+
"#Uncomment to install ydata-synthetic lib\n",
23+
"#!pip install ydata-synthetic"
24+
],
25+
"metadata": {
26+
"id": "fwXSWiYu_tl0",
27+
"pycharm": {
28+
"name": "#%%\n"
29+
}
30+
},
31+
"execution_count": null,
32+
"outputs": []
33+
},
34+
{
35+
"cell_type": "markdown",
36+
"source": [
37+
"# Tabular Synthetic Data Generation with Gaussian Mixture\n",
38+
"- This notebook is an example of how to use a synthetic data generation methods based on [GMM](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html) to generate synthetic tabular data with numeric and categorical features.\n",
39+
"\n",
40+
"## Dataset\n",
41+
"- The data used is the [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income) which we will fecth by importing the `pmlb` library (a wrapper for the Penn Machine Learning Benchmark data repository).\n"
42+
],
43+
"metadata": {
44+
"id": "6T8gjToi_yKA",
45+
"pycharm": {
46+
"name": "#%% md\n"
47+
}
48+
}
49+
},
50+
{
51+
"cell_type": "code",
52+
"source": [
53+
"from pmlb import fetch_data\n",
54+
"\n",
55+
"from ydata_synthetic.synthesizers.regular import RegularSynthesizer\n",
56+
"from ydata_synthetic.synthesizers import ModelParameters, TrainParameters"
57+
],
58+
"metadata": {
59+
"id": "Ix4gZ9iSCVZI",
60+
"pycharm": {
61+
"name": "#%%\n"
62+
}
63+
},
64+
"execution_count": null,
65+
"outputs": []
66+
},
67+
{
68+
"cell_type": "markdown",
69+
"source": [
70+
"## Load the data"
71+
],
72+
"metadata": {
73+
"id": "I0qyPwoECZ5x",
74+
"pycharm": {
75+
"name": "#%% md\n"
76+
}
77+
}
78+
},
79+
{
80+
"cell_type": "code",
81+
"source": [
82+
"# Load data\n",
83+
"data = fetch_data('adult')\n",
84+
"num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
85+
"cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
86+
" 'native-country', 'target']"
87+
],
88+
"metadata": {
89+
"id": "YeFPnJVOMVqd",
90+
"pycharm": {
91+
"name": "#%%\n"
92+
}
93+
},
94+
"execution_count": 2,
95+
"outputs": []
96+
},
97+
{
98+
"cell_type": "markdown",
99+
"source": [
100+
"## Create and Train the synthetic data generator"
101+
],
102+
"metadata": {
103+
"id": "68MoepO0Cpx6",
104+
"pycharm": {
105+
"name": "#%% md\n"
106+
}
107+
}
108+
},
109+
{
110+
"cell_type": "code",
111+
"source": [
112+
"synth = RegularSynthesizer(modelname='fast')\n",
113+
"synth.fit(data=data, num_cols=num_cols, cat_cols=cat_cols)"
114+
],
115+
"metadata": {
116+
"id": "oIHMVgSZMg8_",
117+
"pycharm": {
118+
"name": "#%%\n"
119+
}
120+
},
121+
"execution_count": null,
122+
"outputs": []
123+
},
124+
{
125+
"cell_type": "markdown",
126+
"source": [
127+
"## Generate new synthetic data"
128+
],
129+
"metadata": {
130+
"id": "xHK-SRPyDUin",
131+
"pycharm": {
132+
"name": "#%% md\n"
133+
}
134+
}
135+
},
136+
{
137+
"cell_type": "code",
138+
"source": [
139+
"synth_data = synth.sample(1000)\n",
140+
"print(synth_data)"
141+
],
142+
"metadata": {
143+
"id": "0aa2g0RLMkqe",
144+
"colab": {
145+
"base_uri": "https://localhost:8080/"
146+
},
147+
"outputId": "01808aa4-a700-4385-e7df-b2f7abd162a0",
148+
"pycharm": {
149+
"name": "#%%\n"
150+
}
151+
},
152+
"execution_count": 8,
153+
"outputs": [
154+
{
155+
"output_type": "stream",
156+
"name": "stdout",
157+
"text": [
158+
" age workclass fnlwgt education education-num \\\n",
159+
"0 38.753654 4 179993.565472 8 10.0 \n",
160+
"1 36.408844 4 245841.807958 9 10.0 \n",
161+
"2 56.251066 4 400895.076058 11 13.0 \n",
162+
"3 26.846605 4 240156.201048 11 10.0 \n",
163+
"4 29.083102 1 5601.059126 11 9.0 \n",
164+
".. ... ... ... ... ... \n",
165+
"995 79.281276 4 30664.183560 1 10.0 \n",
166+
"996 51.423132 4 414524.980527 1 10.0 \n",
167+
"997 17.342915 6 177716.451926 11 13.0 \n",
168+
"998 39.298867 4 132011.369567 15 12.0 \n",
169+
"999 46.977763 2 92662.371635 9 13.0 \n",
170+
"\n",
171+
" marital-status occupation relationship race sex capital-gain \\\n",
172+
"0 4 0 3 4 0 55.771499 \n",
173+
"1 6 7 0 4 1 124.337939 \n",
174+
"2 4 3 3 4 1 27.968087 \n",
175+
"3 4 6 1 4 0 25.065678 \n",
176+
"4 6 3 0 4 0 126.269337 \n",
177+
".. ... ... ... ... ... ... \n",
178+
"995 2 0 3 4 1 4.393001 \n",
179+
"996 4 7 3 2 0 54.841598 \n",
180+
"997 4 4 4 4 0 99.394428 \n",
181+
"998 4 14 1 4 1 97.834797 \n",
182+
"999 4 8 1 4 0 51.258308 \n",
183+
"\n",
184+
" capital-loss hours-per-week native-country target \n",
185+
"0 -1.271118 39.749641 39 1 \n",
186+
"1 -2.114950 44.488198 39 1 \n",
187+
"2 1.541738 40.042696 39 1 \n",
188+
"3 1.148560 39.952615 39 1 \n",
189+
"4 -1.786768 39.808085 39 0 \n",
190+
".. ... ... ... ... \n",
191+
"995 0.224015 50.580637 39 1 \n",
192+
"996 1.319341 4.441194 39 1 \n",
193+
"997 -5.231663 39.779674 39 1 \n",
194+
"998 1.595817 39.731359 13 1 \n",
195+
"999 1.129814 39.838415 39 1 \n",
196+
"\n",
197+
"[1000 rows x 15 columns]\n"
198+
]
199+
}
200+
]
201+
}
202+
]
203+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from ydata_synthetic.preprocessing.regular.processor import RegularDataProcessor
2+
from ydata_synthetic.preprocessing.timeseries.timeseries_processor import TimeSeriesDataProcessor
3+
4+
__all__ = [
5+
"RegularDataProcessor",
6+
"TimeSeriesDataProcessor"
7+
]
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1-
from ydata_synthetic.synthesizers.gan import ModelParameters, TrainParameters
1+
from ydata_synthetic.synthesizers.base import ModelParameters, TrainParameters
2+
3+
__all__ = [
4+
"ModelParameters",
5+
"TrainParameters"
6+
]
Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"Implements a GAN BaseModel synthesizer, not meant to be directly instantiated."
2+
from abc import ABC, abstractmethod
23
from collections import namedtuple
34
from typing import List, Optional, Union
45

6+
import pandas as pd
57
import tqdm
68

79
from numpy import array, vstack, ndarray
@@ -40,19 +42,57 @@
4042
ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df)
4143
TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, 10, 0.005, True))
4244

45+
@typechecked
46+
class BaseModel(ABC):
47+
"""
48+
Abstract class for synthetic data generation nmodels
49+
50+
The main methods are train (for fitting the synthesizer), save/load and sample (generating synthetic records).
51+
52+
"""
53+
__MODEL__ = None
54+
55+
@abstractmethod
56+
def fit(self, data: Union[DataFrame, array],
57+
num_cols: Optional[List[str]] = None,
58+
cat_cols: Optional[List[str]] = None):
59+
"""
60+
### Description:
61+
Trains and fit a synthesizer model to a given input dataset.
62+
63+
### Args:
64+
`data` (Union[DataFrame, array]): Training data
65+
`num_cols` (Optional[List[str]]) : List with the names of the categorical columns
66+
`cat_cols` (Optional[List[str]]): List of names of categorical columns
67+
68+
### Returns:
69+
**self:** *object*
70+
Fitted synthesizer
71+
"""
72+
...
73+
@abstractmethod
74+
def sample(self, n_samples:int) -> pd.DataFrame:
75+
assert n_samples>0, "Please insert a value bigger than 0 for n_samples parameter."
76+
...
77+
78+
@classmethod
79+
def load(cls, path: str):
80+
...
81+
82+
@abstractmethod
83+
def save(self, path: str):
84+
...
4385

4486
# pylint: disable=R0902
4587
@typechecked
46-
class BaseModel():
88+
class BaseGANModel(BaseModel):
4789
"""
4890
Base class of GAN synthesizer models.
4991
The main methods are train (for fitting the synthesizer), save/load and sample (obtain synthetic records).
5092
Args:
5193
model_parameters (ModelParameters):
5294
Set of architectural parameters for model definition.
5395
"""
54-
__MODEL__ = None
55-
5696
def __init__(
5797
self,
5898
model_parameters: ModelParameters
@@ -84,7 +124,7 @@ def __init__(
84124
self.gp_lambda = model_parameters.gp_lambda
85125
self.pac = model_parameters.pac
86126

87-
self.processor = None
127+
self.processor=None
88128
if self.__MODEL__ in RegularModels.__members__ or \
89129
self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL:
90130
self.tau = model_parameters.tau_gs
@@ -183,8 +223,8 @@ def save(self, path):
183223
make_keras_picklable()
184224
dump(self, path)
185225

186-
@staticmethod
187-
def load(path):
226+
@classmethod
227+
def load(cls, path):
188228
"""
189229
### Description:
190230
Loads a saved synthesizer from a pickle.

src/ydata_synthetic/synthesizers/regular/cgan/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
#Import ydata synthetic classes
2222
from ....synthesizers import TrainParameters
23-
from ....synthesizers.gan import ConditionalModel
23+
from ....synthesizers.base import ConditionalModel
2424

2525
class CGAN(ConditionalModel):
2626
"CGAN model for discrete conditions"

src/ydata_synthetic/synthesizers/regular/cramergan/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
#Import ydata synthetic classes
1717
from ....synthesizers import TrainParameters
18-
from ....synthesizers.gan import BaseModel
18+
from ....synthesizers.base import BaseGANModel
1919
from ....synthesizers.loss import Mode, gradient_penalty
2020

21-
class CRAMERGAN(BaseModel):
21+
class CRAMERGAN(BaseGANModel):
2222

2323
__MODEL__='CRAMERGAN'
2424

0 commit comments

Comments
 (0)