
Commit 74376f2

vertex-mg-bot authored and copybara-github committed
Refactor qwen3 axolotl notebook.
PiperOrigin-RevId: 769218301
1 parent abc0e9d commit 74376f2

1 file changed

File tree

notebooks/community/model_garden/model_garden_axolotl_qwen3_finetuning.ipynb

Lines changed: 28 additions & 12 deletions
@@ -611,9 +611,13 @@
 " raise ValueError(\"Do not override base_model flag here.\")\n",
 "axolotl_flag_overrides.append(f\"--base_model={HF_MODEL_ID}\")\n",
 "\n",
-"# Set model_id and publisher. This is required for Vertex AI fine-tuning job and Vertex AI model deployment.\n",
-"publisher = HF_MODEL_ID.split(\"/\")[0]\n",
-"model_id = HF_MODEL_ID.split(\"/\")[1]\n",
+"base_model = axolotl_config[\"base_model\"]\n",
+"for overrides in axolotl_flag_overrides:\n",
+"  if overrides.startswith(\"--base_model=\"):\n",
+"    base_model = overrides.split(\"=\")[1]\n",
+"    break\n",
+"publisher = base_model.split(\"/\")[0]\n",
+"model_id = base_model.split(\"/\")[1]\n",
 "model_id = model_id.replace(\".\", \"-\")"
 ]
 },
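The hunk above changes how `publisher` and `model_id` are derived: instead of reading them straight from `HF_MODEL_ID`, the notebook now resolves the effective base model from the axolotl config and lets any explicit `--base_model=` flag override win. A minimal standalone sketch of that resolution logic, with hypothetical config and override values standing in for the notebook's variables:

# Stand-ins for the notebook's variables; the values are hypothetical.
axolotl_config = {"base_model": "Qwen/Qwen3-8B"}
axolotl_flag_overrides = ["--base_model=Qwen/Qwen3-1.7B"]

# Start from the config value, then let an explicit flag override win.
base_model = axolotl_config["base_model"]
for override in axolotl_flag_overrides:
    if override.startswith("--base_model="):
        base_model = override.split("=", 1)[1]  # maxsplit=1 tolerates "=" in the value
        break

publisher = base_model.split("/")[0]
model_id = base_model.split("/")[1].replace(".", "-")  # Vertex resource IDs disallow dots
print(publisher, model_id)  # -> Qwen Qwen3-1-7B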
@@ -731,7 +735,7 @@
 "CLI_PROMPT = \"What is car?\" # @param {type:\"string\"}\n",
 "\n",
 "if INFERENCE_METHOD == \"gradio\":\n",
-" ! cd axolotl && export CUDA_VISIBLE_DEVICES=0 && axolotl inference --base-model=$HF_MODEL_ID $local_config_path --lora-model-dir=$AXOLOTL_OUTPUT_DIR --gradio\n",
+" ! cd axolotl && export CUDA_VISIBLE_DEVICES=0 && axolotl inference --base-model=$base_model $local_config_path --lora-model-dir=$AXOLOTL_OUTPUT_DIR --gradio\n",
 "elif INFERENCE_METHOD == \"cli\":\n",
 " assert CLI_PROMPT, \"CLI_PROMPT must be set if INFERENCE_METHOD is 'cli'.\"\n",
 " env = os.environ.copy()\n",
@@ -740,7 +744,7 @@
 " \"axolotl\",\n",
 " \"inference\",\n",
 " local_config_path,\n",
-" f\"--base-model={HF_MODEL_ID}\",\n",
+" f\"--base-model={base_model}\",\n",
 " f\"--lora-model-dir={AXOLOTL_OUTPUT_DIR}\",\n",
 " ]\n",
 " run_cmd_and_check_output(cmd, env, f\"{CLI_PROMPT}\\x04\", f\"{WORKING_DIR}/axolotl/\")\n",
@@ -775,7 +779,7 @@
 " \"python3\",\n",
 " \"-m\",\n",
 " \"axolotl.cli.merge_lora\",\n",
-" f\"--base-model={HF_MODEL_ID}\",\n",
+" f\"--base-model={base_model}\",\n",
 " f\"--output-dir={AXOLOTL_OUTPUT_DIR}\",\n",
 " local_config_path,\n",
 "]\n",
@@ -1068,6 +1072,8 @@
 " dtype: str | None = None,\n",
 " enable_trust_remote_code: bool = False,\n",
 " enable_torch_compile: bool = False,\n",
+" torch_compile_max_bs: int | None = None,\n",
+" attention_backend: str = \"\",\n",
 " enable_flashinfer_mla: bool = False,\n",
 " disable_cuda_graph: bool = False,\n",
 " speculative_algorithm: str | None = None,\n",
@@ -1115,6 +1121,11 @@
 "\n",
 " if enable_torch_compile:\n",
 " sglang_args.append(\"--enable-torch-compile\")\n",
+" if torch_compile_max_bs:\n",
+" sglang_args.append(f\"--torch-compile-max-bs={torch_compile_max_bs}\")\n",
+"\n",
+" if attention_backend:\n",
+" sglang_args.append(f\"--attention-backend={attention_backend}\")\n",
 "\n",
 " if enable_flashinfer_mla:\n",
 " sglang_args.append(\"--enable-flashinfer-mla\")\n",
@@ -1144,6 +1155,13 @@
 " if enable_jit_deepgemm:\n",
 " env_vars[\"SGL_ENABLE_JIT_DEEPGEMM\"] = \"1\"\n",
 "\n",
+" # HF_TOKEN is not a compulsory field and may not be defined.\n",
+" try:\n",
+"   if HF_TOKEN:\n",
+"     env_vars[\"HF_TOKEN\"] = HF_TOKEN\n",
+" except NameError:\n",
+"   pass\n",
+"\n",
 " model = aiplatform.Model.upload(\n",
 " display_name=model_name,\n",
 " serving_container_image_uri=SGLANG_DOCKER_URI,\n",
@@ -1186,7 +1204,7 @@
 " \"maxReplicaCount\": 1,\n",
 " },\n",
 " \"system_labels\": {\n",
-" \"NOTEBOOK_NAME\": \"model_garden_axolotl_finetuning.ipynb\",\n",
+" \"NOTEBOOK_NAME\": \"model_garden_axolotl_qwen3_finetuning.ipynb\",\n",
 " \"NOTEBOOK_ENVIRONMENT\": common_util.get_deploy_source(),\n",
 " },\n",
 " },\n",
@@ -1268,11 +1286,11 @@
 "prompt = \"<|im_start|>user What is the best way to diagnose and fix a flickering light in my house?<|im_end|><|im_start|>assistant\" # @param {type: \"string\"}\n",
 "\n",
 "# @markdown By default, Qwen3 has thinking capabilities enabled, similar to QwQ-32B. This means the model will use its reasoning abilities to enhance the quality of generated responses.\n",
-"# @markdown The model will generate think content wrapped in a \\...\\ block, followed by the final response.\n",
+"# @markdown The model will generate think content wrapped in a \\<think>...\\</think> block, followed by the final response.\n",
 "# @markdown `max_new_tokens` may need to be increased to accommodate the additional think content.\n",
 "enable_thinking = True # @param {type:\"boolean\"}\n",
 "if not enable_thinking:\n",
-" prompt += \"\"\n",
+" prompt += \"<think></think>\"\n",
 "\n",
 "\n",
 "max_new_tokens = 1024 # @param {type:\"integer\"}\n",
@@ -1310,9 +1328,7 @@
 ")\n",
 "\n",
 "for prediction in response.predictions:\n",
-" print(prediction)\n",
-"\n",
-"# @markdown Click \"Show Code\" to see more details."
+" print(prediction)"
 ]
 },
 {
