diff --git a/notebooks/community/model_garden/model_garden_oss_distillation_feasibility.ipynb b/notebooks/community/model_garden/model_garden_oss_distillation_feasibility.ipynb new file mode 100644 index 000000000..cf171e414 --- /dev/null +++ b/notebooks/community/model_garden/model_garden_oss_distillation_feasibility.ipynb @@ -0,0 +1,435 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "iUr0AjrgVqUk" + }, + "outputs": [], + "source": [ + "# Copyright 2026 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kxnledoMWKxp" + }, + "source": [ + "# Vertex AI Model Garden - OSS Distillation Feasibility Study\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \"Workbench
Run in Workbench\n", + "
\n", + "
\n", + " \n", + " \"Google
Run in Colab Enterprise\n", + "
\n", + "
\n", + " \n", + " \"GitHub
View on GitHub\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "spClB3u_WNz1" + }, + "source": [ + "## Overview\n", + "This notebook serves as a critical validation step in the Model Distillation pipeline. Before committing resources to distill a smaller \"Student\" model to mimic a larger \"Teacher\" model, we must first determine if the model pair and the target dataset are compatible.\n", + "\n", + "### Core Objective\n", + "The primary goal is to calculate and compare the Perplexity scores of both the Teacher and Student models on a specific domain dataset, in this case, [syz-ml2025/medmcqa](https://huggingface.co/datasets/syz-ml2025/medmcqa) (Medical Multiple-Choice QA).\n", + "\n", + "### Why Perplexity Matters\n", + "Perplexity is a measurement of how well a probability model predicts a sample.\n", + "\n", + "- **Teacher Validation:** We use the `deepseek-ai/deepseek-r1-0528-maas` model via the [Vertex AI MaaS API](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/maas/call-open-model-apis). A low teacher perplexity confirms that the teacher understands the domain well enough to provide high-quality supervision.\n", + "\n", + "- **Gap Identification:** We evaluate the `Qwen/Qwen3-0.6B` student model using the HuggingFace library. 
By quantifying the performance gap between the teacher and the student, we can estimate the potential \"knowledge headroom\" available for distillation.\n", + "\n", + "### Workflow Steps\n", + "- **Environment Setup:** Installation of the prerequisite libraries (Hugging Face Datasets and Transformers, plus supporting packages).\n", + "\n", + "- **Dataset Sampling:** Extraction of a representative subset (200 samples by default) from the medmcqa training split to keep the benchmark fast.\n", + "\n", + "- **Teacher Inference:** Leveraging Vertex AI Model-as-a-Service (MaaS) to compute log-likelihoods and derive the teacher's perplexity.\n", + "\n", + "- **Student Inference:** Running the student model on local GPU resources (for example, an NVIDIA L4) to calculate its baseline perplexity before any training occurs.\n", + "\n", + "### Success Criteria\n", + "The prerequisite check passes when the Teacher model shows significantly lower perplexity than the Student, indicating that the Student has meaningful room to learn the Teacher's stronger reasoning and linguistic patterns in the medical domain."
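, + "\n", + "### How Perplexity Is Computed Here\n", + "As a reference for the cells below (this restates the computation both evaluation cells implement; the notation is generic, lower is better), the perplexity of a token sequence $x$ of length $N$ is the exponential of its mean token negative log-likelihood:\n", + "\n", + "$$\\mathrm{PPL}(x) = \\exp\\left(-\\frac{1}{N}\\sum_{i=1}^{N}\\log p\\left(x_i \\mid x_{<i}\\right)\\right)$$\n", + "\n", + "Both the teacher and student cells compute this per sequence (`exp(sum_nll / token_count)`) and report the mean of the per-sequence perplexities over the sampled dataset."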
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1xpacQK2Wn20" + }, + "source": [ + "## Before you begin" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "R7RIdZD8WqIz" + }, + "outputs": [], + "source": [ + "# @markdown Install the prerequisite libraries.\n", + "!pip3 install --upgrade datasets transformers openai tenacity tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "9GNQZPRpWto6" + }, + "outputs": [], + "source": [ + "import json\n", + "\n", + "# @markdown Create a small test dataset from a Hugging Face dataset for perplexity calculation.\n", + "from datasets import load_dataset\n", + "import tqdm\n", + "\n", + "\n", + "def create_test_dataset(sample_dataset, test_file_path, dataset_split, num_samples=200):\n", + " dataset = load_dataset(sample_dataset, split=dataset_split)\n", + " test_samples = dataset.select(range(num_samples))\n", + " with open(test_file_path, \"w\") as f:\n", + " for sample in test_samples:\n", + " f.write(json.dumps(sample) + \"\\n\")\n", + "\n", + "\n", + "def format_dataset(dataset):\n", + " \"\"\"Formats the dataset for perplexity calculation.\"\"\"\n", + " # Maps the `cop` (correct option) index to its answer letter.\n", + " cop_map = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\"}\n", + " formatted_data = []\n", + " with open(dataset, \"r\") as f:\n", + " data = [json.loads(line) for line in f if line.strip()]\n", + " for record in tqdm.tqdm(data):\n", + " text = (\n", + " f'{record[\"question\"]}\\nA) {record[\"opa\"]}\\nB) {record[\"opb\"]}\\nC) '\n", + " f'{record[\"opc\"]}\\nD) {record[\"opd\"]}'\n", + " )\n", + " if record[\"exp\"]:\n", + " text += f'\\n{record[\"exp\"]}'\n", + " text += f'\\nThe answer is {cop_map[record[\"cop\"]]}'\n", + " formatted_data.append(text)\n", + "\n", + " return formatted_data\n", + "\n", + "\n", + "sample_dataset = \"syz-ml2025/medmcqa\" # @param {type:\"string\"}\n", + "test_file_path = \"test.jsonl\" # @param {type:\"string\"}\n", + "dataset_split = \"train\" # @param 
{type:\"string\"}\n", + "dataset_sample_size = 200 # @param {type:\"integer\"}\n", + "\n", + "create_test_dataset(\n", + " sample_dataset=sample_dataset,\n", + " test_file_path=test_file_path,\n", + " dataset_split=dataset_split,\n", + " num_samples=dataset_sample_size,\n", + ")\n", + "formatted_dataset = format_dataset(test_file_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "IiiN9YiKW3dv" + }, + "outputs": [], + "source": [ + "# @markdown Perplexity calculation on the test dataset using the Vertex AI Model-as-a-Service (MaaS) API for the teacher model.\n", + "\n", + "import concurrent.futures\n", + "import os\n", + "import subprocess\n", + "import sys\n", + "from typing import Any, Tuple\n", + "\n", + "import numpy as np\n", + "import openai\n", + "import tenacity\n", + "import tqdm\n", + "\n", + "\n", + "def get_access_token() -> str:\n", + " \"\"\"Gets the access token from gcloud.\"\"\"\n", + " try:\n", + " return subprocess.check_output(\n", + " [\"gcloud\", \"auth\", \"print-access-token\"], encoding=\"utf-8\"\n", + " ).strip()\n", + " except subprocess.CalledProcessError as e:\n", + " print(f\"Error getting access token: {e!r}\")\n", + " sys.exit(1)\n", + "\n", + "\n", + "@tenacity.retry(\n", + " stop=tenacity.stop_after_attempt(10),\n", + " wait=tenacity.wait_exponential(multiplier=1, min=2, max=60),\n", + " retry=tenacity.retry_if_exception_type(\n", + " (\n", + " openai.APIConnectionError,\n", + " openai.RateLimitError,\n", + " openai.APIStatusError,\n", + " )\n", + " ),\n", + ")\n", + "def _create_completion_with_retry(\n", + " client: openai.OpenAI,\n", + " messages: list[dict[str, str]],\n", + " model: str,\n", + ") -> Any:\n", + " \"\"\"Creates a completion with retry logic to get logprobs.\"\"\"\n", + " # \"prompt_logprobs\": 0 asks the server to score the prompt tokens\n", + " # themselves, which the perplexity computation below consumes.\n", + " return client.chat.completions.create(\n", + " model=model,\n", + " messages=messages,\n", + " max_tokens=16384,\n", + " logprobs=True,\n", + " extra_body={\"prompt_logprobs\": 0},\n", + " temperature=0.0,\n", + " )\n", + "\n", + "\n", + "def 
calculate_perplexity_item(\n", + " messages,\n", + " client: openai.OpenAI,\n", + " model: str,\n", + ") -> Tuple[float, int]:\n", + " \"\"\"Calculates sum neg log likelihood and token count for a single record.\"\"\"\n", + " response = _create_completion_with_retry(\n", + " client=client,\n", + " messages=messages,\n", + " model=model,\n", + " )\n", + "\n", + " logprobs_list = []\n", + " # Check for prompt_logprobs field first\n", + " if hasattr(response, \"prompt_logprobs\") and response.prompt_logprobs:\n", + " # prompt_logprobs is like [None, {'token1': {'logprob': -0.1}}, ...]\n", + " for token_logprob in response.prompt_logprobs[1:]:\n", + " if token_logprob:\n", + " logprobs_list.append(list(token_logprob.values())[0][\"logprob\"])\n", + " # If not, check logprobs in choices (OpenAI format for chat output tokens)\n", + " # in case MaaS populates it with prompt logprobs\n", + " elif (\n", + " response.choices\n", + " and response.choices[0].logprobs\n", + " and response.choices[0].logprobs.content\n", + " ):\n", + " logprobs_list = [\n", + " lp.logprob\n", + " for lp in response.choices[0].logprobs.content\n", + " if lp and lp.logprob is not None\n", + " ]\n", + " else:\n", + " return 0.0, 0\n", + "\n", + " if not logprobs_list:\n", + " return 0.0, 0\n", + "\n", + " sum_nll = -np.sum(logprobs_list)\n", + " token_count = len(logprobs_list)\n", + " return sum_nll, token_count\n", + "\n", + "\n", + "def calculate_perplexity_teacher(\n", + " model: str,\n", + " project_id: str,\n", + " region: str,\n", + " dataset: list[str],\n", + " max_workers: int = 10,\n", + ") -> None:\n", + " \"\"\"Calculate the Perplexity score of the teacher model using VMG MaaS API.\n", + "\n", + " Args:\n", + " model: Model ID for MaaS.\n", + " project_id: GCP Project ID.\n", + " region: GCP Region.\n", + " dataset: The dataset.\n", + " max_workers: Number of parallel workers.\n", + " \"\"\"\n", + " api_key = get_access_token()\n", + " if region == \"global\":\n", + " api_endpoint = 
\"aiplatform.googleapis.com\"\n", + " else:\n", + " api_endpoint = f\"{region}-aiplatform.googleapis.com\"\n", + " base_url = f\"https://{api_endpoint}/v1/projects/{project_id}/locations/{region}/endpoints/openapi\"\n", + " print(f\"Using MaaS base URL for the teacher model: {base_url}\")\n", + "\n", + " client = openai.OpenAI(\n", + " base_url=base_url,\n", + " api_key=api_key,\n", + " )\n", + " per_sequence_ppls = []\n", + " total_tokens = 0\n", + "\n", + " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n", + " records = [[{\"role\": \"user\", \"content\": text}] for text in dataset]\n", + "\n", + " futures = [\n", + " executor.submit(\n", + " calculate_perplexity_item,\n", + " record,\n", + " client,\n", + " model,\n", + " )\n", + " for record in records\n", + " ]\n", + " for future in tqdm.tqdm(\n", + " concurrent.futures.as_completed(futures), total=len(futures)\n", + " ):\n", + " sum_nll, token_count = future.result()\n", + " if token_count > 0:\n", + " total_tokens += token_count\n", + " per_sequence_ppls.append(np.exp(sum_nll / token_count))\n", + "\n", + " if total_tokens == 0:\n", + " print(\"\\n\\nNo perplexity results calculated for the teacher model.\")\n", + " return\n", + "\n", + " ppl = np.mean(per_sequence_ppls)\n", + " print(f\"\\n\\nTeacher model Perplexity for the given dataset: {ppl:.4f}\")\n", + "\n", + "\n", + "# Execute\n", + "teacher_maas_model_id = \"deepseek-ai/deepseek-r1-0528-maas\" # @param {type:\"string\"}\n", + "teacher_project = os.environ[\"GOOGLE_CLOUD_PROJECT\"]\n", + "teacher_region = \"us-central1\" # @param {type:\"string\"}\n", + "\n", + "calculate_perplexity_teacher(\n", + " teacher_maas_model_id,\n", + " teacher_project,\n", + " teacher_region,\n", + " formatted_dataset,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "EY8qkPP1W5rU" + }, + "outputs": [], + "source": [ + "# @markdown Perplexity calculation on the test 
dataset using HuggingFace for the student model. This section requires a GPU runtime environment (for example, an NVIDIA L4) to execute.\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import tqdm\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "\n", + "def calculate_perplexity_student(model_name, dataset, enable_thinking=True) -> None:\n", + " \"\"\"Runs perplexity calculation on input dataset for student model using HuggingFace.\n", + "\n", + " Args:\n", + " model_name: HuggingFace model ID or local path.\n", + " dataset: The dataset to calculate ppl on.\n", + " enable_thinking: Reserved for the chat template's thinking mode;\n", + " currently unused because perplexity is computed on the raw text.\n", + " \"\"\"\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", + " model = AutoModelForCausalLM.from_pretrained(\n", + " model_name,\n", + " device_map=\"auto\",\n", + " torch_dtype=torch.bfloat16,\n", + " trust_remote_code=True,\n", + " )\n", + " model.eval()\n", + "\n", + " # Fall back to the tokenizer limit if the model config does not expose\n", + " # max_position_embeddings.\n", + " max_length = getattr(\n", + " model.config, \"max_position_embeddings\", tokenizer.model_max_length\n", + " )\n", + " loss_fct = torch.nn.CrossEntropyLoss(reduction=\"none\")\n", + "\n", + " per_sequence_ppls = []\n", + " total_tokens = 0\n", + "\n", + " for text in tqdm.tqdm(dataset):\n", + " encodings = tokenizer(text, return_tensors=\"pt\")\n", + " seq_len = encodings.input_ids.size(1)\n", + "\n", + " chunk_nlls = []\n", + " for begin_loc in range(0, seq_len, max_length):\n", + " end_loc = min(begin_loc + max_length, seq_len)\n", + " input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)\n", + "\n", + " with torch.no_grad():\n", + " outputs = model(input_ids)\n", + " logits = outputs.logits\n", + "\n", + " shift_logits = logits[..., :-1, :].contiguous()\n", + " shift_labels = input_ids[..., 1:].contiguous().to(logits.device)\n", + "\n", + " nll = loss_fct(\n", + " shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)\n", + " )\n", + " chunk_nlls.append(nll)\n", + "\n", + " all_nlls = torch.cat(chunk_nlls)\n", + " if all_nlls.numel() > 
0:\n", + " sum_nll_seq = all_nlls.sum().item()\n", + " token_count_seq = all_nlls.numel()\n", + " per_sequence_ppls.append(np.exp(sum_nll_seq / token_count_seq))\n", + " total_tokens += token_count_seq\n", + "\n", + " if total_tokens == 0:\n", + " print(\"\\n\\nNo perplexity results calculated for the student model.\")\n", + " return\n", + "\n", + " ppl = np.mean(per_sequence_ppls)\n", + " print(f\"\\n\\nStudent model Perplexity for the given dataset: {ppl:.4f}\")\n", + "\n", + "\n", + "# Execute\n", + "student_model_id = \"Qwen/Qwen3-0.6B\" # @param {type:\"string\"}\n", + "enable_student_thinking = True # @param {type:\"boolean\"}\n", + "\n", + "calculate_perplexity_student(\n", + " student_model_id, formatted_dataset, enable_student_thinking\n", + ")" + ] + } + ], + "metadata": { + "colab": { + "name": "model_garden_oss_distillation_feasibility.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}