Skip to content

Commit a9a512c

Browse files
vertex-mg-bot authored and copybara-github committed
Dedicated Endpoint Support for SDXL Dreambooth LoRA Finetuning
PiperOrigin-RevId: 770624167
1 parent 51b2be8 commit a9a512c

1 file changed

Lines changed: 50 additions & 26 deletions

File tree

notebooks/community/model_garden/model_garden_pytorch_sd_xl_finetuning_dreambooth_lora.ipynb

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@
6363
"- Deploy the model to a [Vertex AI Endpoint resource](https://cloud.google.com/vertex-ai/docs/predictions/using-private-endpoints).\n",
6464
"- Run online predictions for text-to-image.\n",
6565
"\n",
66+
"\n",
67+
"### File a bug\n",
68+
"\n",
69+
"File a bug on [GitHub](https://github.com/GoogleCloudPlatform/vertex-ai-samples/issues/new) if you encounter any issue with the notebook.\n",
70+
"\n",
6671
"### Costs\n",
6772
"\n",
6873
"This tutorial uses billable components of Google Cloud:\n",
@@ -122,13 +127,14 @@
122127
"from google.cloud import aiplatform, storage\n",
123128
"from huggingface_hub import snapshot_download\n",
124129
"\n",
130+
"if os.environ.get(\"VERTEX_PRODUCT\") != \"COLAB_ENTERPRISE\":\n",
131+
" ! pip install --upgrade tensorflow\n",
125132
"! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git\n",
126133
"\n",
127134
"common_util = importlib.import_module(\n",
128135
" \"vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util\"\n",
129136
")\n",
130137
"\n",
131-
"models, endpoints = {}, {}\n",
132138
"\n",
133139
"# Get the default cloud project id.\n",
134140
"PROJECT_ID = os.environ[\"GOOGLE_CLOUD_PROJECT\"]\n",
@@ -188,7 +194,14 @@
188194
"\n",
189195
"! gcloud config set project $PROJECT_ID\n",
190196
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/storage.admin\"\n",
191-
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/aiplatform.user\""
197+
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/aiplatform.user\"\n",
198+
"\n",
199+
"models, endpoints = {}, {}\n",
200+
"\n",
201+
"# @markdown Set use_dedicated_endpoint to False if you don't want to use a [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint). Note that [dedicated endpoints do not support VPC Service Controls](https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type); uncheck the box if you are using VPC-SC.\n",
202+
"use_dedicated_endpoint = True # @param {type:\"boolean\"}\n",
203+
"\n",
204+
"# @markdown Click \"Show Code\" to see more details."
192205
]
193206
},
194207
{
@@ -200,7 +213,7 @@
200213
},
201214
"outputs": [],
202215
"source": [
203-
"# @title Start Dreambooth LoRA finetune\n",
216+
"# @title Set up the Dreambooth LoRA finetune parameters\n",
204217
"\n",
205218
"# @markdown This section uses [Dreambooth LoRA](https://dreambooth.github.io/) to finetune\n",
206219
"# @markdown the [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model\n",
@@ -421,17 +434,19 @@
421434
"\n",
422435
"\n",
423436
"def deploy_model(\n",
424-
" model_id,\n",
425-
" lora_id,\n",
426-
" task,\n",
427-
" accelerator_type,\n",
428-
" machine_type,\n",
429-
" accelerator_count,\n",
437+
" model_id: str,\n",
438+
" lora_id: str,\n",
439+
" task: str,\n",
440+
" accelerator_type: str = \"g2-standard-8\",\n",
441+
" machine_type: str = \"NVIDIA_L4\",\n",
442+
" accelerator_count: int = 1,\n",
443+
" use_dedicated_endpoint: bool = False,\n",
430444
"):\n",
431445
" \"\"\"Create a Vertex AI Endpoint and deploy the specified model to the endpoint.\"\"\"\n",
432446
" model_name = model_id\n",
433447
" endpoint = aiplatform.Endpoint.create(\n",
434-
" display_name=common_util.get_job_name_with_datetime(model_name)\n",
448+
" display_name=common_util.get_job_name_with_datetime(model_name),\n",
449+
" dedicated_endpoint_enabled=use_dedicated_endpoint,\n",
435450
" )\n",
436451
" serving_env = {\n",
437452
" \"MODEL_ID\": model_id,\n",
@@ -466,30 +481,21 @@
466481
" return model, endpoint\n",
467482
"\n",
468483
"\n",
484+
"LABEL = \"sd_xl\"\n",
485+
"\n",
469486
"# Set the model_id to \"stabilityai/stable-diffusion-xl-base-1.0\" to load the OSS pre-trained model.\n",
470-
"models[\"sd_xl\"], endpoints[\"sd_xl\"] = deploy_model(\n",
487+
"models[LABEL], endpoints[LABEL] = deploy_model(\n",
471488
" model_id=model_id,\n",
472489
" lora_id=lora_id,\n",
473490
" task=\"text-to-image-sdxl\",\n",
474491
" accelerator_type=serve_accelerator_type,\n",
475492
" machine_type=serve_machine_type,\n",
476493
" accelerator_count=serve_accelerator_count,\n",
494+
" use_dedicated_endpoint=use_dedicated_endpoint,\n",
477495
")\n",
478-
"print(\"endpoint_name:\", endpoints[\"sd_xl\"].name)\n",
479496
"\n",
480-
"# Loads an existing endpoint instance using the endpoint name:\n",
481-
"# - Using `endpoint_name = endpoint.name` allows us to get the\n",
482-
"# endpoint name of the endpoint `endpoint` created in the cell\n",
483-
"# above.\n",
484-
"# - Alternatively, you can set `endpoint_name = \"1234567890123456789\"` to load\n",
485-
"# an existing endpoint with the ID 1234567890123456789.\n",
486-
"# You may uncomment the code below to load an existing endpoint.\n",
487-
"\n",
488-
"# endpoint_name = \"\" # @param {type:\"string\"}\n",
489-
"# aip_endpoint_name = (\n",
490-
"# f\"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}\"\n",
491-
"# )\n",
492-
"# endpoint = aiplatform.Endpoint(aip_endpoint_name)"
497+
"model = models[LABEL]\n",
498+
"endpoint = endpoints[LABEL]"
493499
]
494500
},
495501
{
@@ -511,6 +517,20 @@
511517
"# @markdown the `concept_prompt` of your instance in the prompt.\n",
512518
"# @markdown You may adjust the parameters below to achieve best image quality.\n",
513519
"\n",
520+
"# Loads an existing endpoint instance using the endpoint name:\n",
521+
"# - Using `endpoint_name = endpoint.name` allows us to get the\n",
522+
"# endpoint name of the endpoint `endpoint` created in the cell\n",
523+
"# above.\n",
524+
"# - Alternatively, you can set `endpoint_name = \"1234567890123456789\"` to load\n",
525+
"# an existing endpoint with the ID 1234567890123456789.\n",
526+
"# You may uncomment the code below to load an existing endpoint.\n",
527+
"\n",
528+
"# endpoint_name = \"\" # @param {type:\"string\"}\n",
529+
"# aip_endpoint_name = (\n",
530+
"# f\"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}\"\n",
531+
"# )\n",
532+
"# endpoint = aiplatform.Endpoint(aip_endpoint_name)\n",
533+
"\n",
514534
"prompt = \"A picture of a sks dog in a house\" # @param {type: \"string\"}\n",
515535
"negative_prompt = \"\" # @param {type: \"string\"}\n",
516536
"height = 1024 # @param {type:\"integer\"}\n",
@@ -526,7 +546,11 @@
526546
" \"num_inference_steps\": num_inference_steps,\n",
527547
" \"guidance_scale\": guidance_scale,\n",
528548
"}\n",
529-
"response = endpoints[\"sd_xl\"].predict(instances=instances, parameters=parameters)\n",
549+
"response = endpoints[\"sd_xl\"].predict(\n",
550+
" instances=instances,\n",
551+
" parameters=parameters,\n",
552+
" use_dedicated_endpoint=use_dedicated_endpoint,\n",
553+
")\n",
530554
"\n",
531555
"images = [\n",
532556
" common_util.base64_to_image(prediction.get(\"output\"))\n",

0 commit comments

Comments
 (0)