Skip to content

Commit ebd5f08

Browse files
Modify functionality to return records in order of the user query to support ORDER BY operation
1 parent 6eca455 commit ebd5f08

2 files changed

Lines changed: 166 additions & 6 deletions

File tree

cumulusci/tasks/bulkdata/step.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,11 @@ def select_records(self, records):
491491
user_query_records = user_query_executor.get_results()
492492

493493
# Find intersection based on 'Id'
494-
user_query_ids = set(record[0] for record in user_query_records)
494+
user_query_ids = (
495+
list(record[0] for record in user_query_records)
496+
if user_query_records
497+
else []
498+
)
495499

496500
# Execute the main select query using Bulk API
497501
select_query_records = self._execute_select_query(
@@ -500,10 +504,16 @@ def select_records(self, records):
500504

501505
# If user_query_ids exist, filter select_query_records based on the intersection of Ids
502506
if self.selection_filter:
507+
# Create a dictionary to map IDs to their corresponding records
508+
id_to_record_map = {
509+
record[query_fields.index("Id")]: record
510+
for record in select_query_records
511+
}
512+
# Extend query_records in the order of user_query_ids
503513
query_records.extend(
504514
record
505-
for record in select_query_records
506-
if record[query_fields.index("Id")] in user_query_ids
515+
for id in user_query_ids
516+
if (record := id_to_record_map.get(id)) is not None
507517
)
508518
else:
509519
query_records.extend(select_query_records)
@@ -928,12 +938,17 @@ def convert(rec, fields):
928938
f"{sub_response['body'][0]['errorCode']}: {sub_response['body'][0]['message']}"
929939
)
930940
# Find intersection based on 'Id'
931-
user_query_ids = set(record[0] for record in user_query_records)
941+
user_query_ids = list(record[0] for record in user_query_records)
942+
# Create a dictionary to map IDs to their corresponding records
943+
id_to_record_map = {
944+
record[query_fields.index("Id")]: record for record in select_query_records
945+
}
932946

947+
# Return the matching records in the order of user_query_ids
933948
return [
934949
record
935-
for record in select_query_records
936-
if record[query_fields.index("Id")] in user_query_ids
950+
for id in user_query_ids
951+
if (record := id_to_record_map.get(id)) is not None
937952
]
938953

939954
def get_results(self):

cumulusci/tasks/bulkdata/tests/test_step.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,70 @@ def test_select_records_user_selection_filter_success(self, download_mock):
701701
== 3
702702
)
703703

704+
@mock.patch("cumulusci.tasks.bulkdata.step.download_file")
705+
def test_select_records_user_selection_filter_order_success(self, download_mock):
706+
# Set up mock context and BulkApiDmlOperation
707+
context = mock.Mock()
708+
step = BulkApiDmlOperation(
709+
sobject="Contact",
710+
operation=DataOperationType.QUERY,
711+
api_options={"batch_size": 10, "update_key": "LastName"},
712+
context=context,
713+
fields=["LastName"],
714+
selection_strategy=SelectStrategy.STANDARD,
715+
selection_filter="ORDER BY CreatedDate",
716+
)
717+
718+
# Mock Bulk API responses
719+
step.bulk.endpoint = "https://test"
720+
step.bulk.create_query_job.return_value = "JOB"
721+
step.bulk.query.return_value = "BATCH"
722+
step.bulk.get_query_batch_result_ids.return_value = ["RESULT"]
723+
724+
# Mock the downloaded CSV content with three records
725+
download_mock.return_value = io.StringIO(
726+
"""Id
727+
003000000000001
728+
003000000000002
729+
003000000000003"""
730+
)
731+
# Mock the query operation
732+
with mock.patch(
733+
"cumulusci.tasks.bulkdata.step.get_query_operation"
734+
) as query_operation_mock:
735+
query_operation_mock.return_value = mock.Mock()
736+
query_operation_mock.return_value.query = mock.Mock()
737+
query_operation_mock.return_value.get_results = mock.Mock()
738+
query_operation_mock.return_value.get_results.return_value = [
739+
["003000000000003"],
740+
["003000000000001"],
741+
["003000000000002"],
742+
]
743+
744+
# Mock the _wait_for_job method to simulate a successful job
745+
step._wait_for_job = mock.Mock()
746+
step._wait_for_job.return_value = DataOperationJobResult(
747+
DataOperationStatus.SUCCESS, [], 0, 0
748+
)
749+
750+
# Prepare input records
751+
records = iter([["Test1"], ["Test2"], ["Test3"]])
752+
753+
# Execute the select_records operation
754+
step.start()
755+
step.select_records(records)
756+
step.end()
757+
758+
# Get the results and assert their properties
759+
results = list(step.get_results())
760+
assert (
761+
len(results) == 3
762+
) # Expect 3 results (matching the input records count)
763+
# Assert that all results are in the order given by user query
764+
assert results[0].id == "003000000000003"
765+
assert results[1].id == "003000000000001"
766+
assert results[2].id == "003000000000002"
767+
704768
@mock.patch("cumulusci.tasks.bulkdata.step.download_file")
705769
def test_select_records_user_selection_filter_failure(self, download_mock):
706770
# Set up mock context and BulkApiDmlOperation
@@ -1428,6 +1492,87 @@ def test_select_records_user_selection_filter_success(self):
14281492
== 3
14291493
)
14301494

