Skip to content

Commit d67fc6b

Browse files
Fix issue where a threshold of zero was treated as unset, so every record was selected and none were inserted. Add tests for the zero-threshold case as well.
1 parent 2a30113 commit d67fc6b

2 files changed

Lines changed: 189 additions & 4 deletions

File tree

cumulusci/tasks/bulkdata/select_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def annoy_post_process(
397397
# Retrieve the corresponding record from the database
398398
record = query_record_data[neighbor_index]
399399
closest_record_id = record_to_id_map[tuple(record)]
400-
if threshold and (neighbor_distances[idx] >= threshold):
400+
if threshold is not None and (neighbor_distances[idx] >= threshold):
401401
selected_records.append(None)
402402
insertion_candidates.append(load_shaped_records[i])
403403
else:
@@ -445,7 +445,7 @@ def levenshtein_post_process(
445445
select_record, target_records, similarity_weights
446446
)
447447

448-
if distance_threshold and match_distance > distance_threshold:
448+
if distance_threshold is not None and match_distance > distance_threshold:
449449
# Append load record for insertion if distance exceeds threshold
450450
insertion_candidates.append(load_record)
451451
selected_records.append(None)

cumulusci/tasks/bulkdata/tests/test_step.py

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,9 @@ def test_process_insert_records_failure(self, download_mock):
12321232
)
12331233

12341234
@mock.patch("cumulusci.tasks.bulkdata.step.download_file")
1235-
def test_select_records_similarity_strategy__insert_records(self, download_mock):
1235+
def test_select_records_similarity_strategy__insert_records__non_zero_threshold(
1236+
self, download_mock
1237+
):
12361238
# Set up mock context and BulkApiDmlOperation
12371239
context = mock.Mock()
12381240
# Add step with threshold
@@ -1325,6 +1327,102 @@ def test_select_records_similarity_strategy__insert_records(self, download_mock)
13251327
== 1
13261328
)
13271329

