Skip to content

Commit 6fad397

Browse files
Adds selection_filter and RANDOM selection strategy
1 parent a368803 commit 6fad397

6 files changed

Lines changed: 715 additions & 235 deletions

File tree

cumulusci/tasks/bulkdata/load.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,12 @@ def _execute_step(
289289
self, step, self._stream_queried_data(mapping, local_ids, query)
290290
)
291291
step.start()
292-
step.load_records(self._stream_queried_data(mapping, local_ids, query))
292+
if mapping.action == DataOperationType.SELECT:
293+
step.select_records(
294+
self._stream_queried_data(mapping, local_ids, query)
295+
)
296+
else:
297+
step.load_records(self._stream_queried_data(mapping, local_ids, query))
293298
step.end()
294299

295300
# Process Job Results
@@ -336,6 +341,8 @@ def configure_step(self, mapping):
336341
self.check_simple_upsert(mapping)
337342
api_options["update_key"] = mapping.update_key[0]
338343
action = DataOperationType.UPSERT
344+
elif mapping.action == DataOperationType.SELECT:
345+
action = DataOperationType.QUERY
339346
else:
340347
action = mapping.action
341348

@@ -349,6 +356,8 @@ def configure_step(self, mapping):
349356
fields=fields,
350357
api=mapping.api,
351358
volume=query.count(),
359+
selection_strategy=mapping.selection_strategy,
360+
selection_filter=mapping.selection_filter,
352361
)
353362
return step, query
354363

@@ -481,10 +490,11 @@ def _process_job_results(self, mapping, step, local_ids):
481490
"""Get the job results and process the results. If we're raising for
482491
row-level errors, do so; if we're inserting, store the new Ids."""
483492

484-
is_insert_or_upsert = mapping.action in (
493+
is_insert_upsert_or_select = mapping.action in (
485494
DataOperationType.INSERT,
486495
DataOperationType.UPSERT,
487496
DataOperationType.ETL_UPSERT,
497+
DataOperationType.SELECT,
488498
)
489499

490500
conn = self.session.connection()
@@ -500,7 +510,7 @@ def _process_job_results(self, mapping, step, local_ids):
500510
break
501511
# If we know we have no successful inserts, don't attempt to persist Ids.
502512
# Do, however, drain the generator to get error-checking behavior.
503-
if is_insert_or_upsert and (
513+
if is_insert_upsert_or_select and (
504514
step.job_result.records_processed - step.job_result.total_row_errors
505515
):
506516
table = self.metadata.tables[self.ID_TABLE_NAME]
@@ -516,7 +526,7 @@ def _process_job_results(self, mapping, step, local_ids):
516526
# person account Contact records so lookups to
517527
# person account Contact records get populated downstream as expected.
518528
if (
519-
is_insert_or_upsert
529+
is_insert_upsert_or_select
520530
and mapping.sf_object == "Contact"
521531
and self._can_load_person_accounts(mapping)
522532
):
@@ -531,7 +541,7 @@ def _process_job_results(self, mapping, step, local_ids):
531541
),
532542
)
533543

534-
if is_insert_or_upsert:
544+
if is_insert_upsert_or_select:
535545
self.session.commit()
536546

537547
def _generate_results_id_map(self, step, local_ids):

cumulusci/tasks/bulkdata/mapping_parser.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from cumulusci.core.enums import StrEnum
1616
from cumulusci.core.exceptions import BulkDataException
1717
from cumulusci.tasks.bulkdata.dates import iso_to_date
18+
from cumulusci.tasks.bulkdata.select_utils import SelectStrategy
1819
from cumulusci.tasks.bulkdata.step import DataApi, DataOperationType
1920
from cumulusci.utils import convert_to_snake_case
2021
from cumulusci.utils.yaml.model_parser import CCIDictModel
@@ -84,7 +85,7 @@ class BulkMode(StrEnum):
8485

8586
ENUM_VALUES = {
8687
v.value.lower(): v.value
87-
for enum in [BulkMode, DataApi, DataOperationType]
88+
for enum in [BulkMode, DataApi, DataOperationType, SelectStrategy]
8889
for v in enum.__members__.values()
8990
}
9091

@@ -107,9 +108,13 @@ class MappingStep(CCIDictModel):
107108
] = None # default should come from task options
108109
anchor_date: Optional[Union[str, date]] = None
109110
soql_filter: Optional[str] = None # soql_filter property
111+
selection_strategy: SelectStrategy = SelectStrategy.STANDARD # selection strategy
112+
selection_filter: Optional[
113+
str
114+
] = None # filter to be added at the end of select query
110115
update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts
111116

112-
@validator("bulk_mode", "api", "action", pre=True)
117+
@validator("bulk_mode", "api", "action", "selection_strategy", pre=True)
113118
def case_normalize(cls, val):
114119
if isinstance(val, Enum):
115120
return val

cumulusci/tasks/bulkdata/select_utils.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import random
12
import typing as T
23

