Skip to content

Commit 5351ca1

Browse files
authored
fix(optimizer): EXPLODE qualify and annotate (#7549)
* fix(optimizer): EXPLODE qualify and annotate
* change tests
* ref walrus
1 parent 234198a commit 5351ca1

3 files changed

Lines changed: 86 additions & 29 deletions

File tree

sqlglot/optimizer/annotate_types.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -300,14 +300,23 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
300300

301301
alias_column_names = expression.alias_column_names
302302

303-
if (
304-
isinstance(expression, exp.Unnest)
305-
and expression.type
306-
and expression.type.is_type(exp.DType.STRUCT)
303+
if isinstance(expression, exp.Unnest):
304+
exp_type = expression.type
305+
elif isinstance(expression, exp.Lateral) and isinstance(
306+
expression.this, exp.Explode
307307
):
308+
exp_type = expression.this.type
309+
else:
310+
exp_type = None
311+
312+
struct_type = (
313+
exp_type if exp_type and exp_type.is_type(exp.DType.STRUCT) else None
314+
)
315+
316+
if struct_type:
308317
selects[name] = {
309318
col_def.name: t.cast(t.Union[exp.DataType, exp.DType], col_def.kind)
310-
for col_def in expression.type.expressions
319+
for col_def in struct_type.expressions
311320
if isinstance(col_def, exp.ColumnDef) and col_def.kind
312321
}
313322
else:

sqlglot/optimizer/resolver.py

Lines changed: 33 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -142,33 +142,34 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[
142142
if isinstance(source, exp.Table):
143143
columns = self.schema.column_names(source, only_visible)
144144
elif isinstance(source, Scope) and isinstance(
145-
source.expression, (exp.Values, exp.Unnest)
145+
source_expr := source.expression, (exp.Values, exp.Unnest, exp.Lateral)
146146
):
147-
columns = source.expression.named_selects
147+
columns = source_expr.named_selects
148148

149149
# in bigquery, unnest structs are automatically scoped as tables, so you can
150150
# directly select a struct field in a query.
151151
# this handles the case where the unnest is statically defined.
152-
if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
153-
unnest = source.expression
154-
155-
# if type is not annotated yet, try to get it from the schema
156-
if not unnest.type or unnest.type.is_type(exp.DType.UNKNOWN):
157-
unnest_expr = seq_get(unnest.expressions, 0)
152+
if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source_expr, exp.Unnest):
153+
if not source_expr.type or source_expr.type.is_type(exp.DType.UNKNOWN):
154+
unnest_expr = seq_get(source_expr.expressions, 0)
158155
if isinstance(unnest_expr, exp.Column) and self.scope.parent:
159-
col_type = self._get_unnest_column_type(unnest_expr)
160-
# extract element type if it's an ARRAY
156+
col_type = self._get_unnest_column_type(unnest_expr, self.scope.parent)
161157
if col_type and col_type.is_type(exp.DType.ARRAY):
162158
element_types = col_type.expressions
163159
if element_types:
164-
unnest.type = element_types[0].copy()
165-
else:
166-
if col_type:
167-
unnest.type = col_type.copy()
168-
# check if the result type is a STRUCT - extract struct field names
169-
if unnest.is_type(exp.DType.STRUCT):
170-
for k in unnest.type.expressions: # type: ignore
171-
columns.append(k.name)
160+
source_expr.type = element_types[0].copy()
161+
elif col_type:
162+
source_expr.type = col_type.copy()
163+
164+
columns.extend(self._struct_field_names(source_expr.type))
165+
elif isinstance(source_expr, exp.Lateral) and isinstance(
166+
source_expr.this, exp.Explode
167+
):
168+
explode_col = source_expr.this.this
169+
170+
if isinstance(explode_col, exp.Column) and source.parent:
171+
col_type = self._get_unnest_column_type(explode_col, source.parent)
172+
columns.extend(self._struct_field_names(col_type))
172173
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
173174
columns = self.get_source_columns_from_set_op(source.expression)
174175
else:
@@ -338,19 +339,27 @@ def _get_unambiguous_columns(
338339

339340
return unambiguous_columns
340341

341-
def _get_unnest_column_type(self, column: exp.Column) -> exp.DataType | None:
342+
def _struct_field_names(self, col_type: exp.DataType | None) -> list[str]:
343+
if col_type and col_type.is_type(exp.DType.ARRAY):
344+
col_type = seq_get(col_type.expressions, 0)
345+
346+
return (
347+
[k.name for k in col_type.expressions]
348+
if col_type and col_type.is_type(exp.DType.STRUCT)
349+
else []
350+
)
351+
352+
def _get_unnest_column_type(self, column: exp.Column, scope: Scope) -> exp.DataType | None:
342353
"""
343-
Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
354+
Get the type of a column being unnested/exploded, tracing through CTEs/subqueries to find the base table.
344355
345356
Args:
346-
column: The column expression being unnested.
357+
column: The column expression being unnested/exploded.
358+
scope: The scope to resolve the column in.
347359
348360
Returns:
349361
The DataType of the column, or None if not found.
350362
"""
351-
scope = self.scope.parent
352-
assert scope
353-
354363
# if column is qualified, use that table, otherwise disambiguate using the resolver
355364
if column.table:
356365
table_name = column.table

tests/test_optimizer.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -651,6 +651,45 @@ def test_validate_columns(self):
651651
)
652652
optimizer.qualify_columns.validate_qualify_columns(qualified)
653653

654+
schema = {"my_table": {"items": "ARRAY<STRUCT<name STRING, age INT>>"}}
655+
expression = annotate_types(
656+
optimizer.qualify.qualify(
657+
parse_one(
658+
"SELECT ci.name, ci.age FROM my_table LATERAL VIEW EXPLODE(items) ci AS ci",
659+
read="spark",
660+
),
661+
schema=schema,
662+
dialect="spark",
663+
),
664+
schema=schema,
665+
dialect="spark",
666+
)
667+
self.assertEqual(
668+
expression.sql(dialect="spark"),
669+
"SELECT `ci`.`name` AS `name`, `ci`.`age` AS `age` FROM `my_table` AS `my_table` LATERAL VIEW EXPLODE(`my_table`.`items`) ci AS `ci`",
670+
)
671+
self.assertEqual(expression.selects[0].type, exp.DataType.build("STRING", dialect="spark"))
672+
self.assertEqual(expression.selects[1].type, exp.DataType.build("INT", dialect="spark"))
673+
674+
schema = {"my_table": {"items": "ARRAY<STRUCT<amount FLOAT, type STRING>>"}}
675+
expression = annotate_types(
676+
optimizer.qualify.qualify(
677+
parse_one(
678+
"SELECT (SELECT SUM(ci.amount) FROM my_table LATERAL VIEW EXPLODE(items) ci AS ci WHERE ci.type = 'promotion') AS total FROM my_table",
679+
read="spark",
680+
),
681+
schema=schema,
682+
dialect="spark",
683+
),
684+
schema=schema,
685+
dialect="spark",
686+
)
687+
self.assertEqual(
688+
expression.sql(dialect="spark"),
689+
"SELECT (SELECT SUM(`ci`.`amount`) AS `_col_0` FROM `my_table` AS `my_table` LATERAL VIEW EXPLODE(`my_table`.`items`) ci AS `ci` WHERE `ci`.`type` = 'promotion') AS `total` FROM `my_table` AS `my_table`",
690+
)
691+
self.assertEqual(expression.selects[0].type, exp.DataType.build("DOUBLE", dialect="spark"))
692+
654693
def test_qualify_columns__with_invisible(self):
655694
schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
656695
self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)

0 commit comments

Comments
 (0)