Skip to content

Commit bb72bfb

Browse files
Fixes issue for improper batching and intersection
1 parent 8c2bb3a commit bb72bfb

3 files changed

Lines changed: 152 additions & 103 deletions

File tree

cumulusci/tasks/bulkdata/select_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,22 @@ class SelectStrategy(StrEnum):
1515
RANDOM = "random"
1616

1717

18+
class SelectRecordRetrievalMode(StrEnum):
19+
"""Enum defining whether you need all records or match the
20+
number of records of the local sql file"""
21+
22+
ALL = "all"
23+
MATCH = "match"
24+
25+
1826
class SelectOperationExecutor:
1927
def __init__(self, strategy: SelectStrategy):
2028
self.strategy = strategy
29+
self.retrieval_mode = (
30+
SelectRecordRetrievalMode.ALL
31+
if strategy == SelectStrategy.SIMILARITY
32+
else SelectRecordRetrievalMode.MATCH
33+
)
2134

2235
def select_generate_query(
2336
self,
@@ -96,7 +109,7 @@ def standard_post_process(
96109
original_records = selected_records.copy()
97110
while len(selected_records) < num_records:
98111
selected_records.extend(original_records)
99-
selected_records = selected_records[:num_records]
112+
selected_records = selected_records[:num_records]
100113

101114
return selected_records, None # Return selected records and None for error
102115

@@ -115,8 +128,8 @@ def similarity_generate_query(
115128
else:
116129
where_clause = None
117130
# Construct the query with the WHERE clause (if it exists)
118-
119-
fields.insert(0, "Id")
131+
if "Id" not in fields:
132+
fields.insert(0, "Id")
120133
fields_to_query = ", ".join(field for field in fields if field)
121134

122135
query = f"SELECT {fields_to_query} FROM {sobject}"

cumulusci/tasks/bulkdata/step.py

Lines changed: 127 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from cumulusci.core.utils import process_bool_arg
2020
from cumulusci.tasks.bulkdata.select_utils import (
2121
SelectOperationExecutor,
22+
SelectRecordRetrievalMode,
2223
SelectStrategy,
2324
)
2425
from cumulusci.tasks.bulkdata.utils import DataApi, iterate_in_chunks
@@ -452,71 +453,66 @@ def select_records(self, records):
452453
# Count total number of records to fetch using the copy
453454
total_num_records = sum(1 for _ in records_copy)
454455

455-
# Process in batches based on batch_size from api_options
456-
for offset in range(
457-
0, total_num_records, self.api_options.get("batch_size", 500)
458-
):
459-
# Calculate number of records to fetch in this batch
460-
num_records = min(
461-
self.api_options.get("batch_size", 500), total_num_records - offset
456+
# Since OFFSET is not supported in bulk, we can run only over 1 api_batch_size
457+
# Generate and execute SOQL query
458+
# (not passing offset as it is not supported in Bulk)
459+
(
460+
select_query,
461+
query_fields,
462+
) = self.select_operation_executor.select_generate_query(
463+
sobject=self.sobject,
464+
fields=self.fields,
465+
limit=self.api_options.get("batch_size", 500),
466+
offset=None,
467+
)
468+
if self.selection_filter:
469+
# Generate user filter query if selection_filter is present (offset clause not supported)
470+
user_query = generate_user_filter_query(
471+
filter_clause=self.selection_filter,
472+
sobject=self.sobject,
473+
fields=["Id"],
474+
limit_clause=self.api_options.get("batch_size", 500),
475+
offset_clause=None,
462476
)
463-
464-
# Generate and execute SOQL query
465-
# (not passing offset as it is not supported in Bulk)
466-
(
467-
select_query,
468-
query_fields,
469-
) = self.select_operation_executor.select_generate_query(
470-
sobject=self.sobject, fields=self.fields, limit=num_records, offset=None
477+
# Execute the user query using Bulk API
478+
user_query_executor = get_query_operation(
479+
sobject=self.sobject,
480+
fields=["Id"],
481+
api_options=self.api_options,
482+
context=self,
483+
query=user_query,
484+
api=DataApi.BULK,
471485
)
472-
if self.selection_filter:
473-
# Generate user filter query if selection_filter is present (offset clause not supported)
474-
user_query = generate_user_filter_query(
475-
filter_clause=self.selection_filter,
476-
sobject=self.sobject,
477-
fields=["Id"],
478-
limit_clause=num_records,
479-
offset_clause=None,
480-
)
481-
# Execute the user query using Bulk API
482-
user_query_executor = get_query_operation(
483-
sobject=self.sobject,
484-
fields=["Id"],
485-
api_options=self.api_options,
486-
context=self,
487-
query=user_query,
488-
api=DataApi.BULK,
489-
)
490-
user_query_executor.query()
491-
user_query_records = user_query_executor.get_results()
492-
493-
# Find intersection based on 'Id'
494-
user_query_ids = (
495-
list(record[0] for record in user_query_records)
496-
if user_query_records
497-
else []
498-
)
499-
500-
# Execute the main select query using Bulk API
501-
select_query_records = self._execute_select_query(
502-
select_query=select_query, query_fields=query_fields
486+
user_query_executor.query()
487+
user_query_records = user_query_executor.get_results()
488+
489+
# Find intersection based on 'Id'
490+
user_query_ids = (
491+
list(record[0] for record in user_query_records)
492+
if user_query_records
493+
else []
503494
)
504495

505-
# If user_query_ids exist, filter select_query_records based on the intersection of Ids
506-
if self.selection_filter:
507-
# Create a dictionary to map IDs to their corresponding records
508-
id_to_record_map = {
509-
record[query_fields.index("Id")]: record
510-
for record in select_query_records
511-
}
512-
# Extend query_records in the order of user_query_ids
513-
query_records.extend(
514-
record
515-
for id in user_query_ids
516-
if (record := id_to_record_map.get(id)) is not None
517-
)
518-
else:
519-
query_records.extend(select_query_records)
496+
# Execute the main select query using Bulk API
497+
select_query_records = self._execute_select_query(
498+
select_query=select_query, query_fields=query_fields
499+
)
500+
501+
# If user_query_ids exist, filter select_query_records based on the intersection of Ids
502+
if self.selection_filter:
503+
# Create a dictionary to map IDs to their corresponding records
504+
id_to_record_map = {
505+
record[query_fields.index("Id")]: record
506+
for record in select_query_records
507+
}
508+
# Extend query_records in the order of user_query_ids
509+
query_records.extend(
510+
record
511+
for id in user_query_ids
512+
if (record := id_to_record_map.get(id)) is not None
513+
)
514+
else:
515+
query_records.extend(select_query_records)
520516

521517
# Post-process the query results
522518
(
@@ -525,7 +521,7 @@ def select_records(self, records):
525521
) = self.select_operation_executor.select_post_process(
526522
load_records=records,
527523
query_records=query_records,
528-
num_records=num_records,
524+
num_records=total_num_records,
529525
sobject=self.sobject,
530526
)
531527
if not error_message:
@@ -674,7 +670,7 @@ def __init__(
674670
api_options,
675671
context,
676672
fields,
677-
selection_strategy=SelectStrategy.SIMILARITY,
673+
selection_strategy=SelectStrategy.STANDARD,
678674
selection_filter=None,
679675
):
680676
super().__init__(
@@ -816,60 +812,108 @@ def convert(rec, fields):
816812

817813
self.results = []
818814
query_records = []
815+
user_query_records = []
819816
# Create a copy of the generator using tee
820817
records, records_copy = tee(records)
821818
# Count total number of records to fetch using the copy
822819
total_num_records = sum(1 for _ in records_copy)
820+
# Set offset
821+
offset = 0
823822

824-
# Process in batches
825-
for offset in range(0, total_num_records, self.api_options.get("batch_size")):
826-
num_records = min(
827-
self.api_options.get("batch_size"), total_num_records - offset
828-
)
823+
# Define condition
824+
def condition(retrieval_mode, offset, total_num_records):
825+
if retrieval_mode == SelectRecordRetrievalMode.ALL:
826+
return True
827+
elif retrieval_mode == SelectRecordRetrievalMode.MATCH:
828+
return offset < total_num_records
829829

830+
# Process in batches
831+
while condition(
832+
self.select_operation_executor.retrieval_mode, offset, total_num_records
833+
):
830834
# Generate the SOQL query based on the selection strategy
831835
(
832836
select_query,
833837
query_fields,
834838
) = self.select_operation_executor.select_generate_query(
835839
sobject=self.sobject,
836840
fields=self.fields,
837-
limit=num_records,
841+
limit=self.api_options.get("batch_size"),
838842
offset=offset,
839843
)
840844

841845
# If user given selection filter present, create composite request
842846
if self.selection_filter:
847+
# Generate user query
843848
user_query = generate_user_filter_query(
844849
filter_clause=self.selection_filter,
845850
sobject=self.sobject,
846851
fields=["Id"],
847-
limit_clause=num_records,
852+
limit_clause=self.api_options.get("batch_size"),
848853
offset_clause=offset,
849854
)
850-
query_records.extend(
851-
self._execute_composite_query(
852-
select_query=select_query,
853-
user_query=user_query,
854-
query_fields=query_fields,
855-
)
855+
# Execute composite query
856+
(
857+
current_user_query_records,
858+
current_query_records,
859+
) = self._execute_composite_query(
860+
select_query=select_query,
861+
user_query=user_query,
862+
query_fields=query_fields,
856863
)
864+
# Break if org has no more records
865+
if (
866+
len(current_user_query_records) == 0
867+
and len(current_query_records) == 0
868+
):
869+
break
870+
871+
# Extend to each
872+
user_query_records.extend(current_user_query_records)
873+
query_records.extend(current_query_records)
874+
857875
else:
858876
# Handle the case where self.selection_query is None (and hence user_query is also None)
859877
response = self.sf.restful(
860878
requests.utils.requote_uri(f"query/?q={select_query}"), method="GET"
861879
)
862-
query_records.extend(
863-
list(convert(rec, query_fields) for rec in response["records"])
880+
current_query_records = list(
881+
convert(rec, query_fields) for rec in response["records"]
864882
)
883+
# Break if nothing is returned
884+
if len(current_query_records) == 0:
885+
break
886+
# Extend the query records
887+
query_records.extend(current_query_records)
888+
889+
# Update offset
890+
offset += self.api_options.get("batch_size")
891+
892+
# Find intersection if filter given
893+
if self.selection_filter:
894+
# Find intersection based on 'Id'
895+
user_query_ids = list(record[0] for record in user_query_records)
896+
# Create a dictionary to map IDs to their corresponding records
897+
id_to_record_map = {
898+
record[query_fields.index("Id")]: record for record in query_records
899+
}
900+
901+
# Extend insersection_query_records in the order of user_query_ids
902+
insersection_query_records = [
903+
record
904+
for id in user_query_ids
905+
if (record := id_to_record_map.get(id)) is not None
906+
]
907+
else:
908+
insersection_query_records = query_records
865909

866910
# Post-process the query results for this batch
867911
(
868912
selected_records,
869913
error_message,
870914
) = self.select_operation_executor.select_post_process(
871915
load_records=records,
872-
query_records=query_records,
916+
query_records=insersection_query_records,
873917
num_records=total_num_records,
874918
sobject=self.sobject,
875919
)
@@ -888,7 +932,7 @@ def convert(rec, fields):
888932
)
889933

890934
def _execute_composite_query(self, select_query, user_query, query_fields):
891-
"""Executes a composite request with two queries and returns the intersected results."""
935+
"""Executes a composite request with two queries and returns the results."""
892936

893937
def convert(rec, fields):
894938
"""Helper function to convert record values to strings, handling None values"""
@@ -937,19 +981,8 @@ def convert(rec, fields):
937981
raise SOQLQueryException(
938982
f"{sub_response['body'][0]['errorCode']}: {sub_response['body'][0]['message']}"
939983
)
940-
# Find intersection based on 'Id'
941-
user_query_ids = list(record[0] for record in user_query_records)
942-
# Create a dictionary to map IDs to their corresponding records
943-
id_to_record_map = {
944-
record[query_fields.index("Id")]: record for record in select_query_records
945-
}
946984

947-
# Extend query_records in the order of user_query_ids
948-
return [
949-
record
950-
for id in user_query_ids
951-
if (record := id_to_record_map.get(id)) is not None
952-
]
985+
return user_query_records, select_query_records
953986

954987
def get_results(self):
955988
"""Return a generator of DataOperationResult objects."""
@@ -1076,8 +1109,8 @@ def generate_user_filter_query(
10761109
filter_clause: str,
10771110
sobject: str,
10781111
fields: list,
1079-
limit_clause: Union[int, None] = None,
1080-
offset_clause: Union[int, None] = None,
1112+
limit_clause: Union[float, None] = None,
1113+
offset_clause: Union[float, None] = None,
10811114
) -> str:
10821115
"""
10831116
Generates a SOQL query with the provided filter, object, fields, limit, and offset clauses.

cumulusci/tasks/bulkdata/tests/test_step.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,7 +1685,7 @@ def test_select_records_similarity_strategy_success(self):
16851685
selection_strategy=SelectStrategy.SIMILARITY,
16861686
)
16871687

1688-
results = {
1688+
results_first_call = {
16891689
"records": [
16901690
{
16911691
"Id": "003000000000001",
@@ -1705,13 +1705,16 @@ def test_select_records_similarity_strategy_success(self):
17051705
],
17061706
"done": True,
17071707
}
1708-
step.sf.restful = mock.Mock()
1709-
step.sf.restful.return_value = results
1708+
1709+
# First call returns `results_first_call`, second call returns an empty list
1710+
step.sf.restful = mock.Mock(
1711+
side_effect=[results_first_call, {"records": [], "done": True}]
1712+
)
17101713
records = iter(
17111714
[
1712-
["Id: 1", "Jawad", "mjawadtp@example.com"],
1713-
["Id: 2", "Aditya", "aditya@example.com"],
1714-
["Id: 3", "Tom Cruise", "tom@example.com"],
1715+
["Jawad", "mjawadtp@example.com"],
1716+
["Aditya", "aditya@example.com"],
1717+
["Tom Cruise", "tom@example.com"],
17151718
]
17161719
)
17171720
step.start()

0 commit comments

Comments
 (0)