Skip to content

Commit 0ac2000

Browse files
committed
Add tests for annoy_post_process
1 parent 7fae06e commit 0ac2000

1 file changed

Lines changed: 244 additions & 95 deletions

File tree

cumulusci/tasks/bulkdata/tests/test_select_utils.py

Lines changed: 244 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
import pandas as pd
12
import pytest
23

34
from cumulusci.tasks.bulkdata.select_utils import (
45
SelectOperationExecutor,
56
SelectStrategy,
7+
annoy_post_process,
68
calculate_levenshtein_distance,
9+
determine_field_types,
710
find_closest_record,
811
levenshtein_distance,
12+
replace_empty_strings_with_missing,
13+
vectorize_records,
914
)
1015

1116

@@ -193,107 +198,56 @@ def test_levenshtein_distance():
193198
) # Longer strings with multiple differences
194199

195200

196-
def test_calculate_levenshtein_distance():
197-
# Identical records
198-
record1 = ["Tom Cruise", "24", "Actor"]
199-
record2 = ["Tom Cruise", "24", "Actor"]
200-
assert calculate_levenshtein_distance(record1, record2) == 0 # Distance should be 0
201-
202-
# Records with one different field
203-
record1 = ["Tom Cruise", "24", "Actor"]
204-
record2 = ["Tom Hanks", "24", "Actor"]
205-
assert calculate_levenshtein_distance(record1, record2) > 0 # Non-zero distance
206-
207-
# One record has an empty field
208-
record1 = ["Tom Cruise", "24", "Actor"]
209-
record2 = ["Tom Cruise", "", "Actor"]
210-
assert (
211-
calculate_levenshtein_distance(record1, record2) > 0
212-
) # Distance should reflect the empty field
213-
214-
# Completely empty records
215-
record1 = ["", "", ""]
216-
record2 = ["", "", ""]
217-
assert calculate_levenshtein_distance(record1, record2) == 0 # Distance should be 0
218-
219-
220-
def test_calculate_levenshtein_distance_error():
221-
# Identical records
222-
record1 = ["Tom Cruise", "24", "Actor"]
223-
record2 = [
224-
"Tom Cruise",
225-
"24",
226-
"Actor",
227-
"SomethingElse",
228-
] # Record Length does not match
229-
with pytest.raises(ValueError) as e:
230-
calculate_levenshtein_distance(record1, record2)
231-
assert "Records must have the same number of fields" in str(e.value)
232-
233-
234-
def test_find_closest_record():
235-
# Test case 1: Exact match
236-
load_record = ["Tom Cruise", "62", "Actor"]
237-
query_records = [
238-
[1, "Tom Hanks", "30", "Actor"],
239-
[2, "Tom Cruise", "62", "Actor"], # Exact match
240-
[3, "Jennifer Aniston", "30", "Actress"],
241-
]
242-
assert find_closest_record(load_record, query_records) == [
243-
2,
244-
"Tom Cruise",
245-
"62",
246-
"Actor",
247-
] # Should return the exact match
248-
249-
# Test case 2: Closest match with slight differences
250-
load_record = ["Tom Cruise", "62", "Actor"]
201+
def test_find_closest_record_different_weights():
    """Unequal field weights still pick a lowest-distance candidate.

    record1 and record3 are identical candidates (tied score), so the first
    one is returned — presumably argmin keeps the earliest; TODO confirm.
    """
    load_record = ["hello", "world"]
    # Each candidate row is [id, field1, field2]; every row is one edit away.
    candidates = [
        ["record1", "hello", "word"],
        ["record2", "hullo", "word"],
        ["record3", "hello", "word"],
    ]
    weights = [2.0, 0.5]  # first field counts far more than the second

    result = find_closest_record(load_record, candidates, weights)

    expected = ["record1", "hello", "word"]
    assert result == expected, "The closest record should be 'record1'."
