diff --git a/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/common_util.py b/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/common_util.py new file mode 100644 index 000000000..9fa76247e --- /dev/null +++ b/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/common_util.py @@ -0,0 +1,746 @@ +"""Common util functions for notebook.""" + +import base64 +from collections.abc import Sequence +import datetime +import io +import json +import os +import subprocess +import time +from typing import Any + +from google import auth +from google.cloud import storage +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image +import requests +import tensorflow as tf +import yaml + + +GCS_URI_PREFIX = "gs://" +CHECKPOINT_BUCKET = "gs://model_garden_checkpoints" + + +def convert_numpy_array_to_byte_string_via_tf_tensor( + np_array: np.ndarray, +) -> str: + """Serializes a numpy array to tensor bytes. + + Args: + np_array: A numpy array. + + Returns: + A tensor bytes. + """ + tensor_array = tf.convert_to_tensor(np_array) + tensor_byte_string = tf.io.serialize_tensor(tensor_array) + return tensor_byte_string.numpy() + + +def get_jpeg_bytes(local_image_path: str, new_width: int = -1) -> bytes: + """Returns jpeg bytes given an image path and resizes if required. + + Args: + local_image_path: A string of local image path. + new_width: An integer of new image width. + + Returns: + A jpeg bytes. + """ + image = Image.open(local_image_path) + if new_width <= 0: + new_image = image + else: + width, height = image.size + print("original input image size: ", width, " , ", height) + new_height = int(height * new_width / width) + print("new input image size: ", new_width, " , ", new_height) + new_image = image.resize((new_width, new_height)) + buffered = io.BytesIO() + new_image.save(buffered, format="JPEG") + return buffered.getvalue() + + +def gcs_fuse_path(path: str) -> str: + """Try to convert path to gcsfuse path if it starts with gs:// else do not modify it. + + Args: + path: A string of path. + + Returns: + A gcsfuse path. + """ + path = path.strip() + if path.startswith("gs://"): + return "/gcs/" + path[5:] + return path + + +def get_job_name_with_datetime(prefix: str) -> str: + """Gets a job name by adding current time to prefix. + + Args: + prefix: A string of job name prefix. + + Returns: + A job name. + """ + now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + job_name = f"{prefix}-{now}".replace("_", "-") + return job_name + + +def create_job_name(prefix: str) -> str: + """Creates a job name. + + Args: + prefix: A string of job name prefix. + + Returns: + A job name. + """ + user = os.environ.get("USER") + now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + job_name = f"{prefix}-{user}-{now}".replace("_", "-") + return job_name + + +def save_subset_annotation( + input_annotation_path: str, output_annotation_path: str +): + """Saves a subset of COCO annotation json file with CCA 4.0 license. + + Args: + input_annotation_path: A string of input annotation path. + output_annotation_path: A string of output annotation path. + """ + + with open(input_annotation_path) as f: + coco_json = json.load(f) + + img_ids = set() + images = [] + annotations = [] + + for img in coco_json["images"]: + if img["license"] in [4, 5]: # CCA 4.0 license. 
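+      # License ids index into the annotation file's "licenses" list; ids 4
+      # and 5 are assumed here to be the permissive CC attribution variants.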
+ img_ids.add(img["id"]) + images.append(img) + + for ann in coco_json["annotations"]: + if ann["image_id"] in img_ids: + annotations.append(ann) + + new_json = { + "info": coco_json["info"], + "licenses": coco_json["licenses"], + "images": images, + "annotations": annotations, + "categories": coco_json["categories"], + } + + with open(output_annotation_path, "w") as f: + json.dump(new_json, f) + + +def image_to_base64(image: Any, image_format: str = "JPEG") -> str: + """Converts an image to base64. + + Args: + image: A PIL.Image instance. + image_format: A string of image format. + + Returns: + A base64 string. + """ + buffer = io.BytesIO() + image.save(buffer, format=image_format) + image_str = base64.b64encode(buffer.getvalue()).decode("utf-8") + return image_str + + +def base64_to_image(image_str: str) -> Any: + """Convert base64 encoded string to an image. + + Args: + image_str: A string of base64 encoded image. + + Returns: + A PIL.Image instance. + """ + image = Image.open(io.BytesIO(base64.b64decode(image_str))) + return image + + +def image_grid(imgs: Sequence[Any], rows: int = 2, cols: int = 2) -> Any: + """Creates an image grid. + + Args: + imgs: A list of PIL.Image instances. + rows: An integer of number of rows. + cols: An integer of number of columns. + + Returns: + A PIL.Image instance. + """ + w, h = imgs[0].size + grid = Image.new( + mode="RGB", size=(cols * w + 10 * cols, rows * h), color=(255, 255, 255) + ) + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w + 10 * i, i // cols * h)) + return grid + + +def display_image(image: Any): + """Displays an image. + + Args: + image: A PIL.Image instance. + """ + _ = plt.figure(figsize=(20, 15)) + plt.grid(False) + plt.imshow(image) + + +def download_gcs_file_to_local(gcs_uri: str, local_path: str): + """Download a gcs file to a local path. + + Args: + gcs_uri: A string of file path on GCS. + local_path: A string of local file path. + """ + if not gcs_uri.startswith(GCS_URI_PREFIX): + raise ValueError( + f"{gcs_uri} is not a GCS path starting with {GCS_URI_PREFIX}." + ) + client = storage.Client() + os.makedirs(os.path.dirname(local_path), exist_ok=True) + with open(local_path, "wb") as f: + client.download_blob_to_file(gcs_uri, f) + + +def download_image(url: str) -> str: + """Downloads an image from the given URL. + + Args: + url: The URL of the image to download. + + Returns: + base64 encoded image. + """ + response = requests.get(url) + return Image.open(io.BytesIO(response.content)) # pytype: disable=bad-return-type # pillow-102-upgrade + + +def resize_image(image: Any, new_width: int = 1000) -> Any: + """Resizes an image to a certain width. + + Args: + image: The image which has to be resized. + new_width: New width of the image. + + Returns: + New resized image. + """ + width, height = image.size + new_height = int(height * new_width / width) + new_img = image.resize((new_width, new_height)) + return new_img + + +def load_img(path: str) -> Any: + """Reads image from path and return PIL.Image instance. + + Args: + path: A string of image path. + + Returns: + A PIL.Image instance. + """ + img = tf.io.read_file(path) + img = tf.image.decode_jpeg(img, channels=3) + return Image.fromarray(np.uint8(img)).convert("RGB") + + +def decode_image( + image_str_tensor: tf.string, new_height: int, new_width: int +) -> tf.float32: + """Converts and resizes image bytes to image tensor. + + Args: + image_str_tensor: A string of image bytes. + new_height: An integer of new image height. 
+ new_width: An integer of new image width. + + Returns: + An image tensor. + """ + image = tf.io.decode_image(image_str_tensor, 3, expand_animations=False) + image = tf.image.resize(image, (new_height, new_width)) + return image + + +def get_label_map(label_map_yaml_filepath: str) -> dict[int, str]: + """Returns class id to label mapping given a filepath to the label map. + + Args: + label_map_yaml_filepath: A string of label map yaml file path. + + Returns: + A dictionary of class id to label mapping. + """ + with tf.io.gfile.GFile(label_map_yaml_filepath, "rb") as input_file: + label_map = yaml.safe_load(input_file.read())["label_map"] + return label_map + + +def get_prediction_instances(test_filepath: str, new_width: int = -1) -> Any: + """Generate instance from image path to pass to Vertex AI Endpoint for prediction. + + Args: + test_filepath: A string of test image path. + new_width: An integer of new image width. + + Returns: + A list of instances. + """ + if new_width <= 0: + with tf.io.gfile.GFile(test_filepath, "rb") as input_file: + encoded_string = base64.b64encode(input_file.read()).decode("utf-8") + else: + img = load_img(test_filepath) + width, height = img.size + print("original input image size: ", width, " , ", height) + new_height = int(height * new_width / width) + new_img = img.resize((new_width, new_height)) + print("resized input image size: ", new_width, " , ", new_height) + buffered = io.BytesIO() + new_img.save(buffered, format="JPEG") + encoded_string = base64.b64encode(buffered.getvalue()).decode("utf-8") + + instances = [{ + "encoded_image": {"b64": encoded_string}, + }] + return instances + + +def vqa_predict( + endpoint: Any, + question_prompts: Sequence[str], + image: Any, + language_code: str = "en", + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> Sequence[str]: + """Predicts the answer to a question about an image using an Endpoint.""" + # Resize and convert image to base64 string. + resized_image = resize_image(image, new_width) + resized_image_base64 = image_to_base64(resized_image) + + instances = [] + if question_prompts: + # Format question prompt + question_prompt_format = "answer {} {}\n" + for question_prompt in question_prompts: + if question_prompt: + instances.append({ + "prompt": question_prompt_format.format( + language_code, question_prompt + ), + "image": resized_image_base64, + }) + else: + instances.append({ + "image": resized_image_base64, + }) + + response = endpoint.predict( + instances=instances, use_dedicated_endpoint=use_dedicated_endpoint + ) + return [pred.get("response") for pred in response.predictions] + + +def caption_predict( + endpoint: Any, + language_code: str, + image: Any, + caption_prompt: bool = False, + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> str: + """Predicts a caption for a given image using an Endpoint.""" + # Resize and convert image to base64 string. 
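+  # Note: the image is sent base64-encoded inside the prediction request, and
+  # Vertex AI online prediction requests are size-limited, so large images
+  # should be downscaled before encoding.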
+ resized_image = resize_image(image, new_width) + resized_image_base64 = image_to_base64(resized_image) + + instance = {"image": resized_image_base64} + + if caption_prompt: + # Format caption prompt + caption_prompt_format = "caption {}\n" + instance["prompt"] = caption_prompt_format.format(language_code) + + instances = [instance] + response = endpoint.predict( + instances=instances, use_dedicated_endpoint=use_dedicated_endpoint + ) + return response.predictions[0].get("response") + + +def ocr_predict( + endpoint: Any, + ocr_prompt: str, + image: Any, + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> str: + """Extracts text from a given image using an Endpoint.""" + # Resize and convert image to base64 string. + resized_image = resize_image(image, new_width) + resized_image_base64 = image_to_base64(resized_image) + + instance = {"image": resized_image_base64} + if ocr_prompt: + instance["prompt"] = ocr_prompt + instances = [instance] + + response = endpoint.predict( + instances=instances, use_dedicated_endpoint=use_dedicated_endpoint + ) + return response.predictions[0].get("response") + + +def detect_predict( + endpoint: Any, + detect_prompt: str, + image: Any, + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> str: + """Predicts the answer to a question about an image using an Endpoint.""" + # Resize and convert image to base64 string. + resized_image = resize_image(image, new_width) + resized_image_base64 = image_to_base64(resized_image) + + instance = {"image": resized_image_base64} + if detect_prompt: + instance["prompt"] = detect_prompt + instances = [instance] + + response = endpoint.predict( + instances=instances, use_dedicated_endpoint=use_dedicated_endpoint + ) + return response.predictions[0].get("response") + + +def copy_model_artifacts( + model_id: str, + model_source: str, + model_destination: str, +) -> None: + """Copies model artifacts from model_source to model_destination. + + model_source and model_destination should be GCS path. + + Args: + model_id: The model id. + model_source: The source of the model artifact. + model_destination: The destination of the model artifact. + """ + if not model_source.startswith(GCS_URI_PREFIX): + raise ValueError( + f"{model_source} is not a GCS path starting with {GCS_URI_PREFIX}." + ) + if not model_destination.startswith(GCS_URI_PREFIX): + raise ValueError( + f"{model_destination} is not a GCS path starting with {GCS_URI_PREFIX}." + ) + model_source = f"{model_source}/{model_id}" + model_destination = f"{model_destination}/{model_id}" + print("Copying model artifact from ", model_source, " to ", model_destination) + subprocess.check_output([ + "gcloud", + "storage", + "cp", + "-r", + model_source, + model_destination, + ]) + + +def get_quota(project_id: str, region: str, resource_id: str) -> int: + """Returns the quota for a resource in a region. + + Args: + project_id: The project id. + region: The region. + resource_id: The resource id. + + Returns: + The quota for the resource in the region. Returns -1 if can not figure out + the quota. + + Raises: + RuntimeError: If the command to get quota fails. 
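+
+  Example (hypothetical project and quota value):
+    >>> get_quota("my-project", "us-central1",
+    ...           "custom_model_serving_nvidia_l4_gpus")
+    8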
+ """ + service_endpoint = "aiplatform.googleapis.com" + + command = ( + "gcloud alpha services quota list" + f" --service={service_endpoint} --consumer=projects/{project_id}" + f" --filter='{service_endpoint}/{resource_id}' --format=json" + ) + process = subprocess.run( + command, shell=True, capture_output=True, text=True, check=True + ) + if process.returncode == 0: + quota_data = json.loads(process.stdout) + else: + raise RuntimeError(f"Error fetching quota data: {process.stderr}") + + if not quota_data or "consumerQuotaLimits" not in quota_data[0]: + return -1 + if ( + not quota_data[0]["consumerQuotaLimits"] + or "quotaBuckets" not in quota_data[0]["consumerQuotaLimits"][0] + ): + return -1 + all_regions_data = quota_data[0]["consumerQuotaLimits"][0]["quotaBuckets"] + + # If the quota data does not have dimensions, it is global quota. However, + # global quota may be overridden by regional quota. So we need to check the + # global quota first. + global_quota = -1 + if ( + all_regions_data + and "dimensions" not in all_regions_data[0] + and "effectiveLimit" in all_regions_data[0] + ): + global_quota = int(all_regions_data[0]["effectiveLimit"]) + for region_data in all_regions_data: + if ( + region_data.get("dimensions") + and region_data["dimensions"]["region"] == region + ): + if "effectiveLimit" in region_data: + return int(region_data["effectiveLimit"]) + else: + return 0 + return global_quota + + +def get_resource_id( + accelerator_type: str, + is_for_training: bool, + is_spot: bool = False, + is_restricted_image: bool = False, + is_dynamic_workload_scheduler: bool = False, +) -> str: + """Returns the resource id for a given accelerator type and the use case. + + Args: + accelerator_type: The accelerator type. + is_for_training: Whether the resource is used for training. Set false for + serving use case. + is_spot: Whether the resource is used with Spot. + is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`. + is_dynamic_workload_scheduler: Whether the resource is used with Dynamic + Workload Scheduler. + + Returns: + The resource id. + """ + accelerator_suffix_map = { + "NVIDIA_TESLA_V100": "nvidia_v100_gpus", + "NVIDIA_TESLA_P100": "nvidia_p100_gpus", + "NVIDIA_L4": "nvidia_l4_gpus", + "NVIDIA_TESLA_A100": "nvidia_a100_gpus", + "NVIDIA_A100_80GB": "nvidia_a100_80gb_gpus", + "NVIDIA_H100_80GB": "nvidia_h100_gpus", + "NVIDIA_H100_MEGA_80GB": "nvidia_h100_mega_gpus", + "NVIDIA_H200_141GB": "nvidia_h200_gpus", + "NVIDIA_TESLA_T4": "nvidia_t4_gpus", + "TPU_V6e": "tpu_v6e", + "TPU_V5e": "tpu_v5e", + "TPU_V3": "tpu_v3", + } + default_training_accelerator_map = { + key: f"custom_model_training_{accelerator_suffix_map[key]}" + for key in accelerator_suffix_map + } + dws_training_accelerator_map = { + key: f"custom_model_training_preemptible_{accelerator_suffix_map[key]}" + for key in accelerator_suffix_map + } + restricted_image_training_accelerator_map = { + "NVIDIA_A100_80GB": "restricted_image_training_nvidia_a100_80gb_gpus", + } + spot_serving_accelerator_map = { + key: f"custom_model_serving_preemptible_{accelerator_suffix_map[key]}" + for key in accelerator_suffix_map + } + serving_accelerator_map = { + key: f"custom_model_serving_{accelerator_suffix_map[key]}" + for key in accelerator_suffix_map + } + + if is_for_training: + if is_restricted_image and is_dynamic_workload_scheduler: + raise ValueError( + "Dynamic Workload Scheduler does not work for restricted image" + " training." 
+ ) + training_accelerator_map = ( + restricted_image_training_accelerator_map + if is_restricted_image + else default_training_accelerator_map + ) + if accelerator_type in training_accelerator_map: + if is_dynamic_workload_scheduler: + return dws_training_accelerator_map[accelerator_type] + else: + return training_accelerator_map[accelerator_type] + else: + raise ValueError( + f"Could not find accelerator type: {accelerator_type} for training." + ) + else: + if is_dynamic_workload_scheduler: + raise ValueError("Dynamic Workload Scheduler does not work for serving.") + accelerator_map = ( + spot_serving_accelerator_map if is_spot else serving_accelerator_map + ) + if accelerator_type in accelerator_map: + return accelerator_map[accelerator_type] + else: + raise ValueError( + f"Could not find accelerator type: {accelerator_type} for serving." + ) + + +def check_quota( + project_id: str, + region: str, + accelerator_type: str, + accelerator_count: int, + is_for_training: bool, + is_spot: bool = False, + is_restricted_image: bool = False, + is_dynamic_workload_scheduler: bool = False, +) -> None: + """Checks if the project and the region has the required quota. + + Args: + project_id: The project id. + region: The region. + accelerator_type: The accelerator type. + accelerator_count: The number of accelerators to check quota for. + is_for_training: Whether the resource is used for training. Set false for + serving use case. + is_spot: Whether the resource is used with Spot. + is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`. + is_dynamic_workload_scheduler: Whether the resource is used with Dynamic + Workload Scheduler. + """ + resource_id = get_resource_id( + accelerator_type, + is_for_training=is_for_training, + is_spot=is_spot, + is_restricted_image=is_restricted_image, + is_dynamic_workload_scheduler=is_dynamic_workload_scheduler, + ) + quota = get_quota(project_id, region, resource_id) + quota_request_instruction = ( + "Either use " + "a different region or request additional quota. Follow " + "instructions here " + "https://cloud.google.com/docs/quotas/view-manage#requesting_higher_quota" + " to check quota in a region or request additional quota for " + "your project." + ) + if quota == -1: + raise ValueError( + f"Quota not found for: {resource_id} in {region}." + f" {quota_request_instruction}" + ) + if quota < accelerator_count: + raise ValueError( + f"Quota not enough for {resource_id} in {region}: {quota} <" + f" {accelerator_count}. {quota_request_instruction}" + ) + + +def get_deploy_source() -> str: + """Gets deploy_source string based on running environment.""" + vertex_product = os.environ.get("VERTEX_PRODUCT", "") + match vertex_product: + case "COLAB_ENTERPRISE": + return "notebook_colab_enterprise" + case "WORKBENCH_INSTANCE": + return "notebook_workbench" + case _: + # Legacy workbench, legacy colab, or other custom environments. + return "notebook_environment_unspecified" + + +def _is_operation_done(op_name: str, region: str) -> bool: + """Checks if the operation is done. + + Args: + op_name: The name of the operation to poll. + region: The region of the operation. + + Returns: + True if the operation is done, False otherwise. + + Raises: + ValueError: If the operation failed. 
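+
+  Example (hypothetical operation name):
+    >>> _is_operation_done(
+    ...     "projects/123/locations/us-central1/operations/456",
+    ...     "us-central1",
+    ... )
+    False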
+ """ + creds, _ = auth.default() + auth_req = auth.transport.requests.Request() + creds.refresh(auth_req) + headers = { + "Authorization": f"Bearer {creds.token}", + } + url = f"https://{region}-aiplatform.googleapis.com/ui/{op_name}" + response = requests.get(url, headers=headers) + operation_data = response.json() + if "error" in operation_data: + raise ValueError(f"Operation failed: {operation_data['error']}") + return operation_data.get("done", False) + + +def poll_and_wait( + op_name: str, region: str, total_wait: int, interval: int = 60 +) -> None: + """Polls the operation and waits for it to complete. + + Args: + op_name: The name of the operation to poll. + region: The region of the operation. + total_wait: The total wait time in seconds. + interval: The interval between each poll in seconds. + + Raises: + TimeoutError: If the operation times out. + """ + start_time = time.time() + while True: + if _is_operation_done(op_name, region): + break + time_elapsed = time.time() - start_time + if time_elapsed > total_wait: + raise TimeoutError( + f"Operation timed out after {int(time_elapsed)} seconds." + ) + print( + "\rStill waiting for operation... Elapsed time in seconds:" + f" {int(time_elapsed):<6}", + end="", + flush=True, + ) + time.sleep(interval) diff --git a/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/dataset_validation_util.py b/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/dataset_validation_util.py new file mode 100644 index 000000000..57b794d5d --- /dev/null +++ b/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/dataset_validation_util.py @@ -0,0 +1,592 @@ +"""Functions for dataset validation. + +This tool is used to validate the dataset against the given template. +""" + +from collections.abc import Callable +import json +import multiprocessing +import os +import subprocess +from typing import Any, Union +from absl import logging +import accelerate +import datasets +import transformers + +GCS_URI_PREFIX = "gs://" +GCSFUSE_URI_PREFIX = "/gcs/" +LOCAL_BASE_MODEL_DIR = "/tmp/base_model_dir" +LOCAL_TEMPLATE_DIR = "/tmp/template_dir" +_TEMPLATE_DIRNAME = "templates" +_VERTEX_AI_SAMPLES_GITHUB_REPO_NAME = "vertex-ai-samples" +_VERTEX_AI_SAMPLES_GITHUB_TEMPLATE_DIR = ( + "community-content/vertex_model_garden/model_oss/peft/train/vmg/templates" +) +_MODELS_REQUIRING_PAD_TOKEN = ("llama", "falcon", "mistral", "mixtral") +_MODELS_REQUIRING_EOS_TOEKN = ("gemma-2b", "gemma-7b") +_DESCRIPTION_KEY = "description" +_SOURCE_KEY = "source" +_PROMPT_INPUT_KEY = "prompt_input" +_PROMPT_NO_INPUT_KEY = "prompt_no_input" +_RESPONSE_SEPARATOR = "response_separator" +_INSTRUCTION_SEPARATOR = "instruction_separator" +_CHAT_TEMPLATE_KEY = "chat_template" +_KNOWN_KEYS = ( + _DESCRIPTION_KEY, + _SOURCE_KEY, + _PROMPT_INPUT_KEY, + _PROMPT_NO_INPUT_KEY, + _RESPONSE_SEPARATOR, + _INSTRUCTION_SEPARATOR, + _CHAT_TEMPLATE_KEY, +) + + +def is_gcs_path(input_path: str) -> bool: + """Checks if the input path is a Google Cloud Storage (GCS) path. + + Args: + input_path: The input path to be checked. + + Returns: + True if the input path is a GCS path, False otherwise. + """ + return input_path is not None and input_path.startswith(GCS_URI_PREFIX) + + +def force_gcs_fuse_path(gcs_uri: str) -> str: + """Converts gs:// uris to their /gcs/ equivalents. No-op for other uris. + + Args: + gcs_uri: The GCS URI to convert. + + Returns: + The converted GCS URI. 
+ """ + if is_gcs_path(gcs_uri): + return GCSFUSE_URI_PREFIX + gcs_uri[len(GCS_URI_PREFIX) :] + else: + return gcs_uri + + +def download_gcs_uri_to_local( + gcs_uri: str, + destination_dir: str = LOCAL_BASE_MODEL_DIR, + check_path_exists: bool = True, +) -> str: + """Downloads GCS URI to local. + + If GCS URI is a directory, gs://some/folder is downloaded to + /destination_dir/folder. If GCS URI is a file, gs://some/file is downloaded to + /destination_dir/file. + + Args: + gcs_uri: GCS URI to download. + destination_dir: Local directory directory. + check_path_exists: Whether to check if the path exists. + + Returns: + Local path to target folder/file. + """ + target = os.path.join( + destination_dir, + os.path.basename(os.path.normpath(gcs_uri)), + ) + if check_path_exists and os.path.exists(target): + logging.info("File %s already exists.", target) + return target + if accelerate.PartialState().is_local_main_process: + logging.info( + "Downloading file(s) from %s to %s...", gcs_uri, destination_dir + ) + if not os.path.exists(destination_dir): + os.mkdir(destination_dir) + subprocess.check_output([ + "gsutil", + "-m", + "cp", + "-r", + gcs_uri, + destination_dir, + ]) + logging.info("Downloaded file(s) from %s to %s.", gcs_uri, destination_dir) + # Make sure ALL processes process to next step after data downloading is done. + # It matters for the main process to wait for other processes as well. + accelerate.PartialState().wait_for_everyone() + return target + + +def get_template(template_path: str) -> dict[str, str]: + """Gets the template dictionary given the file path. + + Args: + template_path: Path to the template file. + + Returns: + A dictionary of the template. + + Raises: + ValueError: If the template file does not exist or contains unknown keys. + """ + if is_gcs_path(template_path): + template_path = force_gcs_fuse_path(template_path) + elif not os.path.isfile(template_path): + template_path = os.path.join( + os.path.dirname(__file__), + _TEMPLATE_DIRNAME, + template_path + ".json", + ) + if not os.path.isfile(template_path): + raise ValueError(f"Template file {template_path} does not exist.") + with open(template_path, "r") as f: + template_json: dict[str, str] = json.load(f) + for key in template_json: + if key not in _KNOWN_KEYS: + raise ValueError(f"Unknown key {key} in template {template_path}.") + return template_json + + +def get_response_separator(template_json: dict[str, str]) -> Union[str, None]: + return template_json.get(_RESPONSE_SEPARATOR, None) + + +def get_instruction_separator( + template_json: dict[str, str], +) -> Union[str, None]: + return template_json.get(_INSTRUCTION_SEPARATOR, None) + + +def _format_template_fn( + template: str, + input_column: str, + tokenizer: transformers.PreTrainedTokenizer | None = None, +) -> Callable[[dict[str, str]], dict[str, str]]: + """Formats a dataset example according to a template. + + Args: + template: Name of the JSON template file under `templates/` or GCS path to + the template file. + input_column: The input column in the dataset to be used or updated by the + template. If it does not exist, the template's `prompt_no_input` will be + used, and the input_column will be created. + tokenizer: The tokenizer to use for chat_template templates. + + Returns: + A function that formats data according to the template. 
+ """ + template_json = get_template(template) + + if _CHAT_TEMPLATE_KEY not in template_json: + + def format_fn(example: dict[str, str]) -> dict[str, str]: + format_dict = {key: value for key, value in example.items()} + if format_dict.get(input_column): + format_str = template_json[_PROMPT_INPUT_KEY] + elif _PROMPT_NO_INPUT_KEY in template_json: + format_str = template_json[_PROMPT_NO_INPUT_KEY] + else: + raise KeyError( + f"The template {os.path.basename(template)} does not contain" + f" {_PROMPT_INPUT_KEY} or {_PROMPT_NO_INPUT_KEY} key." + ) + try: + return {input_column: format_str.format(**format_dict)} + except KeyError as e: + raise KeyError( + f"The template {os.path.basename(template)} contains a key {e} in" + f" {_PROMPT_INPUT_KEY} or {_PROMPT_NO_INPUT_KEY} that does not" + " exist in the dataset example. The dataset example looks like" + f" {format_dict}." + ) from e + + return format_fn + elif ( + _PROMPT_INPUT_KEY in template_json + or _PROMPT_NO_INPUT_KEY in template_json + ): + raise ValueError( + f"chat_template templates do not support {_PROMPT_INPUT_KEY} or" + f" {_PROMPT_NO_INPUT_KEY} templates." + ) + else: + if tokenizer is None: + raise ValueError("A tokenizer is required for chat_template templates.") + # Assign HuggingFace jinja template. + tokenizer.chat_template = template_json[_CHAT_TEMPLATE_KEY] + + def format_fn(example: dict[str, str]) -> dict[str, str]: + try: + return { + input_column: tokenizer.apply_chat_template( + example[input_column], + tokenize=False, + add_generation_prompt=False, + ) + } + except KeyError as e: + raise KeyError( + f"The template {os.path.basename(template)} contains a key {e} in" + f" {_CHAT_TEMPLATE_KEY} that does not exist in the dataset example." + ) from e + + return format_fn + + +def _get_split_string( + split: str, + dataset_percent: int | None = None, + dataset_k_rows: int | None = None, +) -> str: + """Gets the formatted split string for the dataset. + + This is used to format the split string as per + https://huggingface.co/docs/datasets/v2.21.0/loading#slice-splits. Also, this + function will only be used to load the partial dataset for validating the + dataset against the template. + + Args: + split: Split of the dataset. + dataset_percent: The percentage of the dataset to load. + dataset_k_rows: The top k sequences to load from the dataset. + + Returns: + A formatted split string. + """ + # Validate the dataset_percent and dataset_k_rows values. + if dataset_percent and dataset_k_rows: + raise ValueError( + "You can set either validate_percentage_of_dataset or" + " validate_k_rows_of_dataset, but not both." + ) + + if dataset_percent: + logging.info("Loading %d percent of the dataset...", dataset_percent) + return f"{split}[:{dataset_percent}%]" + + if dataset_k_rows: + logging.info("Loading top %d rows of the dataset...", dataset_k_rows) + return f"{split}[:{dataset_k_rows}]" + + return split + + +def _github_template_path(template: str) -> str: + """Generates the path to the template in the Vertex AI Samples GitHub repo. + + Args: + template: Name of the template. + + Returns: + The path to the template in the Vertex AI Samples GitHub repo. + """ + # vertex-ai-samples directory may lie under separate directory depending on + # the scratch_dir parameter in the notebook execution environment. 
+  vertex_ai_samples_abs_path = os.getcwd().split(
+      _VERTEX_AI_SAMPLES_GITHUB_REPO_NAME
+  )[0]
+  return os.path.join(
+      vertex_ai_samples_abs_path,
+      _VERTEX_AI_SAMPLES_GITHUB_REPO_NAME,
+      _VERTEX_AI_SAMPLES_GITHUB_TEMPLATE_DIR,
+      template + ".json",
+  )
+
+
+def _get_dataset(
+    dataset_name: str,
+    split: str,
+    num_proc: int | None = None,
+) -> datasets.Dataset:
+  """Gets a dataset.
+
+  Args:
+    dataset_name: Name of the dataset or path to a custom dataset.
+    split: Split of the dataset.
+    num_proc: Number of processes to use.
+
+  Returns:
+    A dataset.
+  """
+  dataset_name = force_gcs_fuse_path(dataset_name)
+  if os.path.isfile(dataset_name):
+    # Custom dataset.
+    return datasets.load_dataset(
+        "json",
+        data_files=[dataset_name],
+        split=split,
+        num_proc=num_proc,
+    )
+  # HF dataset.
+  return datasets.load_dataset(dataset_name, split=split, num_proc=num_proc)
+
+
+def should_add_pad_token(model_id: str) -> bool:
+  """Returns whether the model requires adding a special pad token.
+
+  Args:
+    model_id: The name of the model.
+
+  Returns:
+    True if the model requires adding a special pad token, False otherwise.
+  """
+  return any(s.lower() in model_id.lower() for s in _MODELS_REQUIRING_PAD_TOKEN)
+
+
+def should_add_eos_token(model_id: str) -> bool:
+  """Returns whether the model requires adding a special eos token.
+
+  Args:
+    model_id: The name of the model.
+
+  Returns:
+    True if the model requires adding a special eos token, False otherwise.
+  """
+  return any(m in model_id for m in _MODELS_REQUIRING_EOS_TOKEN)
+
+
+def load_tokenizer(
+    pretrained_model_id: str,
+    padding_side: str | None = None,
+    access_token: str | None = None,
+) -> transformers.AutoTokenizer:
+  """Loads tokenizer based on `pretrained_model_id`.
+
+  Args:
+    pretrained_model_id: The name of the pretrained model.
+    padding_side: The side to pad the input on.
+    access_token: The access token to use for the tokenizer.
+
+  Returns:
+    The tokenizer.
+  """
+  tokenizer_kwargs = {}
+  if should_add_eos_token(pretrained_model_id):
+    tokenizer_kwargs["add_eos_token"] = True
+  if padding_side:
+    tokenizer_kwargs["padding_side"] = padding_side
+
+  with accelerate.PartialState().local_main_process_first():
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        pretrained_model_id,
+        trust_remote_code=False,
+        use_fast=True,
+        token=access_token,
+        **tokenizer_kwargs,
+    )
+
+  if should_add_pad_token(pretrained_model_id):
+    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+  return tokenizer
+
+
+def get_filtered_dataset(
+    dataset: Any,
+    input_column: str,
+    max_seq_length: int,
+    tokenizer: transformers.PreTrainedTokenizer,
+    example_removed_threshold: float = 50.0,
+) -> Any:
+  """Returns the dataset with examples longer than max_seq_length removed.
+
+  Args:
+    dataset: The dataset to filter.
+    input_column: The input column in the dataset to be used.
+    max_seq_length: The maximum sequence length.
+    tokenizer: The tokenizer.
+    example_removed_threshold: The percent threshold for the number of examples
+      removed from the dataset. It should be in the range of [0, 100].
+
+  Returns:
+    The filtered dataset.
+
+  Raises:
+    ValueError: If more than `example_removed_threshold` percent of the
+      dataset is filtered out.
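+
+  Example (illustrative):
+    With max_seq_length=512 and the default example_removed_threshold=50.0,
+    every example whose tokenized length exceeds 512 is dropped, and a
+    ValueError is raised if more than half of the dataset would be removed.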
+ """ + actual_dataset_length = len(dataset) + filtered_dataset = dataset.filter( + lambda x: len(tokenizer(x[input_column])["input_ids"]) <= max_seq_length + ) + filtered_dataset_length = len(filtered_dataset) + if actual_dataset_length != filtered_dataset_length: + examples_removed_percent = ( + (actual_dataset_length - filtered_dataset_length) + * 100 + / actual_dataset_length + ) + logging.info( + "(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >" + " max-seq-length. Filtering out %d example(s) which are longer than" + " max-seq-length.", + 100 - examples_removed_percent, + max_seq_length, + examples_removed_percent, + actual_dataset_length - filtered_dataset_length, + ) + if examples_removed_percent > example_removed_threshold: + raise ValueError( + "More than %.2f%% of the dataset is filtered out. This may be due to" + " small value of max-seq-length(%d) or incorrect template. Please" + " increase the max-seq-length or check the template." + % (examples_removed_percent, max_seq_length) + ) + print(f"Some formatted examples from the dataset are: {filtered_dataset[:5]}") + return filtered_dataset + + +def format_dataset( + dataset: datasets.Dataset, + input_column: str, + template: str = None, + tokenizer: transformers.PreTrainedTokenizer | None = None, +) -> datasets.Dataset: + """Takes a raw dataset and formats it using a template and tokenizer. + + Args: + dataset: The raw (unprocessed) dataset to format. + input_column: The input column in the dataset to be used or updaded by the + template. If it does not exist, the template's `prompt_no_input` will be + used, and the input_column will be created. + template: Name of the JSON template file under `templates/` or GCS path to + the template file. + tokenizer: The tokenizer to use for chat_template templates. + + Returns: + A dataset compatible with the template. + """ + return dataset.map( + _format_template_fn( + template, + input_column=input_column, + tokenizer=tokenizer, + ) + ) + + +def load_dataset_with_template( + dataset_name: str, + split: str, + input_column: str, + template: str = None, + tokenizer: transformers.PreTrainedTokenizer | None = None, +) -> tuple[Any, Any]: + """Loads dataset with templates. + + Args: + dataset_name: Name of the dataset or path to a custom dataset. + split: Split of the dataset. + input_column: The input column in the dataset to be used or updaded by the + template. If it does not exist, the template's `prompt_no_input` will be + used, and the input_column will be created. + template: Name of the JSON template file under `templates/` or GCS path to + the template file. + tokenizer: The tokenizer to use for chat_template templates. + + Returns: + The raw dataset and the dataset compatible with the template. + """ + raw = _get_dataset(dataset_name, split=split) + if template: + templated = format_dataset(raw, input_column, template, tokenizer) + else: + templated = None + + return raw, templated + + +def validate_dataset_with_template( + dataset_name: str, + split: str, + input_column: str, + template: str, + tokenizer: transformers.PreTrainedTokenizer | None = None, + max_seq_length: int | None = None, + use_multiprocessing: bool = False, + validate_percentage_of_dataset: int | None = None, + validate_k_rows_of_dataset: int | None = None, + example_removed_threshold: float = 50.0, +) -> Any: + """Validates dataset with templates. + + This function will be used to load the dataset and validate it against the + template. 
+  For validation, users may load the dataset partially by reading x% of it or
+  only its top k rows. To validate the dataset, the template file must be
+  available in the GCS bucket and the dataset must be available either in the
+  GCS bucket or on Hugging Face.
+
+  Args:
+    dataset_name: Name of the dataset or path to a custom dataset.
+    split: Split of the dataset.
+    input_column: The input column in the dataset to be used or updated by the
+      template. If it does not exist, the template's `prompt_no_input` will be
+      used, and the input_column will be created.
+    template: Name of the JSON template file under `templates/` or GCS path to
+      the template file.
+    tokenizer: The tokenizer to use for chat_template templates.
+    max_seq_length: The maximum sequence length.
+    use_multiprocessing: If True, use multiprocessing to load the dataset.
+    validate_percentage_of_dataset: The percentage of the dataset to load.
+    validate_k_rows_of_dataset: The top k sequences to load from the dataset.
+    example_removed_threshold: The percent threshold for the number of
+      examples removed from the dataset.
+
+  Returns:
+    None if the validation is successful; raises an error otherwise.
+  """
+  if not template:
+    raise ValueError("template is required for validate_dataset.")
+
+  if not dataset_name:
+    raise ValueError("dataset_name is empty.")
+
+  if not split:
+    raise ValueError("split is empty.")
+
+  split = _get_split_string(
+      split,
+      validate_percentage_of_dataset,
+      validate_k_rows_of_dataset,
+  )
+
+  num_proc = multiprocessing.cpu_count() if use_multiprocessing else 1
+
+  # gcsfuse cannot be used from the notebook runtime env, so the dataset and
+  # template are downloaded from GCS to local storage instead.
+  if is_gcs_path(dataset_name):
+    dataset_name = download_gcs_uri_to_local(dataset_name, LOCAL_BASE_MODEL_DIR)
+
+  if is_gcs_path(template):
+    template_path = download_gcs_uri_to_local(template, LOCAL_TEMPLATE_DIR)
+  elif os.path.isfile(_github_template_path(template)):
+    template_path = _github_template_path(template)
+  else:
+    raise ValueError(
+        f"Template file {template} does not exist. To validate the"
+        " dataset, please provide a valid GCS path for the template or a valid"
+        " template name from"
+        f" https://github.com/GoogleCloudPlatform/{_VERTEX_AI_SAMPLES_GITHUB_REPO_NAME}/tree/main/{_VERTEX_AI_SAMPLES_GITHUB_TEMPLATE_DIR}."
+ ) + + dataset = format_dataset( + _get_dataset(dataset_name, split, num_proc), + input_column, + template_path, + tokenizer, + ) + + if tokenizer is not None: + get_filtered_dataset( + dataset=dataset, + input_column=input_column, + max_seq_length=max_seq_length, + tokenizer=tokenizer, + example_removed_threshold=example_removed_threshold, + ) + print( + "Dataset {} is compatible with the {} template.".format( + os.path.basename(dataset_name), os.path.basename(template) + ) + ) diff --git a/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/gcp_utils.py b/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/gcp_utils.py new file mode 100644 index 000000000..34345df99 --- /dev/null +++ b/notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/gcp_utils.py @@ -0,0 +1,275 @@ +"""Utility functions for interacting with Google Cloud Platform.""" + +import datetime +import logging +import os +import subprocess +import uuid + +from google.cloud import aiplatform +import requests + + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_project_id() -> str: + """Read cloud project id from metadata service.""" + project_request = requests.get( + "http://metadata.google.internal/computeMetadata/v1/project/project-id", + headers={"Metadata-Flavor": "Google"}, + ) + return project_request.text + + +def get_region() -> str: + """Read region from metadata service.""" + region_request = requests.get( + "http://metadata.google.internal/computeMetadata/v1/instance/region", + headers={"Metadata-Flavor": "Google"}, + ) + return region_request.text.split("/")[-1] + + +# Get the default cloud project id and region +PROJECT_ID = get_project_id() +REGION = get_region() + + +def init_aiplatform(project: str = None, location: str = None) -> None: + """Initialize the Vertex AI SDK. + + Args: + project: The Google Cloud project ID. + location: The Google Cloud location. + """ + project = PROJECT_ID if project is None else project + location = REGION if location is None else location + aiplatform.init(project=project, location=location) + subprocess.call([ + "gcloud", + "services", + "enable", + "aiplatform.googleapis.com", + "compute.googleapis.com", + ]) + + +def run_command(command: list[str]) -> str: + """Runs a shell command and returns the output. + + Args: + command: The shell command to run as a list. + + Returns: + The output of the command. + """ + try: + result = subprocess.run( + command, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + return result.stdout + except subprocess.CalledProcessError as e: + logger.error("Error: %s", e.stderr) + raise e + + +def enable_apis() -> None: + """Enable the Vertex AI API and Compute Engine API.""" + logger.info("Enabling Vertex AI API and Compute Engine API.") + run_command([ + "gcloud", + "services", + "enable", + "aiplatform.googleapis.com", + "compute.googleapis.com", + ]) + + +def setup_buckets(bucket_uri: str, model_bucket_name: str) -> tuple[str, str]: + """Set up Cloud Storage buckets for storing experiment artifacts. + + Args: + bucket_uri: The bucket URI provided by the user. + model_bucket_name: The name of the model bucket. + + Returns: + A tuple containing the bucket name and model bucket path. 
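+
+  Example (hypothetical bucket and model folder):
+    >>> setup_buckets("gs://my-bucket", "finetuned-model")
+    ('gs://my-bucket', 'gs://my-bucket/finetuned-model')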
+ """ + if not bucket_uri.strip(): + # Generate a default bucket URI if none provided + now = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + bucket_uri = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}" + logger.info("No bucket URI provided. Using default bucket: %s", bucket_uri) + else: + if not bucket_uri.startswith("gs://"): + raise ValueError("Bucket URI must start with 'gs://'.") + # Remove any trailing slashes + bucket_uri = bucket_uri.rstrip("/") + + bucket_name = "/".join(bucket_uri.split("/")[:3]) + + # Check if bucket exists + try: + run_command(["gsutil", "ls", "-b", bucket_uri]) + logger.info("Bucket %s already exists.", bucket_uri) + except subprocess.CalledProcessError: + logger.info("Creating bucket %s.", bucket_uri) + # Create the bucket in the same region as the project + run_command(["gsutil", "mb", "-l", REGION, bucket_uri]) + + # Construct the model bucket path + model_bucket = os.path.join(bucket_uri, model_bucket_name) + + # Check if the model bucket exists (as a folder within the main bucket) + try: + run_command(["gsutil", "ls", model_bucket]) + logger.info("Model bucket %s already exists.", model_bucket) + except subprocess.CalledProcessError: + logger.info("Creating model bucket %s.", model_bucket) + # Create the model bucket folder + run_command(["gsutil", "cp", "/dev/null", model_bucket + "/"]) + + return bucket_name, model_bucket + + +def get_service_account() -> str: + """Get the default service account.""" + shell_output = run_command(["gcloud", "projects", "describe", PROJECT_ID]) + project_number_line = next( + (line for line in shell_output.splitlines() if "projectNumber" in line), + None, + ) + if project_number_line: + project_number = project_number_line.split(":")[1].strip().replace("'", "") + service_account = f"{project_number}-compute@developer.gserviceaccount.com" + logger.info("Using default Service Account: %s", service_account) + return service_account + else: + raise ValueError("Could not find project number in gcloud output.") + + +def get_project_number() -> str: + """Get the default project number.""" + shell_output = run_command(["gcloud", "projects", "describe", PROJECT_ID]) + project_number_line = next( + (line for line in shell_output.splitlines() if "projectNumber" in line), + None, + ) + if project_number_line: + project_number = project_number_line.split(":")[1].strip().replace("'", "") + logger.info("Using default Project Number: %s", project_number) + return project_number + else: + raise ValueError("Could not find project number in gcloud output.") + + +def provision_permissions(service_account: str, bucket_name: str) -> None: + """Provision permissions to the service account with the GCS bucket.""" + if bucket_name: + run_command([ + "gsutil", + "iam", + "ch", + f"serviceAccount:{service_account}:roles/storage.admin", + bucket_name, + ]) + + +def set_gcloud_project() -> None: + """Set gcloud config project.""" + run_command(["gcloud", "config", "set", "project", PROJECT_ID]) + + +def initialize( + bucket_uri: str, model_bucket_name: str, create_bucket: bool +) -> tuple[str, str]: + """Initialize the environment. + + Args: + bucket_uri: The bucket URI provided by the user. + model_bucket_name: The name of the model bucket. + create_bucket: Whether to create the bucket or not. + + Returns: + A tuple containing the model bucket path and service account. 
+ """ + enable_apis() + bucket_name = None + if create_bucket: + bucket_name, model_bucket = setup_buckets(bucket_uri, model_bucket_name) + else: + model_bucket = None + service_account = get_service_account() + provision_permissions(service_account, bucket_name) + set_gcloud_project() + return model_bucket, service_account + + +def clean_resources_ui( + project_id: str, + region: str, + endpoint_name: str, + delete_bucket: bool, + bucket_name: str = None, +) -> str: + """UI function for cleaning a specific Vertex AI endpoint and its model.""" + if delete_bucket and not bucket_name: + raise ValueError("Bucket name is required when 'Delete Bucket' is checked.") + + try: + delete_endpoint_and_model(project_id, region, endpoint_name) + bucket_status_message = "" + if delete_bucket: + bucket_status_message = delete_gcs_bucket(bucket_name) + if endpoint_name: + return ( + f"Endpoint {endpoint_name} and associated model deleted successfully!" + f" {bucket_status_message}" + ) + else: + return ( + "There are currently no endpoints available for deletion." + f" {bucket_status_message}" + ) + except Exception as e: # pylint: disable=broad-exception-caught + return f"Error cleaning up resources: {e}" + + +def delete_endpoint_and_model( + project_id: str, region: str, endpoint_name: str +) -> None: + """Deletes a specific Vertex AI endpoint and its associated model.""" + if endpoint_name: + endpoint_id = endpoint_name.split(" - ")[0] + endpoint_resource_name = ( + f"projects/{project_id}/locations/{region}/endpoints/{endpoint_id}" + ) + endpoint = aiplatform.Endpoint( + endpoint_resource_name, project=project_id, location=region + ) + deployed_models = endpoint.list_models() + for deployed_model in deployed_models: + endpoint.undeploy(deployed_model_id=deployed_model.id) + model = aiplatform.Model(deployed_model.model) + model.delete() + endpoint.delete() + + +def delete_gcs_bucket(bucket_name: str) -> str: + """Deletes a GCS bucket using gsutil.""" + try: + run_command(["gsutil", "-m", "rm", "-r", bucket_name]) + logger.info("Bucket %s deleted using gsutil.", bucket_name) + return f"Bucket {bucket_name} deleted successfully!" + except subprocess.CalledProcessError as e: + logger.error( + "Error deleting bucket %s using gsutil: %s", bucket_name, str(e) + ) + return f"Bucket {bucket_name} could not be found or deleted. " diff --git a/notebooks/community/model_garden/docker_source_codes/notebook_util/common_util.py b/notebooks/community/model_garden/docker_source_codes/notebook_util/common_util.py new file mode 100644 index 000000000..9fa76247e --- /dev/null +++ b/notebooks/community/model_garden/docker_source_codes/notebook_util/common_util.py @@ -0,0 +1,746 @@ +"""Common util functions for notebook.""" + +import base64 +from collections.abc import Sequence +import datetime +import io +import json +import os +import subprocess +import time +from typing import Any + +from google import auth +from google.cloud import storage +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image +import requests +import tensorflow as tf +import yaml + + +GCS_URI_PREFIX = "gs://" +CHECKPOINT_BUCKET = "gs://model_garden_checkpoints" + + +def convert_numpy_array_to_byte_string_via_tf_tensor( + np_array: np.ndarray, +) -> str: + """Serializes a numpy array to tensor bytes. + + Args: + np_array: A numpy array. + + Returns: + A tensor bytes. 
+ """ + tensor_array = tf.convert_to_tensor(np_array) + tensor_byte_string = tf.io.serialize_tensor(tensor_array) + return tensor_byte_string.numpy() + + +def get_jpeg_bytes(local_image_path: str, new_width: int = -1) -> bytes: + """Returns jpeg bytes given an image path and resizes if required. + + Args: + local_image_path: A string of local image path. + new_width: An integer of new image width. + + Returns: + A jpeg bytes. + """ + image = Image.open(local_image_path) + if new_width <= 0: + new_image = image + else: + width, height = image.size + print("original input image size: ", width, " , ", height) + new_height = int(height * new_width / width) + print("new input image size: ", new_width, " , ", new_height) + new_image = image.resize((new_width, new_height)) + buffered = io.BytesIO() + new_image.save(buffered, format="JPEG") + return buffered.getvalue() + + +def gcs_fuse_path(path: str) -> str: + """Try to convert path to gcsfuse path if it starts with gs:// else do not modify it. + + Args: + path: A string of path. + + Returns: + A gcsfuse path. + """ + path = path.strip() + if path.startswith("gs://"): + return "/gcs/" + path[5:] + return path + + +def get_job_name_with_datetime(prefix: str) -> str: + """Gets a job name by adding current time to prefix. + + Args: + prefix: A string of job name prefix. + + Returns: + A job name. + """ + now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + job_name = f"{prefix}-{now}".replace("_", "-") + return job_name + + +def create_job_name(prefix: str) -> str: + """Creates a job name. + + Args: + prefix: A string of job name prefix. + + Returns: + A job name. + """ + user = os.environ.get("USER") + now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + job_name = f"{prefix}-{user}-{now}".replace("_", "-") + return job_name + + +def save_subset_annotation( + input_annotation_path: str, output_annotation_path: str +): + """Saves a subset of COCO annotation json file with CCA 4.0 license. + + Args: + input_annotation_path: A string of input annotation path. + output_annotation_path: A string of output annotation path. + """ + + with open(input_annotation_path) as f: + coco_json = json.load(f) + + img_ids = set() + images = [] + annotations = [] + + for img in coco_json["images"]: + if img["license"] in [4, 5]: # CCA 4.0 license. + img_ids.add(img["id"]) + images.append(img) + + for ann in coco_json["annotations"]: + if ann["image_id"] in img_ids: + annotations.append(ann) + + new_json = { + "info": coco_json["info"], + "licenses": coco_json["licenses"], + "images": images, + "annotations": annotations, + "categories": coco_json["categories"], + } + + with open(output_annotation_path, "w") as f: + json.dump(new_json, f) + + +def image_to_base64(image: Any, image_format: str = "JPEG") -> str: + """Converts an image to base64. + + Args: + image: A PIL.Image instance. + image_format: A string of image format. + + Returns: + A base64 string. + """ + buffer = io.BytesIO() + image.save(buffer, format=image_format) + image_str = base64.b64encode(buffer.getvalue()).decode("utf-8") + return image_str + + +def base64_to_image(image_str: str) -> Any: + """Convert base64 encoded string to an image. + + Args: + image_str: A string of base64 encoded image. + + Returns: + A PIL.Image instance. + """ + image = Image.open(io.BytesIO(base64.b64decode(image_str))) + return image + + +def image_grid(imgs: Sequence[Any], rows: int = 2, cols: int = 2) -> Any: + """Creates an image grid. + + Args: + imgs: A list of PIL.Image instances. 
+ rows: An integer of number of rows. + cols: An integer of number of columns. + + Returns: + A PIL.Image instance. + """ + w, h = imgs[0].size + grid = Image.new( + mode="RGB", size=(cols * w + 10 * cols, rows * h), color=(255, 255, 255) + ) + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w + 10 * i, i // cols * h)) + return grid + + +def display_image(image: Any): + """Displays an image. + + Args: + image: A PIL.Image instance. + """ + _ = plt.figure(figsize=(20, 15)) + plt.grid(False) + plt.imshow(image) + + +def download_gcs_file_to_local(gcs_uri: str, local_path: str): + """Download a gcs file to a local path. + + Args: + gcs_uri: A string of file path on GCS. + local_path: A string of local file path. + """ + if not gcs_uri.startswith(GCS_URI_PREFIX): + raise ValueError( + f"{gcs_uri} is not a GCS path starting with {GCS_URI_PREFIX}." + ) + client = storage.Client() + os.makedirs(os.path.dirname(local_path), exist_ok=True) + with open(local_path, "wb") as f: + client.download_blob_to_file(gcs_uri, f) + + +def download_image(url: str) -> str: + """Downloads an image from the given URL. + + Args: + url: The URL of the image to download. + + Returns: + base64 encoded image. + """ + response = requests.get(url) + return Image.open(io.BytesIO(response.content)) # pytype: disable=bad-return-type # pillow-102-upgrade + + +def resize_image(image: Any, new_width: int = 1000) -> Any: + """Resizes an image to a certain width. + + Args: + image: The image which has to be resized. + new_width: New width of the image. + + Returns: + New resized image. + """ + width, height = image.size + new_height = int(height * new_width / width) + new_img = image.resize((new_width, new_height)) + return new_img + + +def load_img(path: str) -> Any: + """Reads image from path and return PIL.Image instance. + + Args: + path: A string of image path. + + Returns: + A PIL.Image instance. + """ + img = tf.io.read_file(path) + img = tf.image.decode_jpeg(img, channels=3) + return Image.fromarray(np.uint8(img)).convert("RGB") + + +def decode_image( + image_str_tensor: tf.string, new_height: int, new_width: int +) -> tf.float32: + """Converts and resizes image bytes to image tensor. + + Args: + image_str_tensor: A string of image bytes. + new_height: An integer of new image height. + new_width: An integer of new image width. + + Returns: + An image tensor. + """ + image = tf.io.decode_image(image_str_tensor, 3, expand_animations=False) + image = tf.image.resize(image, (new_height, new_width)) + return image + + +def get_label_map(label_map_yaml_filepath: str) -> dict[int, str]: + """Returns class id to label mapping given a filepath to the label map. + + Args: + label_map_yaml_filepath: A string of label map yaml file path. + + Returns: + A dictionary of class id to label mapping. + """ + with tf.io.gfile.GFile(label_map_yaml_filepath, "rb") as input_file: + label_map = yaml.safe_load(input_file.read())["label_map"] + return label_map + + +def get_prediction_instances(test_filepath: str, new_width: int = -1) -> Any: + """Generate instance from image path to pass to Vertex AI Endpoint for prediction. + + Args: + test_filepath: A string of test image path. + new_width: An integer of new image width. + + Returns: + A list of instances. 
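+
+  Example (hypothetical file path; payload truncated):
+    >>> get_prediction_instances("gs://my-bucket/test.jpg", new_width=640)
+    [{'encoded_image': {'b64': '...'}}]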
+ """ + if new_width <= 0: + with tf.io.gfile.GFile(test_filepath, "rb") as input_file: + encoded_string = base64.b64encode(input_file.read()).decode("utf-8") + else: + img = load_img(test_filepath) + width, height = img.size + print("original input image size: ", width, " , ", height) + new_height = int(height * new_width / width) + new_img = img.resize((new_width, new_height)) + print("resized input image size: ", new_width, " , ", new_height) + buffered = io.BytesIO() + new_img.save(buffered, format="JPEG") + encoded_string = base64.b64encode(buffered.getvalue()).decode("utf-8") + + instances = [{ + "encoded_image": {"b64": encoded_string}, + }] + return instances + + +def vqa_predict( + endpoint: Any, + question_prompts: Sequence[str], + image: Any, + language_code: str = "en", + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> Sequence[str]: + """Predicts the answer to a question about an image using an Endpoint.""" + # Resize and convert image to base64 string. + resized_image = resize_image(image, new_width) + resized_image_base64 = image_to_base64(resized_image) + + instances = [] + if question_prompts: + # Format question prompt + question_prompt_format = "answer {} {}\n" + for question_prompt in question_prompts: + if question_prompt: + instances.append({ + "prompt": question_prompt_format.format( + language_code, question_prompt + ), + "image": resized_image_base64, + }) + else: + instances.append({ + "image": resized_image_base64, + }) + + response = endpoint.predict( + instances=instances, use_dedicated_endpoint=use_dedicated_endpoint + ) + return [pred.get("response") for pred in response.predictions] + + +def caption_predict( + endpoint: Any, + language_code: str, + image: Any, + caption_prompt: bool = False, + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> str: + """Predicts a caption for a given image using an Endpoint.""" + # Resize and convert image to base64 string. + resized_image = resize_image(image, new_width) + resized_image_base64 = image_to_base64(resized_image) + + instance = {"image": resized_image_base64} + + if caption_prompt: + # Format caption prompt + caption_prompt_format = "caption {}\n" + instance["prompt"] = caption_prompt_format.format(language_code) + + instances = [instance] + response = endpoint.predict( + instances=instances, use_dedicated_endpoint=use_dedicated_endpoint + ) + return response.predictions[0].get("response") + + +def ocr_predict( + endpoint: Any, + ocr_prompt: str, + image: Any, + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> str: + """Extracts text from a given image using an Endpoint.""" + # Resize and convert image to base64 string. + resized_image = resize_image(image, new_width) + resized_image_base64 = image_to_base64(resized_image) + + instance = {"image": resized_image_base64} + if ocr_prompt: + instance["prompt"] = ocr_prompt + instances = [instance] + + response = endpoint.predict( + instances=instances, use_dedicated_endpoint=use_dedicated_endpoint + ) + return response.predictions[0].get("response") + + +def detect_predict( + endpoint: Any, + detect_prompt: str, + image: Any, + new_width: int = 1000, + use_dedicated_endpoint: bool = False, +) -> str: + """Predicts the answer to a question about an image using an Endpoint.""" + # Resize and convert image to base64 string. 
+def detect_predict(
+    endpoint: Any,
+    detect_prompt: str,
+    image: Any,
+    new_width: int = 1000,
+    use_dedicated_endpoint: bool = False,
+) -> str:
+  """Detects objects described by a prompt in a given image using an Endpoint."""
+  # Resize and convert image to base64 string.
+  resized_image = resize_image(image, new_width)
+  resized_image_base64 = image_to_base64(resized_image)
+
+  instance = {"image": resized_image_base64}
+  if detect_prompt:
+    instance["prompt"] = detect_prompt
+  instances = [instance]
+
+  response = endpoint.predict(
+      instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
+  )
+  return response.predictions[0].get("response")
+
+
+def copy_model_artifacts(
+    model_id: str,
+    model_source: str,
+    model_destination: str,
+) -> None:
+  """Copies model artifacts from model_source to model_destination.
+
+  Both model_source and model_destination must be GCS paths.
+
+  Args:
+    model_id: The model id.
+    model_source: The source of the model artifact.
+    model_destination: The destination of the model artifact.
+  """
+  if not model_source.startswith(GCS_URI_PREFIX):
+    raise ValueError(
+        f"{model_source} is not a GCS path starting with {GCS_URI_PREFIX}."
+    )
+  if not model_destination.startswith(GCS_URI_PREFIX):
+    raise ValueError(
+        f"{model_destination} is not a GCS path starting with {GCS_URI_PREFIX}."
+    )
+  model_source = f"{model_source}/{model_id}"
+  model_destination = f"{model_destination}/{model_id}"
+  print("Copying model artifact from ", model_source, " to ", model_destination)
+  subprocess.check_output([
+      "gcloud",
+      "storage",
+      "cp",
+      "-r",
+      model_source,
+      model_destination,
+  ])
+
+
+def get_quota(project_id: str, region: str, resource_id: str) -> int:
+  """Returns the quota for a resource in a region.
+
+  Args:
+    project_id: The project id.
+    region: The region.
+    resource_id: The resource id.
+
+  Returns:
+    The quota for the resource in the region. Returns -1 if the quota cannot
+    be determined.
+
+  Raises:
+    RuntimeError: If the command to get quota fails.
+  """
+  service_endpoint = "aiplatform.googleapis.com"
+
+  command = (
+      "gcloud alpha services quota list"
+      f" --service={service_endpoint} --consumer=projects/{project_id}"
+      f" --filter='{service_endpoint}/{resource_id}' --format=json"
+  )
+  # Use check=False so a failing gcloud invocation surfaces as the documented
+  # RuntimeError below instead of an unhandled CalledProcessError.
+  process = subprocess.run(
+      command, shell=True, capture_output=True, text=True, check=False
+  )
+  if process.returncode == 0:
+    quota_data = json.loads(process.stdout)
+  else:
+    raise RuntimeError(f"Error fetching quota data: {process.stderr}")
+
+  if not quota_data or "consumerQuotaLimits" not in quota_data[0]:
+    return -1
+  if (
+      not quota_data[0]["consumerQuotaLimits"]
+      or "quotaBuckets" not in quota_data[0]["consumerQuotaLimits"][0]
+  ):
+    return -1
+  all_regions_data = quota_data[0]["consumerQuotaLimits"][0]["quotaBuckets"]
+
+  # If the quota data does not have dimensions, it is global quota. However,
+  # global quota may be overridden by regional quota. So we need to check the
+  # global quota first.
+  global_quota = -1
+  if (
+      all_regions_data
+      and "dimensions" not in all_regions_data[0]
+      and "effectiveLimit" in all_regions_data[0]
+  ):
+    global_quota = int(all_regions_data[0]["effectiveLimit"])
+  for region_data in all_regions_data:
+    if (
+        region_data.get("dimensions")
+        and region_data["dimensions"]["region"] == region
+    ):
+      if "effectiveLimit" in region_data:
+        return int(region_data["effectiveLimit"])
+      else:
+        return 0
+  return global_quota
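+
+# Illustrative usage (project id and region are hypothetical):
+#
+#   quota = get_quota("my-project", "us-central1",
+#                     "custom_model_serving_nvidia_l4_gpus")
+#   print(quota)  # e.g. 8, or -1 if the quota could not be determined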
+def get_resource_id(
+    accelerator_type: str,
+    is_for_training: bool,
+    is_spot: bool = False,
+    is_restricted_image: bool = False,
+    is_dynamic_workload_scheduler: bool = False,
+) -> str:
+  """Returns the resource id for a given accelerator type and the use case.
+
+  Args:
+    accelerator_type: The accelerator type.
+    is_for_training: Whether the resource is used for training. Set false for
+      the serving use case.
+    is_spot: Whether the resource is used with Spot.
+    is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
+    is_dynamic_workload_scheduler: Whether the resource is used with Dynamic
+      Workload Scheduler.
+
+  Returns:
+    The resource id.
+  """
+  accelerator_suffix_map = {
+      "NVIDIA_TESLA_V100": "nvidia_v100_gpus",
+      "NVIDIA_TESLA_P100": "nvidia_p100_gpus",
+      "NVIDIA_L4": "nvidia_l4_gpus",
+      "NVIDIA_TESLA_A100": "nvidia_a100_gpus",
+      "NVIDIA_A100_80GB": "nvidia_a100_80gb_gpus",
+      "NVIDIA_H100_80GB": "nvidia_h100_gpus",
+      "NVIDIA_H100_MEGA_80GB": "nvidia_h100_mega_gpus",
+      "NVIDIA_H200_141GB": "nvidia_h200_gpus",
+      "NVIDIA_TESLA_T4": "nvidia_t4_gpus",
+      "TPU_V6e": "tpu_v6e",
+      "TPU_V5e": "tpu_v5e",
+      "TPU_V3": "tpu_v3",
+  }
+  default_training_accelerator_map = {
+      key: f"custom_model_training_{accelerator_suffix_map[key]}"
+      for key in accelerator_suffix_map
+  }
+  dws_training_accelerator_map = {
+      key: f"custom_model_training_preemptible_{accelerator_suffix_map[key]}"
+      for key in accelerator_suffix_map
+  }
+  restricted_image_training_accelerator_map = {
+      "NVIDIA_A100_80GB": "restricted_image_training_nvidia_a100_80gb_gpus",
+  }
+  spot_serving_accelerator_map = {
+      key: f"custom_model_serving_preemptible_{accelerator_suffix_map[key]}"
+      for key in accelerator_suffix_map
+  }
+  serving_accelerator_map = {
+      key: f"custom_model_serving_{accelerator_suffix_map[key]}"
+      for key in accelerator_suffix_map
+  }
+
+  if is_for_training:
+    if is_restricted_image and is_dynamic_workload_scheduler:
+      raise ValueError(
+          "Dynamic Workload Scheduler does not work for restricted image"
+          " training."
+      )
+    training_accelerator_map = (
+        restricted_image_training_accelerator_map
+        if is_restricted_image
+        else default_training_accelerator_map
+    )
+    if accelerator_type in training_accelerator_map:
+      if is_dynamic_workload_scheduler:
+        return dws_training_accelerator_map[accelerator_type]
+      else:
+        return training_accelerator_map[accelerator_type]
+    else:
+      raise ValueError(
+          f"Could not find accelerator type: {accelerator_type} for training."
+      )
+  else:
+    if is_dynamic_workload_scheduler:
+      raise ValueError("Dynamic Workload Scheduler does not work for serving.")
+    accelerator_map = (
+        spot_serving_accelerator_map if is_spot else serving_accelerator_map
+    )
+    if accelerator_type in accelerator_map:
+      return accelerator_map[accelerator_type]
+    else:
+      raise ValueError(
+          f"Could not find accelerator type: {accelerator_type} for serving."
+      )
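+
+# Illustrative examples of the mapping above:
+#
+#   get_resource_id("NVIDIA_L4", is_for_training=False)
+#   # -> "custom_model_serving_nvidia_l4_gpus"
+#   get_resource_id("NVIDIA_TESLA_A100", is_for_training=True,
+#                   is_dynamic_workload_scheduler=True)
+#   # -> "custom_model_training_preemptible_nvidia_a100_gpus"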
+ """ + resource_id = get_resource_id( + accelerator_type, + is_for_training=is_for_training, + is_spot=is_spot, + is_restricted_image=is_restricted_image, + is_dynamic_workload_scheduler=is_dynamic_workload_scheduler, + ) + quota = get_quota(project_id, region, resource_id) + quota_request_instruction = ( + "Either use " + "a different region or request additional quota. Follow " + "instructions here " + "https://cloud.google.com/docs/quotas/view-manage#requesting_higher_quota" + " to check quota in a region or request additional quota for " + "your project." + ) + if quota == -1: + raise ValueError( + f"Quota not found for: {resource_id} in {region}." + f" {quota_request_instruction}" + ) + if quota < accelerator_count: + raise ValueError( + f"Quota not enough for {resource_id} in {region}: {quota} <" + f" {accelerator_count}. {quota_request_instruction}" + ) + + +def get_deploy_source() -> str: + """Gets deploy_source string based on running environment.""" + vertex_product = os.environ.get("VERTEX_PRODUCT", "") + match vertex_product: + case "COLAB_ENTERPRISE": + return "notebook_colab_enterprise" + case "WORKBENCH_INSTANCE": + return "notebook_workbench" + case _: + # Legacy workbench, legacy colab, or other custom environments. + return "notebook_environment_unspecified" + + +def _is_operation_done(op_name: str, region: str) -> bool: + """Checks if the operation is done. + + Args: + op_name: The name of the operation to poll. + region: The region of the operation. + + Returns: + True if the operation is done, False otherwise. + + Raises: + ValueError: If the operation failed. + """ + creds, _ = auth.default() + auth_req = auth.transport.requests.Request() + creds.refresh(auth_req) + headers = { + "Authorization": f"Bearer {creds.token}", + } + url = f"https://{region}-aiplatform.googleapis.com/ui/{op_name}" + response = requests.get(url, headers=headers) + operation_data = response.json() + if "error" in operation_data: + raise ValueError(f"Operation failed: {operation_data['error']}") + return operation_data.get("done", False) + + +def poll_and_wait( + op_name: str, region: str, total_wait: int, interval: int = 60 +) -> None: + """Polls the operation and waits for it to complete. + + Args: + op_name: The name of the operation to poll. + region: The region of the operation. + total_wait: The total wait time in seconds. + interval: The interval between each poll in seconds. + + Raises: + TimeoutError: If the operation times out. + """ + start_time = time.time() + while True: + if _is_operation_done(op_name, region): + break + time_elapsed = time.time() - start_time + if time_elapsed > total_wait: + raise TimeoutError( + f"Operation timed out after {int(time_elapsed)} seconds." + ) + print( + "\rStill waiting for operation... Elapsed time in seconds:" + f" {int(time_elapsed):<6}", + end="", + flush=True, + ) + time.sleep(interval) diff --git a/notebooks/community/model_garden/docker_source_codes/notebook_util/dataset_validation_util.py b/notebooks/community/model_garden/docker_source_codes/notebook_util/dataset_validation_util.py new file mode 100644 index 000000000..57b794d5d --- /dev/null +++ b/notebooks/community/model_garden/docker_source_codes/notebook_util/dataset_validation_util.py @@ -0,0 +1,592 @@ +"""Functions for dataset validation. + +This tool is used to validate the dataset against the given template. 
+""" + +from collections.abc import Callable +import json +import multiprocessing +import os +import subprocess +from typing import Any, Union +from absl import logging +import accelerate +import datasets +import transformers + +GCS_URI_PREFIX = "gs://" +GCSFUSE_URI_PREFIX = "/gcs/" +LOCAL_BASE_MODEL_DIR = "/tmp/base_model_dir" +LOCAL_TEMPLATE_DIR = "/tmp/template_dir" +_TEMPLATE_DIRNAME = "templates" +_VERTEX_AI_SAMPLES_GITHUB_REPO_NAME = "vertex-ai-samples" +_VERTEX_AI_SAMPLES_GITHUB_TEMPLATE_DIR = ( + "community-content/vertex_model_garden/model_oss/peft/train/vmg/templates" +) +_MODELS_REQUIRING_PAD_TOKEN = ("llama", "falcon", "mistral", "mixtral") +_MODELS_REQUIRING_EOS_TOEKN = ("gemma-2b", "gemma-7b") +_DESCRIPTION_KEY = "description" +_SOURCE_KEY = "source" +_PROMPT_INPUT_KEY = "prompt_input" +_PROMPT_NO_INPUT_KEY = "prompt_no_input" +_RESPONSE_SEPARATOR = "response_separator" +_INSTRUCTION_SEPARATOR = "instruction_separator" +_CHAT_TEMPLATE_KEY = "chat_template" +_KNOWN_KEYS = ( + _DESCRIPTION_KEY, + _SOURCE_KEY, + _PROMPT_INPUT_KEY, + _PROMPT_NO_INPUT_KEY, + _RESPONSE_SEPARATOR, + _INSTRUCTION_SEPARATOR, + _CHAT_TEMPLATE_KEY, +) + + +def is_gcs_path(input_path: str) -> bool: + """Checks if the input path is a Google Cloud Storage (GCS) path. + + Args: + input_path: The input path to be checked. + + Returns: + True if the input path is a GCS path, False otherwise. + """ + return input_path is not None and input_path.startswith(GCS_URI_PREFIX) + + +def force_gcs_fuse_path(gcs_uri: str) -> str: + """Converts gs:// uris to their /gcs/ equivalents. No-op for other uris. + + Args: + gcs_uri: The GCS URI to convert. + + Returns: + The converted GCS URI. + """ + if is_gcs_path(gcs_uri): + return GCSFUSE_URI_PREFIX + gcs_uri[len(GCS_URI_PREFIX) :] + else: + return gcs_uri + + +def download_gcs_uri_to_local( + gcs_uri: str, + destination_dir: str = LOCAL_BASE_MODEL_DIR, + check_path_exists: bool = True, +) -> str: + """Downloads GCS URI to local. + + If GCS URI is a directory, gs://some/folder is downloaded to + /destination_dir/folder. If GCS URI is a file, gs://some/file is downloaded to + /destination_dir/file. + + Args: + gcs_uri: GCS URI to download. + destination_dir: Local directory directory. + check_path_exists: Whether to check if the path exists. + + Returns: + Local path to target folder/file. + """ + target = os.path.join( + destination_dir, + os.path.basename(os.path.normpath(gcs_uri)), + ) + if check_path_exists and os.path.exists(target): + logging.info("File %s already exists.", target) + return target + if accelerate.PartialState().is_local_main_process: + logging.info( + "Downloading file(s) from %s to %s...", gcs_uri, destination_dir + ) + if not os.path.exists(destination_dir): + os.mkdir(destination_dir) + subprocess.check_output([ + "gsutil", + "-m", + "cp", + "-r", + gcs_uri, + destination_dir, + ]) + logging.info("Downloaded file(s) from %s to %s.", gcs_uri, destination_dir) + # Make sure ALL processes process to next step after data downloading is done. + # It matters for the main process to wait for other processes as well. + accelerate.PartialState().wait_for_everyone() + return target + + +def get_template(template_path: str) -> dict[str, str]: + """Gets the template dictionary given the file path. + + Args: + template_path: Path to the template file. + + Returns: + A dictionary of the template. + + Raises: + ValueError: If the template file does not exist or contains unknown keys. 
+ """ + if is_gcs_path(template_path): + template_path = force_gcs_fuse_path(template_path) + elif not os.path.isfile(template_path): + template_path = os.path.join( + os.path.dirname(__file__), + _TEMPLATE_DIRNAME, + template_path + ".json", + ) + if not os.path.isfile(template_path): + raise ValueError(f"Template file {template_path} does not exist.") + with open(template_path, "r") as f: + template_json: dict[str, str] = json.load(f) + for key in template_json: + if key not in _KNOWN_KEYS: + raise ValueError(f"Unknown key {key} in template {template_path}.") + return template_json + + +def get_response_separator(template_json: dict[str, str]) -> Union[str, None]: + return template_json.get(_RESPONSE_SEPARATOR, None) + + +def get_instruction_separator( + template_json: dict[str, str], +) -> Union[str, None]: + return template_json.get(_INSTRUCTION_SEPARATOR, None) + + +def _format_template_fn( + template: str, + input_column: str, + tokenizer: transformers.PreTrainedTokenizer | None = None, +) -> Callable[[dict[str, str]], dict[str, str]]: + """Formats a dataset example according to a template. + + Args: + template: Name of the JSON template file under `templates/` or GCS path to + the template file. + input_column: The input column in the dataset to be used or updated by the + template. If it does not exist, the template's `prompt_no_input` will be + used, and the input_column will be created. + tokenizer: The tokenizer to use for chat_template templates. + + Returns: + A function that formats data according to the template. + """ + template_json = get_template(template) + + if _CHAT_TEMPLATE_KEY not in template_json: + + def format_fn(example: dict[str, str]) -> dict[str, str]: + format_dict = {key: value for key, value in example.items()} + if format_dict.get(input_column): + format_str = template_json[_PROMPT_INPUT_KEY] + elif _PROMPT_NO_INPUT_KEY in template_json: + format_str = template_json[_PROMPT_NO_INPUT_KEY] + else: + raise KeyError( + f"The template {os.path.basename(template)} does not contain" + f" {_PROMPT_INPUT_KEY} or {_PROMPT_NO_INPUT_KEY} key." + ) + try: + return {input_column: format_str.format(**format_dict)} + except KeyError as e: + raise KeyError( + f"The template {os.path.basename(template)} contains a key {e} in" + f" {_PROMPT_INPUT_KEY} or {_PROMPT_NO_INPUT_KEY} that does not" + " exist in the dataset example. The dataset example looks like" + f" {format_dict}." + ) from e + + return format_fn + elif ( + _PROMPT_INPUT_KEY in template_json + or _PROMPT_NO_INPUT_KEY in template_json + ): + raise ValueError( + f"chat_template templates do not support {_PROMPT_INPUT_KEY} or" + f" {_PROMPT_NO_INPUT_KEY} templates." + ) + else: + if tokenizer is None: + raise ValueError("A tokenizer is required for chat_template templates.") + # Assign HuggingFace jinja template. + tokenizer.chat_template = template_json[_CHAT_TEMPLATE_KEY] + + def format_fn(example: dict[str, str]) -> dict[str, str]: + try: + return { + input_column: tokenizer.apply_chat_template( + example[input_column], + tokenize=False, + add_generation_prompt=False, + ) + } + except KeyError as e: + raise KeyError( + f"The template {os.path.basename(template)} contains a key {e} in" + f" {_CHAT_TEMPLATE_KEY} that does not exist in the dataset example." + ) from e + + return format_fn + + +def _get_split_string( + split: str, + dataset_percent: int | None = None, + dataset_k_rows: int | None = None, +) -> str: + """Gets the formatted split string for the dataset. 
+def _github_template_path(template: str) -> str:
+  """Generates the path to the template in the Vertex AI Samples GitHub repo.
+
+  Args:
+    template: Name of the template.
+
+  Returns:
+    The path to the template in the Vertex AI Samples GitHub repo.
+  """
+  # The vertex-ai-samples directory may lie under a separate directory
+  # depending on the scratch_dir parameter in the notebook execution
+  # environment.
+  vertex_ai_samples_abs_path = os.getcwd().split(
+      _VERTEX_AI_SAMPLES_GITHUB_REPO_NAME
+  )[0]
+  return os.path.join(
+      vertex_ai_samples_abs_path,
+      _VERTEX_AI_SAMPLES_GITHUB_REPO_NAME,
+      _VERTEX_AI_SAMPLES_GITHUB_TEMPLATE_DIR,
+      template + ".json",
+  )
+
+
+def _get_dataset(
+    dataset_name: str,
+    split: str,
+    num_proc: int | None = None,
+) -> datasets.Dataset:
+  """Gets a dataset split.
+
+  Args:
+    dataset_name: Name of the dataset or path to a custom dataset.
+    split: Split of the dataset.
+    num_proc: Number of processes to use.
+
+  Returns:
+    A dataset.
+  """
+  dataset_name = force_gcs_fuse_path(dataset_name)
+  if os.path.isfile(dataset_name):
+    # Custom dataset.
+    return datasets.load_dataset(
+        "json",
+        data_files=[dataset_name],
+        split=split,
+        num_proc=num_proc,
+    )
+  # HF dataset.
+  return datasets.load_dataset(dataset_name, split=split, num_proc=num_proc)
+
+
+def should_add_pad_token(model_id: str) -> bool:
+  """Returns whether the model requires adding a special pad token.
+
+  Args:
+    model_id: The name of the model.
+
+  Returns:
+    True if the model requires adding a special pad token, False otherwise.
+  """
+  return any(s.lower() in model_id.lower() for s in _MODELS_REQUIRING_PAD_TOKEN)
+
+
+def should_add_eos_token(model_id: str) -> bool:
+  """Returns whether the model requires adding a special eos token.
+
+  Args:
+    model_id: The name of the model.
+
+  Returns:
+    True if the model requires adding a special eos token, False otherwise.
+  """
+  return any(m in model_id for m in _MODELS_REQUIRING_EOS_TOKEN)
+
+
+def load_tokenizer(
+    pretrained_model_id: str,
+    padding_side: str | None = None,
+    access_token: str | None = None,
+) -> transformers.AutoTokenizer:
+  """Loads a tokenizer based on `pretrained_model_id`.
+
+  Args:
+    pretrained_model_id: The name of the pretrained model.
+    padding_side: The side to pad the input on.
+    access_token: The access token to use for the tokenizer.
+
+  Returns:
+    The tokenizer.
+  """
+  tokenizer_kwargs = {}
+  if should_add_eos_token(pretrained_model_id):
+    tokenizer_kwargs["add_eos_token"] = True
+  if padding_side:
+    tokenizer_kwargs["padding_side"] = padding_side
+
+  with accelerate.PartialState().local_main_process_first():
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        pretrained_model_id,
+        trust_remote_code=False,
+        use_fast=True,
+        token=access_token,
+        **tokenizer_kwargs,
+    )
+
+  if should_add_pad_token(pretrained_model_id):
+    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+  return tokenizer
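+
+# Illustrative usage (the model id is an example; gated models also need a
+# Hugging Face access token, shown here as a placeholder):
+#
+#   tokenizer = load_tokenizer("meta-llama/Llama-2-7b-hf",
+#                              padding_side="right",
+#                              access_token="hf_...")
+#   # For Llama-family models, a [PAD] token is added automatically above.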
+ """ + tokenizer_kwargs = {} + if should_add_eos_token(pretrained_model_id): + tokenizer_kwargs["add_eos_token"] = True + if padding_side: + tokenizer_kwargs["padding_side"] = padding_side + + with accelerate.PartialState().local_main_process_first(): + tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_id, + trust_remote_code=False, + use_fast=True, + token=access_token, + **tokenizer_kwargs, + ) + + if should_add_pad_token(pretrained_model_id): + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + return tokenizer + + +def get_filtered_dataset( + dataset: Any, + input_column: str, + max_seq_length: int, + tokenizer: transformers.PreTrainedTokenizer, + example_removed_threshold: float = 50.0, +) -> Any: + """Returns the dataset by removing examples that are longer than max_seq_length. + + Args: + dataset: The dataset to filter. + input_column: The input column in the dataset to be used. + max_seq_length: The maximum sequence length. + tokenizer: The tokenizer. + example_removed_threshold: The percent threshold for the number of examples + removed from the dataset. It should be in the range of [0, 100]. + + Returns: + The filtered dataset. + + Raises: + ValueError: If more than `example_removed_threshold` of the dataset is + filtered out. + """ + actual_dataset_length = len(dataset) + filtered_dataset = dataset.filter( + lambda x: len(tokenizer(x[input_column])["input_ids"]) <= max_seq_length + ) + filtered_dataset_length = len(filtered_dataset) + if actual_dataset_length != filtered_dataset_length: + examples_removed_percent = ( + (actual_dataset_length - filtered_dataset_length) + * 100 + / actual_dataset_length + ) + logging.info( + "(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >" + " max-seq-length. Filtering out %d example(s) which are longer than" + " max-seq-length.", + 100 - examples_removed_percent, + max_seq_length, + examples_removed_percent, + actual_dataset_length - filtered_dataset_length, + ) + if examples_removed_percent > example_removed_threshold: + raise ValueError( + "More than %.2f%% of the dataset is filtered out. This may be due to" + " small value of max-seq-length(%d) or incorrect template. Please" + " increase the max-seq-length or check the template." + % (examples_removed_percent, max_seq_length) + ) + print(f"Some formatted examples from the dataset are: {filtered_dataset[:5]}") + return filtered_dataset + + +def format_dataset( + dataset: datasets.Dataset, + input_column: str, + template: str = None, + tokenizer: transformers.PreTrainedTokenizer | None = None, +) -> datasets.Dataset: + """Takes a raw dataset and formats it using a template and tokenizer. + + Args: + dataset: The raw (unprocessed) dataset to format. + input_column: The input column in the dataset to be used or updaded by the + template. If it does not exist, the template's `prompt_no_input` will be + used, and the input_column will be created. + template: Name of the JSON template file under `templates/` or GCS path to + the template file. + tokenizer: The tokenizer to use for chat_template templates. + + Returns: + A dataset compatible with the template. + """ + return dataset.map( + _format_template_fn( + template, + input_column=input_column, + tokenizer=tokenizer, + ) + ) + + +def load_dataset_with_template( + dataset_name: str, + split: str, + input_column: str, + template: str = None, + tokenizer: transformers.PreTrainedTokenizer | None = None, +) -> tuple[Any, Any]: + """Loads dataset with templates. 
+def validate_dataset_with_template(
+    dataset_name: str,
+    split: str,
+    input_column: str,
+    template: str,
+    tokenizer: transformers.PreTrainedTokenizer | None = None,
+    max_seq_length: int | None = None,
+    use_multiprocessing: bool = False,
+    validate_percentage_of_dataset: int | None = None,
+    validate_k_rows_of_dataset: int | None = None,
+    example_removed_threshold: float = 50.0,
+) -> None:
+  """Validates a dataset against a template.
+
+  This function loads the dataset and validates it against the template. For
+  validation, users may also load the dataset partially by reading x% or the
+  top k rows of the dataset. To validate the dataset, the template file must
+  be available in the GCS bucket and the dataset must be available either in
+  the GCS bucket or on Hugging Face.
+
+  Args:
+    dataset_name: Name of the dataset or path to a custom dataset.
+    split: Split of the dataset.
+    input_column: The input column in the dataset to be used or updated by the
+      template. If it does not exist, the template's `prompt_no_input` will be
+      used, and the input_column will be created.
+    template: Name of the JSON template file under `templates/` or GCS path to
+      the template file.
+    tokenizer: The tokenizer to use for chat_template templates.
+    max_seq_length: The maximum sequence length.
+    use_multiprocessing: If True, it will use multiprocessing to load the
+      dataset.
+    validate_percentage_of_dataset: The percentage of the dataset to load.
+    validate_k_rows_of_dataset: The number of leading rows to load from the
+      dataset.
+    example_removed_threshold: The percent threshold for the number of
+      examples removed from the dataset.
+
+  Raises:
+    ValueError: If the inputs are invalid or the dataset is not compatible
+      with the template.
+  """
+  if not template:
+    raise ValueError("template is required for validate_dataset.")
+
+  if not dataset_name:
+    raise ValueError("dataset_name is empty.")
+
+  if not split:
+    raise ValueError("split is empty.")
+
+  split = _get_split_string(
+      split,
+      validate_percentage_of_dataset,
+      validate_k_rows_of_dataset,
+  )
+
+  num_proc = multiprocessing.cpu_count() if use_multiprocessing else 1
+
+  # gcsfuse cannot be used from the notebook runtime env. Hence, we have
+  # to download the dataset and template from GCS to local disk.
+  if is_gcs_path(dataset_name):
+    dataset_name = download_gcs_uri_to_local(dataset_name, LOCAL_BASE_MODEL_DIR)
+
+  if is_gcs_path(template):
+    template_path = download_gcs_uri_to_local(template, LOCAL_TEMPLATE_DIR)
+  elif os.path.isfile(_github_template_path(template)):
+    template_path = _github_template_path(template)
+  else:
+    raise ValueError(
+        f"Template file {template} does not exist. To validate the"
+        " dataset, please provide a valid GCS path for the template or a valid"
+        " template name from"
+        f" https://github.com/GoogleCloudPlatform/{_VERTEX_AI_SAMPLES_GITHUB_REPO_NAME}/tree/main/{_VERTEX_AI_SAMPLES_GITHUB_TEMPLATE_DIR}."
+    )
+
+  dataset = format_dataset(
+      _get_dataset(dataset_name, split, num_proc),
+      input_column,
+      template_path,
+      tokenizer,
+  )
+
+  if tokenizer is not None:
+    get_filtered_dataset(
+        dataset=dataset,
+        input_column=input_column,
+        max_seq_length=max_seq_length,
+        tokenizer=tokenizer,
+        example_removed_threshold=example_removed_threshold,
+    )
+  print(
+      "Dataset {} is compatible with the {} template.".format(
+          os.path.basename(dataset_name), os.path.basename(template)
+      )
+  )
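+
+# Illustrative end-to-end validation call (all names hypothetical):
+#
+#   validate_dataset_with_template(
+#       dataset_name="gs://my-bucket/data/train.jsonl",
+#       split="train",
+#       input_column="text",
+#       template="gs://my-bucket/templates/my_template.json",
+#       tokenizer=load_tokenizer("meta-llama/Llama-2-7b-hf"),
+#       max_seq_length=512,
+#       validate_percentage_of_dataset=5,
+#   )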
diff --git a/notebooks/community/model_garden/docker_source_codes/notebook_util/gcp_utils.py b/notebooks/community/model_garden/docker_source_codes/notebook_util/gcp_utils.py
new file mode 100644
index 000000000..34345df99
--- /dev/null
+++ b/notebooks/community/model_garden/docker_source_codes/notebook_util/gcp_utils.py
@@ -0,0 +1,275 @@
+"""Utility functions for interacting with Google Cloud Platform."""
+
+import datetime
+import logging
+import os
+import subprocess
+import uuid
+
+from google.cloud import aiplatform
+import requests
+
+
+# Configure logging.
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def get_project_id() -> str:
+  """Reads the cloud project id from the metadata service."""
+  project_request = requests.get(
+      "http://metadata.google.internal/computeMetadata/v1/project/project-id",
+      headers={"Metadata-Flavor": "Google"},
+  )
+  return project_request.text
+
+
+def get_region() -> str:
+  """Reads the region from the metadata service."""
+  region_request = requests.get(
+      "http://metadata.google.internal/computeMetadata/v1/instance/region",
+      headers={"Metadata-Flavor": "Google"},
+  )
+  return region_request.text.split("/")[-1]
+
+
+# Get the default cloud project id and region.
+PROJECT_ID = get_project_id()
+REGION = get_region()
+
+
+def init_aiplatform(
+    project: str | None = None, location: str | None = None
+) -> None:
+  """Initializes the Vertex AI SDK.
+
+  Args:
+    project: The Google Cloud project ID.
+    location: The Google Cloud location.
+  """
+  project = PROJECT_ID if project is None else project
+  location = REGION if location is None else location
+  aiplatform.init(project=project, location=location)
+  subprocess.call([
+      "gcloud",
+      "services",
+      "enable",
+      "aiplatform.googleapis.com",
+      "compute.googleapis.com",
+  ])
+
+
+def run_command(command: list[str]) -> str:
+  """Runs a shell command and returns the output.
+
+  Args:
+    command: The shell command to run as a list.
+
+  Returns:
+    The output of the command.
+  """
+  try:
+    result = subprocess.run(
+        command,
+        check=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    return result.stdout
+  except subprocess.CalledProcessError as e:
+    logger.error("Error: %s", e.stderr)
+    raise e
+
+
+def enable_apis() -> None:
+  """Enables the Vertex AI API and Compute Engine API."""
+  logger.info("Enabling Vertex AI API and Compute Engine API.")
+  run_command([
+      "gcloud",
+      "services",
+      "enable",
+      "aiplatform.googleapis.com",
+      "compute.googleapis.com",
+  ])
+
+
+def setup_buckets(bucket_uri: str, model_bucket_name: str) -> tuple[str, str]:
+  """Sets up Cloud Storage buckets for storing experiment artifacts.
+
+  Args:
+    bucket_uri: The bucket URI provided by the user.
+    model_bucket_name: The name of the model bucket.
+
+  Returns:
+    A tuple containing the bucket name and model bucket path.
+  """
+  if not bucket_uri.strip():
+    # Generate a default bucket URI if none provided.
+    now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+    bucket_uri = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
+    logger.info("No bucket URI provided. Using default bucket: %s", bucket_uri)
+  else:
+    if not bucket_uri.startswith("gs://"):
+      raise ValueError("Bucket URI must start with 'gs://'.")
+    # Remove any trailing slashes.
+    bucket_uri = bucket_uri.rstrip("/")
+
+  bucket_name = "/".join(bucket_uri.split("/")[:3])
+
+  # Check if the bucket exists.
+  try:
+    run_command(["gsutil", "ls", "-b", bucket_uri])
+    logger.info("Bucket %s already exists.", bucket_uri)
+  except subprocess.CalledProcessError:
+    logger.info("Creating bucket %s.", bucket_uri)
+    # Create the bucket in the same region as the notebook runtime.
+    run_command(["gsutil", "mb", "-l", REGION, bucket_uri])
+
+  # Construct the model bucket path.
+  model_bucket = os.path.join(bucket_uri, model_bucket_name)
+
+  # Check if the model bucket exists (as a folder within the main bucket).
+  try:
+    run_command(["gsutil", "ls", model_bucket])
+    logger.info("Model bucket %s already exists.", model_bucket)
+  except subprocess.CalledProcessError:
+    logger.info("Creating model bucket %s.", model_bucket)
+    # Create the model bucket folder.
+    run_command(["gsutil", "cp", "/dev/null", model_bucket + "/"])
+
+  return bucket_name, model_bucket
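+
+# Illustrative behavior (bucket URI is hypothetical):
+#
+#   bucket_name, model_bucket = setup_buckets(
+#       "gs://my-project-artifacts", "llama-experiment")
+#   # bucket_name  -> "gs://my-project-artifacts"
+#   # model_bucket -> "gs://my-project-artifacts/llama-experiment"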
+def get_project_number() -> str:
+  """Gets the default project number."""
+  shell_output = run_command(["gcloud", "projects", "describe", PROJECT_ID])
+  project_number_line = next(
+      (line for line in shell_output.splitlines() if "projectNumber" in line),
+      None,
+  )
+  if project_number_line:
+    project_number = project_number_line.split(":")[1].strip().replace("'", "")
+    logger.info("Using default Project Number: %s", project_number)
+    return project_number
+  else:
+    raise ValueError("Could not find project number in gcloud output.")
+
+
+def get_service_account() -> str:
+  """Gets the default Compute Engine service account."""
+  # Reuse get_project_number() instead of duplicating the gcloud parsing.
+  project_number = get_project_number()
+  service_account = f"{project_number}-compute@developer.gserviceaccount.com"
+  logger.info("Using default Service Account: %s", service_account)
+  return service_account
+
+
+def provision_permissions(service_account: str, bucket_name: str) -> None:
+  """Provisions permissions for the service account on the GCS bucket."""
+  if bucket_name:
+    run_command([
+        "gsutil",
+        "iam",
+        "ch",
+        f"serviceAccount:{service_account}:roles/storage.admin",
+        bucket_name,
+    ])
+
+
+def set_gcloud_project() -> None:
+  """Sets the gcloud config project."""
+  run_command(["gcloud", "config", "set", "project", PROJECT_ID])
+
+
+def initialize(
+    bucket_uri: str, model_bucket_name: str, create_bucket: bool
+) -> tuple[str, str]:
+  """Initializes the environment.
+
+  Args:
+    bucket_uri: The bucket URI provided by the user.
+    model_bucket_name: The name of the model bucket.
+    create_bucket: Whether to create the bucket or not.
+
+  Returns:
+    A tuple containing the model bucket path and service account.
+  """
+  enable_apis()
+  bucket_name = None
+  if create_bucket:
+    bucket_name, model_bucket = setup_buckets(bucket_uri, model_bucket_name)
+  else:
+    model_bucket = None
+  service_account = get_service_account()
+  provision_permissions(service_account, bucket_name)
+  set_gcloud_project()
+  return model_bucket, service_account
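+
+# Illustrative usage (values hypothetical):
+#
+#   model_bucket, service_account = initialize(
+#       bucket_uri="gs://my-project-artifacts",
+#       model_bucket_name="llama-experiment",
+#       create_bucket=True,
+#   )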
+ """ + enable_apis() + bucket_name = None + if create_bucket: + bucket_name, model_bucket = setup_buckets(bucket_uri, model_bucket_name) + else: + model_bucket = None + service_account = get_service_account() + provision_permissions(service_account, bucket_name) + set_gcloud_project() + return model_bucket, service_account + + +def clean_resources_ui( + project_id: str, + region: str, + endpoint_name: str, + delete_bucket: bool, + bucket_name: str = None, +) -> str: + """UI function for cleaning a specific Vertex AI endpoint and its model.""" + if delete_bucket and not bucket_name: + raise ValueError("Bucket name is required when 'Delete Bucket' is checked.") + + try: + delete_endpoint_and_model(project_id, region, endpoint_name) + bucket_status_message = "" + if delete_bucket: + bucket_status_message = delete_gcs_bucket(bucket_name) + if endpoint_name: + return ( + f"Endpoint {endpoint_name} and associated model deleted successfully!" + f" {bucket_status_message}" + ) + else: + return ( + "There are currently no endpoints available for deletion." + f" {bucket_status_message}" + ) + except Exception as e: # pylint: disable=broad-exception-caught + return f"Error cleaning up resources: {e}" + + +def delete_endpoint_and_model( + project_id: str, region: str, endpoint_name: str +) -> None: + """Deletes a specific Vertex AI endpoint and its associated model.""" + if endpoint_name: + endpoint_id = endpoint_name.split(" - ")[0] + endpoint_resource_name = ( + f"projects/{project_id}/locations/{region}/endpoints/{endpoint_id}" + ) + endpoint = aiplatform.Endpoint( + endpoint_resource_name, project=project_id, location=region + ) + deployed_models = endpoint.list_models() + for deployed_model in deployed_models: + endpoint.undeploy(deployed_model_id=deployed_model.id) + model = aiplatform.Model(deployed_model.model) + model.delete() + endpoint.delete() + + +def delete_gcs_bucket(bucket_name: str) -> str: + """Deletes a GCS bucket using gsutil.""" + try: + run_command(["gsutil", "-m", "rm", "-r", bucket_name]) + logger.info("Bucket %s deleted using gsutil.", bucket_name) + return f"Bucket {bucket_name} deleted successfully!" + except subprocess.CalledProcessError as e: + logger.error( + "Error deleting bucket %s using gsutil: %s", bucket_name, str(e) + ) + return f"Bucket {bucket_name} could not be found or deleted. "