Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
using standard environment variables like JAX_COORDINATOR_ADDRESS,
JAX_PROCESS_ID and JAX_NUM_PROCESSES.
"""
# pylint: disable=g-statement-before-imports,g-import-not-at-top

try: # SimDevice import must occur before JAX.
import simdevice # pylint: disable=unused-import

_SIMDEVICE_AVAILABLE = True
except ImportError:
_SIMDEVICE_AVAILABLE = False

import os

Expand All @@ -33,7 +41,7 @@
from orbax.checkpoint._src.testing.benchmarks.core import device_mesh

try:
import pathwaysutils # pylint: disable=g-import-not-at-top
import pathwaysutils

_PATHWAYS_AVAILABLE = True
except ImportError:
Expand Down
133 changes: 120 additions & 13 deletions checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@
False,
'If True, skip validation of the benchmark results.',
)
_ENABLE_SIMDEVICE = flags.DEFINE_boolean(
'enable_simdevice',
False,
'If True, auto-configures SimDevice topology flags based on tpu_type.',
)

# --- Pathways Flags ---
# Pathways uses a "Sidecar" architecture on XPK:
Expand Down Expand Up @@ -314,7 +319,9 @@
def _validate_output_directory(flags_dict):
out_dir = flags_dict['output_directory']
storage = flags_dict['storage']
if storage:
if storage == 'local' and out_dir.startswith('/tmp'):
return True
if storage and out_dir.startswith('gs://'):
return True
return out_dir.startswith('gs://')

Expand All @@ -323,7 +330,8 @@ def _validate_output_directory(flags_dict):
['output_directory', 'storage'],
_validate_output_directory,
message=(
'--output_directory must start with gs:// unless --storage is provided.'
'--output_directory must start with gs:// unless --storage=local is'
' provided for /tmp paths.'
),
)
# LINT.ThenChange(README.md:launch_xpk_flags_table)
Expand Down Expand Up @@ -561,7 +569,7 @@ def create_cluster() -> None:
f'--num-slices={_NUM_SLICES.value}',
]

if _TPU_TYPE.value is not None:
if _TPU_TYPE.value is not None and not _ENABLE_SIMDEVICE.value:
cmd.append(f'--tpu-type={_TPU_TYPE.value}')
if _DEVICE_TYPE.value is not None:
cmd.append(f'--device-type={_DEVICE_TYPE.value}')
Expand Down Expand Up @@ -612,11 +620,14 @@ def create_cluster() -> None:
# ramdisk without MTC might be possible and could be explored later.
if _RAMDISK_DIRECTORY.value:
if not _OUTPUT_DIRECTORY.value.startswith('gs://'):
raise ValueError(
'--ramdisk_directory requires --output_directory to be a gs:// path'
' for MTC.'
)
bucket = _OUTPUT_DIRECTORY.value.split('/')[2]
if _STORAGE.value != 'local':
raise ValueError(
'--ramdisk_directory requires --output_directory to be a gs:// path'
' unless --storage=local is used.'
)
bucket = 'local'
else:
bucket = _OUTPUT_DIRECTORY.value.split('/')[2]
cmd.append('--enable-mtc')
cmd.append('--mtc-ramdisk-size=32G')
cmd.append(f'--mtc-gcs-bucket={bucket}')
Expand Down Expand Up @@ -673,6 +684,77 @@ def get_hardware_type(
return HardwareType.UNKNOWN


TPU_TOPOLOGY_MAP = {
'v4-8': '2,2,1',
'v4-16': '2,2,2',
'v4-32': '2,2,4',
'v4-64': '2,4,4',
'v4-128': '4,4,4',
'v4-256': '4,4,8',
'v4-512': '4,8,8',
'v4-1024': '8,8,8',
'v4-1536': '8,8,12',
'v4-2048': '8,8,16',
'v4-4096': '8,16,16',
'v5litepod-8': '2,4,1',
'v5litepod-16': '4,4,1',
'v5litepod-32': '4,8,1',
'v5litepod-64': '8,8,1',
'v5litepod-128': '8,16,1',
'v5litepod-256': '16,16,1',
'v5e-8': '2,4,1',
'v5e-16': '4,4,1',
'v5e-32': '4,8,1',
'v5e-64': '8,8,1',
'v5e-128': '8,16,1',
'v5e-256': '16,16,1',
'v5p-8': '2,2,1',
'v5p-16': '2,2,2',
'v5p-32': '2,2,4',
'v5p-64': '2,4,4',
'v5p-128': '4,4,4',
'v5p-256': '4,4,8',
'v5p-512': '4,8,8',
'v5p-1024': '8,8,8',
'v5p-2048': '8,8,16',
'v5p-4096': '8,16,16',
'v5p-8192': '16,16,16',
'v5p-12288': '16,16,24',
'v6e-1': '1,1,1',
'v6e-4': '2,2,1',
'v6e-8': '2,4,1',
'v6e-16': '4,4,1',
'v6e-32': '4,8,1',
'v6e-64': '8,8,1',
'v6e-128': '8,16,1',
'v6e-256': '16,16,1',
}


def parse_tpu_spec_to_topology(tpu_type: str) -> tuple[str, str]:
"""Translates a TPU spec name to TPU version and topology.

Args:
tpu_type: A TPU spec string like 'v6e-256', 'v5litepod-8', 'v4-8'.

Returns:
A tuple of (tpu_version, topology_string).
tpu_version: string representation of TPU version (e.g. 'v6e', 'v5e').
topology_string: comma-separated 3D TPU topology coordinates (e.g.
'8,16,2').
"""
tpu_type = tpu_type.strip().lower()
version = tpu_type.split('-')[0]
if version == 'v5litepod':
version = 'v5e'

topology = TPU_TOPOLOGY_MAP.get(tpu_type)
if not topology:
raise ValueError(f'Unsupported or invalid TPU spec format: {tpu_type}')

return version, topology


def construct_workload_command(
*,
workload_name: str,
Expand All @@ -683,6 +765,8 @@ def construct_workload_command(
benchmark_binary_path: str,
hardware_type: HardwareType,
v_level: int | None,
tpu_type: str | None = None,
enable_simdevice: bool = False,
) -> str:
"""Constructs the command to run inside the workload."""
# Environment variables
Expand Down Expand Up @@ -719,10 +803,28 @@ def construct_workload_command(
else:
raise ValueError(f'Unsupported hardware type: {hardware_type}')

if enable_simdevice:
if tpu_type is None:
raise ValueError(
'`tpu_type` must be specified when `enable_simdevice` is True'
)
tpu_version, topology = parse_tpu_spec_to_topology(tpu_type)
env_vars.extend([
f'export SIMDEVICE_TPU_VERSION={tpu_version}',
f'export SIMDEVICE_TOPOLOGY={topology}',
f'echo SIMDEVICE_TPU_VERSION = {tpu_version}',
f'echo SIMDEVICE_TOPOLOGY = {topology}',
])

env_cmd = ' && '.join(env_vars) + ' && ' if env_vars else ''

if benchmark_binary_path.endswith('.py'):
cmd_executable = f'python3 {benchmark_binary_path}'
else:
cmd_executable = benchmark_binary_path

python_args = [
f'python3 {benchmark_binary_path}',
cmd_executable,
f'--config_file={config_file}',
f'--output_directory={os.path.join(output_directory, run_id)}',
'--alsologtostderr',
Expand Down Expand Up @@ -761,9 +863,9 @@ def construct_xpk_command(
f'--priority={_PRIORITY.value}',
]

if _STORAGE.value is not None:
if _STORAGE.value is not None and _STORAGE.value != 'local':
base_cmd.append(f'--storage={_STORAGE.value}')
if _TPU_TYPE.value is not None:
if _TPU_TYPE.value is not None and not _ENABLE_SIMDEVICE.value:
base_cmd.append(f'--tpu-type={_TPU_TYPE.value}')
if _DEVICE_TYPE.value is not None:
base_cmd.append(f'--device-type={_DEVICE_TYPE.value}')
Expand Down Expand Up @@ -877,7 +979,9 @@ def print_summary(
def upload_config_to_gcs(local_path: str, gcs_root: str, run_id: str) -> str:
"""Uploads the local config file to GCS and returns the GCS path."""
if not gcs_root.startswith('gs://'):
raise ValueError('Config diectory is not a GCS path.')
if _STORAGE.value == 'local':
return local_path
raise ValueError('Config directory is not a GCS path.')

filename = os.path.basename(local_path)
gcs_path = os.path.join(gcs_root, run_id, filename)
Expand Down Expand Up @@ -1009,7 +1113,8 @@ def main(argv: Sequence[str]) -> None:

# 5. Construct Commands
Console.print_step(4, 6, 'Constructing Commands')
hardware_type = get_hardware_type(_TPU_TYPE.value, _DEVICE_TYPE.value)
requested_tpu_type = None if _ENABLE_SIMDEVICE.value else _TPU_TYPE.value
hardware_type = get_hardware_type(requested_tpu_type, _DEVICE_TYPE.value)
workload_cmd = construct_workload_command(
workload_name=workload_name_base,
config_file=remote_config_path,
Expand All @@ -1019,6 +1124,8 @@ def main(argv: Sequence[str]) -> None:
benchmark_binary_path=_BENCHMARK_BINARY_PATH.value,
hardware_type=hardware_type,
v_level=_V_LEVEL.value,
tpu_type=_TPU_TYPE.value,
enable_simdevice=_ENABLE_SIMDEVICE.value,
)

attempts = 2 if _TEST_RESTART_WORKFLOW.value else 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,5 +178,94 @@ def test_get_hardware_type(self, tpu, device, expected):
self.assertEqual(launch_xpk.get_hardware_type(tpu, device), expected)


class ParseTpuSpecToTopologyTest(parameterized.TestCase):

@parameterized.named_parameters(
# Cloud TPU specs
dict(
testcase_name='v6e_256',
tpu_type='v6e-256',
expected_version='v6e',
expected_topology='16,16,1',
),
dict(
testcase_name='v6e_16',
tpu_type='v6e-16',
expected_version='v6e',
expected_topology='4,4,1',
),
dict(
testcase_name='v5e_8',
tpu_type='v5litepod-8',
expected_version='v5e',
expected_topology='2,4,1',
),
dict(
testcase_name='v4_8',
tpu_type='v4-8',
expected_version='v4',
expected_topology='2,2,1',
),
dict(
testcase_name='v5p_8',
tpu_type='v5p-8',
expected_version='v5p',
expected_topology='2,2,1',
),
dict(
testcase_name='v6e_8',
tpu_type='v6e-8',
expected_version='v6e',
expected_topology='2,4,1',
),
)
def test_parse_tpu_spec_to_topology(
self, tpu_type, expected_version, expected_topology
):
version, topology = launch_xpk.parse_tpu_spec_to_topology(tpu_type)
self.assertEqual(version, expected_version)
self.assertEqual(topology, expected_topology)


class ConstructXpkCommandTest(parameterized.TestCase):

def setUp(self):
super().setUp()
flags.FLAGS.set_default('xpk_path', 'xpk')
flags.FLAGS.set_default('cluster_name', 'test-cluster')
flags.FLAGS.set_default('project', 'test-project')
flags.FLAGS.set_default('zone', 'test-zone')
flags.FLAGS.set_default('num_slices', 1)
flags.FLAGS.set_default('priority', 'medium')
flags.FLAGS.set_default('storage', 'local')

def test_simdevice_ignores_tpu_type_in_xpk_flags(self):
flags.FLAGS.set_default('enable_simdevice', True)
flags.FLAGS.set_default('tpu_type', 'v6e-8')
flags.FLAGS.set_default('device_type', 'n2-standard-4')

xpk_cmd = launch_xpk.construct_xpk_command(
workload_name='test-workload',
workload_command='python3 benchmark.py',
)

# We should see device-type but not tpu-type in the XPK command flags
self.assertIn('--device-type=n2-standard-4', xpk_cmd)
self.assertNotIn('--tpu-type=v6e-8', xpk_cmd)
self.assertNotIn('--tpu-type', xpk_cmd)

def test_normal_tpu_uses_tpu_type_in_xpk_flags(self):
flags.FLAGS.set_default('enable_simdevice', False)
flags.FLAGS.set_default('tpu_type', 'v6e-8')

xpk_cmd = launch_xpk.construct_xpk_command(
workload_name='test-workload',
workload_command='python3 benchmark.py',
)

# We should see tpu-type in the XPK command flags
self.assertIn('--tpu-type=v6e-8', xpk_cmd)


if __name__ == '__main__':
absltest.main()
Loading