From bc8b539f2f9ee18d941b2c59ff52925ab7972a2f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 2 Jun 2026 23:42:14 +0200 Subject: [PATCH] Fix no-torch PE rebuild import path --- src/microplex_us/__init__.py | 157 ++++++++-------- .../pe_us_data_rebuild_checkpoint.py | 175 ++++++++++++++---- tests/test_package_imports.py | 102 +++++++--- 3 files changed, 295 insertions(+), 139 deletions(-) diff --git a/src/microplex_us/__init__.py b/src/microplex_us/__init__.py index ada081a..3377545 100644 --- a/src/microplex_us/__init__.py +++ b/src/microplex_us/__init__.py @@ -5,50 +5,41 @@ from importlib import import_module from typing import Any -from microplex.targets import TargetSet, TargetSpec - -from microplex_us.calibration_harness import ( - CalibrationHarness, - CalibrationResult, - run_pe_parity_suite, +_CALIBRATION_HARNESS_EXPORTS = ( + "CalibrationHarness", + "CalibrationResult", + "run_pe_parity_suite", ) -from microplex_us.cps_synthetic import ( - CPSSummaryStats, - CPSSyntheticGenerator, - validate_synthetic, +_CPS_SYNTHETIC_EXPORTS = ( + "CPSSummaryStats", + "CPSSyntheticGenerator", + "validate_synthetic", ) -from microplex_us.data import ( - create_sample_data, - get_data_info, - load_cps_asec, - load_cps_for_synthesis, +_DATA_EXPORTS = ( + "create_sample_data", + "get_data_info", + "load_cps_asec", + "load_cps_for_synthesis", ) - -try: - from microplex_us.geography import ( - BLOCK_LEN, - COUNTY_LEN, - STATE_LEN, - TRACT_LEN, - BlockGeography, - derive_geographies, - load_block_probabilities, - normalize_us_state_fips, - ) -except ImportError: - BLOCK_LEN = None - COUNTY_LEN = None - STATE_LEN = None - TRACT_LEN = None - BlockGeography = None - derive_geographies = None - load_block_probabilities = None - normalize_us_state_fips = None -from microplex_us.hierarchical import prepare_cps_for_hierarchical -from microplex_us.pe_targets import ( - PETargets, - create_calibration_targets, - get_pe_targets, +_GEOGRAPHY_EXPORTS = ( + "BLOCK_LEN", + "COUNTY_LEN", + "STATE_LEN", + "TRACT_LEN", + "BlockGeography", + "derive_geographies", + "load_block_probabilities", + "normalize_us_state_fips", +) +_HIERARCHICAL_EXPORTS = ("prepare_cps_for_hierarchical",) +_MICROPLEX_TARGET_EXPORTS = ( + "TargetSet", + "TargetSpec", +) +_PE_TARGETS_EXPORTS = ( + "PETargets", + "create_calibration_targets", + "get_pe_targets", ) _PIPELINE_EXPORTS = ( @@ -180,61 +171,56 @@ "SourceVariablePolicySpec", "resolve_source_variable_capabilities", ) -from microplex_us.target_registry import ( - TargetCategory, - TargetGroup, - TargetLevel, - TargetRegistry, - get_registry, - print_registry_summary, +_TARGET_REGISTRY_EXPORTS = ( + "TargetCategory", + "TargetGroup", + "TargetLevel", + "TargetRegistry", + "get_registry", + "print_registry_summary", ) - _TARGETS_EXPORTS = ( "POLICYENGINE_US_COUNT_ENTITIES", "policyengine_db_target_to_canonical_spec", "policyengine_db_targets_to_canonical_set", ) -from microplex_us.unified_calibration import ( - CalibrationTarget, - UnifiedCalibrator, - calibrate_to_pe_targets, +_UNIFIED_CALIBRATION_EXPORTS = ( + "CalibrationTarget", + "UnifiedCalibrator", + "calibrate_to_pe_targets", +) +_VALIDATION_EXPORTS = ( + "AGI_BRACKETS", + "FILING_STATUSES", + "BaselineComparison", + "MetricComparison", + "SOITargets", + "ValidationResult", + "compute_baseline_comparison", + "compute_validation_metrics", + "export_comparison_json", + "get_soi_years", + "load_soi_targets", + "validate_against_soi", ) - -try: - from microplex_us.validation import ( - AGI_BRACKETS, - FILING_STATUSES, - BaselineComparison, - MetricComparison, - SOITargets, - ValidationResult, - compute_baseline_comparison, - compute_validation_metrics, - export_comparison_json, - get_soi_years, - load_soi_targets, - validate_against_soi, - ) -except ImportError: - AGI_BRACKETS = None - FILING_STATUSES = None - BaselineComparison = None - MetricComparison = None - SOITargets = None - ValidationResult = None - compute_baseline_comparison = None - compute_validation_metrics = None - export_comparison_json = None - get_soi_years = None - load_soi_targets = None - validate_against_soi = None _LAZY_EXPORT_MODULES: dict[str, str] = { + **dict.fromkeys(_CALIBRATION_HARNESS_EXPORTS, "microplex_us.calibration_harness"), + **dict.fromkeys(_CPS_SYNTHETIC_EXPORTS, "microplex_us.cps_synthetic"), + **dict.fromkeys(_DATA_EXPORTS, "microplex_us.data"), + **dict.fromkeys(_GEOGRAPHY_EXPORTS, "microplex_us.geography"), + **dict.fromkeys(_HIERARCHICAL_EXPORTS, "microplex_us.hierarchical"), + **dict.fromkeys(_MICROPLEX_TARGET_EXPORTS, "microplex.targets"), + **dict.fromkeys(_PE_TARGETS_EXPORTS, "microplex_us.pe_targets"), **dict.fromkeys(_PIPELINE_EXPORTS, "microplex_us.pipelines"), **dict.fromkeys(_POLICYENGINE_EXPORTS, "microplex_us.policyengine"), **dict.fromkeys(_SOURCE_REGISTRY_EXPORTS, "microplex_us.source_registry"), + **dict.fromkeys(_TARGET_REGISTRY_EXPORTS, "microplex_us.target_registry"), **dict.fromkeys(_TARGETS_EXPORTS, "microplex_us.targets"), + **dict.fromkeys(_UNIFIED_CALIBRATION_EXPORTS, "microplex_us.unified_calibration"), + **dict.fromkeys(_VALIDATION_EXPORTS, "microplex_us.validation"), } +_OPTIONAL_NONE_EXPORTS = frozenset((*_GEOGRAPHY_EXPORTS, *_VALIDATION_EXPORTS)) def __getattr__(name: str) -> Any: @@ -242,7 +228,12 @@ def __getattr__(name: str) -> Any: module_name = _LAZY_EXPORT_MODULES.get(name) if module_name is None: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - value = getattr(import_module(module_name), name) + try: + value = getattr(import_module(module_name), name) + except ImportError: + if name not in _OPTIONAL_NONE_EXPORTS: + raise + value = None globals()[name] = value return value diff --git a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py index a37f993..ba15b70 100644 --- a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py +++ b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py @@ -14,25 +14,7 @@ import h5py import numpy as np import pandas as pd -from microplex.core import ( - EntityObservation, - EntityType, - ObservationFrame, - SourceDescriptor, - SourceQuery, -) -from microplex.targets import assert_valid_benchmark_artifact_manifest -from microplex_us.pipelines.artifacts import ( - USMicroplexArtifactPaths, - USMicroplexVersionedBuildArtifacts, - build_and_save_versioned_us_microplex_from_source_providers, -) -from microplex_us.pipelines.imputation_ablation import ( - ImputationAblationSliceSpec, - ImputationAblationVariant, - score_imputation_ablation_variants, -) from microplex_us.pipelines.index_db import append_us_microplex_run_index_entry from microplex_us.pipelines.pe_us_data_rebuild import ( PEUSDataRebuildProgram, @@ -40,13 +22,6 @@ default_policyengine_us_data_rebuild_program, default_policyengine_us_data_rebuild_source_providers, ) -from microplex_us.pipelines.pe_us_data_rebuild_audit import ( - build_policyengine_us_data_rebuild_native_audit, -) -from microplex_us.pipelines.pe_us_data_rebuild_parity import ( - build_policyengine_us_data_rebuild_parity_artifact, - write_policyengine_us_data_rebuild_parity_artifact, -) from microplex_us.pipelines.registry import ( append_us_microplex_run_registry_entry, build_us_microplex_run_registry_entry, @@ -58,18 +33,26 @@ resolve_us_stage_artifact_contract_path, ) from microplex_us.pipelines.stage_metrics import stage_metrics -from microplex_us.pipelines.stage_run import ( - USStageInputOverride, - parse_us_stage_input_override, - write_us_stage_run_manifests_from_artifact_manifest, -) -from microplex_us.variables import prune_redundant_variables if TYPE_CHECKING: - from microplex.core import SourceProvider + from microplex.core import ( + EntityObservation, + ObservationFrame, + SourceDescriptor, + SourceProvider, + SourceQuery, + ) from microplex.targets import TargetProvider + from microplex_us.pipelines.artifacts import ( + USMicroplexVersionedBuildArtifacts, + ) + from microplex_us.pipelines.imputation_ablation import ( + ImputationAblationSliceSpec, + ImputationAblationVariant, + ) from microplex_us.pipelines.registry import FrontierMetric + from microplex_us.pipelines.stage_run import USStageInputOverride from microplex_us.pipelines.us import USMicroplexBuildConfig from microplex_us.policyengine.harness import ( PolicyEngineUSComparisonCache, @@ -82,6 +65,96 @@ LOGGER = logging.getLogger(__name__) +def assert_valid_benchmark_artifact_manifest(*args: Any, **kwargs: Any) -> Any: + from microplex.targets import ( + assert_valid_benchmark_artifact_manifest as _assert_valid, + ) + + return _assert_valid(*args, **kwargs) + + +def build_and_save_versioned_us_microplex_from_source_providers( + *args: Any, + **kwargs: Any, +) -> Any: + from microplex_us.pipelines.artifacts import ( + build_and_save_versioned_us_microplex_from_source_providers as _build, + ) + + return _build( + *args, + **kwargs, + ) + + +def score_imputation_ablation_variants(*args: Any, **kwargs: Any) -> Any: + from microplex_us.pipelines.imputation_ablation import ( + score_imputation_ablation_variants as _score_imputation_ablation_variants, + ) + + return _score_imputation_ablation_variants(*args, **kwargs) + + +def build_policyengine_us_data_rebuild_native_audit( + *args: Any, + **kwargs: Any, +) -> Any: + from microplex_us.pipelines.pe_us_data_rebuild_audit import ( + build_policyengine_us_data_rebuild_native_audit as _build_audit, + ) + + return _build_audit(*args, **kwargs) + + +def build_policyengine_us_data_rebuild_parity_artifact( + *args: Any, + **kwargs: Any, +) -> Any: + from microplex_us.pipelines.pe_us_data_rebuild_parity import ( + build_policyengine_us_data_rebuild_parity_artifact as _build_parity, + ) + + return _build_parity(*args, **kwargs) + + +def write_policyengine_us_data_rebuild_parity_artifact( + *args: Any, + **kwargs: Any, +) -> Any: + from microplex_us.pipelines.pe_us_data_rebuild_parity import ( + write_policyengine_us_data_rebuild_parity_artifact as _write_parity, + ) + + return _write_parity(*args, **kwargs) + + +def parse_us_stage_input_override(*args: Any, **kwargs: Any) -> Any: + from microplex_us.pipelines.stage_run import ( + parse_us_stage_input_override as _parse_us_stage_input_override, + ) + + return _parse_us_stage_input_override(*args, **kwargs) + + +def write_us_stage_run_manifests_from_artifact_manifest( + *args: Any, + **kwargs: Any, +) -> Any: + from microplex_us.pipelines.stage_run import ( + write_us_stage_run_manifests_from_artifact_manifest as _write_manifests, + ) + + return _write_manifests(*args, **kwargs) + + +def prune_redundant_variables(*args: Any, **kwargs: Any) -> Any: + from microplex_us.variables import ( + prune_redundant_variables as _prune_redundant_variables, + ) + + return _prune_redundant_variables(*args, **kwargs) + + def _root_logger_has_handlers() -> bool: return bool(logging.getLogger().handlers) @@ -303,6 +376,8 @@ def _infer_policyengine_baseline_household_weight_sum( def _checkpoint_imputation_ablation_variants() -> tuple[ImputationAblationVariant, ...]: + from microplex_us.pipelines.imputation_ablation import ImputationAblationVariant + return ( ImputationAblationVariant( name="broad_common_qrf", @@ -325,6 +400,8 @@ def _checkpoint_imputation_ablation_variants() -> tuple[ImputationAblationVarian def _checkpoint_imputation_ablation_slice_specs() -> tuple[ ImputationAblationSliceSpec, ... ]: + from microplex_us.pipelines.imputation_ablation import ImputationAblationSliceSpec + return ( ImputationAblationSliceSpec( name="state_by_age", @@ -397,6 +474,8 @@ def _build_checkpoint_source_descriptor( person_variables: set[str] | None = None, name: str | None = None, ) -> SourceDescriptor | None: + from microplex.core import EntityObservation, EntityType, SourceDescriptor + def _build_observation( entity: EntityType, table: pd.DataFrame, @@ -463,6 +542,8 @@ def _build_observation( def _household_person_relationship(frame: ObservationFrame) -> Any: + from microplex.core import EntityType + relationship = next( ( candidate @@ -506,6 +587,8 @@ def _subset_checkpoint_frame_to_households( *, source: SourceDescriptor, ) -> ObservationFrame | None: + from microplex.core import EntityType, ObservationFrame + relationship = _household_person_relationship(frame) households = frame.tables[EntityType.HOUSEHOLD] persons = frame.tables[EntityType.PERSON] @@ -548,6 +631,8 @@ def _split_checkpoint_household_ids( eval_fraction: float, random_seed: int, ) -> tuple[tuple[Any, ...], tuple[Any, ...]] | None: + from microplex.core import EntityType + relationship = _household_person_relationship(frame) household_ids = ( frame.tables[EntityType.HOUSEHOLD][relationship.parent_key] @@ -573,6 +658,8 @@ def _build_checkpoint_holdout_scaffold_source( *, masked_target_variables: set[str] | None = None, ) -> SourceDescriptor | None: + from microplex.core import EntityType + excluded_variables = set(masked_target_variables or ()) return _build_checkpoint_source_descriptor( base_source=scaffold_source, @@ -593,6 +680,8 @@ def _resolve_checkpoint_imputation_targets( donor_input: Any, current_seed: pd.DataFrame, ) -> tuple[list[str], list[str]]: + from microplex.core import EntityType + scaffold_observed = prune_redundant_variables( scaffold_input.fusion_plan.variables_for(EntityType.HOUSEHOLD) | scaffold_input.fusion_plan.variables_for(EntityType.PERSON) @@ -830,6 +919,8 @@ def _build_checkpoint_imputation_ablation_payload( if build_result.source_frame is None or not build_result.source_frames: return None + from microplex.core import EntityType + from microplex_us.pipelines.us import USMicroplexPipeline pipeline = USMicroplexPipeline(build_result.config) @@ -1235,6 +1326,11 @@ def _load_checkpoint_versioned_artifacts( artifact_root: Path, frontier_metric: FrontierMetric, ) -> USMicroplexVersionedBuildArtifacts: + from microplex_us.pipelines.artifacts import ( + USMicroplexArtifactPaths, + USMicroplexVersionedBuildArtifacts, + ) + manifest_path = artifact_root / "manifest.json" manifest = json.loads(manifest_path.read_text()) artifacts = dict(manifest.get("artifacts", {})) @@ -1906,6 +2002,8 @@ def default_policyengine_us_data_rebuild_queries( ) -> dict[str, SourceQuery]: """Return default provider queries for a rebuild checkpoint smoke run.""" + from microplex.core import SourceQuery + from microplex_us.data_sources.cps import CPSASECSourceProvider from microplex_us.data_sources.donor_surveys import DonorSurveySourceProvider from microplex_us.data_sources.puf import PUFSourceProvider @@ -2232,9 +2330,8 @@ def run_policyengine_us_data_rebuild_checkpoint( ) -def main(argv: list[str] | None = None) -> None: - """CLI entry point for one PE-US-data rebuild checkpoint.""" - +def build_policyengine_us_data_rebuild_checkpoint_parser() -> argparse.ArgumentParser: + """Build the PE-US-data rebuild checkpoint parser without runtime imports.""" parser = argparse.ArgumentParser( description="Run a versioned PE-US-data rebuild checkpoint in microplex-us." ) @@ -2405,7 +2502,15 @@ def main(argv: list[str] | None = None) -> None: metavar="STAGE_ID.KEY=PATH", help=("Explicit stage input override. Requires --allow-stage-input-overrides."), ) + return parser + + +def main(argv: list[str] | None = None) -> None: + """CLI entry point for one PE-US-data rebuild checkpoint.""" + + parser = build_policyengine_us_data_rebuild_checkpoint_parser() args = parser.parse_args(argv) + stage_input_overrides = tuple( parse_us_stage_input_override(value) for value in args.stage_input_override ) diff --git a/tests/test_package_imports.py b/tests/test_package_imports.py index 7328cab..5cb054a 100644 --- a/tests/test_package_imports.py +++ b/tests/test_package_imports.py @@ -4,38 +4,98 @@ import subprocess import sys +import textwrap -def test_root_import_leaves_pipeline_exports_lazy() -> None: - result = subprocess.run( - [ - sys.executable, - "-c", - ("import microplex_us; print('build_us_microplex' in vars(microplex_us))"), - ], +def _run_python(source: str) -> subprocess.CompletedProcess[str]: + return subprocess.run( + [sys.executable, "-c", textwrap.dedent(source)], check=True, capture_output=True, text=True, ) + +_BLOCK_TORCH_IMPORTS = """ +import importlib.abc +import sys + + +class BlockTorch(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path=None, target=None): + if fullname == "torch" or fullname.startswith("torch."): + raise ModuleNotFoundError("No module named 'torch'") + return None + + +sys.meta_path.insert(0, BlockTorch()) +""" + + +def test_root_import_leaves_pipeline_exports_lazy() -> None: + result = _run_python( + "import microplex_us; print('build_us_microplex' in vars(microplex_us))" + ) + + assert result.stdout.strip() == "False" + + +def test_root_import_does_not_require_torch_or_core_microplex() -> None: + result = _run_python( + _BLOCK_TORCH_IMPORTS + + """ +import microplex_us +print("microplex" in sys.modules) +print("TargetSpec" in vars(microplex_us)) + """ + ) + + assert result.stdout.splitlines() == ["False", "False"] + + +def test_pe_rebuild_checkpoint_import_does_not_require_torch() -> None: + result = _run_python( + _BLOCK_TORCH_IMPORTS + + """ +import microplex_us.pipelines.pe_us_data_rebuild_checkpoint +print("microplex" in sys.modules) + """ + ) + assert result.stdout.strip() == "False" +def test_pe_rebuild_checkpoint_help_does_not_require_torch_or_core_microplex() -> None: + result = _run_python( + _BLOCK_TORCH_IMPORTS + + """ +import runpy + +sys.argv = ["pe_us_data_rebuild_checkpoint", "--help"] +try: + runpy.run_module( + "microplex_us.pipelines.pe_us_data_rebuild_checkpoint", + run_name="__main__", + ) +except SystemExit as exc: + print(f"exit={exc.code}") +print(f"microplex_imported={'microplex' in sys.modules}") + """ + ) + + assert result.stdout.splitlines()[-2:] == [ + "exit=0", + "microplex_imported=False", + ] + + def test_data_sources_import_leaves_family_benchmark_lazy() -> None: - result = subprocess.run( - [ - sys.executable, - "-c", - ( - "import sys; " - "import microplex_us.data_sources; " - "print('microplex_us.data_sources.family_imputation_benchmark' " - "in sys.modules)" - ), - ], - check=True, - capture_output=True, - text=True, + result = _run_python( + """ + import sys + import microplex_us.data_sources + print("microplex_us.data_sources.family_imputation_benchmark" in sys.modules) + """ ) assert result.stdout.strip() == "False"