
Commit 149e9ef

chore: add ctgan colab notebook (#248)
Adds ctgan notebook example. Updates dataset url to kaggle for consistency. Updates open in colab url (substitutes vanilla gan example).
1 parent dea3c83 commit 149e9ef

2 files changed

Lines changed: 224 additions & 2 deletions

File tree

README.md
examples/regular/models/CTGAN_Adult_Census_Income_Data.ipynb

Lines changed: 3 additions & 2 deletions
@@ -78,14 +78,15 @@ The below models are supported:
 
 ### Examples
 Here you can find usage examples of the package and models to synthesize tabular data.
-- Synthesizing the minority class with VanillaGAN on credit fraud dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/gan_example.ipynb)
+
+- Tabular synthetic data generation with CTGAN on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/models/CTGAN_Adult_Census_Income_Data.ipynb)
 - Time Series synthetic data generation with TimeGAN on stock dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/timeseries/TimeGAN_Synthetic_stock_data.ipynb)
 - More examples are continuously added and can be found in `/examples` directory.
 
 ### Datasets for you to experiment
 Here are some example datasets for you to try with the synthesizers:
 #### Tabular datasets
-- [Adult census](https://archive.ics.uci.edu/ml/datasets/adult)
+- [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income)
 - [Credit card fraud](https://www.kaggle.com/mlg-ulb/creditcardfraud)
 - [Cardiovascular Disease dataset](https://www.kaggle.com/datasets/sulianova/cardiovascular-disease-dataset)

examples/regular/models/CTGAN_Adult_Census_Income_Data.ipynb

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Note: You can run this notebook on either \"CPU\" or \"GPU\"\n",
        "# Click \"Runtime > Change runtime type\" and set \"GPU\""
      ],
      "metadata": {
        "id": "Kh7c1F1J_sD7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Uncomment to install the ydata-synthetic lib\n",
        "#!pip install ydata-synthetic"
      ],
      "metadata": {
        "id": "fwXSWiYu_tl0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Tabular Synthetic Data Generation with CTGAN\n",
        "- CTGAN - implemented according to the [paper](https://arxiv.org/pdf/1907.00503.pdf)\n",
        "- This notebook is an example of how to use CTGAN to generate synthetic tabular data with numeric and categorical features.\n",
        "\n",
        "## Dataset\n",
        "\n",
        "- The data used is the [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income) dataset, which we fetch by importing the `pmlb` library (a wrapper for the Penn Machine Learning Benchmark data repository).\n"
      ],
      "metadata": {
        "id": "6T8gjToi_yKA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pmlb import fetch_data\n",
        "\n",
        "from ydata_synthetic.synthesizers.regular import RegularSynthesizer\n",
        "from ydata_synthetic.synthesizers import ModelParameters, TrainParameters"
      ],
      "metadata": {
        "id": "Ix4gZ9iSCVZI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load the data"
      ],
      "metadata": {
        "id": "I0qyPwoECZ5x"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Load data\n",
        "data = fetch_data('adult')\n",
        "num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
        "cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
        "            'native-country', 'target']"
      ],
      "metadata": {
        "id": "YeFPnJVOMVqd"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Define model and training parameters"
      ],
      "metadata": {
        "id": "m6-dt5hLCgxG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Defining the training parameters\n",
        "batch_size = 500\n",
        "epochs = 500+1\n",
        "learning_rate = 2e-4\n",
        "beta_1 = 0.5\n",
        "beta_2 = 0.9\n",
        "\n",
        "ctgan_args = ModelParameters(batch_size=batch_size,\n",
        "                             lr=learning_rate,\n",
        "                             betas=(beta_1, beta_2))\n",
        "\n",
        "train_args = TrainParameters(epochs=epochs)"
      ],
      "metadata": {
        "id": "9SsyBS2nMaSA"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Create and Train the CTGAN"
      ],
      "metadata": {
        "id": "68MoepO0Cpx6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "synth = RegularSynthesizer(modelname='ctgan', model_parameters=ctgan_args)\n",
        "synth.fit(data=data, train_arguments=train_args, num_cols=num_cols, cat_cols=cat_cols)"
      ],
      "metadata": {
        "id": "oIHMVgSZMg8_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Generate new synthetic data"
      ],
      "metadata": {
        "id": "xHK-SRPyDUin"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "synth_data = synth.sample(1000)\n",
        "print(synth_data)"
      ],
      "metadata": {
        "id": "0aa2g0RLMkqe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "01808aa4-a700-4385-e7df-b2f7abd162a0"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " age workclass fnlwgt education education-num \\\n",
            "0 38.753654 4 179993.565472 8 10.0 \n",
            "1 36.408844 4 245841.807958 9 10.0 \n",
            "2 56.251066 4 400895.076058 11 13.0 \n",
            "3 26.846605 4 240156.201048 11 10.0 \n",
            "4 29.083102 1 5601.059126 11 9.0 \n",
            ".. ... ... ... ... ... \n",
            "995 79.281276 4 30664.183560 1 10.0 \n",
            "996 51.423132 4 414524.980527 1 10.0 \n",
            "997 17.342915 6 177716.451926 11 13.0 \n",
            "998 39.298867 4 132011.369567 15 12.0 \n",
            "999 46.977763 2 92662.371635 9 13.0 \n",
            "\n",
            " marital-status occupation relationship race sex capital-gain \\\n",
            "0 4 0 3 4 0 55.771499 \n",
            "1 6 7 0 4 1 124.337939 \n",
            "2 4 3 3 4 1 27.968087 \n",
            "3 4 6 1 4 0 25.065678 \n",
            "4 6 3 0 4 0 126.269337 \n",
            ".. ... ... ... ... ... ... \n",
            "995 2 0 3 4 1 4.393001 \n",
            "996 4 7 3 2 0 54.841598 \n",
            "997 4 4 4 4 0 99.394428 \n",
            "998 4 14 1 4 1 97.834797 \n",
            "999 4 8 1 4 0 51.258308 \n",
            "\n",
            " capital-loss hours-per-week native-country target \n",
            "0 -1.271118 39.749641 39 1 \n",
            "1 -2.114950 44.488198 39 1 \n",
            "2 1.541738 40.042696 39 1 \n",
            "3 1.148560 39.952615 39 1 \n",
            "4 -1.786768 39.808085 39 0 \n",
            ".. ... ... ... ... \n",
            "995 0.224015 50.580637 39 1 \n",
            "996 1.319341 4.441194 39 1 \n",
            "997 -5.231663 39.779674 39 1 \n",
            "998 1.595817 39.731359 13 1 \n",
            "999 1.129814 39.838415 39 1 \n",
            "\n",
            "[1000 rows x 15 columns]\n"
          ]
        }
      ]
    }
  ]
}
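
The file added in this commit is a Jupyter notebook, i.e. nbformat 4 JSON. A notebook like this can be sanity-checked with the standard library alone before committing; the sketch below assumes nothing beyond the `json` module, and `check_notebook` is a hypothetical helper, not part of the repository:

```python
import json

def check_notebook(text):
    """Basic nbformat-4 sanity checks on raw notebook JSON; returns the cell count."""
    nb = json.loads(text)
    assert nb["nbformat"] == 4, "expected nbformat 4"
    assert isinstance(nb["cells"], list), "cells must be a list"
    for cell in nb["cells"]:
        # every cell carries a type, a metadata dict, and a source
        assert cell["cell_type"] in ("code", "markdown", "raw")
        assert isinstance(cell["metadata"], dict)
        assert "source" in cell
        if cell["cell_type"] == "code":
            # code cells additionally carry execution_count and outputs
            assert "execution_count" in cell and "outputs" in cell
    return len(nb["cells"])

# A tiny two-cell notebook, mirroring the structure of the file above
sample = json.dumps({
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {},
    "cells": [
        {"cell_type": "markdown", "source": ["# Title"], "metadata": {}},
        {"cell_type": "code", "source": ["print('hi')"], "metadata": {},
         "execution_count": None, "outputs": []},
    ],
})
print(check_notebook(sample))  # prints 2
```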