34
from cumulusci.core.enums import StrEnum
@@ -9,14 +10,52 @@
910
class SelectStrategy(StrEnum):
1011
"""Enum defining the different selection strategies requested."""
1112

12-
RANDOM = "random"
13+
STANDARD = "standard"
1314
SIMILARITY = "similarity"
15+
RANDOM = "random"
1416

1517

16-
def random_generate_query(
17-
sobject: str, fields: T.List[str], num_records: float
18+
class SelectOperationExecutor:
19+
def __init__(self, strategy: SelectStrategy):
20+
self.strategy = strategy
21+
22+
def select_generate_query(
23+
self, sobject: str, fields: T.List[str], num_records: int
24+
):
25+
# For STANDARD strategy
26+
if self.strategy == SelectStrategy.STANDARD:
27+
return standard_generate_query(sobject=sobject, num_records=num_records)
28+
# For SIMILARITY strategy
29+
elif self.strategy == SelectStrategy.SIMILARITY:
30+
return similarity_generate_query(sobject=sobject, fields=fields)
31+
# For RANDOM strategy
32+
elif self.strategy == SelectStrategy.RANDOM:
33+
return standard_generate_query(sobject=sobject, num_records=num_records)
34+
35+
def select_post_process(
36+
self, load_records, query_records: list, num_records: int, sobject: str
37+
):
38+
# For STANDARD strategy
39+
if self.strategy == SelectStrategy.STANDARD:
40+
return standard_post_process(
41+
query_records=query_records, num_records=num_records, sobject=sobject
42+
)
43+
# For SIMILARITY strategy
44+
elif self.strategy == SelectStrategy.SIMILARITY:
45+
return similarity_post_process(
46+
load_records=load_records, query_records=query_records, sobject=sobject
47+
)
48+
# For RANDOM strategy
49+
elif self.strategy == SelectStrategy.RANDOM:
50+
return random_post_process(
51+
query_records=query_records, num_records=num_records, sobject=sobject
52+
)
53+
54+
55+
def standard_generate_query(
56+
sobject: str, num_records: int
1857
) -> T.Tuple[str, T.List[str]]:
19-
"""Generates the SOQL query for the random selection strategy"""
58+
"""Generates the SOQL query for the standard (as well as random) selection strategy"""
2059
# Get the WHERE clause from DEFAULT_DECLARATIONS if available
2160
declaration = DEFAULT_DECLARATIONS.get(sobject)
2261
if declaration:
@@ -32,10 +71,10 @@ def random_generate_query(
3271
return query, ["Id"]
3372

3473

35-
def random_post_process(
36-
load_records, query_records: list, num_records: float, sobject: str
74+
def standard_post_process(
75+
query_records: list, num_records: int, sobject: str
3776
) -> T.Tuple[T.List[dict], T.Union[str, None]]:
38-
"""Processes the query results for the random selection strategy"""
77+
"""Processes the query results for the standard selection strategy"""
3978
# Handle case where query returns 0 records
4079
if not query_records:
4180
error_message = f"No records found for {sobject} in the target org."
@@ -59,9 +98,8 @@ def random_post_process(
5998
def similarity_generate_query(
6099
sobject: str,
61100
fields: T.List[str],
62-
num_records: float,
63101
) -> T.Tuple[str, T.List[str]]:
64-
"""Generates the SOQL query for the random selection strategy"""
102+
"""Generates the SOQL query for the similarity selection strategy"""
65103
# Get the WHERE clause from DEFAULT_DECLARATIONS if available
66104
declaration = DEFAULT_DECLARATIONS.get(sobject)
67105
if declaration:
@@ -81,7 +119,7 @@ def similarity_generate_query(
81119

82120

83121
def similarity_post_process(
84-
load_records, query_records: list, num_records: float, sobject: str
122+
load_records: list, query_records: list, sobject: str
85123
) -> T.Tuple[T.List[dict], T.Union[str, None]]:
86124
"""Processes the query results for the similarity selection strategy"""
87125
# Handle case where query returns 0 records
@@ -100,6 +138,26 @@ def similarity_post_process(
100138
return closest_records, None
101139

102140

141+
def random_post_process(
142+
query_records: list, num_records: int, sobject: str
143+
) -> T.Tuple[T.List[dict], T.Union[str, None]]:
144+
"""Processes the query results for the random selection strategy"""
145+
146+
if not query_records:
147+
error_message = f"No records found for {sobject} in the target org."
148+
return [], error_message
149+
150+
selected_records = []
151+
for _ in range(num_records): # Loop 'num_records' times
152+
# Randomly select one record from query_records
153+
random_record = random.choice(query_records)
154+
selected_records.append(
155+
{"id": random_record[0], "success": True, "created": False}
156+
)
157+
158+
return selected_records, None
159+
160+
103161
def find_closest_record(load_record: list, query_records: list):
104162
closest_distance = float("inf")
105163
closest_record = query_records[0]

0 commit comments

Comments
 (0)