@@ -330,6 +330,50 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
330330 elif isinstance (expression , exp .Selectable ):
331331 selects [name ] = {s .alias_or_name : s .type for s in expression .selects if s .type }
332332
333+ for pivot in scope .pivots :
334+ pivot_source = scope .sources .get (pivot .alias )
335+ if not pivot_source :
336+ continue
337+
338+ inner_name = (
339+ pivot_source .name if isinstance (pivot_source , exp .Table ) else pivot .alias
340+ )
341+ col_types = dict (selects .get (inner_name , {}))
342+
343+ if pivot .unpivot :
344+ for field in pivot .fields :
345+ field_col = field .this
346+ if not isinstance (field_col , exp .Column ) or not field .expressions :
347+ continue
348+
349+ first = field .expressions [0 ]
350+ is_pivot_alias = isinstance (first , exp .PivotAlias )
351+
352+ # FOR column type from the alias literal, or VARCHAR if no alias
353+ if is_pivot_alias :
354+ alias_node = first .args .get ("alias" )
355+ if alias_node :
356+ col_types [field_col .name ] = alias_node .type
357+ else :
358+ col_types [field_col .name ] = exp .DType .VARCHAR
359+
360+ # Value column types from the IN source columns
361+ src = first .this if is_pivot_alias else first
362+ src_cols = src .expressions if isinstance (src , exp .Tuple ) else [src ]
363+ for val_expr in pivot .expressions :
364+ val_cols = (
365+ val_expr .expressions
366+ if isinstance (val_expr , exp .Tuple )
367+ else [val_expr ]
368+ )
369+ for val_col , src_col in zip (val_cols , src_cols ):
370+ src_type = col_types .get (src_col .output_name )
371+ if src_type :
372+ col_types [val_col .output_name ] = src_type
373+
374+ if col_types :
375+ selects [pivot .alias ] = col_types
376+
333377 self ._scope_selects [scope ] = selects
334378
335379 return self ._scope_selects [scope ]
@@ -419,7 +463,14 @@ def _annotate_expression(
419463 source_scope = source_scope .parent
420464
421465 if isinstance (source , exp .Table ):
422- self ._set_type (expr , self .schema .get_column_type (source , expr ))
466+ schema_type = self .schema .get_column_type (source , expr )
467+ if schema_type .is_type (exp .DType .UNKNOWN ) and source .args .get ("pivots" ):
468+ pivot_type = (
469+ self ._get_scope_selects (scope ).get (expr .table , {}).get (expr .name )
470+ )
471+ if pivot_type :
472+ schema_type = pivot_type
473+ self ._set_type (expr , schema_type )
423474 elif source and source_scope :
424475 col_type = (
425476 self ._get_scope_selects (source_scope ).get (expr .table , {}).get (expr .name )
0 commit comments