Skip to content

Commit 56ed1ad

Browse files
wip: validate column names in ON CONFLICT DO UPDATE SET
- Add OnConflictClause validation in internal/sql/validate/on_conflict.go - Integrate validation into analyze.go pipeline - Add unit tests (5 cases) - Add endtoend testcase for invalid column TODO: - Test qualified column names (cart_items.col) - Test composite expressions on right side of SET - Test WHERE clause in ON CONFLICT - Test explicit schema (public.cart_items)
1 parent c048334 commit 56ed1ad

File tree

7 files changed

+271
-0
lines changed

7 files changed

+271
-0
lines changed

internal/compiler/analyze.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
152152
if err := check(err); err != nil {
153153
return nil, err
154154
}
155+
if err := check(validate.OnConflictClause(c.catalog, n, table)); err != nil {
156+
return nil, err
157+
}
155158
}
156159

157160
if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil {
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-- name: AddItem :exec
2+
INSERT INTO cart_items (owner_id, product_id, price_amount, price_currency)
3+
VALUES ($1, $2, $3, $4)
4+
ON CONFLICT (owner_id, product_id) DO UPDATE
5+
SET price_amount1 = EXCLUDED.price_amount1,
6+
price_currency = EXCLUDED.price_currency;
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
CREATE TABLE cart_items (
2+
owner_id VARCHAR(255) NOT NULL,
3+
product_id UUID NOT NULL,
4+
price_amount DECIMAL NOT NULL,
5+
price_currency VARCHAR(3) NOT NULL,
6+
PRIMARY KEY (owner_id, product_id)
7+
);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"version": "2",
3+
"sql": [
4+
{
5+
"engine": "postgresql",
6+
"schema": "schema.sql",
7+
"queries": "query.sql",
8+
"gen": {
9+
"go": {
10+
"package": "querytest",
11+
"out": "go",
12+
"sql_package": "pgx/v5"
13+
}
14+
}
15+
}
16+
]
17+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:5:9: column "price_amount1" of relation "cart_items" does not exist
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package validate
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
8+
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
9+
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
10+
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
11+
)
12+
13+
func OnConflictClause(cat *catalog.Catalog, stmt *ast.InsertStmt, tableName *ast.TableName) error {
14+
if stmt.OnConflictClause == nil {
15+
return nil
16+
}
17+
18+
occ := stmt.OnConflictClause
19+
20+
if occ.Action != ast.OnConflictActionUpdate {
21+
return nil
22+
}
23+
24+
if tableName == nil {
25+
return nil
26+
}
27+
28+
tbl, err := cat.GetTable(tableName)
29+
if err != nil {
30+
return err
31+
}
32+
33+
relName := ""
34+
if tbl.Rel != nil {
35+
relName = tbl.Rel.Name
36+
}
37+
38+
validCols := make(map[string]struct{}, len(tbl.Columns))
39+
for _, c := range tbl.Columns {
40+
validCols[strings.ToLower(c.Name)] = struct{}{}
41+
}
42+
43+
if occ.TargetList == nil {
44+
return nil
45+
}
46+
47+
for _, item := range occ.TargetList.Items {
48+
res, ok := item.(*ast.ResTarget)
49+
if !ok {
50+
continue
51+
}
52+
53+
if res.Name != nil {
54+
colName := strings.ToLower(*res.Name)
55+
if _, exists := validCols[colName]; !exists {
56+
return &sqlerr.Error{
57+
Code: "42703",
58+
Message: fmt.Sprintf("column %q of relation %q does not exist", *res.Name, relName),
59+
Location: res.Location,
60+
}
61+
}
62+
}
63+
64+
if res.Val != nil {
65+
if err := validateExcludedRefs(res.Val, validCols, relName); err != nil {
66+
return err
67+
}
68+
}
69+
}
70+
71+
return nil
72+
}
73+
74+
func validateExcludedRefs(node ast.Node, validCols map[string]struct{}, tableName string) error {
75+
refs := astutils.Search(node, func(n ast.Node) bool {
76+
_, ok := n.(*ast.ColumnRef)
77+
return ok
78+
})
79+
80+
for _, ref := range refs.Items {
81+
colRef, ok := ref.(*ast.ColumnRef)
82+
if !ok {
83+
continue
84+
}
85+
86+
parts := make([]string, 0, len(colRef.Fields.Items))
87+
for _, field := range colRef.Fields.Items {
88+
if s, ok := field.(*ast.String); ok {
89+
parts = append(parts, s.Str)
90+
}
91+
}
92+
93+
if len(parts) == 2 && strings.ToLower(parts[0]) == "excluded" {
94+
colName := strings.ToLower(parts[1])
95+
if _, exists := validCols[colName]; !exists {
96+
return &sqlerr.Error{
97+
Code: "42703",
98+
Message: fmt.Sprintf("column %q does not exist in relation %q (via EXCLUDED)", parts[1], tableName),
99+
Location: colRef.Location,
100+
}
101+
}
102+
}
103+
}
104+
105+
return nil
106+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package validate
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
8+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
9+
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
10+
)
11+
12+
func makeTestCatalog(t *testing.T) (*catalog.Catalog, *ast.TableName) {
13+
t.Helper()
14+
15+
p := postgresql.NewParser()
16+
stmts, err := p.Parse(strings.NewReader(`
17+
CREATE TABLE cart_items (
18+
owner_id VARCHAR(255) NOT NULL,
19+
product_id UUID NOT NULL,
20+
price_amount DECIMAL NOT NULL,
21+
price_currency VARCHAR(3) NOT NULL,
22+
PRIMARY KEY (owner_id, product_id)
23+
);
24+
`))
25+
if err != nil {
26+
t.Fatalf("parse schema: %v", err)
27+
}
28+
29+
cat := catalog.New("public")
30+
for _, stmt := range stmts {
31+
if err := cat.Update(stmt, nil); err != nil {
32+
t.Fatalf("update catalog: %v", err)
33+
}
34+
}
35+
36+
tableName := &ast.TableName{Schema: "public", Name: "cart_items"}
37+
return cat, tableName
38+
}
39+
40+
func makeStmt(action ast.OnConflictAction, setItems []struct{ col, val string }) *ast.InsertStmt {
41+
stmt := &ast.InsertStmt{
42+
Relation: &ast.RangeVar{
43+
Schemaname: strPtr("public"),
44+
Relname: strPtr("cart_items"),
45+
},
46+
}
47+
48+
if action == ast.OnConflictActionNone {
49+
return stmt
50+
}
51+
52+
items := make([]ast.Node, 0, len(setItems))
53+
for _, si := range setItems {
54+
colName := si.col
55+
items = append(items, &ast.ResTarget{
56+
Name: &colName,
57+
Val: &ast.ColumnRef{
58+
Fields: &ast.List{
59+
Items: []ast.Node{
60+
&ast.String{Str: "excluded"},
61+
&ast.String{Str: si.val},
62+
},
63+
},
64+
},
65+
})
66+
}
67+
68+
stmt.OnConflictClause = &ast.OnConflictClause{
69+
Action: action,
70+
TargetList: &ast.List{Items: items},
71+
}
72+
return stmt
73+
}
74+
75+
func strPtr(s string) *string { return &s }
76+
77+
func TestOnConflictClause(t *testing.T) {
78+
cat, tableName := makeTestCatalog(t)
79+
80+
tests := []struct {
81+
name string
82+
stmt *ast.InsertStmt
83+
wantErr bool
84+
}{
85+
{
86+
name: "valid columns in SET and EXCLUDED",
87+
stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{
88+
{"price_amount", "price_amount"},
89+
{"price_currency", "price_currency"},
90+
}),
91+
wantErr: false,
92+
},
93+
{
94+
name: "invalid column on left side of SET",
95+
stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{
96+
{"price_amount1", "price_amount"},
97+
}),
98+
wantErr: true,
99+
},
100+
{
101+
name: "invalid EXCLUDED reference on right side",
102+
stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{
103+
{"price_amount", "price_amount1"},
104+
}),
105+
wantErr: true,
106+
},
107+
{
108+
name: "DO NOTHING skips column validation",
109+
stmt: makeStmt(ast.OnConflictActionNothing, nil),
110+
wantErr: false,
111+
},
112+
{
113+
name: "no OnConflictClause passes without error",
114+
stmt: makeStmt(ast.OnConflictActionNone, nil),
115+
wantErr: false,
116+
},
117+
}
118+
119+
for _, tt := range tests {
120+
t.Run(tt.name, func(t *testing.T) {
121+
err := OnConflictClause(cat, tt.stmt, tableName)
122+
if tt.wantErr && err == nil {
123+
t.Error("expected error but got none")
124+
}
125+
if !tt.wantErr && err != nil {
126+
t.Errorf("unexpected error: %v", err)
127+
}
128+
})
129+
}
130+
}

0 commit comments

Comments
 (0)