11package compiler
22
33import (
4+ "fmt"
45 "sort"
56
67 analyzer "github.com/sqlc-dev/sqlc/internal/analysis"
78 "github.com/sqlc-dev/sqlc/internal/config"
9+ "github.com/sqlc-dev/sqlc/internal/engine/postgresql"
810 "github.com/sqlc-dev/sqlc/internal/source"
911 "github.com/sqlc-dev/sqlc/internal/sql/ast"
1012 "github.com/sqlc-dev/sqlc/internal/sql/named"
@@ -143,6 +145,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
143145 raw , namedParams , edits := rewrite .NamedParameters (c .conf .Engine , raw , numbers , dollar )
144146
145147 var table * ast.TableName
148+ var insertStmt * ast.InsertStmt
146149 switch n := raw .Stmt .(type ) {
147150 case * ast.InsertStmt :
148151 if err := check (validate .InsertStmt (n )); err != nil {
@@ -156,6 +159,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
156159 if err := check (c .validateOnConflictColumns (n )); err != nil {
157160 return nil , err
158161 }
162+ insertStmt = n
159163 }
160164
161165 if err := check (validate .FuncCall (c .catalog , c .combo , raw )); err != nil {
@@ -189,6 +193,11 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
189193 if err := check (err ); err != nil {
190194 return nil , err
191195 }
196+ if c .conf .Engine == config .EnginePostgreSQL {
197+ if err := check (c .validateOnConflictTypes (insertStmt , params )); err != nil {
198+ return nil , err
199+ }
200+ }
192201 cols , err := c .outputColumns (qc , raw .Stmt )
193202 if err := check (err ); err != nil {
194203 return nil , err
@@ -281,6 +290,79 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error {
281290 return nil
282291}
283292
293+ // validateOnConflictTypes checks that $N params used in DO UPDATE SET assignments
294+ // are type-compatible with the target column, based on the type already resolved
295+ // for that param from the INSERT columns.
296+ func (c * Compiler ) validateOnConflictTypes (n * ast.InsertStmt , params []Parameter ) error {
297+ if n == nil || n .OnConflictClause == nil || n .OnConflictClause .Action != ast .OnConflictActionUpdate {
298+ return nil
299+ }
300+ fqn , err := ParseTableName (n .Relation )
301+ if err != nil {
302+ return err
303+ }
304+ table , err := c .catalog .GetTable (fqn )
305+ if err != nil {
306+ return err
307+ }
308+
309+ // Build param number → resolved DataType string from already-resolved params.
310+ // Skips params with "any" type (unresolved).
311+ paramDataTypes := make (map [int ]string , len (params ))
312+ for i := range params {
313+ if params [i ].Column != nil && params [i ].Column .DataType != "any" {
314+ paramDataTypes [params [i ].Number ] = params [i ].Column .DataType
315+ }
316+ }
317+
318+ // Build column name → DataType string using the same dataType() function
319+ // used by resolveCatalogRefs, so formats are comparable.
320+ colDataTypes := make (map [string ]string , len (table .Columns ))
321+ for _ , col := range table .Columns {
322+ colDataTypes [col .Name ] = dataType (& col .Type )
323+ }
324+
325+ for _ , item := range n .OnConflictClause .TargetList .Items {
326+ target , ok := item .(* ast.ResTarget )
327+ if ! ok || target .Name == nil {
328+ continue
329+ }
330+ colDT , ok := colDataTypes [* target .Name ]
331+ if ! ok {
332+ continue
333+ }
334+ switch val := target .Val .(type ) {
335+ case * ast.ParamRef :
336+ paramDT , ok := paramDataTypes [val .Number ]
337+ if ! ok {
338+ continue
339+ }
340+ if postgresql .TypeFamily (paramDT ) != postgresql .TypeFamily (colDT ) {
341+ return & sqlerr.Error {
342+ Message : fmt .Sprintf ("parameter $%d has type %q but column %q has type %q" , val .Number , paramDT , * target .Name , colDT ),
343+ Location : val .Location ,
344+ }
345+ }
346+ case * ast.ColumnRef :
347+ excludedCol , ok := excludedColumn (val )
348+ if ! ok {
349+ continue
350+ }
351+ excludedDT , ok := colDataTypes [excludedCol ]
352+ if ! ok {
353+ continue
354+ }
355+ if postgresql .TypeFamily (excludedDT ) != postgresql .TypeFamily (colDT ) {
356+ return & sqlerr.Error {
357+ Message : fmt .Sprintf ("EXCLUDED.%s has type %q but column %q has type %q" , excludedCol , excludedDT , * target .Name , colDT ),
358+ Location : val .Location ,
359+ }
360+ }
361+ }
362+ }
363+ return nil
364+ }
365+
284366// excludedColumn returns the column name if the ColumnRef is an EXCLUDED.col
285367// reference, and ok=true. Returns "", false otherwise.
286368func excludedColumn (ref * ast.ColumnRef ) (string , bool ) {
0 commit comments