217+
218+
219+
def test_find_closest_record_basic():
    """With uniform weights the earliest best-scoring candidate is selected."""
    load_record = ["hello", "world"]
    # All three candidates sit at the same edit distance from load_record.
    candidates = [
        ["record1", "hello", "word"],
        ["record2", "hullo", "word"],
        ["record3", "hello", "word"],
    ]
    uniform_weights = [1.0, 1.0]

    match = find_closest_record(load_record, candidates, uniform_weights)

    assert match == [
        "record1",
        "hello",
        "word",
    ], "The closest record should be 'record1'."
273234

274-
# Test case 4: Closest match is the last in the list
275-
load_record = ["Tom Cruise", "62", "Actor"]
235+
236+
def test_find_closest_record_multiple_matches():
    """An exact match (distance 0) wins over near misses regardless of order."""
    load_record = ["cat", "dog"]
    candidates = [
        ["record1", "bat", "dog"],  # one substitution away
        ["record2", "cat", "dog"],  # exact duplicate of load_record
        ["record3", "dog", "cat"],  # both fields differ
    ]

    match = find_closest_record(load_record, candidates, [1.0, 1.0])

    assert match == [
        "record2",
        "cat",
        "dog",
    ], "The closest record should be 'record2'."
297251

298252

299253
def test_similarity_post_process_with_records():
@@ -307,10 +261,16 @@ def test_similarity_post_process_with_records():
307261
["003", "Jennifer Aniston", "30", "Actress"],
308262
]
309263

264+
weights = [1.0, 1.0, 1.0] # Adjust weights to match your data structure
265+
310266
selected_records, error_message = select_operator.select_post_process(
311-
load_records, query_records, num_records, sobject
267+
load_records, query_records, num_records, sobject, weights
312268
)
313269

270+
# selected_records, error_message = select_operator.select_post_process(
271+
# load_records, query_records, num_records, sobject
272+
# )
273+
314274
assert error_message is None
315275
assert len(selected_records) == num_records
316276
assert all(record["success"] for record in selected_records)
@@ -329,3 +289,192 @@ def test_similarity_post_process_with_no_records():
329289

330290
assert selected_records == []
331291
assert error_message == f"No records found for {sobject} in the target org."
292+
293+
294+
def test_calculate_levenshtein_distance_basic():
    """Unit weights: result is the mean of the per-field edit distances."""
    record1 = ["hello", "world"]
    record2 = ["hullo", "word"]

    # "hello"->"hullo" and "world"->"word" are each a single edit apart,
    # so the weighted average over the two fields is (1 + 1) / 2.
    expected = (1 * 1.0 + 1 * 1.0) / 2

    actual = calculate_levenshtein_distance(record1, record2, [1.0, 1.0])
    assert actual == pytest.approx(expected), "Basic distance calculation failed."
307+
308+
309+
def test_calculate_levenshtein_distance_weighted():
    """Per-field weights scale each field's edit distance before averaging."""
    first = ["cat", "dog"]
    second = ["bat", "fog"]
    weights = [2.0, 0.5]

    # Both fields are one edit apart; the weights scale those distances.
    expected = (1 * 2.0 + 1 * 0.5) / 2

    assert calculate_levenshtein_distance(first, second, weights) == pytest.approx(
        expected
    ), "Weighted distance calculation failed."
321+
322+
323+
def test_replace_empty_strings_with_missing():
    """Every empty-string cell becomes "missing"; list structure is preserved."""
    cases = [
        # mixture of empty and populated cells
        (
            [
                ["Alice", "", "New York"],
                ["Bob", "Engineer", ""],
                ["", "Teacher", "Chicago"],
            ],
            [
                ["Alice", "missing", "New York"],
                ["Bob", "Engineer", "missing"],
                ["missing", "Teacher", "Chicago"],
            ],
        ),
        # nothing to replace: output equals input
        (
            [["Alice", "Manager", "New York"], ["Bob", "Engineer", "San Francisco"]],
            [["Alice", "Manager", "New York"], ["Bob", "Engineer", "San Francisco"]],
        ),
        # every cell empty
        (
            [["", "", ""], ["", "", ""]],
            [["missing", "missing", "missing"], ["missing", "missing", "missing"]],
        ),
        # no rows at all
        ([], []),
        # some rows themselves empty
        ([[], ["Alice", ""], []], [[], ["Alice", "missing"], []]),
    ]

    for records, expected in cases:
        assert replace_empty_strings_with_missing(records) == expected
