diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py index a0d864548e..f0c9514f6a 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py @@ -20,6 +20,14 @@ using standard environment variables like JAX_COORDINATOR_ADDRESS, JAX_PROCESS_ID and JAX_NUM_PROCESSES. """ +# pylint: disable=g-statement-before-imports,g-import-not-at-top + +try: # SimDevice import must occur before JAX. + import simdevice # pylint: disable=unused-import + + _SIMDEVICE_AVAILABLE = True +except ImportError: + _SIMDEVICE_AVAILABLE = False import os @@ -33,7 +41,7 @@ from orbax.checkpoint._src.testing.benchmarks.core import device_mesh try: - import pathwaysutils # pylint: disable=g-import-not-at-top + import pathwaysutils _PATHWAYS_AVAILABLE = True except ImportError: diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py index 8a583c884f..885d4b301a 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py @@ -285,6 +285,11 @@ False, 'If True, skip validation of the benchmark results.', ) +_ENABLE_SIMDEVICE = flags.DEFINE_boolean( + 'enable_simdevice', + False, + 'If True, auto-configures SimDevice topology flags based on tpu_type.', +) # --- Pathways Flags --- # Pathways uses a "Sidecar" architecture on XPK: @@ -314,7 +319,9 @@ def _validate_output_directory(flags_dict): out_dir = flags_dict['output_directory'] storage = flags_dict['storage'] - if storage: + if storage == 'local' and out_dir.startswith('/tmp'): + return True + if storage and out_dir.startswith('gs://'): return True return out_dir.startswith('gs://') @@ -323,7 +330,8 @@ def _validate_output_directory(flags_dict): ['output_directory', 'storage'], _validate_output_directory, message=( - '--output_directory must start with gs:// unless --storage is provided.' + '--output_directory must start with gs:// unless --storage=local is' + ' provided for /tmp paths.' ), ) # LINT.ThenChange(README.md:launch_xpk_flags_table) @@ -561,7 +569,7 @@ def create_cluster() -> None: f'--num-slices={_NUM_SLICES.value}', ] - if _TPU_TYPE.value is not None: + if _TPU_TYPE.value is not None and not _ENABLE_SIMDEVICE.value: cmd.append(f'--tpu-type={_TPU_TYPE.value}') if _DEVICE_TYPE.value is not None: cmd.append(f'--device-type={_DEVICE_TYPE.value}') @@ -612,11 +620,14 @@ def create_cluster() -> None: # ramdisk without MTC might be possible and could be explored later. if _RAMDISK_DIRECTORY.value: if not _OUTPUT_DIRECTORY.value.startswith('gs://'): - raise ValueError( - '--ramdisk_directory requires --output_directory to be a gs:// path' - ' for MTC.' - ) - bucket = _OUTPUT_DIRECTORY.value.split('/')[2] + if _STORAGE.value != 'local': + raise ValueError( + '--ramdisk_directory requires --output_directory to be a gs:// path' + ' unless --storage=local is used.' + ) + bucket = 'local' + else: + bucket = _OUTPUT_DIRECTORY.value.split('/')[2] cmd.append('--enable-mtc') cmd.append('--mtc-ramdisk-size=32G') cmd.append(f'--mtc-gcs-bucket={bucket}') @@ -673,6 +684,77 @@ def get_hardware_type( return HardwareType.UNKNOWN +TPU_TOPOLOGY_MAP = { + 'v4-8': '2,2,1', + 'v4-16': '2,2,2', + 'v4-32': '2,2,4', + 'v4-64': '2,4,4', + 'v4-128': '4,4,4', + 'v4-256': '4,4,8', + 'v4-512': '4,8,8', + 'v4-1024': '8,8,8', + 'v4-1536': '8,8,12', + 'v4-2048': '8,8,16', + 'v4-4096': '8,16,16', + 'v5litepod-8': '2,4,1', + 'v5litepod-16': '4,4,1', + 'v5litepod-32': '4,8,1', + 'v5litepod-64': '8,8,1', + 'v5litepod-128': '8,16,1', + 'v5litepod-256': '16,16,1', + 'v5e-8': '2,4,1', + 'v5e-16': '4,4,1', + 'v5e-32': '4,8,1', + 'v5e-64': '8,8,1', + 'v5e-128': '8,16,1', + 'v5e-256': '16,16,1', + 'v5p-8': '2,2,1', + 'v5p-16': '2,2,2', + 'v5p-32': '2,2,4', + 'v5p-64': '2,4,4', + 'v5p-128': '4,4,4', + 'v5p-256': '4,4,8', + 'v5p-512': '4,8,8', + 'v5p-1024': '8,8,8', + 'v5p-2048': '8,8,16', + 'v5p-4096': '8,16,16', + 'v5p-8192': '16,16,16', + 'v5p-12288': '16,16,24', + 'v6e-1': '1,1,1', + 'v6e-4': '2,2,1', + 'v6e-8': '2,4,1', + 'v6e-16': '4,4,1', + 'v6e-32': '4,8,1', + 'v6e-64': '8,8,1', + 'v6e-128': '8,16,1', + 'v6e-256': '16,16,1', +} + + +def parse_tpu_spec_to_topology(tpu_type: str) -> tuple[str, str]: + """Translates a TPU spec name to TPU version and topology. + + Args: + tpu_type: A TPU spec string like 'v6e-256', 'v5litepod-8', 'v4-8'. + + Returns: + A tuple of (tpu_version, topology_string). + tpu_version: string representation of TPU version (e.g. 'v6e', 'v5e'). + topology_string: comma-separated 3D TPU topology coordinates (e.g. + '8,16,2'). + """ + tpu_type = tpu_type.strip().lower() + version = tpu_type.split('-')[0] + if version == 'v5litepod': + version = 'v5e' + + topology = TPU_TOPOLOGY_MAP.get(tpu_type) + if not topology: + raise ValueError(f'Unsupported or invalid TPU spec format: {tpu_type}') + + return version, topology + + def construct_workload_command( *, workload_name: str, @@ -683,6 +765,8 @@ def construct_workload_command( benchmark_binary_path: str, hardware_type: HardwareType, v_level: int | None, + tpu_type: str | None = None, + enable_simdevice: bool = False, ) -> str: """Constructs the command to run inside the workload.""" # Environment variables @@ -719,10 +803,28 @@ def construct_workload_command( else: raise ValueError(f'Unsupported hardware type: {hardware_type}') + if enable_simdevice: + if tpu_type is None: + raise ValueError( + '`tpu_type` must be specified when `enable_simdevice` is True' + ) + tpu_version, topology = parse_tpu_spec_to_topology(tpu_type) + env_vars.extend([ + f'export SIMDEVICE_TPU_VERSION={tpu_version}', + f'export SIMDEVICE_TOPOLOGY={topology}', + f'echo SIMDEVICE_TPU_VERSION = {tpu_version}', + f'echo SIMDEVICE_TOPOLOGY = {topology}', + ]) + env_cmd = ' && '.join(env_vars) + ' && ' if env_vars else '' + if benchmark_binary_path.endswith('.py'): + cmd_executable = f'python3 {benchmark_binary_path}' + else: + cmd_executable = benchmark_binary_path + python_args = [ - f'python3 {benchmark_binary_path}', + cmd_executable, f'--config_file={config_file}', f'--output_directory={os.path.join(output_directory, run_id)}', '--alsologtostderr', @@ -761,9 +863,9 @@ def construct_xpk_command( f'--priority={_PRIORITY.value}', ] - if _STORAGE.value is not None: + if _STORAGE.value is not None and _STORAGE.value != 'local': base_cmd.append(f'--storage={_STORAGE.value}') - if _TPU_TYPE.value is not None: + if _TPU_TYPE.value is not None and not _ENABLE_SIMDEVICE.value: base_cmd.append(f'--tpu-type={_TPU_TYPE.value}') if _DEVICE_TYPE.value is not None: base_cmd.append(f'--device-type={_DEVICE_TYPE.value}') @@ -877,7 +979,9 @@ def print_summary( def upload_config_to_gcs(local_path: str, gcs_root: str, run_id: str) -> str: """Uploads the local config file to GCS and returns the GCS path.""" if not gcs_root.startswith('gs://'): - raise ValueError('Config diectory is not a GCS path.') + if _STORAGE.value == 'local': + return local_path + raise ValueError('Config directory is not a GCS path.') filename = os.path.basename(local_path) gcs_path = os.path.join(gcs_root, run_id, filename) @@ -1009,7 +1113,8 @@ def main(argv: Sequence[str]) -> None: # 5. Construct Commands Console.print_step(4, 6, 'Constructing Commands') - hardware_type = get_hardware_type(_TPU_TYPE.value, _DEVICE_TYPE.value) + requested_tpu_type = None if _ENABLE_SIMDEVICE.value else _TPU_TYPE.value + hardware_type = get_hardware_type(requested_tpu_type, _DEVICE_TYPE.value) workload_cmd = construct_workload_command( workload_name=workload_name_base, config_file=remote_config_path, @@ -1019,6 +1124,8 @@ def main(argv: Sequence[str]) -> None: benchmark_binary_path=_BENCHMARK_BINARY_PATH.value, hardware_type=hardware_type, v_level=_V_LEVEL.value, + tpu_type=_TPU_TYPE.value, + enable_simdevice=_ENABLE_SIMDEVICE.value, ) attempts = 2 if _TEST_RESTART_WORKFLOW.value else 1 diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py index 602ed9486d..55532cd758 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py @@ -178,5 +178,94 @@ def test_get_hardware_type(self, tpu, device, expected): self.assertEqual(launch_xpk.get_hardware_type(tpu, device), expected) +class ParseTpuSpecToTopologyTest(parameterized.TestCase): + + @parameterized.named_parameters( + # Cloud TPU specs + dict( + testcase_name='v6e_256', + tpu_type='v6e-256', + expected_version='v6e', + expected_topology='16,16,1', + ), + dict( + testcase_name='v6e_16', + tpu_type='v6e-16', + expected_version='v6e', + expected_topology='4,4,1', + ), + dict( + testcase_name='v5e_8', + tpu_type='v5litepod-8', + expected_version='v5e', + expected_topology='2,4,1', + ), + dict( + testcase_name='v4_8', + tpu_type='v4-8', + expected_version='v4', + expected_topology='2,2,1', + ), + dict( + testcase_name='v5p_8', + tpu_type='v5p-8', + expected_version='v5p', + expected_topology='2,2,1', + ), + dict( + testcase_name='v6e_8', + tpu_type='v6e-8', + expected_version='v6e', + expected_topology='2,4,1', + ), + ) + def test_parse_tpu_spec_to_topology( + self, tpu_type, expected_version, expected_topology + ): + version, topology = launch_xpk.parse_tpu_spec_to_topology(tpu_type) + self.assertEqual(version, expected_version) + self.assertEqual(topology, expected_topology) + + +class ConstructXpkCommandTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + flags.FLAGS.set_default('xpk_path', 'xpk') + flags.FLAGS.set_default('cluster_name', 'test-cluster') + flags.FLAGS.set_default('project', 'test-project') + flags.FLAGS.set_default('zone', 'test-zone') + flags.FLAGS.set_default('num_slices', 1) + flags.FLAGS.set_default('priority', 'medium') + flags.FLAGS.set_default('storage', 'local') + + def test_simdevice_ignores_tpu_type_in_xpk_flags(self): + flags.FLAGS.set_default('enable_simdevice', True) + flags.FLAGS.set_default('tpu_type', 'v6e-8') + flags.FLAGS.set_default('device_type', 'n2-standard-4') + + xpk_cmd = launch_xpk.construct_xpk_command( + workload_name='test-workload', + workload_command='python3 benchmark.py', + ) + + # We should see device-type but not tpu-type in the XPK command flags + self.assertIn('--device-type=n2-standard-4', xpk_cmd) + self.assertNotIn('--tpu-type=v6e-8', xpk_cmd) + self.assertNotIn('--tpu-type', xpk_cmd) + + def test_normal_tpu_uses_tpu_type_in_xpk_flags(self): + flags.FLAGS.set_default('enable_simdevice', False) + flags.FLAGS.set_default('tpu_type', 'v6e-8') + + xpk_cmd = launch_xpk.construct_xpk_command( + workload_name='test-workload', + workload_command='python3 benchmark.py', + ) + + # We should see tpu-type in the XPK command flags + self.assertIn('--tpu-type=v6e-8', xpk_cmd) + + if __name__ == '__main__': absltest.main()