Skip to content

Commit fe19865

Browse files
committed
validateOnConflictClause
1 parent fca278d commit fe19865

File tree

1 file changed

+33
-47
lines changed

1 file changed

+33
-47
lines changed

internal/compiler/analyze.go

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
145145
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar)
146146

147147
var table *ast.TableName
148-
var insertStmt *ast.InsertStmt
149148
switch n := raw.Stmt.(type) {
150149
case *ast.InsertStmt:
151150
if err := check(validate.InsertStmt(n)); err != nil {
@@ -156,10 +155,6 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
156155
if err := check(err); err != nil {
157156
return nil, err
158157
}
159-
if err := check(c.validateOnConflictColumns(n)); err != nil {
160-
return nil, err
161-
}
162-
insertStmt = n
163158
}
164159

165160
if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil {
@@ -193,8 +188,8 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
193188
if err := check(err); err != nil {
194189
return nil, err
195190
}
196-
if c.conf.Engine == config.EnginePostgreSQL {
197-
if err := check(c.validateOnConflictTypes(insertStmt, params)); err != nil {
191+
if n, ok := raw.Stmt.(*ast.InsertStmt); ok {
192+
if err := check(c.validateOnConflictClause(n, params)); err != nil {
198193
return nil, err
199194
}
200195
}
@@ -227,12 +222,13 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
227222
}, rerr
228223
}
229224

230-
// validateOnConflictColumns checks column names in an ON CONFLICT DO UPDATE
231-
// clause against the target table:
232-
// - ON CONFLICT (col, ...) conflict target columns
233-
// - DO UPDATE SET col = ... assignment target columns
234-
// - EXCLUDED.col references in assignment values
235-
func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error {
225+
// validateOnConflictClause validates an ON CONFLICT DO UPDATE clause against
226+
// the target table. It checks:
227+
// - ON CONFLICT (col, ...) conflict target columns exist
228+
// - DO UPDATE SET col = ... assignment target columns exist
229+
// - EXCLUDED.col references exist
230+
// - For PostgreSQL: $N param and EXCLUDED.col type compatibility with SET target
231+
func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt, params []Parameter) error {
236232
if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
237233
return nil
238234
}
@@ -244,9 +240,11 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error {
244240
if err != nil {
245241
return err
246242
}
247-
colSet := make(map[string]struct{}, len(table.Columns))
243+
244+
// Build column name → DataType from catalog for existence and type checks.
245+
colDataTypes := make(map[string]string, len(table.Columns))
248246
for _, col := range table.Columns {
249-
colSet[col.Name] = struct{}{}
247+
colDataTypes[col.Name] = dataType(&col.Type)
250248
}
251249

252250
// Validate ON CONFLICT (col, ...) conflict target columns.
@@ -256,57 +254,52 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error {
256254
if !ok || elem.Name == nil {
257255
continue
258256
}
259-
if _, exists := colSet[*elem.Name]; !exists {
257+
if _, exists := colDataTypes[*elem.Name]; !exists {
260258
e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name)
261259
e.Location = n.OnConflictClause.Infer.Location
262260
return e
263261
}
264262
}
265263
}
266264

267-
// Validate DO UPDATE SET col = ... and EXCLUDED.col references.
265+
// Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references.
268266
for _, item := range n.OnConflictClause.TargetList.Items {
269267
target, ok := item.(*ast.ResTarget)
270268
if !ok || target.Name == nil {
271269
continue
272270
}
273-
// Validate the assignment target column.
274-
if _, exists := colSet[*target.Name]; !exists {
271+
if _, exists := colDataTypes[*target.Name]; !exists {
275272
e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name)
276273
e.Location = target.Location
277274
return e
278275
}
279-
// Validate EXCLUDED.col references in the assigned value.
280276
if ref, ok := target.Val.(*ast.ColumnRef); ok {
281-
if col, ok := excludedColumn(ref); ok {
282-
if _, exists := colSet[col]; !exists {
283-
e := sqlerr.ColumnNotFound(table.Rel.Name, col)
277+
if excludedCol, ok := excludedColumn(ref); ok {
278+
if _, exists := colDataTypes[excludedCol]; !exists {
279+
e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol)
284280
e.Location = ref.Location
285281
return e
286282
}
287283
}
288284
}
289285
}
290-
return nil
291-
}
292286

293-
// validateOnConflictTypes checks that $N params used in DO UPDATE SET assignments
294-
// are type-compatible with the target column, based on the type already resolved
295-
// for that param from the INSERT columns.
296-
func (c *Compiler) validateOnConflictTypes(n *ast.InsertStmt, params []Parameter) error {
297-
if n == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
298-
return nil
299-
}
300-
fqn, err := ParseTableName(n.Relation)
301-
if err != nil {
302-
return err
303-
}
304-
table, err := c.catalog.GetTable(fqn)
305-
if err != nil {
306-
return err
287+
// Type compatibility checks (PostgreSQL only).
288+
// To remove type checking: delete this block and validateOnConflictSetTypes.
289+
if c.conf.Engine == config.EnginePostgreSQL {
290+
if err := c.validateOnConflictSetTypes(n, params, colDataTypes); err != nil {
291+
return err
292+
}
307293
}
308294

309-
// Build param number → resolved DataType string from already-resolved params.
295+
return nil
296+
}
297+
298+
// validateOnConflictSetTypes checks that values in DO UPDATE SET assignments
299+
// are type-compatible with their target columns (PostgreSQL only).
300+
// It handles $N params (typed from INSERT VALUES) and EXCLUDED.col references.
301+
func (c *Compiler) validateOnConflictSetTypes(n *ast.InsertStmt, params []Parameter, colDataTypes map[string]string) error {
302+
// Build param number → resolved DataType from already-resolved params.
310303
// Skips params with "any" type (unresolved).
311304
paramDataTypes := make(map[int]string, len(params))
312305
for i := range params {
@@ -315,13 +308,6 @@ func (c *Compiler) validateOnConflictTypes(n *ast.InsertStmt, params []Parameter
315308
}
316309
}
317310

318-
// Build column name → DataType string using the same dataType() function
319-
// used by resolveCatalogRefs, so formats are comparable.
320-
colDataTypes := make(map[string]string, len(table.Columns))
321-
for _, col := range table.Columns {
322-
colDataTypes[col.Name] = dataType(&col.Type)
323-
}
324-
325311
for _, item := range n.OnConflictClause.TargetList.Items {
326312
target, ok := item.(*ast.ResTarget)
327313
if !ok || target.Name == nil {

0 commit comments

Comments
 (0)