diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index 0732d57777..a6cbdee9ac 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -28,6 +28,7 @@ AddPersonAccountsToQuery, AddRecordTypesToQuery, DynamicLookupQueryExtender, + register_sqlite_functions, ) from cumulusci.tasks.bulkdata.step import ( DEFAULT_BULK_BATCH_SIZE, @@ -766,6 +767,9 @@ def _init_db(self): with self._database_url() as database_url: parent_engine = create_engine(database_url) with parent_engine.connect() as connection: + # Register custom SQLite functions for smart lookup resolution + register_sqlite_functions(connection) + # initialize the DB session self.session = Session(connection) diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index 4278dda6bd..63ed9c48f1 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -545,7 +545,7 @@ def check_required( if fields_describe[field]["createable"] and not defaulted: required_fields.add(field) missing_fields = required_fields.difference( - set(self.fields.keys()) | set(self.lookups) + set(self.fields.keys()) | set(self.lookups) | set(self.static.keys()) ) if len(missing_fields) > 0: message = f"One or more required fields are missing for loading on {self.sf_object} :{missing_fields}" diff --git a/cumulusci/tasks/bulkdata/query_transformers.py b/cumulusci/tasks/bulkdata/query_transformers.py index 181736a4bc..3f632c694e 100644 --- a/cumulusci/tasks/bulkdata/query_transformers.py +++ b/cumulusci/tasks/bulkdata/query_transformers.py @@ -1,7 +1,8 @@ +import re import typing as T from functools import cached_property -from sqlalchemy import String, and_, func, text +from sqlalchemy import String, and_, case, func, text from sqlalchemy.orm import Query, aliased from sqlalchemy.sql import literal_column @@ -10,6 +11,29 @@ Criterion = T.Any ID_TABLE_NAME = "cumulusci_id_table" +# Salesforce ID pattern: 15 or 18 alphanumeric characters +# This matches the OID_REGEX pattern used in robotframework/Salesforce.py +SF_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{15}$|^[a-zA-Z0-9]{18}$") + + +def is_salesforce_id(value: T.Optional[str]) -> bool: + """Check if a value looks like a valid Salesforce ID.""" + if value is None: + return False + return bool(SF_ID_PATTERN.match(str(value))) + + +def _is_salesforce_id_sqlite(value: T.Optional[str]) -> int: + """SQLite UDF wrapper for is_salesforce_id.""" + return 1 if is_salesforce_id(value) else 0 + + +def register_sqlite_functions(connection) -> None: + """Register custom SQLite functions on a database connection.""" + # Get the underlying DBAPI connection + dbapi_connection = connection.connection.dbapi_connection + dbapi_connection.create_function("is_salesforce_id", 1, _is_salesforce_id_sqlite) + class LoadQueryExtender: """Class that transforms a load.py query with columns, filters, joins""" @@ -61,9 +85,26 @@ def __init__(self, mapping, metadata, model, _old_format) -> None: @cached_property def columns_to_add(self): + """Build column expressions for lookup fields with smart ID resolution.""" + columns = [] for lookup in self.lookups: lookup.aliased_table = aliased(self.metadata.tables[ID_TABLE_NAME]) - return [lookup.aliased_table.columns.sf_id for lookup in self.lookups] + key_field = lookup.get_lookup_key_field(self.model) + value_column = getattr(self.model, key_field) + + # The resolved SF ID from the ID table join (may be NULL) + sf_id_from_table = lookup.aliased_table.columns.sf_id + + smart_lookup = case( + # If we found a match in the ID table, use that + (sf_id_from_table.isnot(None), sf_id_from_table), + # If the original value is already a SF ID, use it directly + (func.is_salesforce_id(value_column) == 1, value_column), + # Otherwise return NULL (lookup not found) + else_=None, + ) + columns.append(smart_lookup) + return columns @cached_property def outerjoins_to_add(self): diff --git a/cumulusci/tasks/bulkdata/tests/test_load.py b/cumulusci/tasks/bulkdata/tests/test_load.py index 8fb8ee0756..f3657779c3 100644 --- a/cumulusci/tasks/bulkdata/tests/test_load.py +++ b/cumulusci/tasks/bulkdata/tests/test_load.py @@ -979,7 +979,7 @@ def test_query_db__joins_self_lookups(self): sql_path=Path(__file__).parent / "test_query_db__joins_self_lookups.sql", mapping=Path(__file__).parent / "test_query_db__joins_self_lookups.yml", mapping_step_name="Update Accounts", - expected="""SELECT accounts.id AS accounts_id, accounts."Name" AS "accounts_Name", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM accounts LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(accounts.parent_id as varchar) ORDER BY accounts.parent_id""", + expected="""SELECT accounts.id AS accounts_id, accounts."Name" AS "accounts_Name", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(accounts.parent_id) = ?) THEN accounts.parent_id END AS anon_1 FROM accounts LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(accounts.parent_id as varchar) ORDER BY accounts.parent_id""", old_format=True, ) @@ -989,7 +989,7 @@ def test_query_db__joins_select_lookups(self): sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql", mapping=Path(__file__).parent / "test_query_db_joins_lookups_select.yml", mapping_step_name="Select Event", - expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''', + expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(events."whoid") = ?) THEN events."whoid" END AS anon_1 FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''', ) def test_query_db__joins_polymorphic_lookups(self): @@ -998,7 +998,7 @@ def test_query_db__joins_polymorphic_lookups(self): sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql", mapping=Path(__file__).parent / "test_query_db_joins_lookups.yml", mapping_step_name="Update Event", - expected="""SELECT events.id AS events_id, events."Subject" AS "events_Subject", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(events."WhoId" as varchar) ORDER BY events."WhoId" """, + expected="""SELECT events.id AS events_id, events."Subject" AS "events_Subject", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(events."WhoId") = ?) THEN events."WhoId" END AS anon_1 FROM events LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(events."WhoId" as varchar) ORDER BY events."WhoId" """, ) @responses.activate @@ -1109,12 +1109,13 @@ def test_query_db__person_accounts_enabled__contact_mapping(self, aliased): ) # check that query chaining from above worked. query_columns, added_filters = _inspect_query(query) - # Validate that the column set is accurate - assert query_columns == ( - model.sf_id, - model.__table__.columns["name"], - aliased.return_value.columns.sf_id, - ) + # Validate that the initialization columns are accurate + assert query_columns[0] == model.sf_id + assert query_columns[1] == model.__table__.columns["name"] + # Third column is a CASE expression for smart lookup resolution + assert len(query_columns) == 3 + assert "CASE" in str(query_columns[2]).upper() + assert "is_salesforce_id" in str(query_columns[2]) # Validate person contact records WERE filtered out filter_out_contacts, *rest = added_filters @@ -1182,12 +1183,13 @@ def test_query_db__person_accounts_disabled__contact_mapping(self, aliased): added_columns.extend(args) all_columns = initialization_columns + tuple(added_columns) - # Validate that the column set is accurate - assert all_columns == ( - model.sf_id, - model.__table__.columns["name"], - aliased.return_value.columns.sf_id, - ) + # Validate that the initialization columns are accurate + assert all_columns[0] == model.sf_id + assert all_columns[1] == model.__table__.columns["name"] + # Third column is a CASE expression for smart lookup resolution + assert len(all_columns) == 3 + assert "CASE" in str(all_columns[2]).upper() + assert "is_salesforce_id" in str(all_columns[2]) # Validate person contact records were not filtered out task._can_load_person_accounts.assert_called_once_with(mapping) @@ -1240,12 +1242,13 @@ def test_query_db__person_accounts_enabled__neither_account_nor_contact_mapping( query = task._query_db(mapping) query_columns, added_filters = _inspect_query(query) - # Validate that the column set is accurate - assert query_columns == ( - model.sf_id, - model.__table__.columns["name"], - aliased.return_value.columns.sf_id, - ) + # Validate that the initialization columns are accurate + assert query_columns[0] == model.sf_id + assert query_columns[1] == model.__table__.columns["name"] + # Third column is a CASE expression for smart lookup resolution + assert len(query_columns) == 3 + assert "CASE" in str(query_columns[2]).upper() + assert "is_salesforce_id" in str(query_columns[2]) # Validate person contact db records had their Name updated as blank task._can_load_person_accounts.assert_not_called() @@ -2715,8 +2718,6 @@ def get_random_string(): chunks_index = 0 def fetchmany(batch_size): - nonlocal chunks_index - assert 200 == batch_size # _generate_contact_id_map_for_person_accounts should break if fetchmany returns falsy. @@ -3021,6 +3022,38 @@ def test_mapping_file_with_explicit_IsPersonAccount(self, caplog): task._init_task() task._init_mapping() + def test_smart_lookup__mixed_sf_ids_and_local_refs(self): + """Test that smart lookup handles both pre-resolved SF IDs and local references""" + base_path = Path(__file__).parent + sql_path = base_path / "test_smart_lookup.sql" + mapping_path = base_path / "test_smart_lookup.yml" + + task = _make_task( + LoadData, + { + "options": { + "sql_path": sql_path, + "mapping": mapping_path, + } + }, + ) + + with mock.patch( + "cumulusci.tasks.bulkdata.load.validate_and_inject_mapping" + ), mock.patch.object(task, "sf", create=True): + task._init_mapping() + + with task._init_db(): + task._old_format = False + query = task._query_db(task.mapping["Insert PricebookEntry"]) + results = list(query.all()) + results_by_id = {row[0]: row[2] for row in results} + assert results_by_id["PricebookEntry-1"] == "01sSG00000Dsd89YAB" + assert results_by_id["PricebookEntry-2"] == "01s000000000001AAA" + assert results_by_id["PricebookEntry-3"] == "01sSG00000Dsd89" + assert results_by_id["PricebookEntry-4"] is None + assert results_by_id["PricebookEntry-5"] is None + class TestLoadDataIntegrationTests: # bulk API not supported by VCR yet @@ -3033,6 +3066,7 @@ def test_error_result_counting__multi_batches( { "sql_path": cumulusci_test_repo_root / "datasets/bad_sample.sql", "mapping": cumulusci_test_repo_root / "datasets/mapping.yml", + "ignore_row_errors": True, }, ) with mock.patch("cumulusci.tasks.bulkdata.step.DEFAULT_BULK_BATCH_SIZE", 3): diff --git a/cumulusci/tasks/bulkdata/tests/test_smart_lookup.sql b/cumulusci/tasks/bulkdata/tests/test_smart_lookup.sql new file mode 100644 index 0000000000..bd15c2898a --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/test_smart_lookup.sql @@ -0,0 +1,30 @@ +BEGIN TRANSACTION; + +CREATE TABLE "PricebookEntry" ( + id VARCHAR(255) NOT NULL, + "Pricebook2Id" VARCHAR(255), + "UnitPrice" VARCHAR(255), + PRIMARY KEY (id) +); +INSERT INTO "PricebookEntry" VALUES('PricebookEntry-1', '01sSG00000Dsd89YAB', '100'); +INSERT INTO "PricebookEntry" VALUES('PricebookEntry-2', 'Pricebook2-1', '200'); +INSERT INTO "PricebookEntry" VALUES('PricebookEntry-3', '01sSG00000Dsd89', '300'); +INSERT INTO "PricebookEntry" VALUES('PricebookEntry-4', NULL, '400'); +INSERT INTO "PricebookEntry" VALUES('PricebookEntry-5', 'invalid-ref', '500'); + +CREATE TABLE "Pricebook2" ( + id VARCHAR(255) NOT NULL, + "Name" VARCHAR(255), + PRIMARY KEY (id) +); +INSERT INTO "Pricebook2" VALUES('Pricebook2-1', 'Standard Price Book'); +INSERT INTO "Pricebook2" VALUES('Pricebook2-2', 'Partner Price Book'); + +CREATE TABLE "cumulusci_id_table" ( + id VARCHAR(255) NOT NULL, + sf_id VARCHAR(18) +); +INSERT INTO "cumulusci_id_table" VALUES('Pricebook2-1', '01s000000000001AAA'); +INSERT INTO "cumulusci_id_table" VALUES('Pricebook2-2', '01s000000000002AAA'); + +COMMIT; diff --git a/cumulusci/tasks/bulkdata/tests/test_smart_lookup.yml b/cumulusci/tasks/bulkdata/tests/test_smart_lookup.yml new file mode 100644 index 0000000000..96643322c7 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/test_smart_lookup.yml @@ -0,0 +1,15 @@ +Insert Pricebook2: + sf_object: Pricebook2 + table: Pricebook2 + fields: + Name: Name + +Insert PricebookEntry: + sf_object: PricebookEntry + table: PricebookEntry + fields: + UnitPrice: UnitPrice + lookups: + Pricebook2Id: + table: Pricebook2 + key_field: Pricebook2Id diff --git a/pyproject.toml b/pyproject.toml index 979be8fe17..b2956c25cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ dependencies = [ "sarge", "selenium<4", "simple-salesforce>=1.12.6", - "snowfakery>=4.2.0", + "snowfakery>=4.2.1", "xmltodict", "docutils<=0.21.2", ] diff --git a/uv.lock b/uv.lock index 764d19ad31..2ba9df31a2 100644 --- a/uv.lock +++ b/uv.lock @@ -493,7 +493,7 @@ requires-dist = [ { name = "scikit-learn", marker = "extra == 'select'" }, { name = "selenium", specifier = "<4" }, { name = "simple-salesforce", specifier = ">=1.12.6" }, - { name = "snowfakery", specifier = ">=4.2.0" }, + { name = "snowfakery", specifier = ">=4.2.1" }, { name = "sqlalchemy", specifier = "<2" }, { name = "xmltodict" }, ] @@ -2138,7 +2138,7 @@ wheels = [ [[package]] name = "snowfakery" -version = "4.2.0" +version = "4.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -2155,9 +2155,9 @@ dependencies = [ { name = "setuptools" }, { name = "sqlalchemy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ea/48/b155854cd42d4bf345873fc11bdeb3257a505914de088dc8a0ec90e8753e/snowfakery-4.2.0.tar.gz", hash = "sha256:930df06131749d033559e5edb4a60daa747beacafa8a36b07093d23da7095908", size = 110275, upload-time = "2025-12-19T09:17:56.172Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/f2/0e21350da9dbf8fc28bd42f1c9f4fa1ced67e207d65ffa88f95f7eee53a6/snowfakery-4.2.1.tar.gz", hash = "sha256:aa410b5de078f54e29c2e6675df3b331e617c3df36e524e564d594bba2f83f61", size = 110726, upload-time = "2026-01-09T07:02:02.199Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/00/dca6dd168eeb2783e0e667923155cee1695b594bdcb22c27214379214728/snowfakery-4.2.0-py3-none-any.whl", hash = "sha256:e5df8361c86bc70741fdb8d31d2e420d0bb9e0c804097874756911ec6d2ee49b", size = 138789, upload-time = "2025-12-19T09:17:54.94Z" }, + { url = "https://files.pythonhosted.org/packages/4d/3c/08bd6c8b5b4fbf6a461d8aa9b8089c1f2fd6c1c712e8e85767bfe3567caf/snowfakery-4.2.1-py3-none-any.whl", hash = "sha256:38fd9cedd0dca8cc2fa65c4adcc88ae1f91aab99efc328c2c6e52c9cb1965522", size = 139258, upload-time = "2026-01-09T07:02:03.841Z" }, ] [[package]]