Skip to content

Commit cf0ffd9

Browse files
authored
fix(optimizer): qualify (UN)PIVOT on CTE sources (#7560)
* fix(optimizer): qualify (UN)PIVOT on CTE sources Fixes `SELECT * FROM cte UNPIVOT(...)` and `SELECT alias.col FROM cte UNPIVOT(...) AS alias` in bigquery and other dialects. Both cases used to error or silently skip star expansion because column resolution for a pivoted CTE went to the wrong place. Background: when a CTE is referenced with a pivot, `scope.py` stores the pivoted `exp.Table` under the pivot alias (not the CTE's `Scope`) so the pivot is treated as a new logical source. But the Table has no schema entry for the CTE name, so column lookups returned `[]`. Changes: - parser: UNPIVOT's pre-FOR value column(s) and the FOR field are now parsed as `Identifier` rather than `Column`. They're new output names, not references to existing columns. IN-list items stay as `Column` since those do reference source-table columns. PIVOT is unchanged. - optimizer/resolver: for pivoted-CTE sources (`Table` with pivots, no db qualifier, name matches a known CTE), fall back to the CTE's `Scope` to read pre-pivot columns. Enables star expansion. - optimizer/qualify_columns: when validating a column against a pivoted source, validate against the post-pivot column set rather than the pre-pivot source columns. Direct references like `u.val` pass, typos like `u.nonexistent` still error. - optimizer/qualify_columns: `_pivot_output_columns` helper factored out of `_expand_stars` and reused by the validator. Collapses ~18 lines of inline logic that computed excluded-vs-output column sets separately then combined them at the use site. - optimizer/qualify_columns: the `_unpivot_columns` filter in `validate_qualify_columns` is removed. With the parser change, unpivot output names are Identifiers and never enter `scope.unqualified_columns`, so the filter was dead code. * Add multi-value/name test
1 parent fc6e7cb commit cf0ffd9

2 files changed

Lines changed: 64 additions & 25 deletions

File tree

sqlglot/optimizer/qualify_columns.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
if t.TYPE_CHECKING:
1717
from sqlglot._typing import E
18-
from collections.abc import Iterator, Iterable
18+
from collections.abc import Iterator, Iterable, Sequence
1919

2020

2121
def qualify_columns(
@@ -181,6 +181,33 @@ def _separate_pseudocolumns(scope: Scope, pseudocolumns: set[str]) -> None:
181181
scope.clear_cache()
182182

183183

184+
def _pivot_output_columns(pivot: exp.Pivot, pre_pivot_columns: Sequence[str]) -> list[str]:
185+
"""Compute the columns exposed after a (UN)PIVOT, given its pre-pivot source columns.
186+
187+
Returns an empty list for degenerate pivots (no IN-list or no output names) so callers
188+
can fall through to their non-pivot handling.
189+
"""
190+
if pivot.unpivot:
191+
excluded = {
192+
c.output_name
193+
for field in pivot.fields
194+
if isinstance(field, exp.In)
195+
for e in field.expressions
196+
for c in e.find_all(exp.Column)
197+
}
198+
outputs = [i.name for i in _unpivot_columns(pivot)]
199+
else:
200+
excluded = {c.output_name for c in pivot.find_all(exp.Column)}
201+
outputs = [c.output_name for c in pivot.args.get("columns") or []]
202+
if not outputs:
203+
outputs = [c.alias_or_name for c in pivot.expressions]
204+
205+
if not excluded or not outputs:
206+
return []
207+
208+
return [c for c in pre_pivot_columns if c not in excluded] + outputs
209+
210+
184211
def _unpivot_columns(unpivot: exp.Pivot) -> Iterator[exp.Identifier]:
185212
name_columns = [
186213
field.this
@@ -605,7 +632,13 @@ def _qualify_columns(
605632
column_name = column.name
606633

607634
if column_table and column_table in scope.sources:
635+
column_source = scope.sources[column_table]
608636
source_columns = resolver.get_source_columns(column_table)
637+
# For pivoted sources, source_columns are pre-pivot; validate against the post-pivot set.
638+
if isinstance(column_source, exp.Table) and (
639+
pivots := column_source.args.get("pivots")
640+
):
641+
source_columns = _pivot_output_columns(pivots[0], source_columns)
609642
if (
610643
not allow_partial_qualification
611644
and source_columns
@@ -782,26 +815,7 @@ def _expand_stars(
782815
coalesced_columns = set()
783816
dialect = resolver.dialect
784817

785-
pivot_output_columns = None
786-
pivot_exclude_columns: set[str] = set()
787-
788818
pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
789-
if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
790-
if pivot.unpivot:
791-
pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
792-
793-
for field in pivot.fields:
794-
if isinstance(field, exp.In):
795-
pivot_exclude_columns.update(
796-
c.output_name for e in field.expressions for c in e.find_all(exp.Column)
797-
)
798-
799-
else:
800-
pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
801-
802-
pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
803-
if not pivot_output_columns:
804-
pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
805819

806820
if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
807821
isinstance(col, exp.Dot) for col in scope.stars
@@ -865,11 +879,7 @@ def _expand_stars(
865879
replaced_columns = replace_columns.get(table_id, {})
866880

867881
if pivot:
868-
if pivot_output_columns and pivot_exclude_columns:
869-
pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
870-
pivot_columns.extend(pivot_output_columns)
871-
else:
872-
pivot_columns = pivot.alias_column_names
882+
pivot_columns = pivot.alias_column_names or _pivot_output_columns(pivot, columns)
873883

874884
if pivot_columns:
875885
new_selections.extend(

tests/test_optimizer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,35 @@ def test_qualify_columns(self, logger):
647647
"UNPIVOT(`sales` FOR `quarter` IN (`produce`.`q1`, `produce`.`q2`)) AS `produce`",
648648
)
649649

650+
self.assertEqual(
651+
optimizer.qualify.qualify(
652+
parse_one(
653+
"WITH cte AS (SELECT 1 AS a, 2 AS b, 3 AS c) "
654+
"SELECT u.val, u.name FROM cte UNPIVOT(val FOR name IN (a, b, c)) AS u"
655+
),
656+
).sql(),
657+
'WITH "cte" AS (SELECT 1 AS "a", 2 AS "b", 3 AS "c") '
658+
'SELECT "u"."val" AS "val", "u"."name" AS "name" FROM "cte" AS "cte" '
659+
'UNPIVOT("val" FOR "name" IN ("cte"."a", "cte"."b", "cte"."c")) AS "u"',
660+
)
661+
662+
self.assertEqual(
663+
optimizer.qualify.qualify(
664+
parse_one(
665+
"WITH produce AS (SELECT 'Kale' AS product, 51 AS q1, 23 AS q2, 45 AS q3, 3 AS q4) "
666+
"SELECT * FROM produce UNPIVOT((first_half, second_half) FOR semesters "
667+
"IN ((q1, q2) AS 'h1', (q3, q4) AS 'h2'))",
668+
dialect="bigquery",
669+
),
670+
dialect="bigquery",
671+
).sql(dialect="bigquery"),
672+
"WITH `produce` AS (SELECT 'Kale' AS `product`, 51 AS `q1`, 23 AS `q2`, 45 AS `q3`, 3 AS `q4`) "
673+
"SELECT `produce`.`product` AS `product`, `produce`.`semesters` AS `semesters`, "
674+
"`produce`.`first_half` AS `first_half`, `produce`.`second_half` AS `second_half` "
675+
"FROM `produce` AS `produce` UNPIVOT((`first_half`, `second_half`) FOR `semesters` "
676+
"IN ((`produce`.`q1`, `produce`.`q2`) AS 'h1', (`produce`.`q3`, `produce`.`q4`) AS 'h2')) AS `produce`",
677+
)
678+
650679
def test_validate_columns(self):
651680
with self.assertRaisesRegex(
652681
OptimizeError, "Column 'foo' could not be resolved. Line: 1, Col: 10"

0 commit comments

Comments
 (0)