Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ class SingleReplicaArrayRestoreArgs(ArrayRestoreArgs):

def __post_init__(self):
  """Runs the parent post-init, then warns if the deprecated field is set.

  The deprecation warning for `single_replica_sharding` is emitted only when
  the caller actually supplied a value for it.
  """
  super().__post_init__()
  # Only warn when the deprecated field is explicitly provided. Logging the
  # same message unconditionally as well would emit it for callers that never
  # touch the field.
  if self.single_replica_sharding is not None:
    logging.log_first_n(
        logging.WARNING,
        '`single_replica_sharding` is deprecated and will be removed in a'
        ' future version. It is not needed, as Orbax code will automatically'
        ' construct a single-replica sharding used for restoring before'
        ' broadcasting.',
        1,
    )

250 changes: 250 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
"""Database initialization utilities for Tiering Service."""

import contextlib
import datetime
import sqlite3

from orbax.checkpoint.experimental.tiering_service import db_schema
from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2
import sqlalchemy
from sqlalchemy import event
from sqlalchemy.dialects.sqlite.aiosqlite import AsyncAdapt_aiosqlite_connection
from sqlalchemy.engine import Engine
Expand All @@ -27,6 +29,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.future import select
import sqlalchemy.orm
from sqlalchemy.orm import sessionmaker


Expand Down Expand Up @@ -228,3 +231,250 @@ async def async_verify_db(config: tiering_service_pb2.ServerConfig) -> None:
f" prefix: DB has {db_backend.prefix!r}, config expects"
f" {instance.prefix!r}"
)


async def get_active_jobs(
    session: AsyncSession, hostname: str, pid: int
) -> list[db_schema.AssetJob]:
  """Fetches the PROCESSING jobs currently assigned to one worker.

  Args:
    session: Async session used to run the query.
    hostname: Value matched against `AssetJob.worker_host`.
    pid: Value matched against `AssetJob.worker_pid`.

  Returns:
    The matching `AssetJob` rows, with the target tier path (and its storage
    backend) and the asset (with its tier paths and their storage backends)
    eagerly loaded.
  """
  # Eager-load every relationship up front via select-in loading.
  load_target_backend = sqlalchemy.orm.selectinload(
      db_schema.AssetJob.target_tier_path
  ).selectinload(db_schema.TierPath.storage_backend)
  load_asset_backends = (
      sqlalchemy.orm.selectinload(db_schema.AssetJob.asset)
      .selectinload(db_schema.Asset.tier_paths)
      .selectinload(db_schema.TierPath.storage_backend)
  )

  owned_and_processing = (
      db_schema.AssetJob.status == db_schema.JobStatus.JOB_STATUS_PROCESSING,
      db_schema.AssetJob.worker_host == hostname,
      db_schema.AssetJob.worker_pid == pid,
  )

  query = (
      select(db_schema.AssetJob)
      .options(load_target_backend, load_asset_backends)
      .where(*owned_and_processing)
  )
  rows = await session.execute(query)
  return list(rows.scalars().all())


async def _has_eligible_jobs(
    session: AsyncSession,
    backend_id: int | None,
    now: datetime.datetime,
) -> bool:
  """Reports whether at least one claimable job exists, taking no row locks.

  A job is considered claimable when it is QUEUED, or PROCESSING with an
  `expiration_at` earlier than `now`, and its asset has no other PROCESSING
  job whose `expiration_at` is at or after `now`.

  Args:
    session: Async session used to run the query.
    backend_id: Storage backend to filter on. When None, only jobs with no
      target tier path are considered.
    now: Reference timestamp for the lease-expiry comparisons.

  Returns:
    True if the query yields a job id, False otherwise.
  """
  job = db_schema.AssetJob
  processing = db_schema.JobStatus.JOB_STATUS_PROCESSING

  # Assets that already have a live (unexpired) PROCESSING job.
  busy_asset_uuids = (
      select(job.asset_uuid)
      .where(job.status == processing, job.expiration_at >= now)
      .scalar_subquery()
  )

  backend_cond = (
      job.target_tier_path_id.is_(None)
      if backend_id is None
      else db_schema.TierPath.storage_backend_id == backend_id
  )

  is_queued = job.status == db_schema.JobStatus.JOB_STATUS_QUEUED
  is_stale = sqlalchemy.and_(job.status == processing, job.expiration_at < now)

  probe = (
      select(job.id)
      .join(
          db_schema.TierPath,
          job.target_tier_path_id == db_schema.TierPath.id,
          isouter=True,
      )
      .where(
          sqlalchemy.or_(is_queued, is_stale),
          ~job.asset_uuid.in_(busy_asset_uuids),
          backend_cond,
      )
      .limit(1)
  )
  found = await session.execute(probe)
  return found.scalar() is not None