1495+
@responses.activate
1496+
def test_select_records_user_selection_filter_order_success(self):
1497+
mock_describe_calls()
1498+
task = _make_task(
1499+
LoadData,
1500+
{
1501+
"options": {
1502+
"database_url": "sqlite:///test.db",
1503+
"mapping": "mapping.yml",
1504+
}
1505+
},
1506+
)
1507+
task.project_config.project__package__api_version = CURRENT_SF_API_VERSION
1508+
task._init_task()
1509+
1510+
responses.add(
1511+
responses.POST,
1512+
url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects",
1513+
json=[
1514+
{"id": "003000000000001", "success": True},
1515+
{"id": "003000000000002", "success": True},
1516+
],
1517+
status=200,
1518+
)
1519+
responses.add(
1520+
responses.POST,
1521+
url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects",
1522+
json=[{"id": "003000000000003", "success": True}],
1523+
status=200,
1524+
)
1525+
step = RestApiDmlOperation(
1526+
sobject="Contact",
1527+
operation=DataOperationType.UPSERT,
1528+
api_options={"batch_size": 10, "update_key": "LastName"},
1529+
context=task,
1530+
fields=["LastName"],
1531+
selection_strategy=SelectStrategy.STANDARD,
1532+
selection_filter="ORDER BY CreatedDate",
1533+
)
1534+
1535+
results = {
1536+
"compositeResponse": [
1537+
{
1538+
"body": {
1539+
"records": [
1540+
{"Id": "003000000000001"},
1541+
{"Id": "003000000000002"},
1542+
{"Id": "003000000000003"},
1543+
]
1544+
},
1545+
"referenceId": "select_query",
1546+
"httpStatusCode": 200,
1547+
},
1548+
{
1549+
"body": {
1550+
"records": [
1551+
{"Id": "003000000000003"},
1552+
{"Id": "003000000000001"},
1553+
{"Id": "003000000000002"},
1554+
]
1555+
},
1556+
"referenceId": "user_query",
1557+
"httpStatusCode": 200,
1558+
},
1559+
]
1560+
}
1561+
step.sf.restful = mock.Mock()
1562+
step.sf.restful.return_value = results
1563+
records = iter([["Test1"], ["Test2"], ["Test3"]])
1564+
step.start()
1565+
step.select_records(records)
1566+
step.end()
1567+
1568+
# Get the results and assert their properties
1569+
results = list(step.get_results())
1570+
assert len(results) == 3 # Expect 3 results (matching the input records count)
1571+
# Assert that all results are in the order of user_query
1572+
assert results[0].id == "003000000000003"
1573+
assert results[1].id == "003000000000001"
1574+
assert results[2].id == "003000000000002"
1575+
14311576
@responses.activate
14321577
def test_select_records_user_selection_filter_failure(self):
14331578
mock_describe_calls()

0 commit comments

Comments
 (0)