11"""Common util functions for notebook."""
22
33import base64
4+ from collections .abc import Sequence
45import datetime
56import io
67import json
78import os
89import subprocess
910import time
10- from typing import Any , Dict , Sequence
11+ from typing import Any
1112
1213from google import auth
1314from google .cloud import storage
@@ -283,7 +284,7 @@ def decode_image(
283284 return image
284285
285286
286- def get_label_map (label_map_yaml_filepath : str ) -> Dict [int , str ]:
287+ def get_label_map (label_map_yaml_filepath : str ) -> dict [int , str ]:
287288 """Returns class id to label mapping given a filepath to the label map.
288289
289290 Args:
@@ -509,6 +510,17 @@ def get_quota(project_id: str, region: str, resource_id: str) -> int:
509510 ):
510511 return - 1
511512 all_regions_data = quota_data [0 ]["consumerQuotaLimits" ][0 ]["quotaBuckets" ]
513+
514+ # If the quota data does not have dimensions, it is global quota. However,
515+ # global quota may be overridden by regional quota. So we need to check the
516+ # global quota first.
517+ global_quota = - 1
518+ if (
519+ all_regions_data
520+ and "dimensions" not in all_regions_data [0 ]
521+ and "effectiveLimit" in all_regions_data [0 ]
522+ ):
523+ global_quota = int (all_regions_data [0 ]["effectiveLimit" ])
512524 for region_data in all_regions_data :
513525 if (
514526 region_data .get ("dimensions" )
@@ -518,12 +530,13 @@ def get_quota(project_id: str, region: str, resource_id: str) -> int:
518530 return int (region_data ["effectiveLimit" ])
519531 else :
520532 return 0
521- return - 1
533+ return global_quota
522534
523535
524536def get_resource_id (
525537 accelerator_type : str ,
526538 is_for_training : bool ,
539+ is_spot : bool = False ,
527540 is_restricted_image : bool = False ,
528541 is_dynamic_workload_scheduler : bool = False ,
529542) -> str :
@@ -533,6 +546,7 @@ def get_resource_id(
533546 accelerator_type: The accelerator type.
534547 is_for_training: Whether the resource is used for training. Set false for
535548 serving use case.
549+ is_spot: Whether the resource is used with Spot.
536550 is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
537551 is_dynamic_workload_scheduler: Whether the resource is used with Dynamic
538552 Workload Scheduler.
@@ -548,6 +562,7 @@ def get_resource_id(
548562 "NVIDIA_A100_80GB" : "nvidia_a100_80gb_gpus" ,
549563 "NVIDIA_H100_80GB" : "nvidia_h100_gpus" ,
550564 "NVIDIA_H100_MEGA_80GB" : "nvidia_h100_mega_gpus" ,
565+ "NVIDIA_H200_141GB" : "nvidia_h200_gpus" ,
551566 "NVIDIA_TESLA_T4" : "nvidia_t4_gpus" ,
552567 "TPU_V5e" : "tpu_v5e" ,
553568 "TPU_V3" : "tpu_v3" ,
@@ -563,6 +578,10 @@ def get_resource_id(
563578 restricted_image_training_accelerator_map = {
564579 "NVIDIA_A100_80GB" : "restricted_image_training_nvidia_a100_80gb_gpus" ,
565580 }
581+ spot_serving_accelerator_map = {
582+ key : f"custom_model_serving_preemptible_{ accelerator_suffix_map [key ]} "
583+ for key in accelerator_suffix_map
584+ }
566585 serving_accelerator_map = {
567586 key : f"custom_model_serving_{ accelerator_suffix_map [key ]} "
568587 for key in accelerator_suffix_map
@@ -591,8 +610,11 @@ def get_resource_id(
591610 else :
592611 if is_dynamic_workload_scheduler :
593612 raise ValueError ("Dynamic Workload Scheduler does not work for serving." )
594- if accelerator_type in serving_accelerator_map :
595- return serving_accelerator_map [accelerator_type ]
613+ accelerator_map = (
614+ spot_serving_accelerator_map if is_spot else serving_accelerator_map
615+ )
616+ if accelerator_type in accelerator_map :
617+ return accelerator_map [accelerator_type ]
596618 else :
597619 raise ValueError (
598620 f"Could not find accelerator type: { accelerator_type } for serving."
@@ -605,13 +627,28 @@ def check_quota(
605627 accelerator_type : str ,
606628 accelerator_count : int ,
607629 is_for_training : bool ,
630+ is_spot : bool = False ,
608631 is_restricted_image : bool = False ,
609632 is_dynamic_workload_scheduler : bool = False ,
610- ):
611- """Checks if the project and the region has the required quota."""
633+ ) -> None :
634+ """Checks if the project and the region has the required quota.
635+
636+ Args:
637+ project_id: The project id.
638+ region: The region.
639+ accelerator_type: The accelerator type.
640+ accelerator_count: The number of accelerators to check quota for.
641+ is_for_training: Whether the resource is used for training. Set false for
642+ serving use case.
643+ is_spot: Whether the resource is used with Spot.
644+ is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
645+ is_dynamic_workload_scheduler: Whether the resource is used with Dynamic
646+ Workload Scheduler.
647+ """
612648 resource_id = get_resource_id (
613649 accelerator_type ,
614650 is_for_training = is_for_training ,
651+ is_spot = is_spot ,
615652 is_restricted_image = is_restricted_image ,
616653 is_dynamic_workload_scheduler = is_dynamic_workload_scheduler ,
617654 )
0 commit comments