Skip to content

Commit 6fc5418

Browse files
committed
fix(optimizer): unpivot annotate types for bq
1 parent cf0ffd9 commit 6fc5418

2 files changed

Lines changed: 53 additions & 2 deletions

File tree

sqlglot-integration-tests

sqlglot/optimizer/annotate_types.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,50 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
330330
elif isinstance(expression, exp.Selectable):
331331
selects[name] = {s.alias_or_name: s.type for s in expression.selects if s.type}
332332

333+
for pivot in scope.pivots:
334+
pivot_source = scope.sources.get(pivot.alias)
335+
if not pivot_source:
336+
continue
337+
338+
inner_name = (
339+
pivot_source.name if isinstance(pivot_source, exp.Table) else pivot.alias
340+
)
341+
col_types = dict(selects.get(inner_name, {}))
342+
343+
if pivot.unpivot:
344+
for field in pivot.fields:
345+
field_col = field.this
346+
if not isinstance(field_col, exp.Column) or not field.expressions:
347+
continue
348+
349+
first = field.expressions[0]
350+
is_pivot_alias = isinstance(first, exp.PivotAlias)
351+
352+
# FOR column type from the alias literal, or VARCHAR if no alias
353+
if is_pivot_alias:
354+
alias_node = first.args.get("alias")
355+
if alias_node:
356+
col_types[field_col.name] = alias_node.type
357+
else:
358+
col_types[field_col.name] = exp.DType.VARCHAR
359+
360+
# Value column types from the IN source columns
361+
src = first.this if is_pivot_alias else first
362+
src_cols = src.expressions if isinstance(src, exp.Tuple) else [src]
363+
for val_expr in pivot.expressions:
364+
val_cols = (
365+
val_expr.expressions
366+
if isinstance(val_expr, exp.Tuple)
367+
else [val_expr]
368+
)
369+
for val_col, src_col in zip(val_cols, src_cols):
370+
src_type = col_types.get(src_col.output_name)
371+
if src_type:
372+
col_types[val_col.output_name] = src_type
373+
374+
if col_types:
375+
selects[pivot.alias] = col_types
376+
333377
self._scope_selects[scope] = selects
334378

335379
return self._scope_selects[scope]
@@ -419,7 +463,14 @@ def _annotate_expression(
419463
source_scope = source_scope.parent
420464

421465
if isinstance(source, exp.Table):
422-
self._set_type(expr, self.schema.get_column_type(source, expr))
466+
schema_type = self.schema.get_column_type(source, expr)
467+
if schema_type.is_type(exp.DType.UNKNOWN) and source.args.get("pivots"):
468+
pivot_type = (
469+
self._get_scope_selects(scope).get(expr.table, {}).get(expr.name)
470+
)
471+
if pivot_type:
472+
schema_type = pivot_type
473+
self._set_type(expr, schema_type)
423474
elif source and source_scope:
424475
col_type = (
425476
self._get_scope_selects(source_scope).get(expr.table, {}).get(expr.name)

0 commit comments

Comments
 (0)