Commit a72b7bc

vertex-mg-bot authored and copybara-github committed
The main changes:
- `get_deployment_pod_name` now extracts the deployment's selector labels to query the pods
- removes the dependency on the Service, using the pod port instead
- adds a `POD_PORT` template variable so the port can be passed from the UI

PiperOrigin-RevId: 771205143
1 parent e1cd3ce commit a72b7bc
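To picture the new lookup path: instead of resolving a `{DEPLOYMENT}-service` endpoint, the cell reads the deployment's `spec.selector.matchLabels` and uses them as a pod label selector. Below is a minimal sketch of that same flow using the official `kubernetes` Python client for illustration only; the notebook itself shells out to `kubectl`, and `my-deployment` / `my-namespace` are hypothetical placeholders:

```python
# Sketch only: the notebook uses kubectl via subprocess; this shows the
# equivalent selector-based pod lookup with the kubernetes Python client.
from kubernetes import client, config

config.load_kube_config()  # assumes `gcloud container clusters get-credentials` ran

apps = client.AppsV1Api()
deployment = apps.read_namespaced_deployment("my-deployment", "my-namespace")

# Use the deployment's own selector labels instead of a hard-coded app label.
selector = ",".join(
    f"{k}={v}" for k, v in deployment.spec.selector.match_labels.items()
)

core = client.CoreV1Api()
pods = core.list_namespaced_pod(
    "my-namespace", label_selector=selector, field_selector="status.phase=Running"
)
print(pods.items[0].metadata.name)  # first running pod backing the deployment
```

Deriving the selector from the deployment itself avoids assuming an `app=<label>` naming convention, which is why the separate `DEPLOYMENT_APP_LABEL` input could be dropped.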

1 file changed: 151 additions & 135 deletions

File tree

notebooks/community/model_garden/gke_model_ui_deployment_notebook_auto.ipynb

@@ -3,7 +3,6 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "Pr9TgOcV9vAXeqGiyTaTI5kS",
    "metadata": {
     "cellView": "form",
     "id": "Pr9TgOcV9vAXeqGiyTaTI5kS"
@@ -27,7 +26,6 @@
   },
   {
    "cell_type": "markdown",
-   "id": "M1CpgYundFwz",
    "metadata": {
     "id": "M1CpgYundFwz"
    },
@@ -50,7 +48,6 @@
   },
   {
    "cell_type": "markdown",
-   "id": "t2jj2XOgkS4F",
    "metadata": {
     "id": "t2jj2XOgkS4F"
    },
@@ -124,7 +121,6 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "XMf-T58TkDy1",
    "metadata": {
     "cellView": "form",
     "id": "XMf-T58TkDy1"
@@ -155,15 +151,14 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "IKGTaN84p8rX",
    "metadata": {
     "cellView": "form",
     "id": "IKGTaN84p8rX"
    },
    "outputs": [],
    "source": [
-    "# @title # Chat completion for text-only models { vertical-output: true}\n",
-    "# @markdown You may send prompts to the model server for prediction.\n",
+    "# @title # Chat completion for text-only models {vertical-output: true}\n",
+    "# @markdown Run cell to prompt the model server for prediction.\n",
     "# @markdown\n",
     "# @markdown * **user_prompt (string):** This is the text prompt you provide to the language model. It's the question or instruction (e.g., \"Explain neural networks\").\n",
     "# @markdown * **temperature (number):** This parameter controls the randomness of the model's output. It influences how the model selects the next token in the sequence it generates. Typical values range from 0.2 to 1.0.\n",
@@ -180,133 +175,171 @@
     "REGION = \"\" # @param {type:\"string\", isTemplate:true}\n",
     "NAMESPACE = \"\" # @param {type:\"string\", isTemplate:true}\n",
     "DEPLOYMENT = \"\" # @param {type:\"string\", isTemplate:true}\n",
-    "DEPLOYMENT_APP_LABEL = \"\" # @param {type:\"string\", isTemplate:true}\n",
+    "POD_PORT = \"\" # @param {type:\"string\", isTemplate:true}\n",
     "\n",
-    "SERVICE = f\"{DEPLOYMENT}-service\"\n",
     "\n",
-    "\n",
-    "def _run_kubectl(cmd):\n",
-    "    \"\"\"Executes a kubectl command and returns its stdout.\"\"\"\n",
-    "    result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=60)\n",
-    "    return result.stdout.strip()\n",
-    "\n",
-    "\n",
-    "def fetch_cluster_credential(cluster, region, project_id):\n",
+    "def _run_kubectl(cmd, timeout=60):\n",
+    "    \"\"\"Executes a kubectl command.\"\"\"\n",
     "    try:\n",
-    "        # Ensure credentials for the target cluster\n",
-    "        cred_cmd = [\n",
-    "            \"gcloud\",\n",
-    "            \"container\",\n",
-    "            \"clusters\",\n",
-    "            \"get-credentials\",\n",
-    "            cluster,\n",
-    "            f\"--location={region}\",\n",
-    "            f\"--project={project_id}\",\n",
-    "        ]\n",
-    "        _run_kubectl(cred_cmd)\n",
-    "    except Exception as e:\n",
-    "        # Original code prints error and returns empty dict\n",
-    "        print(f\"Error fetching cluster credentials: {e}\")\n",
-    "        return {}\n",
-    "\n",
-    "\n",
-    "def get_deployment_pod_name(deployment, namespace, deployment_app_label):\n",
-    "    \"\"\"Finds the running pod name for a given deployment and namespace.\"\"\"\n",
+    "        result = subprocess.run(\n",
+    "            cmd, capture_output=True, text=True, check=True, timeout=timeout\n",
+    "        )\n",
+    "        return result.stdout.strip()\n",
+    "    except subprocess.CalledProcessError as e:\n",
+    "        raise RuntimeError(\n",
+    "            f\"Kubectl command failed: {' '.join(e.cmd)}\\nStderr: {e.stderr}\"\n",
+    "        ) from e\n",
+    "    except subprocess.TimeoutExpired as e:\n",
+    "        raise RuntimeError(f\"Kubectl command timed out: {' '.join(e.cmd)}\") from e\n",
+    "\n",
+    "\n",
+    "def fetch_cluster_credentials(cluster, region, project_id):\n",
+    "    \"\"\"Ensures credentials for the target GKE cluster.\"\"\"\n",
+    "    cred_cmd = [\n",
+    "        \"gcloud\",\n",
+    "        \"container\",\n",
+    "        \"clusters\",\n",
+    "        \"get-credentials\",\n",
+    "        cluster,\n",
+    "        f\"--location={region}\",\n",
+    "        f\"--project={project_id}\",\n",
+    "    ]\n",
+    "    _run_kubectl(cred_cmd)\n",
     "\n",
+    "\n",
+    "def get_deployment_selector_labels(deployment_name, namespace):\n",
+    "    \"\"\"Retrieves the selector labels for a given Kubernetes deployment.\"\"\"\n",
     "    cmd = [\n",
     "        \"kubectl\",\n",
     "        \"get\",\n",
-    "        \"pods\",\n",
+    "        \"deployment\",\n",
+    "        deployment_name,\n",
     "        \"-n\",\n",
     "        namespace,\n",
     "        \"-o\",\n",
     "        \"json\",\n",
-    "        \"-l\",\n",
-    "        f\"app={deployment_app_label}\",\n",
-    "        \"--field-selector=status.phase=Running\",\n",
     "    ]\n",
-    "    try:\n",
-    "        pods_json = _run_kubectl(cmd)\n",
-    "        pods = json.loads(pods_json)\n",
-    "        if pods.get(\"items\"):\n",
-    "            return pods[\"items\"][0][\"metadata\"][\"name\"]\n",
-    "        print(f\"No running pods found for {deployment} in {namespace}.\")\n",
-    "        return None\n",
-    "    except (\n",
-    "        subprocess.CalledProcessError,\n",
-    "        json.JSONDecodeError,\n",
-    "        IndexError,\n",
-    "        KeyError,\n",
-    "    ) as e:\n",
-    "        print(f\"Error getting pod name for {deployment} in {namespace}: {e}\")\n",
-    "        return None\n",
-    "\n",
-    "\n",
-    "def check_inference_label(pod_name, namespace):\n",
-    "    \"\"\"Checks if the specified pod has the vLLM inference server label.\"\"\"\n",
+    "    deployment_json = _run_kubectl(cmd)\n",
+    "    deployment_data = json.loads(deployment_json)\n",
+    "\n",
+    "    selector_labels = (\n",
+    "        deployment_data.get(\"spec\", {}).get(\"selector\", {}).get(\"matchLabels\")\n",
+    "    )\n",
+    "    if not selector_labels:\n",
+    "        raise RuntimeError(\n",
+    "            f\"No selector labels found for deployment '{deployment_name}' in\"\n",
+    "            f\" namespace '{namespace}'.\"\n",
+    "        )\n",
+    "    return selector_labels\n",
     "\n",
-    "    cmd = [\"kubectl\", \"get\", \"pod\", pod_name, \"-n\", namespace, \"-o\", \"json\"]\n",
-    "    try:\n",
-    "        pod_json = _run_kubectl(cmd)\n",
-    "        labels = json.loads(pod_json).get(\"metadata\", {}).get(\"labels\", {})\n",
-    "        return labels.get(\"ai.gke.io/inference-server\") == \"vllm\"\n",
-    "    except (subprocess.CalledProcessError, json.JSONDecodeError, KeyError) as e:\n",
-    "        print(f\"Error checking labels for pod {pod_name} in {namespace}: {e}\")\n",
-    "        return False\n",
     "\n",
+    "def get_running_pod_name(deployment_name, namespace):\n",
+    "    \"\"\"Retrieves the name of a running pod associated with a deployment.\"\"\"\n",
+    "    selector_labels = get_deployment_selector_labels(deployment_name, namespace)\n",
+    "    label_selector_str = \",\".join(f\"{k}={v}\" for k, v in selector_labels.items())\n",
     "\n",
-    "def get_service_endpoint(service, namespace):\n",
-    "    \"\"\"Retrieve the service endpoint of the deployment\"\"\"\n",
-    "    endpoint_cmd = [\n",
+    "    cmd = [\n",
     "        \"kubectl\",\n",
     "        \"get\",\n",
-    "        \"endpoints\",\n",
-    "        service,\n",
+    "        \"pods\",\n",
     "        \"-n\",\n",
     "        namespace,\n",
+    "        \"-o\",\n",
+    "        \"json\",\n",
+    "        \"-l\",\n",
+    "        label_selector_str,\n",
+    "        \"--field-selector=status.phase=Running\",\n",
     "    ]\n",
-    "    try:\n",
-    "        endpoint_output = _run_kubectl(endpoint_cmd).splitlines()\n",
-    "        if len(endpoint_output) < 2 or len(endpoint_output[1].split()) < 2:\n",
-    "            print(f\"Endpoint data incomplete for {service}.\")\n",
-    "            return None\n",
-    "        endpoint = endpoint_output[1].split()[\n",
-    "            1\n",
-    "        ]  # Assumes format: NAME ENDPOINTS AGE -> service ip:port,... age\n",
-    "        return endpoint\n",
-    "    except subprocess.CalledProcessError as e:\n",
-    "        print(f\"Error getting endpoints for {service}: {e}\")\n",
-    "        return None\n",
+    "    pods_json = _run_kubectl(cmd)\n",
+    "    pods_data = json.loads(pods_json)\n",
+    "\n",
+    "    if not pods_data.get(\"items\"):\n",
+    "        raise RuntimeError(\n",
+    "            f\"No running pods found for deployment '{deployment_name}' in namespace\"\n",
+    "            f\" '{namespace}' with selector '{label_selector_str}'.\"\n",
+    "        )\n",
+    "    return pods_data[\"items\"][0][\"metadata\"][\"name\"]\n",
     "\n",
     "\n",
-    "def process_response(request, pod_name, pod_endpoint, is_vllm_inference, namespace):\n",
-    "    \"\"\"Sends a request to the pod and processes the response.\"\"\"\n",
+    "def check_vllm_inference_label(pod_name, namespace):\n",
+    "    \"\"\"Checks if the specified pod has the vLLM inference server label.\"\"\"\n",
+    "    cmd = [\"kubectl\", \"get\", \"pod\", pod_name, \"-n\", namespace, \"-o\", \"json\"]\n",
+    "    pod_json = _run_kubectl(cmd)\n",
+    "    labels = json.loads(pod_json).get(\"metadata\", {}).get(\"labels\", {})\n",
+    "    return labels.get(\"ai.gke.io/inference-server\") == \"vllm\"\n",
     "\n",
-    "    json_data_escaped = json.dumps(request).replace(\"'\", \"'\\\\''\")\n",
+    "\n",
+    "def send_inference_request(\n",
+    "    request_payload, pod_name, pod_port, is_vllm_inference, namespace\n",
+    "):\n",
+    "    \"\"\"Sends an inference request to the specified pod and returns the model's response.\"\"\"\n",
+    "    json_data_escaped = json.dumps(request_payload).replace(\"'\", \"'\\\\''\")\n",
     "    curl_cmd = (\n",
     "        f\"kubectl exec -n {namespace} -t {pod_name} -- curl -s -X POST\"\n",
-    "        f' http://{pod_endpoint}/generate -H \"Content-Type: application/json\"'\n",
-    "        f\" -d '{json_data_escaped}' 2> /dev/null\"\n",
+    "        f' http://localhost:{pod_port}/generate -H \"Content-Type:'\n",
+    "        ' application/json\"'\n",
+    "        f\" -d '{json_data_escaped}' 2\u003e /dev/null\"\n",
     "    )\n",
+    "\n",
+    "    response_raw = _run_kubectl([\"bash\", \"-c\", curl_cmd])\n",
+    "\n",
+    "    if not response_raw:\n",
+    "        raise RuntimeError(f\"Empty response received from pod '{pod_name}'.\")\n",
+    "\n",
     "    try:\n",
-    "        response_raw = _run_kubectl([\"bash\", \"-c\", curl_cmd])\n",
-    "        if not response_raw:\n",
-    "            return f\"Error: Empty response from pod {pod_name}.\"\n",
     "        first_line = response_raw.splitlines()[0]\n",
     "        data = json.loads(first_line)\n",
+    "    except json.JSONDecodeError as e:\n",
+    "        raise RuntimeError(\n",
+    "            f\"Failed to decode JSON response from pod: {e}. Raw: {response_raw}\"\n",
+    "        ) from e\n",
+    "    except IndexError:\n",
+    "        raise RuntimeError(\n",
+    "            f\"Unexpected empty response line from pod. Raw: {response_raw}\"\n",
+    "        )\n",
+    "\n",
+    "    if is_vllm_inference:\n",
+    "        predictions = data.get(\"predictions\")\n",
+    "        if isinstance(predictions, list) and predictions:\n",
+    "            return predictions[0]\n",
+    "        raise RuntimeError(f\"Unexpected vLLM response format. Raw data: {data}\")\n",
+    "    else:  # TGI format\n",
+    "        generated_text = data.get(\"generated_text\")\n",
+    "        if generated_text is not None:\n",
+    "            return generated_text\n",
+    "        raise RuntimeError(f\"Unexpected TGI response format. Raw data: {data}\")\n",
+    "\n",
+    "\n",
+    "# --- Main Execution Logic ---\n",
+    "\n",
+    "\n",
+    "def execute_chat_completion(\n",
+    "    deployment_name, namespace, pod_port, user_prompt, temperature, max_tokens\n",
+    "):\n",
+    "    \"\"\"Executes the full chat completion process: fetches credentials, finds a pod,\n",
+    "\n",
+    "    determines inference type, sends a request, and returns the response.\n",
+    "    \"\"\"\n",
+    "    display(Markdown(\"Establishing cluster credentials...\"))\n",
+    "    fetch_cluster_credentials(CLUSTER, REGION, PROJECT_ID)\n",
+    "\n",
+    "    display(Markdown(\"Retrieving pod information...\"))\n",
+    "    pod_name = get_running_pod_name(deployment_name, namespace)\n",
+    "    display(Markdown(f\"Successfully identified pod: `{pod_name}`\"))\n",
+    "\n",
+    "    is_vllm = check_vllm_inference_label(pod_name, namespace)\n",
+    "\n",
+    "    request_payload = {\n",
+    "        \"max_tokens\": max_tokens,\n",
+    "        \"temperature\": temperature,\n",
+    "        \"prompt\" if is_vllm else \"inputs\": user_prompt,\n",
+    "    }\n",
+    "    display(Markdown(\"Sending inference request...\"))\n",
+    "    response = send_inference_request(\n",
+    "        request_payload, pod_name, pod_port, is_vllm, namespace\n",
+    "    )\n",
     "\n",
-    "        if is_vllm_inference:  # vLLM format\n",
-    "            predictions = data.get(\"predictions\")\n",
-    "            if isinstance(predictions, (list, tuple)) and predictions:\n",
-    "                return predictions[0]\n",
-    "            return f\"Error: Unexpected vLLM format. Raw: {first_line}\"\n",
-    "        else:  # TGI format\n",
-    "            generated_text = data.get(\"generated_text\")\n",
-    "            if generated_text is not None:\n",
-    "                return generated_text\n",
-    "            return f\"Error: Unexpected TGI format. Raw: {first_line}\"\n",
-    "    except Exception as e:\n",
-    "        return f\"Unexpected error during response processing: {e}\"\n",
+    "    return response\n",
     "\n",
     "\n",
     "# --- Widgets Setup ---\n",
@@ -330,40 +363,24 @@
     "\n",
     "# --- Submit Button Logic ---\n",
     "def on_submit_clicked(b):\n",
-    "    \"\"\"Handles the submit button click event.\"\"\"\n",
     "    with output_area_response:\n",
     "        clear_output()\n",
+    "        display(Markdown(\"Loading...\"))\n",
     "\n",
-    "        fetch_cluster_credential(CLUSTER, REGION, PROJECT_ID)\n",
-    "\n",
-    "        # retrieve deployment pod\n",
-    "        pod_name = get_deployment_pod_name(DEPLOYMENT, NAMESPACE, DEPLOYMENT_APP_LABEL)\n",
-    "        if not pod_name:\n",
-    "            display(\n",
-    "                Markdown(f\"**Error:** Could not find running pod for `{DEPLOYMENT}`.\")\n",
-    "            )\n",
-    "            return\n",
-    "\n",
-    "        # build the request message\n",
-    "        is_vllm = check_inference_label(pod_name, NAMESPACE)\n",
-    "        request = {\n",
-    "            \"max_tokens\": max_tokens_widget.value,\n",
-    "            \"temperature\": temperature_widget.value,\n",
-    "            \"prompt\" if is_vllm else \"inputs\": user_prompt_widget.value,\n",
-    "        }\n",
-    "\n",
-    "        # retrieve service endpoint for the deployment\n",
-    "        endpoint = get_service_endpoint(SERVICE, NAMESPACE)\n",
-    "        if not endpoint:\n",
-    "            display(Markdown(f\"**Error getting endpoints for `{SERVICE}`:**\\n\"))\n",
-    "            return\n",
-    "\n",
-    "        # prompt test the deployment endpoint\n",
     "        try:\n",
-    "            response = process_response(request, pod_name, endpoint, is_vllm, NAMESPACE)\n",
-    "            display(Markdown(f\"**Response:**\\n\\n{response}\"))\n",
+    "            model_response = execute_chat_completion(\n",
+    "                DEPLOYMENT,\n",
+    "                NAMESPACE,\n",
+    "                POD_PORT,\n",
+    "                user_prompt_widget.value,\n",
+    "                temperature_widget.value,\n",
+    "                max_tokens_widget.value,\n",
+    "            )\n",
+    "            clear_output()\n",
+    "            display(Markdown(f\"**Response:**\\n\\n{model_response}\"))\n",
     "        except Exception as e:\n",
-    "            display(Markdown(f\"**Unexpected Error:**\\n```\\n{e}\\n```\"))\n",
+    "            clear_output()\n",
+    "            display(Markdown(f\"**An error occurred:**\\n```\\n{e}\\n```\"))\n",
     "\n",
     "\n",
     "# --- Display Widgets ---\n",
@@ -379,7 +396,6 @@
   },
   {
    "cell_type": "markdown",
-   "id": "5b6ZM2K3fux0",
    "metadata": {
     "id": "5b6ZM2K3fux0"
    },
@@ -453,7 +469,7 @@
  "metadata": {
   "colab": {
    "name": "gke_model_ui_deployment_notebook_auto.ipynb",
-   "toc_visible": true
+   "provenance": []
   },
   "kernelspec": {
    "display_name": "Python 3",
