Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cumulusci/tasks/bulkdata/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AddPersonAccountsToQuery,
AddRecordTypesToQuery,
DynamicLookupQueryExtender,
register_sqlite_functions,
)
from cumulusci.tasks.bulkdata.step import (
DEFAULT_BULK_BATCH_SIZE,
Expand Down Expand Up @@ -766,6 +767,9 @@ def _init_db(self):
with self._database_url() as database_url:
parent_engine = create_engine(database_url)
with parent_engine.connect() as connection:
# Register custom SQLite functions for smart lookup resolution
register_sqlite_functions(connection)

# initialize the DB session
self.session = Session(connection)

Expand Down
45 changes: 43 additions & 2 deletions cumulusci/tasks/bulkdata/query_transformers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
import typing as T
from functools import cached_property

from sqlalchemy import String, and_, func, text
from sqlalchemy import String, and_, case, func, text
from sqlalchemy.orm import Query, aliased
from sqlalchemy.sql import literal_column

Expand All @@ -10,6 +11,29 @@
Criterion = T.Any
ID_TABLE_NAME = "cumulusci_id_table"

# Salesforce ID pattern: 15 or 18 alphanumeric characters
# This matches the OID_REGEX pattern used in robotframework/Salesforce.py
SF_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{15}$|^[a-zA-Z0-9]{18}$")


def is_salesforce_id(value: T.Optional[str]) -> bool:
    """Return True when *value* has the shape of a Salesforce ID.

    A value qualifies when its string form is exactly 15 or exactly 18
    ASCII alphanumeric characters. This is a purely syntactic check; it does
    not verify that the ID exists anywhere.

    Args:
        value: The candidate value, or None. Non-string values are converted
            with ``str()`` before being checked.

    Returns:
        True if the whole value matches ``SF_ID_PATTERN``; False for None or
        any other value.
    """
    if value is None:
        return False
    # fullmatch() rather than match(): "$" also matches just before a trailing
    # newline, so match() would accept "<15 chars>\n" as an ID.
    return SF_ID_PATTERN.fullmatch(str(value)) is not None


def _is_salesforce_id_sqlite(value: T.Optional[str]) -> int:
    """Adapt ``is_salesforce_id`` for use as a SQLite user-defined function.

    Returns the integer 1 when the value is ID-shaped and 0 otherwise, so the
    result is a plain integer rather than a Python bool.
    """
    looks_like_id = bool(is_salesforce_id(value))
    return int(looks_like_id)


def register_sqlite_functions(connection) -> None:
    """Install the custom SQL functions on the given database connection.

    ``connection`` must expose the driver-level connection as
    ``connection.connection.dbapi_connection``; that object's
    ``create_function`` is used to register ``is_salesforce_id`` as a
    one-argument SQL function.
    """
    udf_name = "is_salesforce_id"
    udf_arg_count = 1
    # Registration happens on the raw driver connection, not the wrapper.
    connection.connection.dbapi_connection.create_function(
        udf_name, udf_arg_count, _is_salesforce_id_sqlite
    )


class LoadQueryExtender:
"""Class that transforms a load.py query with columns, filters, joins"""
Expand Down Expand Up @@ -61,9 +85,26 @@ def __init__(self, mapping, metadata, model, _old_format) -> None:

@cached_property
def columns_to_add(self):
    """Build one column expression per lookup, with smart ID resolution.

    For every lookup in ``self.lookups`` a fresh alias of the ID table is
    created and stored on ``lookup.aliased_table``. The column emitted for
    that lookup is a CASE expression that yields, in priority order:

    1. the ``sf_id`` found in the aliased ID table, when it is not NULL;
    2. the original lookup value, when the ``is_salesforce_id`` SQL function
       reports that it is already ID-shaped;
    3. NULL otherwise.

    Returns:
        A list of CASE expressions, one per lookup, in ``self.lookups`` order.
    """
    columns = []
    for lookup in self.lookups:
        lookup.aliased_table = aliased(self.metadata.tables[ID_TABLE_NAME])
        key_field = lookup.get_lookup_key_field(self.model)
        value_column = getattr(self.model, key_field)

        # The resolved SF ID from the ID table join (may be NULL)
        sf_id_from_table = lookup.aliased_table.columns.sf_id

        smart_lookup = case(
            # If we found a match in the ID table, use that
            (sf_id_from_table.isnot(None), sf_id_from_table),
            # If the original value is already a SF ID, use it directly
            (func.is_salesforce_id(value_column) == 1, value_column),
            # Otherwise return NULL (lookup not found)
            else_=None,
        )
        columns.append(smart_lookup)
    return columns

@cached_property
def outerjoins_to_add(self):
Expand Down
79 changes: 56 additions & 23 deletions cumulusci/tasks/bulkdata/tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ def test_query_db__joins_self_lookups(self):
sql_path=Path(__file__).parent / "test_query_db__joins_self_lookups.sql",
mapping=Path(__file__).parent / "test_query_db__joins_self_lookups.yml",
mapping_step_name="Update Accounts",
expected="""SELECT accounts.id AS accounts_id, accounts."Name" AS "accounts_Name", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM accounts LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(accounts.parent_id as varchar) ORDER BY accounts.parent_id""",
expected="""SELECT accounts.id AS accounts_id, accounts."Name" AS "accounts_Name", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(accounts.parent_id) = ?) THEN accounts.parent_id END AS anon_1 FROM accounts LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(accounts.parent_id as varchar) ORDER BY accounts.parent_id""",
old_format=True,
)

