|
3 | 3 | { |
4 | 4 | "cell_type": "code", |
5 | 5 | "execution_count": null, |
6 | | - "id": "Pr9TgOcV9vAXeqGiyTaTI5kS", |
7 | 6 | "metadata": { |
8 | 7 | "cellView": "form", |
9 | 8 | "id": "Pr9TgOcV9vAXeqGiyTaTI5kS" |
|
27 | 26 | }, |
28 | 27 | { |
29 | 28 | "cell_type": "markdown", |
30 | | - "id": "M1CpgYundFwz", |
31 | 29 | "metadata": { |
32 | 30 | "id": "M1CpgYundFwz" |
33 | 31 | }, |
|
50 | 48 | }, |
51 | 49 | { |
52 | 50 | "cell_type": "markdown", |
53 | | - "id": "t2jj2XOgkS4F", |
54 | 51 | "metadata": { |
55 | 52 | "id": "t2jj2XOgkS4F" |
56 | 53 | }, |
|
124 | 121 | { |
125 | 122 | "cell_type": "code", |
126 | 123 | "execution_count": null, |
127 | | - "id": "XMf-T58TkDy1", |
128 | 124 | "metadata": { |
129 | 125 | "cellView": "form", |
130 | 126 | "id": "XMf-T58TkDy1" |
|
155 | 151 | { |
156 | 152 | "cell_type": "code", |
157 | 153 | "execution_count": null, |
158 | | - "id": "IKGTaN84p8rX", |
159 | 154 | "metadata": { |
160 | 155 | "cellView": "form", |
161 | 156 | "id": "IKGTaN84p8rX" |
162 | 157 | }, |
163 | 158 | "outputs": [], |
164 | 159 | "source": [ |
165 | | - "# @title # Chat completion for text-only models { vertical-output: true}\n", |
166 | | - "# @markdown You may send prompts to the model server for prediction.\n", |
| 160 | + "# @title # Chat completion for text-only models {vertical-output: true}\n", |
| 161 | + "# @markdown Run cell to prompt the model server for prediction.\n", |
167 | 162 | "# @markdown\n", |
168 | 163 | "# @markdown * **user_prompt (string):** This is the text prompt you provide to the language model. It's the question or instruction e (e.g., \"Explain neural networks\").\n", |
169 | 164 | "# @markdown * **temperature (number):** This parameter controls the randomness of the model's output. It influences how the model selects the next token in the sequence it generates. Typical values range from 0.2 to 1.0.\n", |
|
180 | 175 | "REGION = \"\" # @param {type:\"string\", isTemplate:true}\n", |
181 | 176 | "NAMESPACE = \"\" # @param {type:\"string\", isTemplate:true}\n", |
182 | 177 | "DEPLOYMENT = \"\" # @param {type:\"string\", isTemplate:true}\n", |
183 | | - "DEPLOYMENT_APP_LABEL = \"\" # @param {type:\"string\", isTemplate:true}\n", |
| 178 | + "POD_PORT = \"\" # @param {type:\"string\", isTemplate:true}\n", |
184 | 179 | "\n", |
185 | | - "SERVICE = f\"{DEPLOYMENT}-service\"\n", |
186 | 180 | "\n", |
187 | | - "\n", |
188 | | - "def _run_kubectl(cmd):\n", |
189 | | - " \"\"\"Executes a kubectl command and returns its stdout.\"\"\"\n", |
190 | | - " result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=60)\n", |
191 | | - " return result.stdout.strip()\n", |
192 | | - "\n", |
193 | | - "\n", |
194 | | - "def fetch_cluster_credential(cluster, region, project_id):\n", |
| 181 | + "def _run_kubectl(cmd, timeout=60):\n", |
| 182 | + " \"\"\"Executes a kubectl command.\"\"\"\n", |
195 | 183 | " try:\n", |
196 | | - " # Ensure credentials for the target cluster\n", |
197 | | - " cred_cmd = [\n", |
198 | | - " \"gcloud\",\n", |
199 | | - " \"container\",\n", |
200 | | - " \"clusters\",\n", |
201 | | - " \"get-credentials\",\n", |
202 | | - " cluster,\n", |
203 | | - " f\"--location={region}\",\n", |
204 | | - " f\"--project={project_id}\",\n", |
205 | | - " ]\n", |
206 | | - " _run_kubectl(cred_cmd)\n", |
207 | | - " except Exception as e:\n", |
208 | | - " # Original code prints error and returns empty dict\n", |
209 | | - " print(f\"Error fetching cluster credentials: {e}\")\n", |
210 | | - " return {}\n", |
211 | | - "\n", |
212 | | - "\n", |
213 | | - "def get_deployment_pod_name(deployment, namespace, deployment_app_label):\n", |
214 | | - " \"\"\"Finds the running pod name for a given deployment and namespace.\"\"\"\n", |
| 184 | + " result = subprocess.run(\n", |
| 185 | + " cmd, capture_output=True, text=True, check=True, timeout=timeout\n", |
| 186 | + " )\n", |
| 187 | + " return result.stdout.strip()\n", |
| 188 | + " except subprocess.CalledProcessError as e:\n", |
| 189 | + " raise RuntimeError(\n", |
| 190 | + " f\"Kubectl command failed: {' '.join(e.cmd)}\\nStderr: {e.stderr}\"\n", |
| 191 | + " ) from e\n", |
| 192 | + " except subprocess.TimeoutExpired as e:\n", |
| 193 | + " raise RuntimeError(f\"Kubectl command timed out: {' '.join(e.cmd)}\") from e\n", |
| 194 | + "\n", |
| 195 | + "\n", |
| 196 | + "def fetch_cluster_credentials(cluster, region, project_id):\n", |
| 197 | + " \"\"\"Ensures credentials for the target GKE cluster.\"\"\"\n", |
| 198 | + " cred_cmd = [\n", |
| 199 | + " \"gcloud\",\n", |
| 200 | + " \"container\",\n", |
| 201 | + " \"clusters\",\n", |
| 202 | + " \"get-credentials\",\n", |
| 203 | + " cluster,\n", |
| 204 | + " f\"--location={region}\",\n", |
| 205 | + " f\"--project={project_id}\",\n", |
| 206 | + " ]\n", |
| 207 | + " _run_kubectl(cred_cmd)\n", |
215 | 208 | "\n", |
| 209 | + "\n", |
| 210 | + "def get_deployment_selector_labels(deployment_name, namespace):\n", |
| 211 | + " \"\"\"Retrieves the selector labels for a given Kubernetes deployment.\"\"\"\n", |
216 | 212 | " cmd = [\n", |
217 | 213 | " \"kubectl\",\n", |
218 | 214 | " \"get\",\n", |
219 | | - " \"pods\",\n", |
| 215 | + " \"deployment\",\n", |
| 216 | + " deployment_name,\n", |
220 | 217 | " \"-n\",\n", |
221 | 218 | " namespace,\n", |
222 | 219 | " \"-o\",\n", |
223 | 220 | " \"json\",\n", |
224 | | - " \"-l\",\n", |
225 | | - " f\"app={deployment_app_label}\",\n", |
226 | | - " \"--field-selector=status.phase=Running\",\n", |
227 | 221 | " ]\n", |
228 | | - " try:\n", |
229 | | - " pods_json = _run_kubectl(cmd)\n", |
230 | | - " pods = json.loads(pods_json)\n", |
231 | | - " if pods.get(\"items\"):\n", |
232 | | - " return pods[\"items\"][0][\"metadata\"][\"name\"]\n", |
233 | | - " print(f\"No running pods found for {deployment} in {namespace}.\")\n", |
234 | | - " return None\n", |
235 | | - " except (\n", |
236 | | - " subprocess.CalledProcessError,\n", |
237 | | - " json.JSONDecodeError,\n", |
238 | | - " IndexError,\n", |
239 | | - " KeyError,\n", |
240 | | - " ) as e:\n", |
241 | | - " print(f\"Error getting pod name for {deployment} in {namespace}: {e}\")\n", |
242 | | - " return None\n", |
243 | | - "\n", |
244 | | - "\n", |
245 | | - "def check_inference_label(pod_name, namespace):\n", |
246 | | - " \"\"\"Checks if the specified pod has the vLLM inference server label.\"\"\"\n", |
| 222 | + " deployment_json = _run_kubectl(cmd)\n", |
| 223 | + " deployment_data = json.loads(deployment_json)\n", |
| 224 | + "\n", |
| 225 | + " selector_labels = (\n", |
| 226 | + " deployment_data.get(\"spec\", {}).get(\"selector\", {}).get(\"matchLabels\")\n", |
| 227 | + " )\n", |
| 228 | + " if not selector_labels:\n", |
| 229 | + " raise RuntimeError(\n", |
| 230 | + " f\"No selector labels found for deployment '{deployment_name}' in\"\n", |
| 231 | + " f\" namespace '{namespace}'.\"\n", |
| 232 | + " )\n", |
| 233 | + " return selector_labels\n", |
247 | 234 | "\n", |
248 | | - " cmd = [\"kubectl\", \"get\", \"pod\", pod_name, \"-n\", namespace, \"-o\", \"json\"]\n", |
249 | | - " try:\n", |
250 | | - " pod_json = _run_kubectl(cmd)\n", |
251 | | - " labels = json.loads(pod_json).get(\"metadata\", {}).get(\"labels\", {})\n", |
252 | | - " return labels.get(\"ai.gke.io/inference-server\") == \"vllm\"\n", |
253 | | - " except (subprocess.CalledProcessError, json.JSONDecodeError, KeyError) as e:\n", |
254 | | - " print(f\"Error checking labels for pod {pod_name} in {namespace}: {e}\")\n", |
255 | | - " return False\n", |
256 | 235 | "\n", |
| 236 | + "def get_running_pod_name(deployment_name, namespace):\n", |
| 237 | + " \"\"\"Retrieves the name of a running pod associated with a deployment.\"\"\"\n", |
| 238 | + " selector_labels = get_deployment_selector_labels(deployment_name, namespace)\n", |
| 239 | + " label_selector_str = \",\".join(f\"{k}={v}\" for k, v in selector_labels.items())\n", |
257 | 240 | "\n", |
258 | | - "def get_service_endpoint(service, namespace):\n", |
259 | | - " \"\"\"Retrieve the service endpoint of the deployment\"\"\"\n", |
260 | | - " endpoint_cmd = [\n", |
| 241 | + " cmd = [\n", |
261 | 242 | " \"kubectl\",\n", |
262 | 243 | " \"get\",\n", |
263 | | - " \"endpoints\",\n", |
264 | | - " service,\n", |
| 244 | + " \"pods\",\n", |
265 | 245 | " \"-n\",\n", |
266 | 246 | " namespace,\n", |
| 247 | + " \"-o\",\n", |
| 248 | + " \"json\",\n", |
| 249 | + " \"-l\",\n", |
| 250 | + " label_selector_str,\n", |
| 251 | + " \"--field-selector=status.phase=Running\",\n", |
267 | 252 | " ]\n", |
268 | | - " try:\n", |
269 | | - " endpoint_output = _run_kubectl(endpoint_cmd).splitlines()\n", |
270 | | - " if len(endpoint_output) < 2 or len(endpoint_output[1].split()) < 2:\n", |
271 | | - " print(f\"Endpoint data incomplete for {service}.\")\n", |
272 | | - " return None\n", |
273 | | - " endpoint = endpoint_output[1].split()[\n", |
274 | | - " 1\n", |
275 | | - " ] # Assumes format: NAME ENDPOINTS AGE -> service ip:port,... age\n", |
276 | | - " return endpoint\n", |
277 | | - " except subprocess.CalledProcessError as e:\n", |
278 | | - " print(f\"Error getting endpoints for {service}: {e}\")\n", |
279 | | - " return None\n", |
| 253 | + " pods_json = _run_kubectl(cmd)\n", |
| 254 | + " pods_data = json.loads(pods_json)\n", |
| 255 | + "\n", |
| 256 | + " if not pods_data.get(\"items\"):\n", |
| 257 | + " raise RuntimeError(\n", |
| 258 | + " f\"No running pods found for deployment '{deployment_name}' in namespace\"\n", |
| 259 | + " f\" '{namespace}' with selector '{label_selector_str}'.\"\n", |
| 260 | + " )\n", |
| 261 | + " return pods_data[\"items\"][0][\"metadata\"][\"name\"]\n", |
280 | 262 | "\n", |
281 | 263 | "\n", |
282 | | - "def process_response(request, pod_name, pod_endpoint, is_vllm_inference, namespace):\n", |
283 | | - " \"\"\"Sends a request to the pod and processes the response.\"\"\"\n", |
| 264 | + "def check_vllm_inference_label(pod_name, namespace):\n", |
| 265 | + " \"\"\"Checks if the specified pod has the vLLM inference server label.\"\"\"\n", |
| 266 | + " cmd = [\"kubectl\", \"get\", \"pod\", pod_name, \"-n\", namespace, \"-o\", \"json\"]\n", |
| 267 | + " pod_json = _run_kubectl(cmd)\n", |
| 268 | + " labels = json.loads(pod_json).get(\"metadata\", {}).get(\"labels\", {})\n", |
| 269 | + " return labels.get(\"ai.gke.io/inference-server\") == \"vllm\"\n", |
284 | 270 | "\n", |
285 | | - " json_data_escaped = json.dumps(request).replace(\"'\", \"'\\\\''\")\n", |
| 271 | + "\n", |
| 272 | + "def send_inference_request(\n", |
| 273 | + " request_payload, pod_name, pod_port, is_vllm_inference, namespace\n", |
| 274 | + "):\n", |
| 275 | + " \"\"\"Sends an inference request to the specified pod and returns the model's response.\"\"\"\n", |
| 276 | + " json_data_escaped = json.dumps(request_payload).replace(\"'\", \"'\\\\''\")\n", |
286 | 277 | " curl_cmd = (\n", |
287 | 278 | " f\"kubectl exec -n {namespace} -t {pod_name} -- curl -s -X POST\"\n", |
288 | | - " f' http://{pod_endpoint}/generate -H \"Content-Type: application/json\"'\n", |
289 | | - " f\" -d '{json_data_escaped}' 2> /dev/null\"\n", |
| 279 | + " f' http://localhost:{pod_port}/generate -H \"Content-Type:'\n", |
| 280 | + " ' application/json\"'\n", |
| 281 | + " f\" -d '{json_data_escaped}' 2\u003e /dev/null\"\n", |
290 | 282 | " )\n", |
| 283 | + "\n", |
| 284 | + " response_raw = _run_kubectl([\"bash\", \"-c\", curl_cmd])\n", |
| 285 | + "\n", |
| 286 | + " if not response_raw:\n", |
| 287 | + " raise RuntimeError(f\"Empty response received from pod '{pod_name}'.\")\n", |
| 288 | + "\n", |
291 | 289 | " try:\n", |
292 | | - " response_raw = _run_kubectl([\"bash\", \"-c\", curl_cmd])\n", |
293 | | - " if not response_raw:\n", |
294 | | - " return f\"Error: Empty response from pod {pod_name}.\"\n", |
295 | 290 | " first_line = response_raw.splitlines()[0]\n", |
296 | 291 | " data = json.loads(first_line)\n", |
| 292 | + " except json.JSONDecodeError as e:\n", |
| 293 | + " raise RuntimeError(\n", |
| 294 | + " f\"Failed to decode JSON response from pod: {e}. Raw: {response_raw}\"\n", |
| 295 | + " ) from e\n", |
| 296 | + " except IndexError:\n", |
| 297 | + " raise RuntimeError(\n", |
| 298 | + " f\"Unexpected empty response line from pod. Raw: {response_raw}\"\n", |
| 299 | + " )\n", |
| 300 | + "\n", |
| 301 | + " if is_vllm_inference:\n", |
| 302 | + " predictions = data.get(\"predictions\")\n", |
| 303 | + " if isinstance(predictions, list) and predictions:\n", |
| 304 | + " return predictions[0]\n", |
| 305 | + " raise RuntimeError(f\"Unexpected vLLM response format. Raw data: {data}\")\n", |
| 306 | + " else: # TGI format\n", |
| 307 | + " generated_text = data.get(\"generated_text\")\n", |
| 308 | + " if generated_text is not None:\n", |
| 309 | + " return generated_text\n", |
| 310 | + " raise RuntimeError(f\"Unexpected TGI response format. Raw data: {data}\")\n", |
| 311 | + "\n", |
| 312 | + "\n", |
| 313 | + "# --- Main Execution Logic ---\n", |
| 314 | + "\n", |
| 315 | + "\n", |
| 316 | + "def execute_chat_completion(\n", |
| 317 | + " deployment_name, namespace, pod_port, user_prompt, temperature, max_tokens\n", |
| 318 | + "):\n", |
| 319 | + " \"\"\"Executes the full chat completion process: fetches credentials, finds a pod,\n", |
| 320 | + "\n", |
| 321 | + " determines inference type, sends a request, and returns the response.\n", |
| 322 | + " \"\"\"\n", |
| 323 | + " display(Markdown(\"Establishing cluster credentials...\"))\n", |
| 324 | + " fetch_cluster_credentials(CLUSTER, REGION, PROJECT_ID)\n", |
| 325 | + "\n", |
| 326 | + " display(Markdown(\"Retrieving pod information...\"))\n", |
| 327 | + " pod_name = get_running_pod_name(deployment_name, namespace)\n", |
| 328 | + " display(Markdown(f\"Successfully identified pod: `{pod_name}`\"))\n", |
| 329 | + "\n", |
| 330 | + " is_vllm = check_vllm_inference_label(pod_name, namespace)\n", |
| 331 | + "\n", |
| 332 | + " request_payload = {\n", |
| 333 | + " \"max_tokens\": max_tokens,\n", |
| 334 | + " \"temperature\": temperature,\n", |
| 335 | + " \"prompt\" if is_vllm else \"inputs\": user_prompt,\n", |
| 336 | + " }\n", |
| 337 | + " display(Markdown(\"Sending inference request...\"))\n", |
| 338 | + " response = send_inference_request(\n", |
| 339 | + " request_payload, pod_name, pod_port, is_vllm, namespace\n", |
| 340 | + " )\n", |
297 | 341 | "\n", |
298 | | - " if is_vllm_inference: # vLLM format\n", |
299 | | - " predictions = data.get(\"predictions\")\n", |
300 | | - " if isinstance(predictions, (list, tuple)) and predictions:\n", |
301 | | - " return predictions[0]\n", |
302 | | - " return f\"Error: Unexpected vLLM format. Raw: {first_line}\"\n", |
303 | | - " else: # TGI format\n", |
304 | | - " generated_text = data.get(\"generated_text\")\n", |
305 | | - " if generated_text is not None:\n", |
306 | | - " return generated_text\n", |
307 | | - " return f\"Error: Unexpected TGI format. Raw: {first_line}\"\n", |
308 | | - " except Exception as e:\n", |
309 | | - " return f\"Unexpected error during response processing: {e}\"\n", |
| 342 | + " return response\n", |
310 | 343 | "\n", |
311 | 344 | "\n", |
312 | 345 | "# --- Widgets Setup ---\n", |
|
330 | 363 | "\n", |
331 | 364 | "# --- Submit Button Logic ---\n", |
332 | 365 | "def on_submit_clicked(b):\n", |
333 | | - " \"\"\"Handles the submit button click event.\"\"\"\n", |
334 | 366 | " with output_area_response:\n", |
335 | 367 | " clear_output()\n", |
| 368 | + " display(Markdown(\"Loading...\"))\n", |
336 | 369 | "\n", |
337 | | - " fetch_cluster_credential(CLUSTER, REGION, PROJECT_ID)\n", |
338 | | - "\n", |
339 | | - " # retrieve deployment pod\n", |
340 | | - " pod_name = get_deployment_pod_name(DEPLOYMENT, NAMESPACE, DEPLOYMENT_APP_LABEL)\n", |
341 | | - " if not pod_name:\n", |
342 | | - " display(\n", |
343 | | - " Markdown(f\"**Error:** Could not find running pod for `{DEPLOYMENT}`.\")\n", |
344 | | - " )\n", |
345 | | - " return\n", |
346 | | - "\n", |
347 | | - " # build the request message\n", |
348 | | - " is_vllm = check_inference_label(pod_name, NAMESPACE)\n", |
349 | | - " request = {\n", |
350 | | - " \"max_tokens\": max_tokens_widget.value,\n", |
351 | | - " \"temperature\": temperature_widget.value,\n", |
352 | | - " \"prompt\" if is_vllm else \"inputs\": user_prompt_widget.value,\n", |
353 | | - " }\n", |
354 | | - "\n", |
355 | | - " # retrieve service endpoint for the deployment\n", |
356 | | - " endpoint = get_service_endpoint(SERVICE, NAMESPACE)\n", |
357 | | - " if not endpoint:\n", |
358 | | - " display(Markdown(f\"**Error getting endpoints for `{SERVICE}`:**\\n\"))\n", |
359 | | - " return\n", |
360 | | - "\n", |
361 | | - " # prompt test the deployment endpoint\n", |
362 | 370 | " try:\n", |
363 | | - " response = process_response(request, pod_name, endpoint, is_vllm, NAMESPACE)\n", |
364 | | - " display(Markdown(f\"**Response:**\\n\\n{response}\"))\n", |
| 371 | + " model_response = execute_chat_completion(\n", |
| 372 | + " DEPLOYMENT,\n", |
| 373 | + " NAMESPACE,\n", |
| 374 | + " POD_PORT,\n", |
| 375 | + " user_prompt_widget.value,\n", |
| 376 | + " temperature_widget.value,\n", |
| 377 | + " max_tokens_widget.value,\n", |
| 378 | + " )\n", |
| 379 | + " clear_output()\n", |
| 380 | + " display(Markdown(f\"**Response:**\\n\\n{model_response}\"))\n", |
365 | 381 | " except Exception as e:\n", |
366 | | - " display(Markdown(f\"**Unexpected Error:**\\n```\\n{e}\\n```\"))\n", |
| 382 | + " clear_output()\n", |
| 383 | + " display(Markdown(f\"**An error occurred:**\\n```\\n{e}\\n```\"))\n", |
367 | 384 | "\n", |
368 | 385 | "\n", |
369 | 386 | "# --- Display Widgets ---\n", |
|
379 | 396 | }, |
380 | 397 | { |
381 | 398 | "cell_type": "markdown", |
382 | | - "id": "5b6ZM2K3fux0", |
383 | 399 | "metadata": { |
384 | 400 | "id": "5b6ZM2K3fux0" |
385 | 401 | }, |
|
453 | 469 | "metadata": { |
454 | 470 | "colab": { |
455 | 471 | "name": "gke_model_ui_deployment_notebook_auto.ipynb", |
456 | | - "toc_visible": true |
| 472 | + "provenance": [] |
457 | 473 | }, |
458 | 474 | "kernelspec": { |
459 | 475 | "display_name": "Python 3", |
|