Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cumulusci/tasks/bulkdata/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AddPersonAccountsToQuery,
AddRecordTypesToQuery,
DynamicLookupQueryExtender,
register_sqlite_functions,
)
from cumulusci.tasks.bulkdata.step import (
DEFAULT_BULK_BATCH_SIZE,
Expand Down Expand Up @@ -766,6 +767,9 @@ def _init_db(self):
with self._database_url() as database_url:
parent_engine = create_engine(database_url)
with parent_engine.connect() as connection:
# Register custom SQLite functions for smart lookup resolution
register_sqlite_functions(connection)

# initialize the DB session
self.session = Session(connection)

Expand Down
2 changes: 1 addition & 1 deletion cumulusci/tasks/bulkdata/mapping_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def check_required(
if fields_describe[field]["createable"] and not defaulted:
required_fields.add(field)
missing_fields = required_fields.difference(
set(self.fields.keys()) | set(self.lookups)
set(self.fields.keys()) | set(self.lookups) | set(self.static.keys())
)
if len(missing_fields) > 0:
message = f"One or more required fields are missing for loading on {self.sf_object} :{missing_fields}"
Expand Down
45 changes: 43 additions & 2 deletions cumulusci/tasks/bulkdata/query_transformers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
import typing as T
from functools import cached_property

from sqlalchemy import String, and_, func, text
from sqlalchemy import String, and_, case, func, text
from sqlalchemy.orm import Query, aliased
from sqlalchemy.sql import literal_column

Expand All @@ -10,6 +11,29 @@
Criterion = T.Any
ID_TABLE_NAME = "cumulusci_id_table"

# Salesforce ID pattern: 15 or 18 alphanumeric characters
# This matches the OID_REGEX pattern used in robotframework/Salesforce.py
SF_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{15}$|^[a-zA-Z0-9]{18}$")


def is_salesforce_id(value: T.Optional[str]) -> bool:
"""Check if a value looks like a valid Salesforce ID."""
if value is None:
return False
return bool(SF_ID_PATTERN.match(str(value)))


def _is_salesforce_id_sqlite(value: T.Optional[str]) -> int:
"""SQLite UDF wrapper for is_salesforce_id."""
return 1 if is_salesforce_id(value) else 0


def register_sqlite_functions(connection) -> None:
"""Register custom SQLite functions on a database connection."""
# Get the underlying DBAPI connection
dbapi_connection = connection.connection.dbapi_connection
dbapi_connection.create_function("is_salesforce_id", 1, _is_salesforce_id_sqlite)


class LoadQueryExtender:
"""Class that transforms a load.py query with columns, filters, joins"""
Expand Down Expand Up @@ -61,9 +85,26 @@ def __init__(self, mapping, metadata, model, _old_format) -> None:

@cached_property
def columns_to_add(self):
    """Build column expressions for lookup fields with smart ID resolution.

    For each lookup, emit a CASE expression that prefers the sf_id resolved
    through the ID table join, falls back to the original value when it is
    already a Salesforce ID, and otherwise yields NULL.
    """
    smart_columns = []
    for lookup in self.lookups:
        lookup.aliased_table = aliased(self.metadata.tables[ID_TABLE_NAME])
        raw_value = getattr(self.model, lookup.get_lookup_key_field(self.model))
        # Resolved SF ID from the ID table join; NULL when no row matched.
        resolved_sf_id = lookup.aliased_table.columns.sf_id
        smart_columns.append(
            case(
                # A match in the ID table always wins.
                (resolved_sf_id.isnot(None), resolved_sf_id),
                # Value already looks like a Salesforce ID: use it as-is.
                (func.is_salesforce_id(raw_value) == 1, raw_value),
                # No match and not an ID: lookup resolves to NULL.
                else_=None,
            )
        )
    return smart_columns

@cached_property
def outerjoins_to_add(self):
Expand Down
80 changes: 57 additions & 23 deletions cumulusci/tasks/bulkdata/tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ def test_query_db__joins_self_lookups(self):
sql_path=Path(__file__).parent / "test_query_db__joins_self_lookups.sql",
mapping=Path(__file__).parent / "test_query_db__joins_self_lookups.yml",
mapping_step_name="Update Accounts",
expected="""SELECT accounts.id AS accounts_id, accounts."Name" AS "accounts_Name", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM accounts LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(accounts.parent_id as varchar) ORDER BY accounts.parent_id""",
expected="""SELECT accounts.id AS accounts_id, accounts."Name" AS "accounts_Name", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(accounts.parent_id) = ?) THEN accounts.parent_id END AS anon_1 FROM accounts LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(accounts.parent_id as varchar) ORDER BY accounts.parent_id""",
old_format=True,
)

Expand All @@ -989,7 +989,7 @@ def test_query_db__joins_select_lookups(self):
sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql",
mapping=Path(__file__).parent / "test_query_db_joins_lookups_select.yml",
mapping_step_name="Select Event",
expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''',
expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(events."whoid") = ?) THEN events."whoid" END AS anon_1 FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''',
)

