Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlglot-integration-tests
47 changes: 46 additions & 1 deletion sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,47 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
elif isinstance(expression, exp.Selectable):
selects[name] = {s.alias_or_name: s.type for s in expression.selects if s.type}

for pivot in scope.pivots:
inner_name = (
pivot.parent.name if isinstance(pivot.parent, exp.Table) else pivot.alias
)

col_types = selects.get(inner_name, {}).copy()

if pivot.unpivot:
for field in pivot.fields:
field_col = field.this

first = seq_get(field.expressions, 0)

# FOR column type from the alias literal, or VARCHAR if no alias
if isinstance(first, exp.PivotAlias):
alias_node = first.args.get("alias")
if alias_node:
col_types[field_col.name] = alias_node.type
src = first.this
else:
col_types[field_col.name] = exp.DType.VARCHAR.into_expr()
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we handle this differently ?

src = first

# Value column types from the IN source columns
src_cols = src.expressions if isinstance(src, exp.Tuple) else [src]
for val_expr in pivot.expressions:
val_cols = (
val_expr.expressions
if isinstance(val_expr, exp.Tuple)
else [val_expr]
)
for val_col, src_col in zip(val_cols, src_cols):
src_type = col_types.get(src_col.output_name) or src_col.type
if isinstance(src_type, exp.DataType) and not src_type.is_type(
exp.DType.UNKNOWN
):
col_types[val_col.output_name] = src_type

if col_types:
selects[pivot.alias] = col_types

self._scope_selects[scope] = selects

return self._scope_selects[scope]
Expand Down Expand Up @@ -418,7 +459,11 @@ def _annotate_expression(
if not source:
source_scope = source_scope.parent

if isinstance(source, exp.Table):
# Pivot-indexed selects win first: they capture UNPIVOT outputs whether
# or not the pivot alias made it into scope.sources.
if pivot_type := self._get_scope_selects(scope).get(expr.table, {}).get(expr.name):
self._set_type(expr, pivot_type)
elif isinstance(source, exp.Table):
self._set_type(expr, self.schema.get_column_type(source, expr))
elif source and source_scope:
col_type = (
Expand Down
Loading