Skip to content

Commit f150b45

Browse files
committed
refactor impl
1 parent 24a25bc commit f150b45

2 files changed

Lines changed: 9 additions & 16 deletions

File tree

sqlglot-integration-tests

sqlglot/optimizer/annotate_types.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,8 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
322322
selects[name] = {s.alias_or_name: s.type for s in expression.selects if s.type}
323323

324324
for pivot in scope.pivots:
325-
pivot_source = scope.sources.get(pivot.alias)
326-
if not pivot_source:
327-
continue
328-
329325
inner_name = (
330-
pivot_source.name if isinstance(pivot_source, exp.Table) else pivot.alias
326+
pivot.parent.name if isinstance(pivot.parent, exp.Table) else pivot.alias
331327
)
332328

333329
col_types = selects.get(inner_name, {}).copy()
@@ -346,7 +342,7 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
346342
src = first.this
347343
else:
348344
col_types[field_col.name] = exp.DataType.build(
349-
"VARCHAR", dialect=self.dialect
345+
"TEXT", dialect=self.dialect
350346
)
351347
src = first
352348

@@ -456,15 +452,12 @@ def _annotate_expression(
456452
if not source:
457453
source_scope = source_scope.parent
458454

459-
if isinstance(source, exp.Table):
460-
schema_type = self.schema.get_column_type(source, expr)
461-
if schema_type.is_type(exp.DType.UNKNOWN) and source.args.get("pivots"):
462-
pivot_type = (
463-
self._get_scope_selects(scope).get(expr.table, {}).get(expr.name)
464-
)
465-
if pivot_type:
466-
schema_type = pivot_type
467-
self._set_type(expr, schema_type)
455+
# Pivot-indexed selects win first: they capture UNPIVOT outputs whether
456+
# or not the pivot alias made it into scope.sources.
457+
if pivot_type := self._get_scope_selects(scope).get(expr.table, {}).get(expr.name):
458+
self._set_type(expr, pivot_type)
459+
elif isinstance(source, exp.Table):
460+
self._set_type(expr, self.schema.get_column_type(source, expr))
468461
elif source and source_scope:
469462
col_type = (
470463
self._get_scope_selects(source_scope).get(expr.table, {}).get(expr.name)

0 commit comments

Comments
 (0)