Skip to content

Commit 7f6b186

Browse files
authored
Catch invalid ON CONFLICT DO UPDATE column references (#4366)
* validateOnConflictColumns * validateOnConflictTypes * validateOnConflictClause * remove validateOnConflictSetTypes * move to validate * params order * ON CONSTRAINT
1 parent 4bf2159 commit 7f6b186

File tree

6 files changed

+138
-13
lines changed

6 files changed

+138
-13
lines changed

internal/compiler/analyze.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,14 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
144144
var table *ast.TableName
145145
switch n := raw.Stmt.(type) {
146146
case *ast.InsertStmt:
147-
if err := check(validate.InsertStmt(n)); err != nil {
148-
return nil, err
149-
}
150147
var err error
151148
table, err = ParseTableName(n.Relation)
152149
if err := check(err); err != nil {
153150
return nil, err
154151
}
152+
if err := check(validate.InsertStmt(c.catalog, table, n)); err != nil {
153+
return nil, err
154+
}
155155
}
156156

157157
if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil {
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"contexts": ["managed-db"]
2+
"contexts": ["base"]
33
}
Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,23 @@
1-
-- name: UpsertServer :exec
2-
INSERT INTO servers(code, name) VALUES ($1, $2)
3-
ON CONFLICT (code)
4-
DO UPDATE SET name_typo = 1111;
1+
-- name: UpsertServerSetColumnTypo :exec
2+
INSERT INTO servers(code, name) VALUES ($1, $2)
3+
ON CONFLICT (code)
4+
DO UPDATE SET name_typo = 1111;
5+
6+
-- name: UpsertServerConflictTargetTypo :exec
7+
INSERT INTO servers(code, name) VALUES ($1, $2)
8+
ON CONFLICT (code_typo)
9+
DO UPDATE SET name = 1111;
10+
11+
-- name: UpsertServerExcludedColumnTypo :exec
12+
INSERT INTO servers(code, name) VALUES ($1, $2)
13+
ON CONFLICT (code)
14+
DO UPDATE SET name = EXCLUDED.name_typo;
15+
16+
-- name: UpsertServerMissingConflictTarget :exec
17+
INSERT INTO servers(code, name) VALUES ($1, $2)
18+
ON CONFLICT DO UPDATE SET name = EXCLUDED.name;
19+
20+
-- name: UpsertServerOnConstraintExcludedTypo :exec
21+
INSERT INTO servers(code, name) VALUES ($1, $2)
22+
ON CONFLICT ON CONSTRAINT servers_pkey DO UPDATE SET name = EXCLUDED.name_typo;
23+
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
CREATE TABLE servers (
2-
code varchar PRIMARY KEY,
3-
name text NOT NULL
4-
);
2+
code varchar PRIMARY KEY,
3+
name text NOT NULL
4+
);
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
# package querytest
2-
query.sql:4:15: column "name_typo" of relation "servers" does not exist
2+
query.sql:4:15: column "name_typo" of relation "servers" does not exist
3+
query.sql:8:13: column "code_typo" of relation "servers" does not exist
4+
query.sql:14:22: column "name_typo" of relation "EXCLUDED" does not exist
5+
query.sql:17:1: ON CONFLICT DO UPDATE requires inference specification or constraint name
6+
query.sql:22:61: column "name_typo" of relation "EXCLUDED" does not exist

internal/sql/validate/insert_stmt.go

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package validate
22

33
import (
4+
"strings"
5+
46
"github.com/sqlc-dev/sqlc/internal/sql/ast"
7+
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
58
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
69
)
710

8-
func InsertStmt(stmt *ast.InsertStmt) error {
11+
const excludedTable = "EXCLUDED"
12+
13+
func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) error {
914
sel, ok := stmt.SelectStmt.(*ast.SelectStmt)
1015
if !ok {
1116
return nil
@@ -35,5 +40,102 @@ func InsertStmt(stmt *ast.InsertStmt) error {
3540
Message: "INSERT has more expressions than target columns",
3641
}
3742
}
43+
44+
return onConflictClause(c, fqn, stmt)
45+
}
46+
47+
// onConflictClause validates an ON CONFLICT DO UPDATE clause against the target
48+
// table. It checks:
49+
// - ON CONFLICT (col, ...) conflict target columns exist
50+
// - DO UPDATE SET col = ... assignment target columns exist
51+
// - EXCLUDED.col references exist
52+
func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt) error {
53+
if fqn == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
54+
return nil
55+
}
56+
57+
table, err := c.GetTable(fqn)
58+
if err != nil {
59+
return err
60+
}
61+
62+
// Build set of column names for existence checks.
63+
colNames := make(map[string]struct{}, len(table.Columns))
64+
for _, col := range table.Columns {
65+
colNames[col.Name] = struct{}{}
66+
}
67+
68+
// DO UPDATE requires a conflict target: ON CONFLICT (col) or ON CONFLICT ON CONSTRAINT name.
69+
if n.OnConflictClause.Infer == nil {
70+
return &sqlerr.Error{
71+
Code: "42601",
72+
Message: "ON CONFLICT DO UPDATE requires inference specification or constraint name",
73+
}
74+
}
75+
76+
// Validate ON CONFLICT (col, ...) conflict target columns.
77+
if n.OnConflictClause.Infer.IndexElems != nil {
78+
for _, item := range n.OnConflictClause.Infer.IndexElems.Items {
79+
elem, ok := item.(*ast.IndexElem)
80+
if !ok || elem.Name == nil {
81+
continue
82+
}
83+
84+
if _, exists := colNames[*elem.Name]; !exists {
85+
e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name)
86+
e.Location = n.OnConflictClause.Infer.Location
87+
return e
88+
}
89+
}
90+
}
91+
92+
// Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references.
93+
if n.OnConflictClause.TargetList == nil {
94+
return nil
95+
}
96+
97+
for _, item := range n.OnConflictClause.TargetList.Items {
98+
target, ok := item.(*ast.ResTarget)
99+
if !ok || target.Name == nil {
100+
continue
101+
}
102+
103+
if _, exists := colNames[*target.Name]; !exists {
104+
e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name)
105+
e.Location = target.Location
106+
return e
107+
}
108+
109+
if ref, ok := target.Val.(*ast.ColumnRef); ok {
110+
if excludedCol, ok := excludedColumnRef(ref); ok {
111+
if _, exists := colNames[excludedCol]; !exists {
112+
e := sqlerr.ColumnNotFound(excludedTable, excludedCol)
113+
e.Location = ref.Location
114+
return e
115+
}
116+
}
117+
}
118+
}
119+
38120
return nil
39121
}
122+
123+
// excludedColumnRef returns the column name if the ColumnRef is an EXCLUDED.col
124+
// reference, and ok=true. Returns "", false otherwise.
125+
func excludedColumnRef(ref *ast.ColumnRef) (string, bool) {
126+
if ref.Fields == nil || len(ref.Fields.Items) != 2 {
127+
return "", false
128+
}
129+
130+
first, ok := ref.Fields.Items[0].(*ast.String)
131+
if !ok || !strings.EqualFold(first.Str, excludedTable) {
132+
return "", false
133+
}
134+
135+
second, ok := ref.Fields.Items[1].(*ast.String)
136+
if !ok {
137+
return "", false
138+
}
139+
140+
return second.Str, true
141+
}

0 commit comments

Comments
 (0)