async def _try_lock_backend(
    session: AsyncSession,
    backend_id: int,
    max_active: int,
    now: datetime.datetime,
) -> bool:
  """Takes a row lock on one backend and checks for spare job capacity.

  Args:
    session: Async session; the lock lives in its current transaction.
    backend_id: ID of the `StorageBackend` row to lock.
    max_active: Upper bound on concurrently active jobs for this backend.
    now: Reference timestamp; PROCESSING jobs with `expiration_at >= now`
      count as active.

  Returns:
    True if the number of active jobs targeting this backend is strictly below
    `max_active`.
  """
  # SELECT ... FOR UPDATE on just this backend row, so concurrent claimers for
  # the same backend serialize here.
  lock_row = (
      select(db_schema.StorageBackend)
      .where(db_schema.StorageBackend.id == backend_id)
      .with_for_update()
  )
  await session.execute(lock_row)

  # Count live PROCESSING jobs for this backend while the lock is held.
  is_live = (
      db_schema.AssetJob.status == db_schema.JobStatus.JOB_STATUS_PROCESSING,
      db_schema.AssetJob.expiration_at >= now,
      db_schema.TierPath.storage_backend_id == backend_id,
  )
  count_live = (
      select(sqlalchemy.func.count(db_schema.AssetJob.id))
      .join(
          db_schema.TierPath,
          db_schema.AssetJob.target_tier_path_id == db_schema.TierPath.id,
      )
      .where(*is_live)
  )
  live_jobs = (await session.execute(count_live)).scalar()

  return live_jobs < max_active


