|
611 | 611 | " raise ValueError(\"Do not override base_model flag here.\")\n", |
612 | 612 | "axolotl_flag_overrides.append(f\"--base_model={HF_MODEL_ID}\")\n", |
613 | 613 | "\n", |
614 | | - "# Set model_id and publisher. This is required for Vertex AI fine-tuning job and Vertex AI model deployment.\n", |
615 | | - "publisher = HF_MODEL_ID.split(\"/\")[0]\n", |
616 | | - "model_id = HF_MODEL_ID.split(\"/\")[1]\n", |
| 614 | + "base_model = axolotl_config[\"base_model\"]\n", |
| 615 | + "for overrides in axolotl_flag_overrides:\n", |
| 616 | + " if overrides.startswith(\"--base_model=\"):\n", |
| 617 | + " base_model = overrides.split(\"=\")[1]\n", |
| 618 | + " break\n", |
| 619 | + "publisher = base_model.split(\"/\")[0]\n", |
| 620 | + "model_id = base_model.split(\"/\")[1]\n", |
617 | 621 | "model_id = model_id.replace(\".\", \"-\")" |
618 | 622 | ] |
619 | 623 | }, |
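For reference, the replacement in the hunk above resolves the effective base model (the config value unless a `--base_model=` flag override is present) before splitting it into `publisher` and `model_id`. A minimal standalone sketch of that parsing, using hypothetical values in place of the notebook's config and override list:

# Hypothetical stand-ins for the notebook's axolotl config and flag overrides.
axolotl_config = {"base_model": "Qwen/Qwen3-8B"}
axolotl_flag_overrides = ["--base_model=Qwen/Qwen2.5-7B-Instruct"]

base_model = axolotl_config["base_model"]
for override in axolotl_flag_overrides:
    if override.startswith("--base_model="):
        # A flag override, when present, wins over the config value.
        base_model = override.split("=", 1)[1]
        break

publisher, model_id = base_model.split("/")
model_id = model_id.replace(".", "-")  # dots swapped for dashes, as the notebook does
print(publisher, model_id)  # Qwen Qwen2-5-7B-Instruct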
|
731 | 735 | "CLI_PROMPT = \"What is car?\" # @param {type:\"string\"}\n", |
732 | 736 | "\n", |
733 | 737 | "if INFERENCE_METHOD == \"gradio\":\n", |
734 | | - " ! cd axolotl && export CUDA_VISIBLE_DEVICES=0 && axolotl inference --base-model=$HF_MODEL_ID $local_config_path --lora-model-dir=$AXOLOTL_OUTPUT_DIR --gradio\n", |
| 738 | + " ! cd axolotl && export CUDA_VISIBLE_DEVICES=0 && axolotl inference --base-model=$base_model $local_config_path --lora-model-dir=$AXOLOTL_OUTPUT_DIR --gradio\n", |
735 | 739 | "elif INFERENCE_METHOD == \"cli\":\n", |
736 | 740 | " assert CLI_PROMPT, \"CLI_PROMPT must be set if INFERENCE_METHOD is 'cli'.\"\n", |
737 | 741 | " env = os.environ.copy()\n", |
|
740 | 744 | " \"axolotl\",\n", |
741 | 745 | " \"inference\",\n", |
742 | 746 | " local_config_path,\n", |
743 | | - " f\"--base-model={HF_MODEL_ID}\",\n", |
| 747 | + " f\"--base-model={base_model}\",\n", |
744 | 748 | " f\"--lora-model-dir={AXOLOTL_OUTPUT_DIR}\",\n", |
745 | 749 | " ]\n", |
746 | 750 | " run_cmd_and_check_output(cmd, env, f\"{CLI_PROMPT}\\x04\", f\"{WORKING_DIR}/axolotl/\")\n", |
|
775 | 779 | " \"python3\",\n", |
776 | 780 | " \"-m\",\n", |
777 | 781 | " \"axolotl.cli.merge_lora\",\n", |
778 | | - " f\"--base-model={HF_MODEL_ID}\",\n", |
| 782 | + " f\"--base-model={base_model}\",\n", |
779 | 783 | " f\"--output-dir={AXOLOTL_OUTPUT_DIR}\",\n", |
780 | 784 | " local_config_path,\n", |
781 | 785 | "]\n", |
|
1068 | 1072 | " dtype: str | None = None,\n", |
1069 | 1073 | " enable_trust_remote_code: bool = False,\n", |
1070 | 1074 | " enable_torch_compile: bool = False,\n", |
| 1075 | + " torch_compile_max_bs: int | None = None,\n", |
| 1076 | + " attention_backend: str = \"\",\n", |
1071 | 1077 | " enable_flashinfer_mla: bool = False,\n", |
1072 | 1078 | " disable_cuda_graph: bool = False,\n", |
1073 | 1079 | " speculative_algorithm: str | None = None,\n", |
|
1115 | 1121 | "\n", |
1116 | 1122 | " if enable_torch_compile:\n", |
1117 | 1123 | " sglang_args.append(\"--enable-torch-compile\")\n", |
| 1124 | + " if torch_compile_max_bs:\n", |
| 1125 | + " sglang_args.append(f\"--torch-compile-max-bs={torch_compile_max_bs}\")\n", |
| 1126 | + "\n", |
| 1127 | + " if attention_backend:\n", |
| 1128 | + " sglang_args.append(f\"--attention-backend={attention_backend}\")\n", |
1118 | 1129 | "\n", |
1119 | 1130 | " if enable_flashinfer_mla:\n", |
1120 | 1131 | " sglang_args.append(\"--enable-flashinfer-mla\")\n", |
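The two new deployment parameters are forwarded to the SGLang server as plain `--flag=value` arguments. A small self-contained sketch of that argument assembly, with hypothetical values and flattened control flow rather than the notebook's exact function:

# Hypothetical parameter values for illustration only.
enable_torch_compile = True
torch_compile_max_bs = 8
attention_backend = "flashinfer"

sglang_args = []
if enable_torch_compile:
    sglang_args.append("--enable-torch-compile")
if torch_compile_max_bs:
    sglang_args.append(f"--torch-compile-max-bs={torch_compile_max_bs}")
if attention_backend:
    sglang_args.append(f"--attention-backend={attention_backend}")

print(sglang_args)
# ['--enable-torch-compile', '--torch-compile-max-bs=8', '--attention-backend=flashinfer']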
|
1144 | 1155 | " if enable_jit_deepgemm:\n", |
1145 | 1156 | " env_vars[\"SGL_ENABLE_JIT_DEEPGEMM\"] = \"1\"\n", |
1146 | 1157 | "\n", |
| 1158 | + " # HF_TOKEN is not a compulsory field and may not be defined.\n", |
| 1159 | + " try:\n", |
| 1160 | + " if HF_TOKEN:\n", |
| 1161 | + " env_vars[\"HF_TOKEN\"] = HF_TOKEN\n", |
| 1162 | + " except NameError:\n", |
| 1163 | + " pass\n", |
| 1164 | + "\n", |
1147 | 1165 | " model = aiplatform.Model.upload(\n", |
1148 | 1166 | " display_name=model_name,\n", |
1149 | 1167 | " serving_container_image_uri=SGLANG_DOCKER_URI,\n", |
|
1186 | 1204 | " \"maxReplicaCount\": 1,\n", |
1187 | 1205 | " },\n", |
1188 | 1206 | " \"system_labels\": {\n", |
1189 | | - " \"NOTEBOOK_NAME\": \"model_garden_axolotl_finetuning.ipynb\",\n", |
| 1207 | + " \"NOTEBOOK_NAME\": \"model_garden_axolotl_qwen3_finetuning.ipynb\",\n", |
1190 | 1208 | " \"NOTEBOOK_ENVIRONMENT\": common_util.get_deploy_source(),\n", |
1191 | 1209 | " },\n", |
1192 | 1210 | " },\n", |
|
1268 | 1286 | "prompt = \"<|im_start|>user What is the best way to diagnose and fix a flickering light in my house?<|im_end|><|im_start|>assistant\" # @param {type: \"string\"}\n", |
1269 | 1287 | "\n", |
1270 | 1288 | "# @markdown By default, Qwen3 has thinking capabilities enabled, similar to QwQ-32B. This means the model will use its reasoning abilities to enhance the quality of generated responses.\n", |
1271 | | - "# @markdown The model will generate think content wrapped in a \\...\\ block, followed by the final response.\n", |
| 1289 | + "# @markdown The model will generate think content wrapped in a \\<think>...\\</think> block, followed by the final response.\n", |
1272 | 1290 | "# @markdown `max_new_tokens` may need to be increased to accommodate the additional think content.\n", |
1273 | 1291 | "enable_thinking = True # @param {type:\"boolean\"}\n", |
1274 | 1292 | "if not enable_thinking:\n", |
1275 | | - " prompt += \"\"\n", |
| 1293 | + " prompt += \"<think></think>\"\n", |
1276 | 1294 | "\n", |
1277 | 1295 | "\n", |
1278 | 1296 | "max_new_tokens = 1024 # @param {type:\"integer\"}\n", |
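To make the prompt handling above concrete: when thinking is disabled, the notebook appends an empty think block to the chat-formatted prompt, which signals the model to skip its reasoning trace. A tiny sketch with a hypothetical prompt string:

# Hypothetical prompt; the notebook builds it from the @param field above.
prompt = "<|im_start|>user What is car?<|im_end|><|im_start|>assistant"
enable_thinking = False
if not enable_thinking:
    # An empty <think></think> block suppresses the model's reasoning output.
    prompt += "<think></think>"
print(prompt)  # ...<|im_start|>assistant<think></think>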
|
1310 | 1328 | ")\n", |
1311 | 1329 | "\n", |
1312 | 1330 | "for prediction in response.predictions:\n", |
1313 | | - " print(prediction)\n", |
1314 | | - "\n", |
1315 | | - "# @markdown Click \"Show Code\" to see more details." |
| 1331 | + " print(prediction)" |
1316 | 1332 | ] |
1317 | 1333 | }, |
1318 | 1334 | { |
|