def test_query_db__joins_polymorphic_lookups(self):
Expand All @@ -998,7 +998,7 @@ def test_query_db__joins_polymorphic_lookups(self):
sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql",
mapping=Path(__file__).parent / "test_query_db_joins_lookups.yml",
mapping_step_name="Update Event",
expected="""SELECT events.id AS events_id, events."Subject" AS "events_Subject", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(events."WhoId" as varchar) ORDER BY events."WhoId" """,
expected="""SELECT events.id AS events_id, events."Subject" AS "events_Subject", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(events."WhoId") = ?) THEN events."WhoId" END AS anon_1 FROM events LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(events."WhoId" as varchar) ORDER BY events."WhoId" """,
)

@responses.activate
Expand Down Expand Up @@ -1109,12 +1109,13 @@ def test_query_db__person_accounts_enabled__contact_mapping(self, aliased):
) # check that query chaining from above worked.

query_columns, added_filters = _inspect_query(query)
# Validate that the column set is accurate
assert query_columns == (
model.sf_id,
model.__table__.columns["name"],
aliased.return_value.columns.sf_id,
)
# Validate that the initialization columns are accurate
assert query_columns[0] == model.sf_id
assert query_columns[1] == model.__table__.columns["name"]
# Third column is a CASE expression for smart lookup resolution
assert len(query_columns) == 3
assert "CASE" in str(query_columns[2]).upper()
assert "is_salesforce_id" in str(query_columns[2])

# Validate person contact records WERE filtered out
filter_out_contacts, *rest = added_filters
Expand Down Expand Up @@ -1182,12 +1183,13 @@ def test_query_db__person_accounts_disabled__contact_mapping(self, aliased):
added_columns.extend(args)
all_columns = initialization_columns + tuple(added_columns)

# Validate that the column set is accurate
assert all_columns == (
model.sf_id,
model.__table__.columns["name"],
aliased.return_value.columns.sf_id,
)
# Validate that the initialization columns are accurate
assert all_columns[0] == model.sf_id
assert all_columns[1] == model.__table__.columns["name"]
# Third column is a CASE expression for smart lookup resolution
assert len(all_columns) == 3
assert "CASE" in str(all_columns[2]).upper()
assert "is_salesforce_id" in str(all_columns[2])

