Skip to content

Commit 446550f

Browse files
committed
fix(optimizer): unpivot annotate types for bq
1 parent 234198a commit 446550f

2 files changed

Lines changed: 53 additions & 2 deletions

File tree

sqlglot-integration-tests

sqlglot/optimizer/annotate_types.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,50 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
321321
elif isinstance(expression, exp.Selectable):
322322
selects[name] = {s.alias_or_name: s.type for s in expression.selects if s.type}
323323

324+
for pivot in scope.pivots:
325+
pivot_source = scope.sources.get(pivot.alias)
326+
if not pivot_source:
327+
continue
328+
329+
inner_name = (
330+
pivot_source.name if isinstance(pivot_source, exp.Table) else pivot.alias
331+
)
332+
col_types = dict(selects.get(inner_name, {}))
333+
334+
if pivot.unpivot:
335+
for field in pivot.fields:
336+
field_col = field.this
337+
if not isinstance(field_col, exp.Column) or not field.expressions:
338+
continue
339+
340+
first = field.expressions[0]
341+
is_pivot_alias = isinstance(first, exp.PivotAlias)
342+
343+
# FOR column type from the alias literal, or VARCHAR if no alias
344+
if is_pivot_alias:
345+
alias_node = first.args.get("alias")
346+
if alias_node:
347+
col_types[field_col.name] = alias_node.type
348+
else:
349+
col_types[field_col.name] = exp.DType.VARCHAR
350+
351+
# Value column types from the IN source columns
352+
src = first.this if is_pivot_alias else first
353+
src_cols = src.expressions if isinstance(src, exp.Tuple) else [src]
354+
for val_expr in pivot.expressions:
355+
val_cols = (
356+
val_expr.expressions
357+
if isinstance(val_expr, exp.Tuple)
358+
else [val_expr]
359+
)
360+
for val_col, src_col in zip(val_cols, src_cols):
361+
src_type = col_types.get(src_col.output_name)
362+
if src_type:
363+
col_types[val_col.output_name] = src_type
364+
365+
if col_types:
366+
selects[pivot.alias] = col_types
367+
324368
self._scope_selects[scope] = selects
325369

326370
return self._scope_selects[scope]
@@ -410,7 +454,14 @@ def _annotate_expression(
410454
source_scope = source_scope.parent
411455

412456
if isinstance(source, exp.Table):
413-
self._set_type(expr, self.schema.get_column_type(source, expr))
457+
schema_type = self.schema.get_column_type(source, expr)
458+
if schema_type.is_type(exp.DType.UNKNOWN) and source.args.get("pivots"):
459+
pivot_type = (
460+
self._get_scope_selects(scope).get(expr.table, {}).get(expr.name)
461+
)
462+
if pivot_type:
463+
schema_type = pivot_type
464+
self._set_type(expr, schema_type)
414465
elif source and source_scope:
415466
col_type = (
416467
self._get_scope_selects(source_scope).get(expr.table, {}).get(expr.name)

0 commit comments

Comments
 (0)