1+ {
2+ "nbformat" : 4 ,
3+ "nbformat_minor" : 0 ,
4+ "metadata" : {
5+ "colab" : {
6+ "provenance" : []
7+ },
8+ "kernelspec" : {
9+ "name" : " python3" ,
10+ "display_name" : " Python 3"
11+ },
12+ "language_info" : {
13+ "name" : " python"
14+ }
15+ },
16+ "cells" : [
17+ {
18+ "cell_type" : " code" ,
19+ "source" : [
"%pip install -q \"transformers>=4.49\" \"optimum[onnxruntime]>=1.20.0\" \"datasets>=2.20\" \"evaluate>=0.4\" accelerate\n",
"\n",
"from pathlib import Path\n",
"import os, time, numpy as np, torch\n",
"from datasets import load_dataset\n",
"import evaluate\n",
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline\n",
"from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer\n",
"from optimum.onnxruntime.configuration import QuantizationConfig\n",
"\n",
"# Pin BLAS/OpenMP to one thread so CPU latency numbers are stable across runs\n",
"# (the closing notes suggest tuning these for throughput instead).\n",
"os.environ.setdefault(\"OMP_NUM_THREADS\", \"1\")\n",
"os.environ.setdefault(\"MKL_NUM_THREADS\", \"1\")\n",
"\n",
"MODEL_ID = \"distilbert-base-uncased-finetuned-sst-2-english\"\n",
"ORT_DIR = Path(\"onnx-distilbert\")      # exported FP32 ONNX model lands here\n",
"Q_DIR = Path(\"onnx-distilbert-quant\")  # dynamically quantized INT8 model\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"BATCH = 16    # inference batch size\n",
"MAXLEN = 128  # max tokenized sequence length\n",
"N_WARM = 3    # warm-up passes before timing\n",
"N_ITERS = 8   # timed passes\n",
"\n",
"print(f\"Device: {DEVICE} | torch={torch.__version__}\")"
43+ ],
44+ "metadata" : {
45+ "colab" : {
46+ "base_uri" : " https://localhost:8080/"
47+ },
48+ "id" : " Eli2cXUjJsiT" ,
49+ "outputId" : " 7a623347-afc8-4620-99b2-432ca491fac2"
50+ },
51+ "execution_count" : 3 ,
52+ "outputs" : [
53+ {
54+ "output_type" : " stream" ,
55+ "name" : " stdout" ,
56+ "text" : [
57+ " Device: cpu | torch=2.8.0+cu126\n "
58+ ]
59+ }
60+ ]
61+ },
62+ {
63+ "cell_type" : " code" ,
64+ "source" : [
"ds = load_dataset(\"glue\", \"sst2\", split=\"validation[:20%]\")\n",
"texts, labels = ds[\"sentence\"], ds[\"label\"]\n",
"metric = evaluate.load(\"accuracy\")\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
"\n",
"def make_batches(texts, max_len=MAXLEN, batch=BATCH):\n",
"    \"\"\"Yield tokenized mini-batches of `texts` as PyTorch tensors.\"\"\"\n",
"    for start in range(0, len(texts), batch):\n",
"        chunk = texts[start:start + batch]\n",
"        yield tokenizer(chunk, padding=True, truncation=True,\n",
"                        max_length=max_len, return_tensors=\"pt\")\n",
"\n",
"def run_eval(predict_fn, texts, labels):\n",
"    \"\"\"Accuracy of `predict_fn` (token batch -> label ids) over `texts`.\"\"\"\n",
"    preds = [p for toks in make_batches(texts) for p in predict_fn(toks)]\n",
"    return metric.compute(predictions=preds, references=labels)[\"accuracy\"]\n",
"\n",
"def bench(predict_fn, texts, n_warm=N_WARM, n_iters=N_ITERS):\n",
"    \"\"\"Mean/std wall-clock time (ms) for a full pass over `texts`, after warm-up.\"\"\"\n",
"    warm_subset = texts[:BATCH * 2]  # small slice is enough to warm caches/JIT\n",
"    for _ in range(n_warm):\n",
"        for toks in make_batches(warm_subset):\n",
"            predict_fn(toks)\n",
"    times_ms = []\n",
"    for _ in range(n_iters):\n",
"        started = time.time()\n",
"        for toks in make_batches(texts):\n",
"            predict_fn(toks)\n",
"        times_ms.append((time.time() - started) * 1000)\n",
"    return float(np.mean(times_ms)), float(np.std(times_ms))"
92+ ],
93+ "metadata" : {
94+ "id" : " 7E_aUakIJyWU"
95+ },
96+ "execution_count" : 4 ,
97+ "outputs" : []
98+ },
99+ {
100+ "cell_type" : " code" ,
101+ "source" : [
"torch_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(DEVICE).eval()\n",
"\n",
"@torch.no_grad()\n",
"def pt_predict(toks):\n",
"    \"\"\"Eager PyTorch forward pass -> list of predicted label ids.\"\"\"\n",
"    toks = {k: v.to(DEVICE) for k, v in toks.items()}\n",
"    logits = torch_model(**toks).logits\n",
"    return logits.argmax(-1).detach().cpu().tolist()\n",
"\n",
"pt_ms, pt_sd = bench(pt_predict, texts)\n",
"pt_acc = run_eval(pt_predict, texts, labels)\n",
"print(f\"[PyTorch eager] {pt_ms:.1f}±{pt_sd:.1f} ms | acc={pt_acc:.4f}\")\n",
"\n",
"compiled_model = torch_model\n",
"compile_ok = False\n",
"try:\n",
"    candidate = torch.compile(torch_model, mode=\"reduce-overhead\", fullgraph=False)\n",
"    # torch.compile is lazy: backend failures surface at the *first forward*,\n",
"    # not at compile() time. Trial-run one tiny batch inside the guard so a\n",
"    # broken backend falls back to eager instead of crashing the benchmark.\n",
"    with torch.no_grad():\n",
"        probe = tokenizer(texts[:2], padding=True, truncation=True,\n",
"                          max_length=MAXLEN, return_tensors=\"pt\")\n",
"        candidate(**{k: v.to(DEVICE) for k, v in probe.items()})\n",
"    compiled_model = candidate\n",
"    compile_ok = True\n",
"except Exception as e:\n",
"    print(\"torch.compile unavailable or failed -> skipping:\", repr(e))\n",
"\n",
"@torch.no_grad()\n",
"def ptc_predict(toks):\n",
"    \"\"\"Forward pass through the torch.compile'd model -> list of label ids.\"\"\"\n",
"    toks = {k: v.to(DEVICE) for k, v in toks.items()}\n",
"    logits = compiled_model(**toks).logits\n",
"    return logits.argmax(-1).detach().cpu().tolist()\n",
"\n",
"if compile_ok:\n",
"    ptc_ms, ptc_sd = bench(ptc_predict, texts)\n",
"    ptc_acc = run_eval(ptc_predict, texts, labels)\n",
"    print(f\"[torch.compile] {ptc_ms:.1f}±{ptc_sd:.1f} ms | acc={ptc_acc:.4f}\")"
132+ ],
133+ "metadata" : {
134+ "colab" : {
135+ "base_uri" : " https://localhost:8080/"
136+ },
137+ "id" : " 8uYCXTJHJ3cu" ,
138+ "outputId" : " 4cf2584c-d8f8-4ee5-83fb-47a68164a809"
139+ },
140+ "execution_count" : null ,
141+ "outputs" : [
142+ {
143+ "output_type" : " stream" ,
144+ "name" : " stdout" ,
145+ "text" : [
146+ " [PyTorch eager] 20008.8±6119.6 ms | acc=0.9080\n "
147+ ]
148+ }
149+ ]
150+ },
151+ {
152+ "cell_type" : " code" ,
153+ "source" : [
"provider = \"CUDAExecutionProvider\" if DEVICE == \"cuda\" else \"CPUExecutionProvider\"\n",
"\n",
"# Export to ONNX and *persist* the exported files: `cache_dir` only stores the\n",
"# HF download cache, so ORTQuantizer.from_pretrained(ORT_DIR) would not find a\n",
"# model there. save_pretrained() writes model.onnx + config into ORT_DIR.\n",
"ort_model = ORTModelForSequenceClassification.from_pretrained(\n",
"    MODEL_ID, export=True, provider=provider\n",
")\n",
"ort_model.save_pretrained(ORT_DIR)\n",
"\n",
"@torch.no_grad()\n",
"def ort_predict(toks):\n",
"    \"\"\"ONNX Runtime forward pass -> list of predicted label ids.\"\"\"\n",
"    logits = ort_model(**{k: v.cpu() for k, v in toks.items()}).logits\n",
"    return logits.argmax(-1).cpu().tolist()\n",
"\n",
"ort_ms, ort_sd = bench(ort_predict, texts)\n",
"ort_acc = run_eval(ort_predict, texts, labels)\n",
"print(f\"[ONNX Runtime] {ort_ms:.1f}±{ort_sd:.1f} ms | acc={ort_acc:.4f}\")\n",
"\n",
"# Dynamic (weight-only) INT8 quantization. QuantizationConfig has no\n",
"# `approach` kwarg, and ORTQuantizer.quantize() has no `model_input` kwarg —\n",
"# the documented path is an AutoQuantizationConfig preset with is_static=False.\n",
"from optimum.onnxruntime.configuration import AutoQuantizationConfig\n",
"\n",
"Q_DIR.mkdir(parents=True, exist_ok=True)\n",
"quantizer = ORTQuantizer.from_pretrained(ORT_DIR)\n",
"qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)\n",
"quantizer.quantize(quantization_config=qconfig, save_dir=Q_DIR)\n",
"\n",
"ort_quant = ORTModelForSequenceClassification.from_pretrained(Q_DIR, provider=provider)\n",
"\n",
"@torch.no_grad()\n",
"def ortq_predict(toks):\n",
"    \"\"\"Quantized ONNX Runtime forward pass -> list of predicted label ids.\"\"\"\n",
"    logits = ort_quant(**{k: v.cpu() for k, v in toks.items()}).logits\n",
"    return logits.argmax(-1).cpu().tolist()\n",
"\n",
"oq_ms, oq_sd = bench(ortq_predict, texts)\n",
"oq_acc = run_eval(ortq_predict, texts, labels)\n",
"print(f\"[ORT Quantized] {oq_ms:.1f}±{oq_sd:.1f} ms | acc={oq_acc:.4f}\")"
183+ ],
184+ "metadata" : {
185+ "id" : " znWVaD4BJ7Nc"
186+ },
187+ "execution_count" : null ,
188+ "outputs" : []
189+ },
190+ {
191+ "cell_type" : " code" ,
192+ "execution_count" : null ,
193+ "metadata" : {
194+ "id" : " hHqiX_aNGjsI"
195+ },
196+ "outputs" : [],
197+ "source" : [
"pt_pipe = pipeline(\"sentiment-analysis\", model=torch_model, tokenizer=tokenizer,\n",
"                   device=0 if DEVICE == \"cuda\" else -1)\n",
"ort_pipe = pipeline(\"sentiment-analysis\", model=ort_model, tokenizer=tokenizer, device=-1)\n",
"\n",
"samples = [\n",
"    \"What a fantastic movie—performed brilliantly!\",\n",
"    \"This was a complete waste of time.\",\n",
"    \"I’m not sure how I feel about this one.\"\n",
"]\n",
"print(\"\\nSample predictions (PT | ORT):\")\n",
"for text in samples:\n",
"    pt_label = pt_pipe(text)[0][\"label\"]\n",
"    ort_label = ort_pipe(text)[0][\"label\"]\n",
"    print(f\"- {text}\\n  PT={pt_label} | ORT={ort_label}\")\n",
"\n",
"import pandas as pd\n",
"\n",
"# Same row order as before: torch.compile (when it ran) sits right after eager.\n",
"rows = [[\"PyTorch eager\", pt_ms, pt_sd, pt_acc]]\n",
"if compile_ok:\n",
"    rows.append([\"torch.compile\", ptc_ms, ptc_sd, ptc_acc])\n",
"rows.append([\"ONNX Runtime\", ort_ms, ort_sd, ort_acc])\n",
"rows.append([\"ORT Quantized\", oq_ms, oq_sd, oq_acc])\n",
"df = pd.DataFrame(rows, columns=[\"Engine\", \"Mean ms (↓)\", \"Std ms\", \"Accuracy\"])\n",
"display(df)\n",
"\n",
"print(\"\"\"\n",
"Notes:\n",
"- BetterTransformer is deprecated on transformers>=4.49, hence omitted.\n",
"- For larger gains on GPU, also try FlashAttention2 models or FP8 with TensorRT-LLM.\n",
"- For CPU, tune threads: set OMP_NUM_THREADS/MKL_NUM_THREADS; try NUMA pinning.\n",
"- For static (calibrated) quantization, use QuantizationConfig(approach='static') with a calibration set.\n",
"\"\"\")"
227+ ]
228+ }
229+ ]
230+ }
0 commit comments