async def _claim_eligible_job(
    session: AsyncSession,
    backend_id: int | None,
    lease_duration: datetime.timedelta,
    hostname: str,
    pid: int,
    now: datetime.datetime,
) -> db_schema.AssetJob | None:
  """Fetches and claims the next eligible job, if any.

  The claim is applied to the ORM objects in `session`; the caller is
  responsible for committing (or rolling back) the transaction.

  Args:
    session: Async session in which the job row is locked and mutated.
    backend_id: Storage backend to claim for. When None, only jobs with no
      target tier path are considered.
    lease_duration: Added to `now` to compute the new `expiration_at`.
    hostname: Recorded as the job's `worker_host`.
    pid: Recorded as the job's `worker_pid`.
    now: Reference timestamp for lease comparisons and bookkeeping fields.

  Returns:
    The claimed `AssetJob` (status set to PROCESSING), or None if no eligible
    job could be selected.
  """
  active_assets_subquery = (
      select(db_schema.AssetJob.asset_uuid)
      .where(
          db_schema.AssetJob.status
          == db_schema.JobStatus.JOB_STATUS_PROCESSING,
          db_schema.AssetJob.expiration_at >= now,
      )
      .scalar_subquery()
  )

  if backend_id is None:
    backend_cond = db_schema.AssetJob.target_tier_path_id.is_(None)
  else:
    backend_cond = db_schema.TierPath.storage_backend_id == backend_id

  # Fetch the next eligible job, filtering for:
  # 1. Jobs that are queued or whose execution lease has expired (stale jobs).
  # 2. Jobs targeting assets that aren't already actively being processed by
  #    another job, preventing concurrency conflicts on the same asset.
  # 3. Jobs belonging to the requested storage backend.
  # We select the oldest job (FIFO order) and use SKIP LOCKED concurrency
  # control to prevent multiple workers from matching or blocking on the same
  # job.
  #
  # The lock is scoped to the AssetJob table (`FOR UPDATE OF`): TierPath sits
  # on the nullable side of the outer join, and PostgreSQL rejects a bare
  # FOR UPDATE in that situation. Scoping also avoids locking the joined
  # TierPath row, which is only used for filtering here.
  stmt = (
      select(db_schema.AssetJob)
      .options(
          sqlalchemy.orm.selectinload(
              db_schema.AssetJob.target_tier_path
          ).selectinload(db_schema.TierPath.storage_backend),
          sqlalchemy.orm.selectinload(db_schema.AssetJob.asset)
          .selectinload(db_schema.Asset.tier_paths)
          .selectinload(db_schema.TierPath.storage_backend),
      )
      .join(
          db_schema.TierPath,
          db_schema.AssetJob.target_tier_path_id == db_schema.TierPath.id,
          isouter=True,
      )
      .where(
          sqlalchemy.or_(
              db_schema.AssetJob.status
              == db_schema.JobStatus.JOB_STATUS_QUEUED,
              sqlalchemy.and_(
                  db_schema.AssetJob.status
                  == db_schema.JobStatus.JOB_STATUS_PROCESSING,
                  db_schema.AssetJob.expiration_at < now,
              ),
          ),
          ~db_schema.AssetJob.asset_uuid.in_(active_assets_subquery),
          backend_cond,
      )
      .order_by(db_schema.AssetJob.created_at.asc())
      .limit(1)
      .with_for_update(skip_locked=True, of=db_schema.AssetJob)
  )

  result = await session.execute(stmt)
  job = result.scalars().first()
  if job is None:
    return None

  # Claim the job: mark it PROCESSING and stamp the new lease and owner.
  job.status = db_schema.JobStatus.JOB_STATUS_PROCESSING
  job.expiration_at = now + lease_duration
  job.worker_host = hostname
  job.worker_pid = pid
  job.last_updated_at = now
  session.add(job)

  # For COPY requests that have a target tier path, mark it IN_PROGRESS.
  if (
      job.request_type == db_schema.RequestType.REQUEST_TYPE_COPY
      and job.target_tier_path
  ):
    job.target_tier_path.state = db_schema.TierPathState.IN_PROGRESS
    session.add(job.target_tier_path)

  return job


async def acquire_next_job(
    session_maker: sessionmaker,
    backend_id: int | None,
    lease_duration: datetime.timedelta,
    hostname: str,
    pid: int,
    max_active: int,
) -> db_schema.AssetJob | None:
  """Claims the next eligible job for a backend inside a single transaction.

  Args:
    session_maker: A session maker or session factory. MUST be configured with
      `expire_on_commit=False` to prevent returned Job relationships from being
      expired upon transaction commit.
    backend_id: The ID of the storage backend. When None, the backend lock and
      capacity check are skipped.
    lease_duration: Lease duration for the claimed job.
    hostname: Hostname of the claiming worker.
    pid: PID of the claiming worker.
    max_active: Maximum active jobs allowed on this backend.

  Returns:
    The claimed AssetJob instance, or None if no eligible jobs are available or
    if capacity is full.
  """
  now = datetime.datetime.now(datetime.timezone.utc)

  async with session_maker() as session:
    await session.begin()
    try:
      claimed = None
      # Cheap lock-free probe first; skip locking entirely when there is
      # nothing to claim.
      if await _has_eligible_jobs(session, backend_id, now):
        # With a concrete backend, take its lock and require spare capacity
        # before attempting the claim.
        if backend_id is None or await _try_lock_backend(
            session, backend_id, max_active, now
        ):
          claimed = await _claim_eligible_job(
              session, backend_id, lease_duration, hostname, pid, now
          )

      if claimed is None:
        # Nothing claimed: end the transaction, releasing any backend lock.
        await session.rollback()
        return None

      await session.commit()
      return claimed
    except Exception:
      await session.rollback()
      raise
Loading
Loading