Skip to content

Commit c50a089

Browse files
committed
refactor impl
1 parent 0f27e51 commit c50a089

2 files changed

Lines changed: 9 additions & 16 deletions

File tree

sqlglot-integration-tests

sqlglot/optimizer/annotate_types.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,8 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
331331
selects[name] = {s.alias_or_name: s.type for s in expression.selects if s.type}
332332

333333
for pivot in scope.pivots:
334-
pivot_source = scope.sources.get(pivot.alias)
335-
if not pivot_source:
336-
continue
337-
338334
inner_name = (
339-
pivot_source.name if isinstance(pivot_source, exp.Table) else pivot.alias
335+
pivot.parent.name if isinstance(pivot.parent, exp.Table) else pivot.alias
340336
)
341337

342338
col_types = selects.get(inner_name, {}).copy()
@@ -355,7 +351,7 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
355351
src = first.this
356352
else:
357353
col_types[field_col.name] = exp.DataType.build(
358-
"VARCHAR", dialect=self.dialect
354+
"TEXT", dialect=self.dialect
359355
)
360356
src = first
361357

@@ -465,15 +461,12 @@ def _annotate_expression(
465461
if not source:
466462
source_scope = source_scope.parent
467463

468-
if isinstance(source, exp.Table):
469-
schema_type = self.schema.get_column_type(source, expr)
470-
if schema_type.is_type(exp.DType.UNKNOWN) and source.args.get("pivots"):
471-
pivot_type = (
472-
self._get_scope_selects(scope).get(expr.table, {}).get(expr.name)
473-
)
474-
if pivot_type:
475-
schema_type = pivot_type
476-
self._set_type(expr, schema_type)
464+
# Pivot-indexed selects win first: they capture UNPIVOT outputs whether
465+
# or not the pivot alias made it into scope.sources.
466+
if pivot_type := self._get_scope_selects(scope).get(expr.table, {}).get(expr.name):
467+
self._set_type(expr, pivot_type)
468+
elif isinstance(source, exp.Table):
469+
self._set_type(expr, self.schema.get_column_type(source, expr))
477470
elif source and source_scope:
478471
col_type = (
479472
self._get_scope_selects(source_scope).get(expr.table, {}).get(expr.name)

0 commit comments

Comments
 (0)