diff --git a/sqlglot-integration-tests b/sqlglot-integration-tests index 28ddb0cda5..aebec88e8a 160000 --- a/sqlglot-integration-tests +++ b/sqlglot-integration-tests @@ -1 +1 @@ -Subproject commit 28ddb0cda5c17d0f5f27d4940d3cee12bd80871f +Subproject commit aebec88e8a966be481c083b813bb9ecf67ed3eea diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 0e6b1c407c..740cd69db7 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -330,6 +330,47 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]: elif isinstance(expression, exp.Selectable): selects[name] = {s.alias_or_name: s.type for s in expression.selects if s.type} + for pivot in scope.pivots: + inner_name = ( + pivot.parent.name if isinstance(pivot.parent, exp.Table) else pivot.alias + ) + + col_types = selects.get(inner_name, {}).copy() + + if pivot.unpivot: + for field in pivot.fields: + field_col = field.this + + first = seq_get(field.expressions, 0) + + # FOR column type from the alias literal, or VARCHAR if no alias + if isinstance(first, exp.PivotAlias): + alias_node = first.args.get("alias") + if alias_node: + col_types[field_col.name] = alias_node.type + src = first.this + else: + col_types[field_col.name] = exp.DType.VARCHAR.into_expr() + src = first + + # Value column types from the IN source columns + src_cols = src.expressions if isinstance(src, exp.Tuple) else [src] + for val_expr in pivot.expressions: + val_cols = ( + val_expr.expressions + if isinstance(val_expr, exp.Tuple) + else [val_expr] + ) + for val_col, src_col in zip(val_cols, src_cols): + src_type = col_types.get(src_col.output_name) or src_col.type + if isinstance(src_type, exp.DataType) and not src_type.is_type( + exp.DType.UNKNOWN + ): + col_types[val_col.output_name] = src_type + + if col_types: + selects[pivot.alias] = col_types + self._scope_selects[scope] = selects return self._scope_selects[scope] @@ -418,7 +459,11 @@ def _annotate_expression( if not source: source_scope = source_scope.parent - if isinstance(source, exp.Table): + # Pivot-indexed selects win first: they capture UNPIVOT outputs whether + # or not the pivot alias made it into scope.sources. + if pivot_type := self._get_scope_selects(scope).get(expr.table, {}).get(expr.name): + self._set_type(expr, pivot_type) + elif isinstance(source, exp.Table): self._set_type(expr, self.schema.get_column_type(source, expr)) elif source and source_scope: col_type = (