-
Notifications
You must be signed in to change notification settings - Fork 254
Expand file tree
/
Copy pathquery_transformers.py
More file actions
305 lines (250 loc) · 11.1 KB
/
query_transformers.py
File metadata and controls
305 lines (250 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import re
import typing as T
from functools import cached_property
from sqlalchemy import String, and_, case, func, text
from sqlalchemy.orm import Query, aliased
from sqlalchemy.sql import literal_column
from cumulusci.core.exceptions import BulkDataException
# Opaque alias for a SQLAlchemy filter expression.
Criterion = T.Any

ID_TABLE_NAME = "cumulusci_id_table"

# Salesforce ID pattern: 15 or 18 alphanumeric characters
# This matches the OID_REGEX pattern used in robotframework/Salesforce.py
SF_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{15}$|^[a-zA-Z0-9]{18}$")


def is_salesforce_id(value: T.Optional[str]) -> bool:
    """Return True when *value* looks like a valid Salesforce ID.

    None never matches; any other value is coerced to str before testing
    against the 15/18-character alphanumeric pattern.
    """
    if value is None:
        return False
    return SF_ID_PATTERN.match(str(value)) is not None
def _is_salesforce_id_sqlite(value: T.Optional[str]) -> int:
    """SQLite UDF wrapper for is_salesforce_id.

    SQLite has no boolean type, so expose the result as 1/0.
    """
    return int(is_salesforce_id(value))
def register_sqlite_functions(connection) -> None:
    """Register custom SQLite functions on a database connection.

    Installs the ``is_salesforce_id`` single-argument UDF so it can be used
    in SQL expressions (e.g. via ``func.is_salesforce_id``).
    """
    # Reach through the SQLAlchemy wrapper to the raw DBAPI connection,
    # which is where SQLite UDFs must be installed.
    raw_connection = connection.connection.dbapi_connection
    raw_connection.create_function("is_salesforce_id", 1, _is_salesforce_id_sqlite)
class LoadQueryExtender:
    """Base class that transforms a load.py query with columns, filters, joins.

    Subclasses override the ``*_to_add`` cached properties to declare what
    they contribute; the ``add_*`` methods apply whatever was declared and
    leave the query untouched when a property returns a falsy value (the
    default ``None``).
    """

    # Fix: these were declared as ``def columns_to_add(*args)`` etc., which
    # only worked because cached_property passes the instance as the single
    # positional argument.  Use an explicit ``self`` like every other method.
    @cached_property
    def columns_to_add(self) -> T.Optional[T.List]:
        """Columns to append to the query; None means none."""
        return None

    @cached_property
    def filters_to_add(self) -> T.Optional[T.List]:
        """Filter criteria to apply; None means none."""
        return None

    @cached_property
    def outerjoins_to_add(self) -> T.Optional[T.List]:
        """(table, condition) pairs to outer-join; None means none."""
        return None

    def __init__(self, mapping, metadata, model) -> None:
        """Capture the mapping step, SQLAlchemy metadata and ORM model."""
        self.mapping, self.metadata, self.model = mapping, metadata, model

    def add_columns(self, query: "Query"):
        """Add columns to the query"""
        if self.columns_to_add:
            query = query.add_columns(*self.columns_to_add)
        return query

    def add_filters(self, query: "Query"):
        """Add filters to the query"""
        if self.filters_to_add:
            return query.filter(*self.filters_to_add)
        return query

    def add_outerjoins(self, query: "Query"):
        """Add outer joins to the query"""
        if self.outerjoins_to_add:
            for table, condition in self.outerjoins_to_add:
                query = query.outerjoin(table, condition)
        return query
class AddLookupsToQuery(LoadQueryExtender):
    """Adds columns and joins relating to lookups"""

    def __init__(self, mapping, metadata, model, _old_format) -> None:
        # `_old_format` selects the legacy "<table>-<id>" key format for the
        # ID-table join (see outerjoins_to_add).
        super().__init__(mapping, metadata, model)
        self._old_format = _old_format
        # Lookups marked `after` are excluded here — presumably resolved in a
        # separate pass after the initial load; confirm against load.py.
        self.lookups = [
            lookup for lookup in self.mapping.lookups.values() if not lookup.after
        ]

    @cached_property
    def columns_to_add(self):
        """Build column expressions for lookup fields with smart ID resolution."""
        columns = []
        for lookup in self.lookups:
            # Each lookup gets its own alias of the shared ID-mapping table
            # so multiple lookups can join against it independently.
            lookup.aliased_table = aliased(self.metadata.tables[ID_TABLE_NAME])
            key_field = lookup.get_lookup_key_field(self.model)
            value_column = getattr(self.model, key_field)
            # The resolved SF ID from the ID table join (may be NULL)
            sf_id_from_table = lookup.aliased_table.columns.sf_id
            smart_lookup = case(
                # If we found a match in the ID table, use that
                (sf_id_from_table.isnot(None), sf_id_from_table),
                # If the original value is already a SF ID, use it directly.
                # `is_salesforce_id` is the SQLite UDF installed by
                # register_sqlite_functions above.
                (func.is_salesforce_id(value_column) == 1, value_column),
                # Otherwise return NULL (lookup not found)
                else_=None,
            )
            columns.append(smart_lookup)
        return columns

    @cached_property
    def outerjoins_to_add(self):
        # Outer join with lookup ids table:
        # returns main obj even if lookup is null
        def join_for_lookup(lookup):
            key_field = lookup.get_lookup_key_field(self.model)
            value_column = getattr(self.model, key_field)
            if self._old_format:
                # Legacy datasets key the ID table as "<table>-<local id>".
                return (
                    lookup.aliased_table,
                    lookup.aliased_table.columns.id
                    == str(lookup.table) + "-" + func.cast(value_column, String),
                )
            else:
                return (
                    lookup.aliased_table,
                    lookup.aliased_table.columns.id == value_column,
                )

        # NOTE(review): depends on columns_to_add having been evaluated first,
        # since that property sets lookup.aliased_table used above.
        return [join_for_lookup(lookup) for lookup in self.lookups]
class DynamicLookupQueryExtender(LoadQueryExtender):
    """Dynamically adds columns and joins for all fields in lookup tables, handling polymorphic lookups"""

    def __init__(
        self, mapping, all_mappings, metadata, model, _old_format: bool
    ) -> None:
        super().__init__(mapping, metadata, model)
        self._old_format = _old_format
        # Full set of mapping steps, used below to locate the step that loads
        # each (possibly polymorphic) lookup target table.
        self.all_mappings = all_mappings
        # Lookups marked `after` are excluded here, matching AddLookupsToQuery.
        self.lookups = [
            lookup for lookup in self.mapping.lookups.values() if not lookup.after
        ]

    @cached_property
    def columns_to_add(self):
        """Add all relevant fields from lookup tables directly without CASE, with support for polymorphic lookups."""
        columns = []
        for lookup in self.lookups:
            # A polymorphic lookup carries a list of target tables; normalize
            # the single-table case to a one-element list.
            tables = lookup.table if isinstance(lookup.table, list) else [lookup.table]
            # Alias each target table per-lookup so several lookups can join
            # the same table independently; outerjoins_to_add reads this.
            lookup.parent_tables = [
                aliased(
                    self.metadata.tables[table], name=f"{lookup.name}_{table}_alias"
                )
                for table in tables
            ]
            for parent_table, table_name in zip(lookup.parent_tables, tables):
                # Find the mapping step for this polymorphic type
                lookup_mapping_step = next(
                    (
                        step
                        for step in self.all_mappings.values()
                        if step.table == table_name
                    ),
                    None,
                )
                if lookup_mapping_step:
                    load_fields = lookup_mapping_step.fields.keys()
                    for field in load_fields:
                        # NOTE(review): `field` iterates over
                        # lookup_mapping_step.fields.keys(), so this membership
                        # test is always true and the `else` branch below looks
                        # unreachable — confirm the intended condition (perhaps
                        # presence of the column in parent_table?).
                        if field in lookup_mapping_step.fields:
                            # NOTE(review): next() here has no default, so a
                            # missing column raises StopIteration — confirm the
                            # mapped column is guaranteed to exist in the alias.
                            matching_column = next(
                                (
                                    col
                                    for col in parent_table.columns
                                    if col.name == lookup_mapping_step.fields[field]
                                )
                            )
                            # Label as "<alias>_<field>" so downstream code can
                            # find each polymorphic column unambiguously.
                            columns.append(
                                matching_column.label(f"{parent_table.name}_{field}")
                            )
                        else:
                            # Append an empty string if the field is not present
                            columns.append(
                                literal_column("''").label(
                                    f"{parent_table.name}_{field}"
                                )
                            )
        return columns

    @cached_property
    def outerjoins_to_add(self):
        """Add outer joins for each lookup table directly, including handling for polymorphic lookups."""

        def join_for_lookup(lookup, parent_table):
            # Join the aliased parent table's local `id` to the lookup key
            # column on the model being loaded.
            key_field = lookup.get_lookup_key_field(self.model)
            value_column = getattr(self.model, key_field)
            return (parent_table, parent_table.columns.id == value_column)

        # NOTE(review): depends on columns_to_add having been evaluated first,
        # since that property sets lookup.parent_tables used here.
        joins = []
        for lookup in self.lookups:
            for parent_table in lookup.parent_tables:
                joins.append(join_for_lookup(lookup, parent_table))
        return joins
class AddRecordTypesToQuery(LoadQueryExtender):
    """Adds columns, joins and filters relating to recordtypes"""

    def __init__(self, mapping, metadata, model) -> None:
        super().__init__(mapping, metadata, model)
        # The destination record-type table is only relevant when the mapping
        # actually loads RecordTypeId; otherwise all properties are no-ops.
        if "RecordTypeId" in mapping.fields:
            self.rt_dest_table = metadata.tables[
                mapping.get_destination_record_type_table()
            ]
        else:
            self.rt_dest_table = None

    @cached_property
    def columns_to_add(self):
        # Emit the destination org's record_type_id, resolved via the joins
        # built in outerjoins_to_add.
        if self.rt_dest_table is not None:
            return [self.rt_dest_table.columns.record_type_id]

    @cached_property
    def filters_to_add(self):
        # Restrict rows to the record type named by the mapping, but only when
        # the source model actually has a record_type column.
        if self.mapping.record_type and hasattr(self.model, "record_type"):
            return [self.model.record_type == self.mapping.record_type]

    @cached_property
    def outerjoins_to_add(self):
        """Join source record-type ids to destination record-type ids.

        Chain: model.RecordTypeId -> source RT table (by record_type_id)
        -> destination RT table (by developer_name, plus is_person_type when
        available), which yields the record_type_id used by columns_to_add.
        """
        if "RecordTypeId" in self.mapping.fields:
            try:
                rt_source_table = self.metadata.tables[
                    self.mapping.get_source_record_type_table()
                ]
            except KeyError as e:
                # For generate_and_load_from_yaml, In case of namespace_inject true, mapping table name doesn't have namespace added
                # We are checking for table_rt_mapping table
                try:
                    rt_source_table = self.metadata.tables[
                        f"{self.mapping.table}_rt_mapping"
                    ]
                except KeyError as f:
                    # Both table names missing: the dataset has no record-type
                    # mapping at all.  The message carries the original error.
                    raise BulkDataException(
                        "A record type mapping table was not found in your dataset. "
                        f"Was it generated by extract_data? {e}",
                    ) from f
            rt_dest_table = self.metadata.tables[
                self.mapping.get_destination_record_type_table()
            ]
            # Check if 'is_person_type' column exists in rt_source_table.columns
            is_person_type_column = getattr(
                rt_source_table.columns, "is_person_type", None
            )
            # If it does not exist, set condition to True — presumably for
            # datasets extracted before that column existed; confirm.
            is_person_type_condition = (
                rt_dest_table.columns.is_person_type == is_person_type_column
                if is_person_type_column is not None
                else True
            )
            return [
                (
                    rt_source_table,
                    rt_source_table.columns.record_type_id
                    == getattr(self.model, self.mapping.fields["RecordTypeId"]),
                ),
                # Combination of IsPersonType and DeveloperName is unique
                (
                    rt_dest_table,
                    and_(
                        rt_dest_table.columns.developer_name
                        == rt_source_table.columns.developer_name,
                        is_person_type_condition,
                    ),
                ),
            ]
class AddMappingFiltersToQuery(LoadQueryExtender):
    """Applies the user-specified ``filters`` from the mapping step."""

    @cached_property
    def filters_to_add(self):
        """Wrap each raw filter string in a SQL ``text()`` clause."""
        filters = self.mapping.filters
        if not filters:
            return None
        return [text(clause) for clause in filters]
class AddPersonAccountsToQuery(LoadQueryExtender):
    """Add filters relating to Person accounts."""

    @cached_property
    def filters_to_add(self):
        """Restrict the query to Contacts whose IsPersonAccount is false.

        Contact records for person accounts were already created by the
        system, so they must not be loaded again.
        """
        # This extender is only valid on the Contact mapping step.
        assert self.mapping.sf_object == "Contact"
        is_person_account = self.model.__table__.columns.get("IsPersonAccount")
        # The flag is stored as text; compare case-insensitively.
        return [func.lower(is_person_account) == "false"]