Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 136 additions & 12 deletions cumulusci/tasks/bulkdata/generate_and_load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@

from cumulusci.core.config import TaskConfig
from cumulusci.core.exceptions import TaskOptionsError
from cumulusci.core.utils import import_global
from cumulusci.core.utils import import_global, process_bool_arg
from cumulusci.tasks.bulkdata import LoadData
from cumulusci.tasks.bulkdata.mapping_parser import (
parse_from_yaml,
validate_and_inject_mapping,
)
from cumulusci.tasks.bulkdata.step import DataOperationType
from cumulusci.tasks.bulkdata.utils import generate_batches
from cumulusci.tasks.salesforce import BaseSalesforceApiTask

Expand Down Expand Up @@ -79,6 +84,10 @@ class GenerateAndLoadData(BaseSalesforceApiTask):
"working_directory": {
"description": "Store temporary files in working_directory for easier debugging."
},
"validate_only": {
"description": "Boolean: if True, only validate the generated mapping against the org schema without loading data. "
"Defaults to False."
},
**LoadData.task_options,
}
task_options["mapping"]["required"] = False
Expand Down Expand Up @@ -114,6 +123,7 @@ def _init_options(self, kwargs):

self.working_directory = self.options.get("working_directory", None)
self.database_url = self.options.get("database_url")
self.validate_only = process_bool_arg(self.options.get("validate_only", False))

if self.database_url:
engine, metadata = self._setup_engine(self.database_url)
Expand All @@ -132,6 +142,16 @@ def _run_task(self):
if working_directory:
tempdir = Path(working_directory)
tempdir.mkdir(exist_ok=True)

# Route to validation flow if validate_only is True
if self.validate_only:
return self._run_validation(
database_url=self.database_url,
tempdir=self.working_directory or tempdir,
mapping_file=self.mapping_file,
)

# Normal data generation and loading flow
if self.batch_size:
batches = generate_batches(self.num_records, self.batch_size)
else:
Expand Down Expand Up @@ -186,6 +206,47 @@ def _generate_batch(
total_batches: int,
) -> dict:
"""Generate a batch in database_url or a tempfile if it isn't specified."""
# Setup and generate data
subtask_options = self._setup_and_generate_data(
database_url=database_url,
tempdir=tempdir,
mapping_file=mapping_file,
num_records=batch_size,
batch_index=index,
)

# Load the data
return self._dataload(subtask_options)

def _setup_engine(self, database_url):
    """Create an engine for ``database_url`` and reflect its tables.

    Args:
        database_url: URL handed directly to ``create_engine``.

    Returns:
        tuple: ``(engine, metadata)``; ``metadata`` has had ``reflect()``
        called on it before it is returned.
    """
    db_engine = create_engine(database_url)

    # MetaData is constructed with the engine, which is why reflect()
    # is invoked without any arguments.
    reflected = MetaData(db_engine)
    reflected.reflect()
    return db_engine, reflected

def _setup_and_generate_data(
self,
*,
database_url: Optional[str],
tempdir: Union[Path, str, None],
mapping_file: Union[Path, str, None],
num_records: Optional[int],
batch_index: int,
) -> dict:
"""Setup database and generate data, returning subtask options with mapping.

Args:
database_url: Database URL or None to create temp SQLite
tempdir: Temporary directory for generated files
mapping_file: Path to mapping file or None to generate
num_records: Number of records to generate
batch_index: Current batch number

Returns:
dict: subtask_options with mapping file path set
"""
if not database_url:
sqlite_path = Path(tempdir) / "generated_data.db"
database_url = f"sqlite:///{sqlite_path}"
Expand All @@ -197,28 +258,91 @@ def _generate_batch(
"mapping": mapping_file,
"reset_oids": False,
"database_url": database_url,
"num_records": batch_size,
"current_batch_number": index,
"num_records": num_records,
"current_batch_number": batch_index,
"working_directory": tempdir,
}