# Validate person contact records were not filtered out
task._can_load_person_accounts.assert_called_once_with(mapping)
Expand Down Expand Up @@ -1240,12 +1242,13 @@ def test_query_db__person_accounts_enabled__neither_account_nor_contact_mapping(
query = task._query_db(mapping)
query_columns, added_filters = _inspect_query(query)

# Validate that the column set is accurate
assert query_columns == (
model.sf_id,
model.__table__.columns["name"],
aliased.return_value.columns.sf_id,
)
# Validate that the initialization columns are accurate
assert query_columns[0] == model.sf_id
assert query_columns[1] == model.__table__.columns["name"]
# Third column is a CASE expression for smart lookup resolution
assert len(query_columns) == 3
assert "CASE" in str(query_columns[2]).upper()
assert "is_salesforce_id" in str(query_columns[2])

# Validate person contact db records had their Name updated as blank
task._can_load_person_accounts.assert_not_called()
Expand Down Expand Up @@ -2715,8 +2718,6 @@ def get_random_string():
chunks_index = 0

def fetchmany(batch_size):
nonlocal chunks_index

assert 200 == batch_size

# _generate_contact_id_map_for_person_accounts should break if fetchmany returns falsy.
Expand Down Expand Up @@ -3021,6 +3022,38 @@ def test_mapping_file_with_explicit_IsPersonAccount(self, caplog):
task._init_task()
task._init_mapping()

def test_smart_lookup__mixed_sf_ids_and_local_refs(self):
    """Test that smart lookup handles both pre-resolved SF IDs and local references"""
    fixtures_dir = Path(__file__).parent
    task = _make_task(
        LoadData,
        {
            "options": {
                "sql_path": fixtures_dir / "test_smart_lookup.sql",
                "mapping": fixtures_dir / "test_smart_lookup.yml",
            }
        },
    )

    # Skip org schema validation and the Salesforce connection.
    with mock.patch(
        "cumulusci.tasks.bulkdata.load.validate_and_inject_mapping"
    ), mock.patch.object(task, "sf", create=True):
        task._init_mapping()

    with task._init_db():
        task._old_format = False
        rows = list(task._query_db(task.mapping["Insert PricebookEntry"]).all())
        lookup_by_pk = {row[0]: row[2] for row in rows}
        # Valid 18-char SF ID passes through unchanged.
        assert lookup_by_pk["PricebookEntry-1"] == "01sSG00000Dsd89YAB"
        # Local reference resolves through cumulusci_id_table.
        assert lookup_by_pk["PricebookEntry-2"] == "01s000000000001AAA"
        # Valid 15-char SF ID also passes through unchanged.
        assert lookup_by_pk["PricebookEntry-3"] == "01sSG00000Dsd89"
        # NULL and unresolvable values both yield None.
        assert lookup_by_pk["PricebookEntry-4"] is None
        assert lookup_by_pk["PricebookEntry-5"] is None


class TestLoadDataIntegrationTests:
# bulk API not supported by VCR yet
Expand All @@ -3033,6 +3066,7 @@ def test_error_result_counting__multi_batches(
{
"sql_path": cumulusci_test_repo_root / "datasets/bad_sample.sql",
"mapping": cumulusci_test_repo_root / "datasets/mapping.yml",
"ignore_row_errors": True,
},
)
with mock.patch("cumulusci.tasks.bulkdata.step.DEFAULT_BULK_BATCH_SIZE", 3):
Expand Down
30 changes: 30 additions & 0 deletions cumulusci/tasks/bulkdata/tests/test_smart_lookup.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-- Fixture for smart-lookup resolution tests.
-- PricebookEntry.Pricebook2Id deliberately mixes value kinds: an 18-char
-- Salesforce ID, a local row reference, a 15-char Salesforce ID, NULL,
-- and an unresolvable string.
BEGIN TRANSACTION;

CREATE TABLE "PricebookEntry" (
id VARCHAR(255) NOT NULL,
"Pricebook2Id" VARCHAR(255),
"UnitPrice" VARCHAR(255),
PRIMARY KEY (id)
);
-- 18-char SF ID: should pass through unchanged
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-1', '01sSG00000Dsd89YAB', '100');
-- Local reference: should resolve via cumulusci_id_table
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-2', 'Pricebook2-1', '200');
-- 15-char SF ID: should pass through unchanged
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-3', '01sSG00000Dsd89', '300');
-- NULL lookup value: should stay NULL
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-4', NULL, '400');
-- Non-ID, unresolvable value: should resolve to NULL
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-5', 'invalid-ref', '500');

CREATE TABLE "Pricebook2" (
id VARCHAR(255) NOT NULL,
"Name" VARCHAR(255),
PRIMARY KEY (id)
);
INSERT INTO "Pricebook2" VALUES('Pricebook2-1', 'Standard Price Book');
INSERT INTO "Pricebook2" VALUES('Pricebook2-2', 'Partner Price Book');

-- Simulates SF IDs captured from an earlier load step.
CREATE TABLE "cumulusci_id_table" (
id VARCHAR(255) NOT NULL,
sf_id VARCHAR(18)
);
INSERT INTO "cumulusci_id_table" VALUES('Pricebook2-1', '01s000000000001AAA');
INSERT INTO "cumulusci_id_table" VALUES('Pricebook2-2', '01s000000000002AAA');

COMMIT;
15 changes: 15 additions & 0 deletions cumulusci/tasks/bulkdata/tests/test_smart_lookup.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Mapping fixture for smart-lookup tests: PricebookEntry.Pricebook2Id may
# hold either a local Pricebook2 row reference or an already-valid SF ID.
Insert Pricebook2:
  sf_object: Pricebook2
  table: Pricebook2
  fields:
    Name: Name

Insert PricebookEntry:
  sf_object: PricebookEntry
  table: PricebookEntry
  fields:
    UnitPrice: UnitPrice
  lookups:
    Pricebook2Id:
      table: Pricebook2
      key_field: Pricebook2Id
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies = [
"sarge",
"selenium<4",
"simple-salesforce>=1.12.6",
"snowfakery>=4.2.0",
"snowfakery>=4.2.1",
"xmltodict",
"docutils<=0.21.2",
]
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading