diff --git a/notebooks/community/model_garden/model_garden_oss_distillation_feasibility.ipynb b/notebooks/community/model_garden/model_garden_oss_distillation_feasibility.ipynb
new file mode 100644
index 000000000..cf171e414
--- /dev/null
+++ b/notebooks/community/model_garden/model_garden_oss_distillation_feasibility.ipynb
@@ -0,0 +1,435 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "iUr0AjrgVqUk"
+ },
+ "outputs": [],
+ "source": [
+ "# Copyright 2026 Google LLC\n",
+ "#\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kxnledoMWKxp"
+ },
+ "source": [
+ "# Vertex AI Model Garden - OSS Distillation Feasibility Study\n",
+ "\n",
+ "Run in Workbench | Run in Colab Enterprise | View on GitHub"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "spClB3u_WNz1"
+ },
+ "source": [
+ "## Overview\n",
+ "This notebook serves as a critical validation step in the Model Distillation pipeline. Before committing resources to distill a smaller \"Student\" model to mimic a larger \"Teacher\" model, we must first determine if the model pair and the target dataset are compatible.\n",
+ "\n",
+ "### Core Objective\n",
+ "The primary goal is to calculate and compare the Perplexity scores of both the Teacher and Student models on a specific domain dataset, in this case, [syz-ml2025/medmcqa](https://huggingface.co/datasets/syz-ml2025/medmcqa) (Medical Multiple-Choice QA).\n",
+ "\n",
+ "### Why Perplexity Matters\n",
+ "Perplexity measures how well a probability model predicts a sample: it is the exponential of the average negative log-likelihood per token, so lower values indicate a better fit.\n",
+ "\n",
+ "- **Teacher Validation:** We use the `deepseek-ai/deepseek-r1-0528-maas` model via the [Vertex AI MaaS API](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/maas/call-open-model-apis). A low teacher perplexity confirms that the teacher understands the domain well enough to provide high-quality supervision.\n",
+ "\n",
+ "- **Gap Identification:** We evaluate the `Qwen/Qwen3-0.6B` student model using the Hugging Face Transformers library. By quantifying the performance gap between the teacher and the student, we can estimate the potential \"knowledge headroom\" available for distillation.\n",
+ "\n",
+ "### Workflow Steps\n",
+ "- **Environment Setup:** Installation of the prerequisite libraries (Datasets, Transformers, and the OpenAI client used for the MaaS calls).\n",
+ "\n",
+ "- **Dataset Sampling:** Extraction of a representative subset (200 samples) from the medmcqa training split to ensure efficient benchmarking.\n",
+ "\n",
+ "- **Teacher Inference:** Leveraging Vertex AI Model-as-a-Service (MaaS) to compute log-likelihoods and derive the teacher's perplexity.\n",
+ "\n",
+ "- **Student Inference:** Running the student model on local GPU resources (NVIDIA L4) to calculate its baseline perplexity before any training occurs.\n",
+ "\n",
+ "### Success Criteria\n",
+ "The prerequisite check passes when the Teacher model shows significantly lower perplexity than the Student, indicating that the Student has meaningful headroom to learn the Teacher's reasoning and linguistic patterns in the medical domain."
+ ]
+ },
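The definition above can be made concrete with a small numerical sketch: perplexity is the exponential of the mean negative log-likelihood over tokens. The token logprobs below are illustrative values, not output from either model.

```python
import math


def perplexity(logprobs):
    """Per-sequence perplexity: exp of the mean negative log-likelihood."""
    return math.exp(-sum(logprobs) / len(logprobs))


# A model that assigns every token probability 0.5 is exactly as
# uncertain as a fair coin flip, so its perplexity is 2.
coin_flip = [math.log(0.5)] * 4
print(perplexity(coin_flip))

# A more confident model (probability 0.9 per token) scores lower.
confident = [math.log(0.9)] * 4
print(perplexity(confident))
```

This is the same quantity both the teacher and student cells compute below, just with logprobs coming from a real model instead of fixed values.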
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1xpacQK2Wn20"
+ },
+ "source": [
+ "## Before you begin"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "R7RIdZD8WqIz"
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown Install the prerequisite libraries.\n",
+ "!pip3 install --upgrade datasets transformers openai tenacity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "9GNQZPRpWto6"
+ },
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import tqdm\n",
+ "\n",
+ "# @markdown Create a small test dataset using HF dataset for perplexity calculation.\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "\n",
+ "def create_test_dataset(sample_dataset, test_file_path, dataset_split, num_samples=200):\n",
+ " dataset = load_dataset(sample_dataset, split=dataset_split)\n",
+ " test_samples = dataset.select(range(num_samples))\n",
+ " with open(test_file_path, \"w\") as f:\n",
+ " for sample in test_samples:\n",
+ " f.write(json.dumps(sample) + \"\\n\")\n",
+ "\n",
+ "\n",
+ "def format_dataset(dataset):\n",
+ " \"\"\"Formats the dataset for perplexity calculation.\"\"\"\n",
+ " formatted_data = []\n",
+ " with open(dataset, \"r\") as f:\n",
+ " data = [json.loads(line) for line in f if line.strip()]\n",
+ " for record in tqdm.tqdm(data):\n",
+ " cop_map = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\"}\n",
+ " text = (\n",
+ " f'{record[\"question\"]}\\nA) {record[\"opa\"]}\\nB) {record[\"opb\"]}\\nC) '\n",
+ " f'{record[\"opc\"]}\\nD) {record[\"opd\"]}'\n",
+ " )\n",
+ " if record[\"exp\"]:\n",
+ " text += f'\\n{record[\"exp\"]}'\n",
+ " text += f'\\nThe answer is {cop_map[record[\"cop\"]]}'\n",
+ " formatted_data.append(text)\n",
+ "\n",
+ " return formatted_data\n",
+ "\n",
+ "\n",
+ "sample_dataset = \"syz-ml2025/medmcqa\" # @param {type:\"string\"}\n",
+ "test_file_path = \"test.jsonl\" # @param {type:\"string\"}\n",
+ "dataset_split = \"train\" # @param {type:\"string\"}\n",
+ "dataset_sample_size = 200 # @param {type:\"integer\"}\n",
+ "\n",
+ "create_test_dataset(\n",
+ " sample_dataset=sample_dataset,\n",
+ " test_file_path=test_file_path,\n",
+ " dataset_split=dataset_split,\n",
+ " num_samples=dataset_sample_size,\n",
+ ")\n",
+ "formatted_dataset = format_dataset(test_file_path)"
+ ]
+ },
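To make the formatting logic concrete, here is the text `format_dataset` would produce for a single medmcqa-style record. The field names (`question`, `opa`–`opd`, `exp`, `cop`) match the code above; the record content itself is made up for illustration.

```python
# Hypothetical medmcqa-style record (content is illustrative only).
record = {
    "question": "Which vitamin deficiency causes scurvy?",
    "opa": "Vitamin A",
    "opb": "Vitamin B12",
    "opc": "Vitamin C",
    "opd": "Vitamin D",
    "exp": "Vitamin C is required for collagen synthesis.",
    "cop": 2,  # zero-based index of the correct option -> "C"
}

# Same formatting as format_dataset: question, lettered options,
# optional explanation, then the answer letter.
cop_map = {0: "A", 1: "B", 2: "C", 3: "D"}
text = (
    f"{record['question']}\nA) {record['opa']}\nB) {record['opb']}\n"
    f"C) {record['opc']}\nD) {record['opd']}"
)
if record["exp"]:
    text += f"\n{record['exp']}"
text += f"\nThe answer is {cop_map[record['cop']]}"
print(text)
```

Each sequence therefore ends with the ground-truth answer letter, so the perplexity computed later reflects how well a model predicts both the explanation and the final answer.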
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "IiiN9YiKW3dv"
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown Perplexity calculation on the test dataset using VMG MaaS API for the teacher model.\n",
+ "\n",
+ "import concurrent.futures\n",
+ "import os\n",
+ "import subprocess\n",
+ "import sys\n",
+ "from typing import Any, Tuple\n",
+ "\n",
+ "import numpy as np\n",
+ "import openai\n",
+ "import tenacity\n",
+ "import tqdm\n",
+ "\n",
+ "\n",
+ "def get_access_token() -> str:\n",
+ " \"\"\"Gets the access token from gcloud.\"\"\"\n",
+ " try:\n",
+ " return subprocess.check_output(\n",
+ " [\"gcloud\", \"auth\", \"print-access-token\"], encoding=\"utf-8\"\n",
+ " ).strip()\n",
+ " except subprocess.CalledProcessError as e:\n",
+ " print(f\"Error getting access token: {e!r}\")\n",
+ " sys.exit(1)\n",
+ "\n",
+ "\n",
+ "@tenacity.retry(\n",
+ " stop=tenacity.stop_after_attempt(10),\n",
+ " wait=tenacity.wait_exponential(multiplier=1, min=2, max=60),\n",
+ " retry=tenacity.retry_if_exception_type(\n",
+ " (\n",
+ " openai.APIConnectionError,\n",
+ " openai.RateLimitError,\n",
+ " openai.APIStatusError,\n",
+ " )\n",
+ " ),\n",
+ ")\n",
+ "def _create_completion_with_retry(\n",
+ " client: openai.OpenAI,\n",
+ " messages: list[dict[str, str]],\n",
+ " model: str,\n",
+ ") -> Any:\n",
+ " \"\"\"Creates a completion with retry logic to get logprobs.\"\"\"\n",
+ " return client.chat.completions.create(\n",
+ " model=model,\n",
+ " messages=messages,\n",
+ " max_tokens=16384,\n",
+ " logprobs=True,\n",
+ " extra_body={\"prompt_logprobs\": 0},\n",
+ " temperature=0.0,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def calculate_perplexity_item(\n",
+ " messages,\n",
+ " client: openai.OpenAI,\n",
+ " model: str,\n",
+ ") -> Tuple[float, int]:\n",
+ " \"\"\"Calculates sum neg log likelihood and token count for a single record.\"\"\"\n",
+ " response = _create_completion_with_retry(\n",
+ " client=client,\n",
+ " messages=messages,\n",
+ " model=model,\n",
+ " )\n",
+ "\n",
+ " logprobs_list = []\n",
+ " # Check for prompt_logprobs field first\n",
+ " if hasattr(response, \"prompt_logprobs\") and response.prompt_logprobs:\n",
+ " # prompt_logprobs is like [None, {'token1': {'logprob': -0.1}}, ...]\n",
+ " for token_logprob in response.prompt_logprobs[1:]:\n",
+ " if token_logprob:\n",
+ " logprobs_list.append(list(token_logprob.values())[0][\"logprob\"])\n",
+ " # If not, check logprobs in choices (OpenAI format for chat output tokens)\n",
+ " # in case MaaS populates it with prompt logprobs\n",
+ " elif (\n",
+ " response.choices\n",
+ " and response.choices[0].logprobs\n",
+ " and response.choices[0].logprobs.content\n",
+ " ):\n",
+ " logprobs_list = [\n",
+ " lp.logprob\n",
+ " for lp in response.choices[0].logprobs.content\n",
+ " if lp and lp.logprob is not None\n",
+ " ]\n",
+ " else:\n",
+ " return 0.0, 0\n",
+ "\n",
+ " if not logprobs_list:\n",
+ " return 0.0, 0\n",
+ "\n",
+ " sum_nll = -np.sum(logprobs_list)\n",
+ " token_count = len(logprobs_list)\n",
+ " return sum_nll, token_count\n",
+ "\n",
+ "\n",
+ "def calculate_perplexity_teacher(\n",
+ " model: str,\n",
+ " project_id: str,\n",
+ " region: str,\n",
+ " dataset: list[str],\n",
+ " max_workers: int = 10,\n",
+ ") -> None:\n",
+ " \"\"\"Calculate the Perplexity score of the teacher model using VMG MaaS API.\n",
+ "\n",
+ " Args:\n",
+ " model: Model ID for MaaS.\n",
+ " project_id: GCP Project ID.\n",
+ " region: GCP Region.\n",
+ " dataset: The dataset.\n",
+ " max_workers: Number of parallel workers.\n",
+ " \"\"\"\n",
+ " api_key = get_access_token()\n",
+ " if region == \"global\":\n",
+ " api_endpoint = \"aiplatform.googleapis.com\"\n",
+ " else:\n",
+ " api_endpoint = f\"{region}-aiplatform.googleapis.com\"\n",
+ " base_url = f\"https://{api_endpoint}/v1/projects/{project_id}/locations/{region}/endpoints/openapi\"\n",
+ " print(f\"Using MaaS base URL for the teacher model: {base_url}\")\n",
+ "\n",
+ " client = openai.OpenAI(\n",
+ " base_url=base_url,\n",
+ " api_key=api_key,\n",
+ " )\n",
+ " per_sequence_ppls = []\n",
+ " total_tokens = 0\n",
+ "\n",
+ " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
+ " records = [[{\"role\": \"user\", \"content\": text}] for text in dataset]\n",
+ "\n",
+ " futures = [\n",
+ " executor.submit(\n",
+ " calculate_perplexity_item,\n",
+ " record,\n",
+ " client,\n",
+ " model,\n",
+ " )\n",
+ " for record in records\n",
+ " ]\n",
+ " for future in tqdm.tqdm(\n",
+ " concurrent.futures.as_completed(futures), total=len(futures)\n",
+ " ):\n",
+ " sum_nll, token_count = future.result()\n",
+ " if token_count > 0:\n",
+ " total_tokens += token_count\n",
+ " per_sequence_ppls.append(np.exp(sum_nll / token_count))\n",
+ "\n",
+ " if total_tokens == 0:\n",
+ " print(\"\\n\\nNo perplexity results calculated for the teacher model.\")\n",
+ " return\n",
+ "\n",
+ " ppl = np.mean(per_sequence_ppls)\n",
+ " print(f\"\\n\\nTeacher model Perplexity for the given dataset: {ppl:.4f}\")\n",
+ "\n",
+ "\n",
+ "# Execute\n",
+ "teacher_maas_model_id = \"deepseek-ai/deepseek-r1-0528-maas\" # @param {type:\"string\"}\n",
+ "teacher_project = os.environ[\"GOOGLE_CLOUD_PROJECT\"]\n",
+ "teacher_region = \"us-central1\" # @param {type:\"string\"}\n",
+ "\n",
+ "calculate_perplexity_teacher(\n",
+ " teacher_maas_model_id,\n",
+ " teacher_project,\n",
+ " teacher_region,\n",
+ " formatted_dataset,\n",
+ ")"
+ ]
+ },
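One design choice worth noting in `calculate_perplexity_teacher`: it averages per-sequence perplexities, rather than pooling negative log-likelihood over all tokens before exponentiating. The two aggregations can differ noticeably. A sketch with made-up `(sum_nll, token_count)` pairs:

```python
import numpy as np

# Hypothetical (sum_nll, token_count) pairs for two sequences.
results = [(2.0, 4), (9.0, 3)]

# Mean of per-sequence perplexities (what this notebook computes):
# each sequence gets equal weight regardless of length.
per_seq = np.mean([np.exp(s / n) for s, n in results])

# Token-weighted (corpus-level) perplexity: pool NLL over all tokens,
# so every token gets equal weight.
pooled = np.exp(sum(s for s, _ in results) / sum(n for _, n in results))

print(f"per-sequence mean: {per_seq:.3f}, pooled: {pooled:.3f}")
```

Since the same aggregation (mean of per-sequence perplexities) is used for both teacher and student below, the teacher-vs-student comparison stays apples-to-apples either way.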
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "EY8qkPP1W5rU"
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown Perplexity calculation on the test dataset using HuggingFace for the student model. This section will require a GPU runtime environment to execute.\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import tqdm\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "\n",
+ "\n",
+ "def calculate_perplexity_student(model_name, dataset, enable_thinking=True) -> None:\n",
+ " \"\"\"Runs perplexity calculation on input dataset for student model using HuggingFace.\n",
+ "\n",
+ " Args:\n",
+ " model_name: HuggingFace model ID or local path.\n",
+ " dataset: The dataset to calculate ppl on.\n",
+ " enable_thinking: Whether to enable thinking in chat template.\n",
+ " \"\"\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " device_map=\"auto\",\n",
+ " torch_dtype=torch.bfloat16,\n",
+ " trust_remote_code=True,\n",
+ " )\n",
+ " model.eval()\n",
+ "\n",
+ " max_length = model.config.max_position_embeddings\n",
+ " loss_fct = torch.nn.CrossEntropyLoss(reduction=\"none\")\n",
+ "\n",
+ " per_sequence_ppls = []\n",
+ " total_tokens = 0\n",
+ "\n",
+ " for text in tqdm.tqdm(dataset):\n",
+ " # Apply the chat template (when available) so the student scores the\n",
+ " # same chat-formatted input as the teacher; `enable_thinking` is\n",
+ " # forwarded to templates that support it (e.g. Qwen3).\n",
+ " if tokenizer.chat_template:\n",
+ " text = tokenizer.apply_chat_template(\n",
+ " [{\"role\": \"user\", \"content\": text}],\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=False,\n",
+ " enable_thinking=enable_thinking,\n",
+ " )\n",
+ " encodings = tokenizer(text, return_tensors=\"pt\")\n",
+ " seq_len = encodings.input_ids.size(1)\n",
+ "\n",
+ " chunk_nlls = []\n",
+ " for begin_loc in range(0, seq_len, max_length):\n",
+ " end_loc = min(begin_loc + max_length, seq_len)\n",
+ " input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " outputs = model(input_ids)\n",
+ " logits = outputs.logits\n",
+ "\n",
+ " shift_logits = logits[..., :-1, :].contiguous()\n",
+ " shift_labels = input_ids[..., 1:].contiguous().to(logits.device)\n",
+ "\n",
+ " nll = loss_fct(\n",
+ " shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)\n",
+ " )\n",
+ " chunk_nlls.append(nll)\n",
+ "\n",
+ " all_nlls = torch.cat(chunk_nlls)\n",
+ " if all_nlls.numel() > 0:\n",
+ " sum_nll_seq = all_nlls.sum().item()\n",
+ " token_count_seq = all_nlls.numel()\n",
+ " per_sequence_ppls.append(np.exp(sum_nll_seq / token_count_seq))\n",
+ " total_tokens += token_count_seq\n",
+ "\n",
+ " if total_tokens == 0:\n",
+ " print(\"\\n\\nNo perplexity results calculated for the student model.\")\n",
+ " return\n",
+ "\n",
+ " ppl = np.mean(per_sequence_ppls)\n",
+ " print(f\"\\n\\nStudent model Perplexity for the given dataset: {ppl:.4f}\")\n",
+ "\n",
+ "\n",
+ "# Execute\n",
+ "student_model_id = \"Qwen/Qwen3-0.6B\" # @param {type:\"string\"}\n",
+ "enable_student_thinking = True # @param {type:\"boolean\"}\n",
+ "\n",
+ "calculate_perplexity_student(\n",
+ " student_model_id, formatted_dataset, enable_student_thinking\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "name": "model_garden_oss_distillation_feasibility.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}