# some generator tasks can generate the mapping file instead of reading it
# Generate mapping file if needed
if not subtask_options.get("mapping"):
temp_mapping = Path(tempdir) / "temp_mapping.yml"
mapping_file = self.options.get("generate_mapping_file", temp_mapping)
subtask_options["generate_mapping_file"] = mapping_file

# Run data generation
self._datagen(subtask_options)

if not subtask_options.get("mapping"):
subtask_options["mapping"] = mapping_file
return self._dataload(subtask_options)
subtask_options["mapping"] = subtask_options["generate_mapping_file"]

def _setup_engine(self, database_url):
"""Set up the database engine"""
engine = create_engine(database_url)
return subtask_options

metadata = MetaData(engine)
metadata.reflect()
return engine, metadata
def _run_validation(
    self,
    *,
    database_url: Optional[str],
    tempdir: Union[Path, str, None],
    mapping_file: Union[Path, str, None],
):
    """Generate a single record's worth of data, then validate the mapping.

    No data is loaded: the generation step is run only so that a mapping
    file is available to validate.

    Args:
        database_url: Database URL, forwarded to the generation step.
        tempdir: Directory for generated files, forwarded to the
            generation step.
        mapping_file: Mapping file path, forwarded to the generation step.

    Returns:
        dict: ``self.return_values``, holding the validation outcome under
        the ``"validation_result"`` key.
    """
    generation_args = {
        "database_url": database_url,
        "tempdir": tempdir,
        "mapping_file": mapping_file,
        # One record in batch 0 is the smallest run that still yields
        # subtask options carrying a mapping.
        "num_records": 1,
        "batch_index": 0,
    }
    generated_options = self._setup_and_generate_data(**generation_args)

    # Store the result on the task as well as returning it.
    self.return_values = {
        "validation_result": self._validate_mapping(generated_options)
    }
    return self.return_values

def _validate_mapping(self, subtask_options):
    """Validate the mapping against the org schema without loading data.

    Args:
        subtask_options: Mapping of sub-task options; its ``"mapping"``
            entry must hold the path of the mapping file to validate.

    Returns:
        The object returned by ``validate_and_inject_mapping`` when called
        with ``validate_only=True``. It is treated here as possibly falsy,
        and otherwise as exposing ``has_errors()``, ``errors`` and
        ``warnings``.

    Raises:
        TaskOptionsError: If ``subtask_options`` has no mapping file path.
    """
    mapping_file = subtask_options.get("mapping")
    if not mapping_file:
        raise TaskOptionsError("Mapping file path required for validation")

    self.logger.info(f"Validating mapping file: {mapping_file}")
    mapping = parse_from_yaml(mapping_file)

    # Task options may arrive as strings (e.g. "False" from the command
    # line), and a non-empty string is truthy. Normalise them so that
    # "False" does not silently switch these features on.
    inject_namespaces = process_bool_arg(
        self.options.get("inject_namespaces", False)
    )
    drop_missing = process_bool_arg(
        self.options.get("drop_missing_schema", False)
    )

    validation_result = validate_and_inject_mapping(
        mapping=mapping,
        sf=self.sf,
        namespace=self.project_config.project__package__namespace,
        data_operation=DataOperationType.INSERT,
        inject_namespaces=inject_namespaces,
        drop_missing=drop_missing,
        validate_only=True,
    )

    # Log a summary framed by blank lines. A falsy result is reported as
    # a plain success, matching the final branch below.
    self.logger.info("")
    if validation_result and validation_result.has_errors():
        self.logger.error("== Validation Failed ==")
        self.logger.error(f"  Errors: {len(validation_result.errors)}")
        if validation_result.warnings:
            self.logger.warning(f"  Warnings: {len(validation_result.warnings)}")
    elif validation_result and validation_result.warnings:
        self.logger.warning("== Validation Successful (With Warnings) ==")
        self.logger.warning(f"  Warnings: {len(validation_result.warnings)}")
    else:
        self.logger.info("== Validation Successful ==")
    self.logger.info("")

    return validation_result

def _cleanup_object_tables(self, engine, metadata):
"""Delete all tables that do not relate to id->OID mapping"""
Expand Down
Loading
Loading