Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ jobs:
working-directory: export
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.11", "3.12"]
jax-version: ["newest"]
include:
- python-version: "3.10"
jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml
- python-version: "3.11"
jax-version: "0.6.0" # keep in sync with minimum version in export/pyproject.toml
# TODO(b/401258175) Re-enable once JAX nightlies are fixed.
# - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it.
# jax-version: "nightly"
Expand Down
4 changes: 3 additions & 1 deletion export/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Defines exported symbols for Orbax Export."""

# pylint: disable=g-importing-member
from orbax.export import bundle
from orbax.export import config
from orbax.export import constants
from orbax.export import obm_configs
Expand All @@ -29,14 +30,15 @@
from orbax.export.dtensor_utils import maybe_enable_dtensor_export_on
from orbax.export.dtensor_utils import shutdown_dtensor
from orbax.export.export_manager import ExportManager

from orbax.export.jax_module import JaxModule
from orbax.export.obm_configs import Jax2ObmOptions
from orbax.export.serving_config import ServingConfig

# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
from orbax.export.utils import remove_signature_defaults
from orbax.export.utils import TensorSpecWithDefault


# A new PyPI release will be pushed everytime `__version__` is increased.
__version__ = '0.0.8'
4 changes: 3 additions & 1 deletion export/orbax/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Defines exported symbols for Orbax Export."""

# pylint: disable=g-importing-member
from orbax.export import bundle
from orbax.export import config
from orbax.export import constants
from orbax.export import obm_configs
Expand All @@ -29,14 +30,15 @@
from orbax.export.dtensor_utils import maybe_enable_dtensor_export_on
from orbax.export.dtensor_utils import shutdown_dtensor
from orbax.export.export_manager import ExportManager

from orbax.export.jax_module import JaxModule
from orbax.export.obm_configs import Jax2ObmOptions
from orbax.export.serving_config import ServingConfig

# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
from orbax.export.utils import remove_signature_defaults
from orbax.export.utils import TensorSpecWithDefault


# A new PyPI release will be pushed everytime `__version__` is increased.
__version__ = '0.0.8'
134 changes: 134 additions & 0 deletions export/orbax/export/bundle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bundle orchestration export utilities."""

import dataclasses
import os
import shutil
from typing import Dict, List

from neptune_model.protos import bundle_orchestration_pb2


@dataclasses.dataclass
class SubModel:
"""Defines a unique model in the bundle.

Attributes:
name: Name of the sub-model subdirectory in the bundle.
path: Path to the exported sub-model.
"""

name: str
path: str


@dataclasses.dataclass
class PipelineStep:
"""Defines one model's execution within the bundle.

Attributes:
model: Name of the sub-model (must match a SubModel.name).
pipeline: Name of the pipeline in the sub-model to execute.
repeated_times: Number of times to repeat this step.
requires_h2d: Whether this step requires Host-to-Device transfer.
requires_d2h: Whether this step requires Device-to-Host transfer.
"""

model: str
pipeline: str
repeated_times: int = 1
requires_h2d: bool = False
requires_d2h: bool = False


@dataclasses.dataclass
class BundleDefinition:
"""Defines organization of sub-models in the bundle.

Attributes:
name: Human-readable name of the bundle.
version: Version of the bundle.
pipelines: Maps pipeline name to a sequential list of steps.
"""

name: str
version: int
pipelines: Dict[str, List[PipelineStep]]


def create_bundle(
output_path: str,
bundle_def: BundleDefinition,
models: List[SubModel],
copy_models: bool = False,
):
"""Creates the bundle directory, symlinks/copies sub-models, and writes the proto.

Args:
output_path: Path where the bundle should be created.
bundle_def: Definition of the bundle pipelines and metadata.
models: List of sub-models to include in the bundle.
copy_models: If True, copies the sub-models instead of symlinking them.
"""
os.makedirs(output_path, exist_ok=True)

for model in models:
dest_path = os.path.join(output_path, model.name)
if os.path.lexists(dest_path):
if os.path.islink(dest_path):
os.unlink(dest_path)
elif os.path.isdir(dest_path):
shutil.rmtree(dest_path)
else:
os.remove(dest_path)
if copy_models:
shutil.copytree(model.path, dest_path)
else:
os.symlink(model.path, dest_path)

proto = bundle_orchestration_pb2.BundleOrchestration()
proto.metadata.name = bundle_def.name
proto.metadata.version = bundle_def.version

for pipeline_name, steps in bundle_def.pipelines.items():
bundle_pipeline = bundle_orchestration_pb2.BundlePipeline()
for step in steps:
component = bundle_pipeline.components.add()
component.model_name = step.model
component.pipeline_name = step.pipeline
component.repeated_times = step.repeated_times
component.requires_h2d = step.requires_h2d
component.requires_d2h = step.requires_d2h

proto.pipelines[pipeline_name].CopyFrom(bundle_pipeline)

pb_path = os.path.join(output_path, "bundle_orchestration.pb")
with open(pb_path, "wb") as f:
f.write(proto.SerializeToString())
Loading
Loading