Skip to content

Commit fca278d

Browse files
committed
validateOnConflictTypes
1 parent b7c9448 commit fca278d

File tree

5 files changed

+168
-3
lines changed

5 files changed

+168
-3
lines changed

internal/compiler/analyze.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package compiler
22

33
import (
4+
"fmt"
45
"sort"
56

67
analyzer "github.com/sqlc-dev/sqlc/internal/analysis"
78
"github.com/sqlc-dev/sqlc/internal/config"
9+
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
810
"github.com/sqlc-dev/sqlc/internal/source"
911
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1012
"github.com/sqlc-dev/sqlc/internal/sql/named"
@@ -143,6 +145,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
143145
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar)
144146

145147
var table *ast.TableName
148+
var insertStmt *ast.InsertStmt
146149
switch n := raw.Stmt.(type) {
147150
case *ast.InsertStmt:
148151
if err := check(validate.InsertStmt(n)); err != nil {
@@ -156,6 +159,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
156159
if err := check(c.validateOnConflictColumns(n)); err != nil {
157160
return nil, err
158161
}
162+
insertStmt = n
159163
}
160164

161165
if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil {
@@ -189,6 +193,11 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
189193
if err := check(err); err != nil {
190194
return nil, err
191195
}
196+
if c.conf.Engine == config.EnginePostgreSQL {
197+
if err := check(c.validateOnConflictTypes(insertStmt, params)); err != nil {
198+
return nil, err
199+
}
200+
}
192201
cols, err := c.outputColumns(qc, raw.Stmt)
193202
if err := check(err); err != nil {
194203
return nil, err
@@ -281,6 +290,79 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error {
281290
return nil
282291
}
283292

293+
// validateOnConflictTypes checks that $N params used in DO UPDATE SET assignments
294+
// are type-compatible with the target column, based on the type already resolved
295+
// for that param from the INSERT columns.
296+
func (c *Compiler) validateOnConflictTypes(n *ast.InsertStmt, params []Parameter) error {
297+
if n == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
298+
return nil
299+
}
300+
fqn, err := ParseTableName(n.Relation)
301+
if err != nil {
302+
return err
303+
}
304+
table, err := c.catalog.GetTable(fqn)
305+
if err != nil {
306+
return err
307+
}
308+
309+
// Build param number → resolved DataType string from already-resolved params.
310+
// Skips params with "any" type (unresolved).
311+
paramDataTypes := make(map[int]string, len(params))
312+
for i := range params {
313+
if params[i].Column != nil && params[i].Column.DataType != "any" {
314+
paramDataTypes[params[i].Number] = params[i].Column.DataType
315+
}
316+
}
317+
318+
// Build column name → DataType string using the same dataType() function
319+
// used by resolveCatalogRefs, so formats are comparable.
320+
colDataTypes := make(map[string]string, len(table.Columns))
321+
for _, col := range table.Columns {
322+
colDataTypes[col.Name] = dataType(&col.Type)
323+
}
324+
325+
for _, item := range n.OnConflictClause.TargetList.Items {
326+
target, ok := item.(*ast.ResTarget)
327+
if !ok || target.Name == nil {
328+
continue
329+
}
330+
colDT, ok := colDataTypes[*target.Name]
331+
if !ok {
332+
continue
333+
}
334+
switch val := target.Val.(type) {
335+
case *ast.ParamRef:
336+
paramDT, ok := paramDataTypes[val.Number]
337+
if !ok {
338+
continue
339+
}
340+
if postgresql.TypeFamily(paramDT) != postgresql.TypeFamily(colDT) {
341+
return &sqlerr.Error{
342+
Message: fmt.Sprintf("parameter $%d has type %q but column %q has type %q", val.Number, paramDT, *target.Name, colDT),
343+
Location: val.Location,
344+
}
345+
}
346+
case *ast.ColumnRef:
347+
excludedCol, ok := excludedColumn(val)
348+
if !ok {
349+
continue
350+
}
351+
excludedDT, ok := colDataTypes[excludedCol]
352+
if !ok {
353+
continue
354+
}
355+
if postgresql.TypeFamily(excludedDT) != postgresql.TypeFamily(colDT) {
356+
return &sqlerr.Error{
357+
Message: fmt.Sprintf("EXCLUDED.%s has type %q but column %q has type %q", excludedCol, excludedDT, *target.Name, colDT),
358+
Location: val.Location,
359+
}
360+
}
361+
}
362+
}
363+
return nil
364+
}
365+
284366
// excludedColumn returns the column name if the ColumnRef is an EXCLUDED.col
285367
// reference, and ok=true. Returns "", false otherwise.
286368
func excludedColumn(ref *ast.ColumnRef) (string, bool) {

internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,14 @@ DO UPDATE SET name = 1111;
1212
INSERT INTO servers(code, name) VALUES ($1, $2)
1313
ON CONFLICT (code)
1414
DO UPDATE SET name = EXCLUDED.name_typo;
15+
16+
-- name: UpsertServerSetParamTypeMismatch :exec
17+
INSERT INTO servers(code, name) VALUES ($1, $2)
18+
ON CONFLICT (code)
19+
DO UPDATE SET count = $2;
20+
21+
-- name: UpsertServerExcludedTypeMismatch :exec
22+
INSERT INTO servers(code, name, count) VALUES ($1, $2, $3)
23+
ON CONFLICT (code)
24+
DO UPDATE SET count = EXCLUDED.code;
25+
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
CREATE TABLE servers (
2-
code varchar PRIMARY KEY,
3-
name text NOT NULL
4-
);
2+
code varchar PRIMARY KEY,
3+
name text NOT NULL,
4+
count integer NOT NULL DEFAULT 0
5+
);

internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
query.sql:4:15: column "name_typo" of relation "servers" does not exist
33
query.sql:8:13: column "code_typo" of relation "servers" does not exist
44
query.sql:14:22: column "name_typo" of relation "servers" does not exist
5+
query.sql:19:23: parameter $2 has type "text" but column "count" has type "pg_catalog.int4"
6+
query.sql:24:23: EXCLUDED.code has type "pg_catalog.varchar" but column "count" has type "pg_catalog.int4"

internal/engine/postgresql/utils.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,75 @@ func IsNamedParamSign(node *nodes.Node) bool {
3939
return ok && joinNodes(expr.AExpr.Name, ".") == "@"
4040
}
4141

42+
// TypeFamily maps a PostgreSQL DataType string to a canonical type family name,
43+
// grouping compatible type aliases together. This is used for type compatibility
44+
// checks rather than exact string equality, because PostgreSQL considers many
45+
// type aliases assignment-compatible (e.g. text and varchar are both string types).
46+
//
47+
// The groupings are derived from postgresType() in
48+
// internal/codegen/golang/postgresql_type.go, which maps these aliases to the
49+
// same Go type. We cannot call postgresType() directly for type compatibility
50+
// checking because it requires *plugin.GenerateRequest — a protobuf codegen
51+
// struct constructed after compilation — and driver-specific opts.Options.
52+
func TypeFamily(dt string) string {
53+
switch dt {
54+
case "serial", "serial4", "pg_catalog.serial4",
55+
"integer", "int", "int4", "pg_catalog.int4":
56+
return "int32"
57+
case "bigserial", "serial8", "pg_catalog.serial8",
58+
"bigint", "int8", "pg_catalog.int8",
59+
"interval", "pg_catalog.interval":
60+
return "int64"
61+
case "smallserial", "serial2", "pg_catalog.serial2",
62+
"smallint", "int2", "pg_catalog.int2":
63+
return "int16"
64+
case "float", "double precision", "float8", "pg_catalog.float8":
65+
return "float64"
66+
case "real", "float4", "pg_catalog.float4":
67+
return "float32"
68+
case "numeric", "pg_catalog.numeric", "money":
69+
return "numeric"
70+
case "boolean", "bool", "pg_catalog.bool":
71+
return "bool"
72+
case "json", "pg_catalog.json":
73+
return "json"
74+
case "jsonb", "pg_catalog.jsonb":
75+
return "jsonb"
76+
case "bytea", "blob", "pg_catalog.bytea":
77+
return "bytes"
78+
case "date":
79+
return "date"
80+
case "pg_catalog.time":
81+
return "time"
82+
case "pg_catalog.timetz":
83+
return "timetz"
84+
case "pg_catalog.timestamp", "timestamp":
85+
return "timestamp"
86+
case "pg_catalog.timestamptz", "timestamptz":
87+
return "timestamptz"
88+
case "text", "pg_catalog.varchar", "pg_catalog.bpchar",
89+
"string", "citext", "name",
90+
"ltree", "lquery", "ltxtquery":
91+
return "text"
92+
case "uuid":
93+
return "uuid"
94+
case "inet":
95+
return "inet"
96+
case "cidr":
97+
return "cidr"
98+
case "macaddr", "macaddr8":
99+
return "macaddr"
100+
case "bit", "varbit", "pg_catalog.bit", "pg_catalog.varbit":
101+
return "bits"
102+
case "hstore":
103+
return "hstore"
104+
case "vector":
105+
return "vector"
106+
default:
107+
return dt
108+
}
109+
}
110+
42111
func makeByte(s string) byte {
43112
var b byte
44113
if s == "" {

0 commit comments

Comments
 (0)