356+
357+
358+
def test_all_numeric_columns():
    """Columns that are entirely numeric are classified as numerical features."""
    frame = pd.DataFrame({"A": [1, 2, 3], "B": [4.5, 5.5, 6.5]})

    result = determine_field_types(frame, [0.1, 0.2])

    assert result == (
        ["A", "B"],  # numerical feature names
        [],          # no boolean features
        [],          # no categorical features
        [0.1, 0.2],  # weights follow the numerical columns
        [],          # no boolean weights
        [],          # no categorical weights
    )
370+
371+
372+
def test_all_boolean_columns():
    """String columns holding only "true"/"false" land in the boolean bucket."""
    frame = pd.DataFrame(
        {"A": ["true", "false", "true"], "B": ["false", "true", "false"]}
    )

    got = determine_field_types(frame, [0.3, 0.4])

    # (numerical, boolean, categorical, num_weights, bool_weights, cat_weights)
    want = ([], ["A", "B"], [], [], [0.3, 0.4], [])
    assert got == want
384+
385+
386+
def test_all_categorical_columns():
    """Free-form string columns go to the categorical bucket with their weights."""
    frame = pd.DataFrame(
        {"A": ["apple", "banana", "cherry"], "B": ["dog", "cat", "mouse"]}
    )

    got = determine_field_types(frame, [0.5, 0.6])

    # (numerical, boolean, categorical, num_weights, bool_weights, cat_weights)
    want = ([], [], ["A", "B"], [], [], [0.5, 0.6])
    assert got == want
400+
401+
402+
def test_mixed_types():
    """Each column is routed to its own bucket, and weights follow their columns."""
    frame = pd.DataFrame(
        {
            "A": [1, 2, 3],                      # numeric
            "B": ["true", "false", "true"],      # boolean-like strings
            "C": ["apple", "banana", "cherry"],  # categorical
        }
    )

    got = determine_field_types(frame, [0.7, 0.8, 0.9])

    # (numerical, boolean, categorical, num_weights, bool_weights, cat_weights)
    assert got == (["A"], ["B"], ["C"], [0.7], [0.8], [0.9])
420+
421+
422+
def test_vectorize_records_mixed_numerical_categorical():
    """Vector width is numeric columns plus hashed categorical features,
    with one output row per input record on each side."""
    db_records = [["1.0", "apple"], ["2.0", "banana"]]
    query_records = [["1.5", "apple"], ["2.5", "cherry"]]
    hash_features = 4  # hashing-vectorizer width for the categorical column

    db_vectors, query_vectors = vectorize_records(
        db_records, query_records, hash_features, [1.0, 1.0]
    )

    # Row counts mirror the inputs.
    assert db_vectors.shape[0] == len(db_records), "DB vectors row count mismatch"
    assert query_vectors.shape[0] == len(
        query_records
    ), "Query vectors row count mismatch"

    # Width: 1 numeric column + the hashed categorical features.
    width = 1 + hash_features
    assert db_vectors.shape[1] == width, "DB vectors column count mismatch"
    assert query_vectors.shape[1] == width, "Query vectors column count mismatch"
447+
448+
449+
def test_annoy_post_process():
    """Two matched results come back and no error is reported."""
    load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]]
    query_records = [["q1", "Alice", "Engineer"], ["q2", "Charlie", "Artist"]]
    weights = [1.0, 1.0, 1.0]

    matches, error = annoy_post_process(load_records, query_records, weights)

    # Two results expected; the first maps to the exact-duplicate query "q1".
    assert len(matches) == 2
    assert matches[0]["id"] == "q1"
    assert error is None
467+
468+
469+
def test_single_record_match_annoy_post_process():
    """With a single query record, every load record maps to that record's id."""
    load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]]
    query_records = [["q1", "Alice", "Engineer"]]

    matches, error = annoy_post_process(load_records, query_records, [1.0, 1.0, 1.0])

    # Both load records can only be paired with the lone query record.
    assert len(matches) == 2
    assert matches[0]["id"] == "q1"
    assert error is None

0 commit comments

Comments
 (0)