Expand All @@ -989,7 +989,7 @@ def test_query_db__joins_select_lookups(self):
sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql",
mapping=Path(__file__).parent / "test_query_db_joins_lookups_select.yml",
mapping_step_name="Select Event",
expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''',
expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(events."whoid") = ?) THEN events."whoid" END AS anon_1 FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''',
)

def test_query_db__joins_polymorphic_lookups(self):
Expand All @@ -998,7 +998,7 @@ def test_query_db__joins_polymorphic_lookups(self):
sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql",
mapping=Path(__file__).parent / "test_query_db_joins_lookups.yml",
mapping_step_name="Update Event",
expected="""SELECT events.id AS events_id, events."Subject" AS "events_Subject", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(events."WhoId" as varchar) ORDER BY events."WhoId" """,
expected="""SELECT events.id AS events_id, events."Subject" AS "events_Subject", CASE WHEN (cumulusci_id_table_1.sf_id IS NOT NULL) THEN cumulusci_id_table_1.sf_id WHEN (is_salesforce_id(events."WhoId") = ?) THEN events."WhoId" END AS anon_1 FROM events LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id = ? || cast(events."WhoId" as varchar) ORDER BY events."WhoId" """,
)

@responses.activate
Expand Down Expand Up @@ -1109,12 +1109,13 @@ def test_query_db__person_accounts_enabled__contact_mapping(self, aliased):
) # check that query chaining from above worked.

query_columns, added_filters = _inspect_query(query)
# Validate that the column set is accurate
assert query_columns == (
model.sf_id,
model.__table__.columns["name"],
aliased.return_value.columns.sf_id,
)
# Validate that the initialization columns are accurate
assert query_columns[0] == model.sf_id
assert query_columns[1] == model.__table__.columns["name"]
# Third column is a CASE expression for smart lookup resolution
assert len(query_columns) == 3
assert "CASE" in str(query_columns[2]).upper()
assert "is_salesforce_id" in str(query_columns[2])

# Validate person contact records WERE filtered out
filter_out_contacts, *rest = added_filters
Expand Down Expand Up @@ -1182,12 +1183,13 @@ def test_query_db__person_accounts_disabled__contact_mapping(self, aliased):
added_columns.extend(args)
all_columns = initialization_columns + tuple(added_columns)

# Validate that the column set is accurate
assert all_columns == (
model.sf_id,
model.__table__.columns["name"],
aliased.return_value.columns.sf_id,
)
# Validate that the initialization columns are accurate
assert all_columns[0] == model.sf_id
assert all_columns[1] == model.__table__.columns["name"]
# Third column is a CASE expression for smart lookup resolution
assert len(all_columns) == 3
assert "CASE" in str(all_columns[2]).upper()
assert "is_salesforce_id" in str(all_columns[2])

