diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py index 8a583c884f..76552f348d 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py @@ -721,8 +721,13 @@ def construct_workload_command( env_cmd = ' && '.join(env_vars) + ' && ' if env_vars else '' + if benchmark_binary_path.endswith('.py'): + cmd_executable = f'python3 {benchmark_binary_path}' + else: + cmd_executable = benchmark_binary_path + python_args = [ - f'python3 {benchmark_binary_path}', + cmd_executable, f'--config_file={config_file}', f'--output_directory={os.path.join(output_directory, run_id)}', '--alsologtostderr',