Skip to content

Commit c68a4cd

Browse files
authored
Add files via upload
1 parent 090648a commit c68a4cd

1 file changed

Lines changed: 230 additions & 0 deletions

File tree

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {
5+
"colab": {
6+
"provenance": []
7+
},
8+
"kernelspec": {
9+
"name": "python3",
10+
"display_name": "Python 3"
11+
},
12+
"language_info": {
13+
"name": "python"
14+
}
15+
},
16+
"cells": [
17+
{
18+
"cell_type": "code",
19+
"source": [
20+
"!pip -q install \"transformers>=4.49\" \"optimum[onnxruntime]>=1.20.0\" \"datasets>=2.20\" \"evaluate>=0.4\" accelerate\n",
21+
"\n",
22+
"from pathlib import Path\n",
23+
"import os, time, numpy as np, torch\n",
24+
"from datasets import load_dataset\n",
25+
"import evaluate\n",
26+
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline\n",
27+
"from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer\n",
28+
"from optimum.onnxruntime.configuration import QuantizationConfig\n",
29+
"\n",
30+
"os.environ.setdefault(\"OMP_NUM_THREADS\", \"1\")\n",
31+
"os.environ.setdefault(\"MKL_NUM_THREADS\", \"1\")\n",
32+
"\n",
33+
"MODEL_ID = \"distilbert-base-uncased-finetuned-sst-2-english\"\n",
34+
"ORT_DIR = Path(\"onnx-distilbert\")\n",
35+
"Q_DIR = Path(\"onnx-distilbert-quant\")\n",
36+
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
37+
"BATCH = 16\n",
38+
"MAXLEN = 128\n",
39+
"N_WARM = 3\n",
40+
"N_ITERS = 8\n",
41+
"\n",
42+
"print(f\"Device: {DEVICE} | torch={torch.__version__}\")"
43+
],
44+
"metadata": {
45+
"colab": {
46+
"base_uri": "https://localhost:8080/"
47+
},
48+
"id": "Eli2cXUjJsiT",
49+
"outputId": "7a623347-afc8-4620-99b2-432ca491fac2"
50+
},
51+
"execution_count": 3,
52+
"outputs": [
53+
{
54+
"output_type": "stream",
55+
"name": "stdout",
56+
"text": [
57+
"Device: cpu | torch=2.8.0+cu126\n"
58+
]
59+
}
60+
]
61+
},
62+
{
63+
"cell_type": "code",
64+
"source": [
65+
# Evaluation slice, metric, tokenizer, and the shared batching/eval/timing helpers.
ds = load_dataset("glue", "sst2", split="validation[:20%]")
texts, labels = ds["sentence"], ds["label"]
metric = evaluate.load("accuracy")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

def make_batches(texts, max_len=MAXLEN, batch=BATCH):
    """Yield tokenized mini-batches of `texts` as PyTorch tensor dicts."""
    for start in range(0, len(texts), batch):
        chunk = texts[start:start + batch]
        yield tokenizer(chunk, padding=True, truncation=True,
                        max_length=max_len, return_tensors="pt")

def run_eval(predict_fn, texts, labels):
    """Run `predict_fn` over every batch and return accuracy against `labels`."""
    predictions = [p for toks in make_batches(texts) for p in predict_fn(toks)]
    return metric.compute(predictions=predictions, references=labels)["accuracy"]

def bench(predict_fn, texts, n_warm=N_WARM, n_iters=N_ITERS):
    """Time `predict_fn` end-to-end (tokenization included) over the full set.

    Returns (mean_ms, std_ms) across `n_iters` passes, after `n_warm` warmup
    passes on a small two-batch slice.
    """
    for _ in range(n_warm):
        for toks in make_batches(texts[:BATCH * 2]):
            predict_fn(toks)
    elapsed_ms = []
    for _ in range(n_iters):
        started = time.time()
        for toks in make_batches(texts):
            predict_fn(toks)
        elapsed_ms.append((time.time() - started) * 1000)
    return float(np.mean(elapsed_ms)), float(np.std(elapsed_ms))
92+
],
93+
"metadata": {
94+
"id": "7E_aUakIJyWU"
95+
},
96+
"execution_count": 4,
97+
"outputs": []
98+
},
99+
{
100+
"cell_type": "code",
101+
"source": [
102+
# Baseline engines: PyTorch eager, then torch.compile when it is available.
torch_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(DEVICE).eval()

@torch.no_grad()
def pt_predict(toks):
    """Eager-mode prediction: returns argmax class ids as a Python list."""
    on_device = {name: tensor.to(DEVICE) for name, tensor in toks.items()}
    logits = torch_model(**on_device).logits
    return logits.argmax(-1).detach().cpu().tolist()

pt_ms, pt_sd = bench(pt_predict, texts)
pt_acc = run_eval(pt_predict, texts, labels)
print(f"[PyTorch eager] {pt_ms:.1f}±{pt_sd:.1f} ms | acc={pt_acc:.4f}")

# Fall back to the eager model if torch.compile is unsupported on this setup.
compiled_model = torch_model
compile_ok = False
try:
    compiled_model = torch.compile(torch_model, mode="reduce-overhead", fullgraph=False)
    compile_ok = True
except Exception as e:
    print("torch.compile unavailable or failed -> skipping:", repr(e))

@torch.no_grad()
def ptc_predict(toks):
    """Same contract as pt_predict, but routed through the compiled model."""
    on_device = {name: tensor.to(DEVICE) for name, tensor in toks.items()}
    logits = compiled_model(**on_device).logits
    return logits.argmax(-1).detach().cpu().tolist()

if compile_ok:
    ptc_ms, ptc_sd = bench(ptc_predict, texts)
    ptc_acc = run_eval(ptc_predict, texts, labels)
    print(f"[torch.compile] {ptc_ms:.1f}±{ptc_sd:.1f} ms | acc={ptc_acc:.4f}")
132+
],
133+
"metadata": {
134+
"colab": {
135+
"base_uri": "https://localhost:8080/"
136+
},
137+
"id": "8uYCXTJHJ3cu",
138+
"outputId": "4cf2584c-d8f8-4ee5-83fb-47a68164a809"
139+
},
140+
"execution_count": null,
141+
"outputs": [
142+
{
143+
"output_type": "stream",
144+
"name": "stdout",
145+
"text": [
146+
"[PyTorch eager] 20008.8±6119.6 ms | acc=0.9080\n"
147+
]
148+
}
149+
]
150+
},
151+
{
152+
"cell_type": "code",
153+
"source": [
154+
# Export to ONNX Runtime, benchmark, then dynamically quantize and benchmark again.
provider = "CUDAExecutionProvider" if DEVICE == "cuda" else "CPUExecutionProvider"
ort_model = ORTModelForSequenceClassification.from_pretrained(
    MODEL_ID, export=True, provider=provider
)
# FIX: persist the exported model. ORTQuantizer.from_pretrained(ORT_DIR) below
# needs a directory containing the exported model.onnx; the original passed
# cache_dir=ORT_DIR, which only holds the hub download cache, not the export.
ort_model.save_pretrained(ORT_DIR)

@torch.no_grad()
def ort_predict(toks):
    """ONNX Runtime prediction; inputs are fed from CPU tensors."""
    logits = ort_model(**{k: v.cpu() for k, v in toks.items()}).logits
    return logits.argmax(-1).cpu().tolist()

ort_ms, ort_sd = bench(ort_predict, texts)
ort_acc = run_eval(ort_predict, texts, labels)
print(f"[ONNX Runtime] {ort_ms:.1f}±{ort_sd:.1f} ms | acc={ort_acc:.4f}")

Q_DIR.mkdir(parents=True, exist_ok=True)
quantizer = ORTQuantizer.from_pretrained(ORT_DIR)
# Dynamic (weights-only) INT8 quantization; reduce_range trades a little speed
# for accuracy on CPUs without VNNI.
# NOTE(review): confirm `approach=` against the installed optimum version —
# dynamic configs are usually built via AutoQuantizationConfig (is_static=False).
qconfig = QuantizationConfig(approach="dynamic", per_channel=False, reduce_range=True)
# FIX: ORTQuantizer.quantize takes (save_dir, quantization_config); the original
# passed an unsupported `model_input=` kwarg.
quantizer.quantize(save_dir=Q_DIR, quantization_config=qconfig)

ort_quant = ORTModelForSequenceClassification.from_pretrained(Q_DIR, provider=provider)

@torch.no_grad()
def ortq_predict(toks):
    """Quantized ONNX Runtime prediction; same contract as ort_predict."""
    logits = ort_quant(**{k: v.cpu() for k, v in toks.items()}).logits
    return logits.argmax(-1).cpu().tolist()

oq_ms, oq_sd = bench(ortq_predict, texts)
oq_acc = run_eval(ortq_predict, texts, labels)
print(f"[ORT Quantized] {oq_ms:.1f}±{oq_sd:.1f} ms | acc={oq_acc:.4f}")
183+
],
184+
"metadata": {
185+
"id": "znWVaD4BJ7Nc"
186+
},
187+
"execution_count": null,
188+
"outputs": []
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"metadata": {
194+
"id": "hHqiX_aNGjsI"
195+
},
196+
"outputs": [],
197+
"source": [
198+
# Spot-check that PyTorch and ONNX Runtime agree, then tabulate all benchmarks.
pt_pipe = pipeline("sentiment-analysis", model=torch_model, tokenizer=tokenizer,
                   device=0 if DEVICE=="cuda" else -1)
ort_pipe = pipeline("sentiment-analysis", model=ort_model, tokenizer=tokenizer, device=-1)
samples = [
    "What a fantastic movie—performed brilliantly!",
    "This was a complete waste of time.",
    "I’m not sure how I feel about this one."
]
print("\nSample predictions (PT | ORT):")
for sentence in samples:
    pt_label = pt_pipe(sentence)[0]["label"]
    ort_label = ort_pipe(sentence)[0]["label"]
    print(f"- {sentence}\n  PT={pt_label} | ORT={ort_label}")

import pandas as pd

rows = [
    ["PyTorch eager", pt_ms, pt_sd, pt_acc],
    ["ONNX Runtime", ort_ms, ort_sd, ort_acc],
    ["ORT Quantized", oq_ms, oq_sd, oq_acc],
]
if compile_ok:
    rows.insert(1, ["torch.compile", ptc_ms, ptc_sd, ptc_acc])
df = pd.DataFrame(rows, columns=["Engine", "Mean ms (↓)", "Std ms", "Accuracy"])
display(df)

print("""
Notes:
- BetterTransformer is deprecated on transformers>=4.49, hence omitted.
- For larger gains on GPU, also try FlashAttention2 models or FP8 with TensorRT-LLM.
- For CPU, tune threads: set OMP_NUM_THREADS/MKL_NUM_THREADS; try NUMA pinning.
- For static (calibrated) quantization, use QuantizationConfig(approach='static') with a calibration set.
""")
227+
]
228+
}
229+
]
230+
}

0 commit comments

Comments
 (0)