# Validate person contact records were not filtered out
task._can_load_person_accounts.assert_called_once_with(mapping)
Expand Down Expand Up @@ -1240,12 +1242,13 @@ def test_query_db__person_accounts_enabled__neither_account_nor_contact_mapping(
query = task._query_db(mapping)
query_columns, added_filters = _inspect_query(query)

# Validate that the column set is accurate
assert query_columns == (
model.sf_id,
model.__table__.columns["name"],
aliased.return_value.columns.sf_id,
)
# Validate that the initialization columns are accurate
assert query_columns[0] == model.sf_id
assert query_columns[1] == model.__table__.columns["name"]
# Third column is a CASE expression for smart lookup resolution
assert len(query_columns) == 3
assert "CASE" in str(query_columns[2]).upper()
assert "is_salesforce_id" in str(query_columns[2])

# Validate person contact db records had their Name updated as blank
task._can_load_person_accounts.assert_not_called()
Expand Down Expand Up @@ -2715,8 +2718,6 @@ def get_random_string():
chunks_index = 0

def fetchmany(batch_size):
nonlocal chunks_index

assert 200 == batch_size

# _generate_contact_id_map_for_person_accounts should break if fetchmany returns falsy.
Expand Down Expand Up @@ -3021,6 +3022,38 @@ def test_mapping_file_with_explicit_IsPersonAccount(self, caplog):
task._init_task()
task._init_mapping()

def test_smart_lookup__mixed_sf_ids_and_local_refs(self):
"""Test that smart lookup handles both pre-resolved SF IDs and local references"""
# Fixture files live next to this test module.
base_path = Path(__file__).parent
sql_path = base_path / "test_smart_lookup.sql"
mapping_path = base_path / "test_smart_lookup.yml"

task = _make_task(
LoadData,
{
"options": {
"sql_path": sql_path,
"mapping": mapping_path,
}
},
)

# validate_and_inject_mapping is patched out and a stand-in ``sf`` attribute
# is created on the task while the mapping is initialised.
with mock.patch(
"cumulusci.tasks.bulkdata.load.validate_and_inject_mapping"
), mock.patch.object(task, "sf", create=True):
task._init_mapping()

with task._init_db():
task._old_format = False
query = task._query_db(task.mapping["Insert PricebookEntry"])
results = list(query.all())
# Key each row by its first column and keep the third column, which is
# the value this test checks as the resolved Pricebook2Id lookup.
results_by_id = {row[0]: row[2] for row in results}
# 18-character value stored directly in Pricebook2Id is passed through.
assert results_by_id["PricebookEntry-1"] == "01sSG00000Dsd89YAB"
# Local reference 'Pricebook2-1' is resolved via cumulusci_id_table.
assert results_by_id["PricebookEntry-2"] == "01s000000000001AAA"
# 15-character value stored directly in Pricebook2Id is passed through.
assert results_by_id["PricebookEntry-3"] == "01sSG00000Dsd89"
# NULL lookup value stays NULL.
assert results_by_id["PricebookEntry-4"] is None
# 'invalid-ref' is neither in the ID table nor ID-shaped, so it becomes NULL.
assert results_by_id["PricebookEntry-5"] is None


class TestLoadDataIntegrationTests:
# bulk API not supported by VCR yet
Expand Down
30 changes: 30 additions & 0 deletions cumulusci/tasks/bulkdata/tests/test_smart_lookup.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-- Fixture data for the smart-lookup load test: PricebookEntry rows whose
-- "Pricebook2Id" column holds a mix of ID-shaped values, a local reference,
-- NULL, and an unresolvable string.
BEGIN TRANSACTION;

CREATE TABLE "PricebookEntry" (
id VARCHAR(255) NOT NULL,
"Pricebook2Id" VARCHAR(255),
"UnitPrice" VARCHAR(255),
PRIMARY KEY (id)
);
-- 18-character alphanumeric value stored directly in the lookup column
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-1', '01sSG00000Dsd89YAB', '100');
-- local reference that has a matching row in cumulusci_id_table below
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-2', 'Pricebook2-1', '200');
-- 15-character alphanumeric value stored directly in the lookup column
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-3', '01sSG00000Dsd89', '300');
-- no lookup value at all
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-4', NULL, '400');
-- not 15/18 alphanumeric characters and absent from cumulusci_id_table
INSERT INTO "PricebookEntry" VALUES('PricebookEntry-5', 'invalid-ref', '500');

CREATE TABLE "Pricebook2" (
id VARCHAR(255) NOT NULL,
"Name" VARCHAR(255),
PRIMARY KEY (id)
);
INSERT INTO "Pricebook2" VALUES('Pricebook2-1', 'Standard Price Book');
INSERT INTO "Pricebook2" VALUES('Pricebook2-2', 'Partner Price Book');

-- Maps local row ids to sf_id values; only 'Pricebook2-1' is referenced by a
-- PricebookEntry row above.
CREATE TABLE "cumulusci_id_table" (
id VARCHAR(255) NOT NULL,
sf_id VARCHAR(18)
);
INSERT INTO "cumulusci_id_table" VALUES('Pricebook2-1', '01s000000000001AAA');
INSERT INTO "cumulusci_id_table" VALUES('Pricebook2-2', '01s000000000002AAA');

COMMIT;
15 changes: 15 additions & 0 deletions cumulusci/tasks/bulkdata/tests/test_smart_lookup.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Insert Pricebook2:
sf_object: Pricebook2
table: Pricebook2
fields:
Name: Name

Insert PricebookEntry:
sf_object: PricebookEntry
table: PricebookEntry
fields:
UnitPrice: UnitPrice
lookups:
Pricebook2Id:
table: Pricebook2
key_field: Pricebook2Id
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies = [
"sarge",
"selenium<4",
"simple-salesforce>=1.12.6",
"snowfakery>=4.2.0",
"snowfakery>=4.2.1",
"xmltodict",
"docutils<=0.21.2",
]
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading