diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ae3a8bc52..f0f7245ca 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -115,11 +115,11 @@ jobs: working-directory: export strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.11", "3.12"] jax-version: ["newest"] include: - - python-version: "3.10" - jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml + - python-version: "3.11" + jax-version: "0.6.0" # keep in sync with minimum version in export/pyproject.toml # TODO(b/401258175) Re-enable once JAX nightlies are fixed. # - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it. # jax-version: "nightly" diff --git a/export/orbax/__init__.py b/export/orbax/__init__.py index b3140b7d1..090c612a8 100644 --- a/export/orbax/__init__.py +++ b/export/orbax/__init__.py @@ -15,6 +15,7 @@ """Defines exported symbols for Orbax Export.""" # pylint: disable=g-importing-member +from orbax.export import bundle from orbax.export import config from orbax.export import constants from orbax.export import obm_configs @@ -29,14 +30,15 @@ from orbax.export.dtensor_utils import maybe_enable_dtensor_export_on from orbax.export.dtensor_utils import shutdown_dtensor from orbax.export.export_manager import ExportManager + from orbax.export.jax_module import JaxModule from orbax.export.obm_configs import Jax2ObmOptions from orbax.export.serving_config import ServingConfig + # TODO(dinghua): remove them after we change all references to # utils.remove_signature_defaults. from orbax.export.utils import remove_signature_defaults from orbax.export.utils import TensorSpecWithDefault - # A new PyPI release will be pushed everytime `__version__` is increased. 
__version__ = '0.0.8'
diff --git a/export/orbax/export/__init__.py b/export/orbax/export/__init__.py
index b3140b7d1..090c612a8 100644
--- a/export/orbax/export/__init__.py
+++ b/export/orbax/export/__init__.py
@@ -15,6 +15,7 @@
 """Defines exported symbols for Orbax Export."""
 
 # pylint: disable=g-importing-member
+from orbax.export import bundle
 from orbax.export import config
 from orbax.export import constants
 from orbax.export import obm_configs
@@ -29,14 +30,15 @@
 from orbax.export.dtensor_utils import maybe_enable_dtensor_export_on
 from orbax.export.dtensor_utils import shutdown_dtensor
 from orbax.export.export_manager import ExportManager
+
 from orbax.export.jax_module import JaxModule
 from orbax.export.obm_configs import Jax2ObmOptions
 from orbax.export.serving_config import ServingConfig
+
 # TODO(dinghua): remove them after we change all references to
 # utils.remove_signature_defaults.
 from orbax.export.utils import remove_signature_defaults
 from orbax.export.utils import TensorSpecWithDefault
-
 
 # A new PyPI release will be pushed everytime `__version__` is increased.
 __version__ = '0.0.8'
diff --git a/export/orbax/export/bundle.py b/export/orbax/export/bundle.py
new file mode 100644
index 000000000..76a6df5e0
--- /dev/null
+++ b/export/orbax/export/bundle.py
@@ -0,0 +1,175 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bundle orchestration export utilities."""
+
+import dataclasses
+import os
+import shutil
+from typing import Dict, List
+
+from neptune_model.protos import bundle_orchestration_pb2
+
+
+@dataclasses.dataclass
+class SubModel:
+  """Defines a unique model in the bundle.
+
+  Attributes:
+    name: Name of the sub-model subdirectory in the bundle.
+    path: Path to the exported sub-model.
+  """
+
+  name: str
+  path: str
+
+
+@dataclasses.dataclass
+class PipelineStep:
+  """Defines one model's execution within the bundle.
+
+  Attributes:
+    model: Name of the sub-model (must match a SubModel.name).
+    pipeline: Name of the pipeline in the sub-model to execute.
+    repeated_times: Number of times to repeat this step.
+    requires_h2d: Whether this step requires Host-to-Device transfer.
+    requires_d2h: Whether this step requires Device-to-Host transfer.
+  """
+
+  model: str
+  pipeline: str
+  repeated_times: int = 1
+  requires_h2d: bool = False
+  requires_d2h: bool = False
+
+
+@dataclasses.dataclass
+class BundleDefinition:
+  """Defines organization of sub-models in the bundle.
+
+  Attributes:
+    name: Human-readable name of the bundle.
+    version: Version of the bundle.
+    pipelines: Maps pipeline name to a sequential list of steps.
+  """
+
+  name: str
+  version: int
+  pipelines: Dict[str, List[PipelineStep]]
+
+
+def _check_model_name(output_path: str, name: str) -> None:
+  """Rejects sub-model names that do not stay inside the bundle directory.
+
+  `create_bundle` joins the name onto `output_path` and deletes whatever
+  already exists at the result, so a name such as '', '.', '..' or an
+  absolute path would delete or overwrite the bundle directory itself or
+  something outside of it. The check is lexical; symlinks are not resolved.
+
+  Args:
+    output_path: Root directory of the bundle.
+    name: Candidate sub-model name, relative to `output_path`.
+
+  Raises:
+    ValueError: If `name` resolves to `output_path` itself or outside of it.
+  """
+  root = os.path.abspath(output_path)
+  dest = os.path.abspath(os.path.join(root, name))
+  if dest == root or os.path.commonpath([root, dest]) != root:
+    raise ValueError(
+        f"Sub-model name {name!r} must resolve to a path inside the bundle."
+    )
+
+
+def create_bundle(
+    output_path: str,
+    bundle_def: BundleDefinition,
+    models: List[SubModel],
+    copy_models: bool = False,
+) -> None:
+  """Creates the bundle directory, adds the sub-models, and writes the proto.
+
+  Each sub-model is placed at `<output_path>/<model.name>`, replacing any
+  file, directory or symlink already at that location. The bundle metadata
+  and pipelines are serialized as a `BundleOrchestration` proto to
+  `<output_path>/bundle_orchestration.pb`.
+
+  Args:
+    output_path: Path where the bundle should be created.
+    bundle_def: Definition of the bundle pipelines and metadata.
+    models: List of sub-models to include in the bundle.
+    copy_models: If True, copies the sub-models instead of symlinking them.
+
+  Raises:
+    ValueError: If a sub-model name resolves to `output_path` itself or to a
+      location outside of it.
+  """
+  # Check every name before touching the filesystem so that a bad name cannot
+  # leave a partially written bundle behind.
+  for model in models:
+    _check_model_name(output_path, model.name)
+
+  os.makedirs(output_path, exist_ok=True)
+
+  for model in models:
+    dest_path = os.path.join(output_path, model.name)
+    # lexists() is also true for a dangling symlink, which must be replaced.
+    if os.path.lexists(dest_path):
+      if os.path.islink(dest_path):
+        os.unlink(dest_path)
+      elif os.path.isdir(dest_path):
+        shutil.rmtree(dest_path)
+      else:
+        os.remove(dest_path)
+    if copy_models:
+      shutil.copytree(model.path, dest_path)
+    else:
+      # A relative link target is resolved against the directory that holds
+      # the link rather than the caller's working directory, so link to the
+      # absolute path to keep a relative `model.path` from dangling.
+      os.symlink(os.path.abspath(model.path), dest_path)
+
+  proto = bundle_orchestration_pb2.BundleOrchestration()
+  proto.metadata.name = bundle_def.name
+  proto.metadata.version = bundle_def.version
+
+  for pipeline_name, steps in bundle_def.pipelines.items():
+    bundle_pipeline = bundle_orchestration_pb2.BundlePipeline()
+    for step in steps:
+      component = bundle_pipeline.components.add()
+      component.model_name = step.model
+      component.pipeline_name = step.pipeline
+      component.repeated_times = step.repeated_times
+      component.requires_h2d = step.requires_h2d
+      component.requires_d2h = step.requires_d2h
+
+    proto.pipelines[pipeline_name].CopyFrom(bundle_pipeline)
+
+  pb_path = os.path.join(output_path, "bundle_orchestration.pb")
+  with open(pb_path, "wb") as f:
+    f.write(proto.SerializeToString())
diff --git a/export/orbax/export/bundle_test.py b/export/orbax/export/bundle_test.py
new file mode 100644
index 000000000..8bcce2f7c
--- /dev/null
+++ b/export/orbax/export/bundle_test.py
@@ -0,0 +1,248 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for bundle creation (sub-model layout and orchestration proto)."""
+
+import os
+
+from absl.testing import absltest
+from orbax.export import bundle
+
+from neptune_model.protos import bundle_orchestration_pb2
+
+
+class BundleTest(absltest.TestCase):
+
+  def test_create_bundle_copy_models(self):
+    """Copy mode puts the sub-model files and the proto in the bundle."""
+    src_workspace = self.create_tempdir('source_workspace').full_path
+    bundle_dest = self.create_tempdir('bundle_destination').full_path
+    src_model_dir = os.path.join(src_workspace, 'src_model')
+    self._create_dummy_orbax_model(src_model_dir)
+
+    bundle_def = bundle.BundleDefinition(
+        name='test_bundle',
+        version=1,
+        pipelines={
+            'pipeline1': [
+                bundle.PipelineStep(model='model1', pipeline='step1'),
+            ]
+        }
+    )
+    models = [
+        bundle.SubModel(name='model1', path=src_model_dir),
+    ]
+
+    bundle.create_bundle(
+        bundle_dest, bundle_def, models, copy_models=True
+    )
+
+    self.assertTrue(os.path.exists(bundle_dest))
+    # The sub-model lands in a subdirectory named after SubModel.name.
+    dest_model_dir = os.path.join(bundle_dest, 'model1')
+    self.assertTrue(os.path.exists(dest_model_dir))
+    self.assertTrue(os.path.isdir(dest_model_dir))
+
+    self._assert_model_files_exist(dest_model_dir)
+    # The orchestration proto sits at the bundle root and parses back.
+    proto_path = os.path.join(bundle_dest, 'bundle_orchestration.pb')
+    self.assertTrue(os.path.exists(proto_path))
+
+    proto = bundle_orchestration_pb2.BundleOrchestration()
+    with open(proto_path, 'rb') as f:
+      proto.ParseFromString(f.read())
+
+    self.assertEqual(proto.metadata.name, 'test_bundle')
+    self.assertEqual(proto.metadata.version, 1)
+    self.assertIn('pipeline1', proto.pipelines)
+    # A step built with defaults runs once and needs no H2D/D2H transfer.
+    pipeline = proto.pipelines['pipeline1']
+    self.assertLen(pipeline.components, 1)
+    self.assertEqual(pipeline.components[0].model_name, 'model1')
+    self.assertEqual(pipeline.components[0].pipeline_name, 'step1')
+    self.assertEqual(pipeline.components[0].repeated_times, 1)
+    self.assertFalse(pipeline.components[0].requires_h2d)
+    self.assertFalse(pipeline.components[0].requires_d2h)
+
+  def test_create_bundle_copy_multiple_nested_models(self):
+    """Copy mode handles several sub-models referenced by one pipeline."""
+    src_workspace = self.create_tempdir('source_workspace').full_path
+    bundle_dest = self.create_tempdir('bundle_destination').full_path
+    src_model1_dir = os.path.join(src_workspace, 'src_model1')
+    self._create_dummy_orbax_model(src_model1_dir)
+    # The second model is just a directory containing one file.
+    src_model2_dir = os.path.join(src_workspace, 'src_model2')
+    os.makedirs(src_model2_dir)
+    with open(os.path.join(src_model2_dir, 'model.json'), 'w') as f:
+      f.write('model2_json_content')
+
+    bundle_def = bundle.BundleDefinition(
+        name='nested_bundle',
+        version=2,
+        pipelines={
+            'pipeline1': [
+                bundle.PipelineStep(model='model1', pipeline='step1'),
+                bundle.PipelineStep(model='model2', pipeline='step2'),
+            ]
+        }
+    )
+    models = [
+        bundle.SubModel(name='model1', path=src_model1_dir),
+        bundle.SubModel(name='model2', path=src_model2_dir),
+    ]
+
+    bundle.create_bundle(
+        bundle_dest, bundle_def, models, copy_models=True
+    )
+
+    dest_model1 = os.path.join(bundle_dest, 'model1')
+    self.assertTrue(os.path.exists(dest_model1))
+    self.assertTrue(os.path.isdir(dest_model1))
+    self.assertFalse(os.path.islink(dest_model1))
+
+    self._assert_model_files_exist(dest_model1)
+
+    dest_model2 = os.path.join(bundle_dest, 'model2')
+    self.assertTrue(os.path.exists(dest_model2))
+    self.assertTrue(os.path.isdir(dest_model2))
+    # The copied file keeps its original contents.
+    self.assertTrue(os.path.exists(os.path.join(dest_model2, 'model.json')))
+    with open(os.path.join(dest_model2, 'model.json'), 'r') as f:
+      self.assertEqual(f.read(), 'model2_json_content')
+
+    proto_path = os.path.join(bundle_dest, 'bundle_orchestration.pb')
+    self.assertTrue(os.path.exists(proto_path))
+
+    proto = bundle_orchestration_pb2.BundleOrchestration()
+    with open(proto_path, 'rb') as f:
+      proto.ParseFromString(f.read())
+
+    self.assertEqual(proto.metadata.name, 'nested_bundle')
+    self.assertEqual(proto.metadata.version, 2)
+    self.assertIn('pipeline1', proto.pipelines)
+    # Components appear in the order the steps were declared.
+    pipeline = proto.pipelines['pipeline1']
+    self.assertLen(pipeline.components, 2)
+
+    self.assertEqual(pipeline.components[0].model_name, 'model1')
+    self.assertEqual(pipeline.components[0].pipeline_name, 'step1')
+    self.assertEqual(pipeline.components[0].repeated_times, 1)
+    self.assertFalse(pipeline.components[0].requires_h2d)
+    self.assertFalse(pipeline.components[0].requires_d2h)
+
+    self.assertEqual(pipeline.components[1].model_name, 'model2')
+    self.assertEqual(pipeline.components[1].pipeline_name, 'step2')
+    self.assertEqual(pipeline.components[1].repeated_times, 1)
+    self.assertFalse(pipeline.components[1].requires_h2d)
+    self.assertFalse(pipeline.components[1].requires_d2h)
+
+  def test_create_bundle_symlink_models(self):
+    """Symlink mode links the sub-model instead of copying its files."""
+    src_workspace = self.create_tempdir('source_workspace').full_path
+    bundle_dest = self.create_tempdir('bundle_destination').full_path
+    src_model_dir = os.path.join(src_workspace, 'src_model')
+    self._create_dummy_orbax_model(src_model_dir)
+
+    bundle_def = bundle.BundleDefinition(
+        name='symlink_bundle',
+        version=1,
+        pipelines={
+            'pipeline1': [
+                bundle.PipelineStep(model='model1', pipeline='step1'),
+            ]
+        }
+    )
+    models = [
+        bundle.SubModel(name='model1', path=src_model_dir),
+    ]
+
+    bundle.create_bundle(
+        bundle_dest, bundle_def, models, copy_models=False
+    )
+
+    self.assertTrue(os.path.exists(bundle_dest))
+    # The bundle entry is a symlink that points at the source model path.
+    dest_model_dir = os.path.join(bundle_dest, 'model1')
+    self.assertTrue(os.path.exists(dest_model_dir))
+    self.assertTrue(os.path.islink(dest_model_dir))
+    self.assertEqual(os.readlink(dest_model_dir), src_model_dir)
+
+    proto_path = os.path.join(bundle_dest, 'bundle_orchestration.pb')
+    self.assertTrue(os.path.exists(proto_path))
+
+    proto = bundle_orchestration_pb2.BundleOrchestration()
+    with open(proto_path, 'rb') as f:
+      proto.ParseFromString(f.read())
+
+    self.assertEqual(proto.metadata.name, 'symlink_bundle')
+    self.assertEqual(proto.metadata.version, 1)
+    self.assertIn('pipeline1', proto.pipelines)
+
+    pipeline = proto.pipelines['pipeline1']
+    self.assertLen(pipeline.components, 1)
+    self.assertEqual(pipeline.components[0].model_name, 'model1')
+    self.assertEqual(pipeline.components[0].pipeline_name, 'step1')
+
+  def _create_dummy_orbax_model(self, model_dir: str):
+    """Creates a dummy Orbax Model structure for testing."""
+    os.makedirs(os.path.join(model_dir, 'checkpoint'))
+    with open(os.path.join(model_dir, 'manifest.pb'), 'w') as f:
+      f.write('dummy_manifest_pb')
+    with open(os.path.join(model_dir, 'neptune_model_version.txt'), 'w') as f:
+      f.write('dummy_version')
+    with open(os.path.join(model_dir, 'orchestration.pb'), 'w') as f:
+      f.write('dummy_orchestration')
+    with open(os.path.join(model_dir, 'predict.shlo'), 'w') as f:
+      f.write('dummy_shlo')
+    for i in range(1, 4):
+      checkpoint_path = os.path.join(
+          model_dir, 'checkpoint', f'checkpoint_{i}'
+      )
+      with open(checkpoint_path, 'w') as f:
+        f.write(f'dummy_checkpoint_{i}')
+    with open(os.path.join(model_dir, 'checkpoint', 'metadata'), 'w') as f:
+      f.write('dummy_metadata')
+
+  def _assert_model_files_exist(self, model_dir: str):
+    """Asserts that the expected dummy Orbax Model files exist."""
+    for filename in [
+        'manifest.pb',
+        'neptune_model_version.txt',
+        'predict.shlo',
+    ]:
+      self.assertTrue(os.path.exists(os.path.join(model_dir, filename)))
+    for i in range(1, 4):
+      self.assertTrue(
+          os.path.exists(
+              os.path.join(model_dir, 'checkpoint', f'checkpoint_{i}')
+          )
+      )
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/export/pyproject.toml b/export/pyproject.toml
index c135b78a9..a4d2d2a40 100644
--- a/export/pyproject.toml
+++ b/export/pyproject.toml
@@ -7,7 +7,7 @@ name = "orbax-export"
 description = "Orbax Export"
 readme = 'README.md'
 license = {file = 'LICENSE'}
-requires-python = '>=3.10'
+requires-python = '>=3.11'
 authors = [{name = 'Orbax Authors', email='orbax-dev@google.com'}]
 classifiers=[
     'Development Status :: 4 - Beta',
@@ -24,12 +24,13 @@ dependencies = [
     'absl-py',
     'dataclasses-json',
     'etils',
-    'jax >= 0.4.34',
+    'jax >= 0.6.0',
     'jaxlib',
     'jaxtyping',
     'numpy',
     'protobuf',
-    "orbax-checkpoint >=0.9.0"
+
"orbax-checkpoint >=0.9.0", + "neptune-model" ] dynamic = ['version']