@@ -321,6 +321,50 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
321321 elif isinstance (expression , exp .Selectable ):
322322 selects [name ] = {s .alias_or_name : s .type for s in expression .selects if s .type }
323323
324+ for pivot in scope .pivots :
325+ pivot_source = scope .sources .get (pivot .alias )
326+ if not pivot_source :
327+ continue
328+
329+ inner_name = (
330+ pivot_source .name if isinstance (pivot_source , exp .Table ) else pivot .alias
331+ )
332+ col_types = dict (selects .get (inner_name , {}))
333+
334+ if pivot .unpivot :
335+ for field in pivot .fields :
336+ field_col = field .this
337+ if not isinstance (field_col , exp .Column ) or not field .expressions :
338+ continue
339+
340+ first = field .expressions [0 ]
341+ is_pivot_alias = isinstance (first , exp .PivotAlias )
342+
343+ # FOR column type from the alias literal, or VARCHAR if no alias
344+ if is_pivot_alias :
345+ alias_node = first .args .get ("alias" )
346+ if alias_node :
347+ col_types [field_col .name ] = alias_node .type
348+ else :
349+ col_types [field_col .name ] = exp .DType .VARCHAR
350+
351+ # Value column types from the IN source columns
352+ src = first .this if is_pivot_alias else first
353+ src_cols = src .expressions if isinstance (src , exp .Tuple ) else [src ]
354+ for val_expr in pivot .expressions :
355+ val_cols = (
356+ val_expr .expressions
357+ if isinstance (val_expr , exp .Tuple )
358+ else [val_expr ]
359+ )
360+ for val_col , src_col in zip (val_cols , src_cols ):
361+ src_type = col_types .get (src_col .output_name )
362+ if src_type :
363+ col_types [val_col .output_name ] = src_type
364+
365+ if col_types :
366+ selects [pivot .alias ] = col_types
367+
324368 self ._scope_selects [scope ] = selects
325369
326370 return self ._scope_selects [scope ]
@@ -410,7 +454,14 @@ def _annotate_expression(
410454 source_scope = source_scope .parent
411455
412456 if isinstance (source , exp .Table ):
413- self ._set_type (expr , self .schema .get_column_type (source , expr ))
457+ schema_type = self .schema .get_column_type (source , expr )
458+ if schema_type .is_type (exp .DType .UNKNOWN ) and source .args .get ("pivots" ):
459+ pivot_type = (
460+ self ._get_scope_selects (scope ).get (expr .table , {}).get (expr .name )
461+ )
462+ if pivot_type :
463+ schema_type = pivot_type
464+ self ._set_type (expr , schema_type )
414465 elif source and source_scope :
415466 col_type = (
416467 self ._get_scope_selects (source_scope ).get (expr .table , {}).get (expr .name )
0 commit comments