1330+
@mock.patch("cumulusci.tasks.bulkdata.step.download_file")
1331+
def test_select_records_similarity_strategy__insert_records__zero_threshold(
1332+
self, download_mock
1333+
):
1334+
# Set up mock context and BulkApiDmlOperation
1335+
context = mock.Mock()
1336+
# Create the step with a threshold of zero
1337+
step = BulkApiDmlOperation(
1338+
sobject="Contact",
1339+
operation=DataOperationType.QUERY,
1340+
api_options={"batch_size": 10, "update_key": "LastName"},
1341+
context=context,
1342+
fields=["Name", "Email"],
1343+
selection_strategy=SelectStrategy.SIMILARITY,
1344+
threshold=0,
1345+
)
1346+
1347+
# Mock Bulk API responses
1348+
step.bulk.endpoint = "https://test"
1349+
step.bulk.create_query_job.return_value = "JOB"
1350+
step.bulk.query.return_value = "BATCH"
1351+
step.bulk.get_query_batch_result_ids.return_value = ["RESULT"]
1352+
1353+
# Mock the downloaded content: select results (a JSON array with a single record) and insert results (CSV)
1354+
select_results = io.StringIO(
1355+
"""[{"Id":"003000000000001", "Name":"Jawad", "Email":"mjawadtp@example.com"}]"""
1356+
)
1357+
insert_results = io.StringIO(
1358+
"Id,Success,Created\n003000000000002,true,true\n003000000000003,true,true\n"
1359+
)
1360+
download_mock.side_effect = [select_results, insert_results]
1361+
1362+
# Mock the _wait_for_job method to simulate a successful job
1363+
step._wait_for_job = mock.Mock()
1364+
step._wait_for_job.return_value = DataOperationJobResult(
1365+
DataOperationStatus.SUCCESS, [], 0, 0
1366+
)
1367+
1368+
# Prepare input records
1369+
records = iter(
1370+
[
1371+
["Jawad", "mjawadtp@example.com"],
1372+
["Aditya", "aditya@example.com"],
1373+
["Tom", "cruise@example.com"],
1374+
]
1375+
)
1376+
1377+
# Mock sub-operation for BulkApiDmlOperation
1378+
insert_step = mock.Mock(spec=BulkApiDmlOperation)
1379+
insert_step.start = mock.Mock()
1380+
insert_step.load_records = mock.Mock()
1381+
insert_step.end = mock.Mock()
1382+
insert_step.batch_ids = ["BATCH1"]
1383+
insert_step.bulk = mock.Mock()
1384+
insert_step.bulk.endpoint = "https://test"
1385+
insert_step.job_id = "JOB"
1386+
1387+
with mock.patch(
1388+
"cumulusci.tasks.bulkdata.step.BulkApiDmlOperation",
1389+
return_value=insert_step,
1390+
):
1391+
# Execute the select_records operation
1392+
step.start()
1393+
step.select_records(records)
1394+
step.end()
1395+
1396+
# Get the results and assert their properties
1397+
results = list(step.get_results())
1398+
1399+
assert len(results) == 3 # Expect 3 results (matching the input records count)
1400+
# Assert that all results have the expected ID, success, and created values
1401+
assert (
1402+
results.count(
1403+
DataOperationResult(
1404+
id="003000000000001", success=True, error="", created=False
1405+
)
1406+
)
1407+
== 1
1408+
)
1409+
assert (
1410+
results.count(
1411+
DataOperationResult(
1412+
id="003000000000002", success=True, error="", created=True
1413+
)
1414+
)
1415+
== 1
1416+
)
1417+
assert (
1418+
results.count(
1419+
DataOperationResult(
1420+
id="003000000000003", success=True, error="", created=True
1421+
)
1422+
)
1423+
== 1
1424+
)
1425+
13281426
@mock.patch("cumulusci.tasks.bulkdata.step.download_file")
13291427
def test_select_records_similarity_strategy__insert_records__no_select_records(
13301428
self, download_mock
@@ -2807,7 +2905,9 @@ def test_process_insert_records_failure(self):
28072905
mock_rest_api_dml_operation.end.assert_not_called()
28082906

28092907
@responses.activate
2810-
def test_select_records_similarity_strategy__insert_records(self):
2908+
def test_select_records_similarity_strategy__insert_records__non_zero_threshold(
2909+
self,
2910+
):
28112911
mock_describe_calls()
28122912
task = _make_task(
28132913
LoadData,
@@ -2891,6 +2991,91 @@ def test_select_records_similarity_strategy__insert_records(self):
28912991
== 1
28922992
)
28932993

2994+
@responses.activate
2995+
def test_select_records_similarity_strategy__insert_records__zero_threshold(self):
2996+
mock_describe_calls()
2997+
task = _make_task(
2998+
LoadData,
2999+
{
3000+
"options": {
3001+
"database_url": "sqlite:///test.db",
3002+
"mapping": "mapping.yml",
3003+
}
3004+
},
3005+
)
3006+
task.project_config.project__package__api_version = CURRENT_SF_API_VERSION
3007+
task._init_task()
3008+
3009+
# Create the step with a threshold of zero
3010+
step = RestApiDmlOperation(
3011+
sobject="Contact",
3012+
operation=DataOperationType.UPSERT,
3013+
api_options={"batch_size": 10},
3014+
context=task,
3015+
fields=["Name", "Email"],
3016+
selection_strategy=SelectStrategy.SIMILARITY,
3017+
threshold=0,
3018+
)
3019+
3020+
results_select_call = {
3021+
"records": [
3022+
{
3023+
"Id": "003000000000001",
3024+
"Name": "Jawad",
3025+
"Email": "mjawadtp@example.com",
3026+
},
3027+
],
3028+
"done": True,
3029+
}
3030+
3031+
results_insert_call = [
3032+
{"id": "003000000000002", "success": True, "created": True},
3033+
{"id": "003000000000003", "success": True, "created": True},
3034+
]
3035+
3036+
step.sf.restful = mock.Mock(
3037+
side_effect=[results_select_call, results_insert_call]
3038+
)
3039+
records = iter(
3040+
[
3041+
["Jawad", "mjawadtp@example.com"],
3042+
["Aditya", "aditya@example.com"],
3043+
["Tom Cruise", "tom@example.com"],
3044+
]
3045+
)
3046+
step.start()
3047+
step.select_records(records)
3048+
step.end()
3049+
3050+
# Get the results and assert their properties
3051+
results = list(step.get_results())
3052+
assert len(results) == 3 # Expect 3 results (matching the input records count)
3053+
# Assert that all results have the expected ID, success, and created values
3054+
assert (
3055+
results.count(
3056+
DataOperationResult(
3057+
id="003000000000001", success=True, error="", created=False
3058+
)
3059+
)
3060+
== 1
3061+
)
3062+
assert (
3063+
results.count(
3064+
DataOperationResult(
3065+
id="003000000000002", success=True, error="", created=True
3066+
)
3067+
)
3068+
== 1
3069+
)
3070+
assert (
3071+
results.count(
3072+
DataOperationResult(
3073+
id="003000000000003", success=True, error="", created=True
3074+
)
3075+
)
3076+
== 1
3077+
)
3078+
28943079
@responses.activate
28953080
def test_insert_dml_operation__boolean_conversion(self):
28963081
mock_describe_calls()

0 commit comments

Comments
 (0)