Skip to content

Commit d0fb60f

Browse files
rayandasoriya (Rayan Dasoriya)
and co-author authored
Add optional support for global quota check (#4096)
Co-authored-by: Rayan Dasoriya <dasoriya@google.com>
1 parent 7d4fb0f commit d0fb60f

1 file changed

Lines changed: 44 additions & 7 deletions

File tree

  • community-content/vertex_model_garden/model_oss/notebook_util

community-content/vertex_model_garden/model_oss/notebook_util/common_util.py

Lines changed: 44 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
11
"""Common util functions for notebook."""
22

33
import base64
4+
from collections.abc import Sequence
45
import datetime
56
import io
67
import json
78
import os
89
import subprocess
910
import time
10-
from typing import Any, Dict, Sequence
11+
from typing import Any
1112

1213
from google import auth
1314
from google.cloud import storage
@@ -283,7 +284,7 @@ def decode_image(
283284
return image
284285

285286

286-
def get_label_map(label_map_yaml_filepath: str) -> Dict[int, str]:
287+
def get_label_map(label_map_yaml_filepath: str) -> dict[int, str]:
287288
"""Returns class id to label mapping given a filepath to the label map.
288289
289290
Args:
@@ -509,6 +510,17 @@ def get_quota(project_id: str, region: str, resource_id: str) -> int:
509510
):
510511
return -1
511512
all_regions_data = quota_data[0]["consumerQuotaLimits"][0]["quotaBuckets"]
513+
514+
# If the quota data does not have dimensions, it is global quota. However,
515+
# global quota may be overridden by regional quota. So we need to check the
516+
# global quota first.
517+
global_quota = -1
518+
if (
519+
all_regions_data
520+
and "dimensions" not in all_regions_data[0]
521+
and "effectiveLimit" in all_regions_data[0]
522+
):
523+
global_quota = int(all_regions_data[0]["effectiveLimit"])
512524
for region_data in all_regions_data:
513525
if (
514526
region_data.get("dimensions")
@@ -518,12 +530,13 @@ def get_quota(project_id: str, region: str, resource_id: str) -> int:
518530
return int(region_data["effectiveLimit"])
519531
else:
520532
return 0
521-
return -1
533+
return global_quota
522534

523535

524536
def get_resource_id(
525537
accelerator_type: str,
526538
is_for_training: bool,
539+
is_spot: bool = False,
527540
is_restricted_image: bool = False,
528541
is_dynamic_workload_scheduler: bool = False,
529542
) -> str:
@@ -533,6 +546,7 @@ def get_resource_id(
533546
accelerator_type: The accelerator type.
534547
is_for_training: Whether the resource is used for training. Set false for
535548
serving use case.
549+
is_spot: Whether the resource is used with Spot.
536550
is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
537551
is_dynamic_workload_scheduler: Whether the resource is used with Dynamic
538552
Workload Scheduler.
@@ -548,6 +562,7 @@ def get_resource_id(
548562
"NVIDIA_A100_80GB": "nvidia_a100_80gb_gpus",
549563
"NVIDIA_H100_80GB": "nvidia_h100_gpus",
550564
"NVIDIA_H100_MEGA_80GB": "nvidia_h100_mega_gpus",
565+
"NVIDIA_H200_141GB": "nvidia_h200_gpus",
551566
"NVIDIA_TESLA_T4": "nvidia_t4_gpus",
552567
"TPU_V5e": "tpu_v5e",
553568
"TPU_V3": "tpu_v3",
@@ -563,6 +578,10 @@ def get_resource_id(
563578
restricted_image_training_accelerator_map = {
564579
"NVIDIA_A100_80GB": "restricted_image_training_nvidia_a100_80gb_gpus",
565580
}
581+
spot_serving_accelerator_map = {
582+
key: f"custom_model_serving_preemptible_{accelerator_suffix_map[key]}"
583+
for key in accelerator_suffix_map
584+
}
566585
serving_accelerator_map = {
567586
key: f"custom_model_serving_{accelerator_suffix_map[key]}"
568587
for key in accelerator_suffix_map
@@ -591,8 +610,11 @@ def get_resource_id(
591610
else:
592611
if is_dynamic_workload_scheduler:
593612
raise ValueError("Dynamic Workload Scheduler does not work for serving.")
594-
if accelerator_type in serving_accelerator_map:
595-
return serving_accelerator_map[accelerator_type]
613+
accelerator_map = (
614+
spot_serving_accelerator_map if is_spot else serving_accelerator_map
615+
)
616+
if accelerator_type in accelerator_map:
617+
return accelerator_map[accelerator_type]
596618
else:
597619
raise ValueError(
598620
f"Could not find accelerator type: {accelerator_type} for serving."
@@ -605,13 +627,28 @@ def check_quota(
605627
accelerator_type: str,
606628
accelerator_count: int,
607629
is_for_training: bool,
630+
is_spot: bool = False,
608631
is_restricted_image: bool = False,
609632
is_dynamic_workload_scheduler: bool = False,
610-
):
611-
"""Checks if the project and the region has the required quota."""
633+
) -> None:
634+
"""Checks if the project and the region has the required quota.
635+
636+
Args:
637+
project_id: The project id.
638+
region: The region.
639+
accelerator_type: The accelerator type.
640+
accelerator_count: The number of accelerators to check quota for.
641+
is_for_training: Whether the resource is used for training. Set false for
642+
serving use case.
643+
is_spot: Whether the resource is used with Spot.
644+
is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
645+
is_dynamic_workload_scheduler: Whether the resource is used with Dynamic
646+
Workload Scheduler.
647+
"""
612648
resource_id = get_resource_id(
613649
accelerator_type,
614650
is_for_training=is_for_training,
651+
is_spot=is_spot,
615652
is_restricted_image=is_restricted_image,
616653
is_dynamic_workload_scheduler=is_dynamic_workload_scheduler,
617654
)

0 commit comments

Comments (0)