Skip to content

Commit 14e9917

Browse files
committed
Add JOIN USING support and refactor output column handling
Implement JOIN...USING clause support for ClickHouse and PostgreSQL. Refactor output_columns.go for improved type resolution. Add comprehensive tests for output columns and type resolution. Update quote character handling and catalog interface.
1 parent d426a16 commit 14e9917

8 files changed

Lines changed: 1503 additions & 269 deletions

File tree

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package compiler
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/sqlc-dev/sqlc/internal/config"
8+
"github.com/sqlc-dev/sqlc/internal/engine/clickhouse"
9+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
10+
)
11+
12+
// TestClickHouseJoinColumnResolution tests that column names are properly resolved
13+
// in JOIN queries now that JoinExpr is correctly converted
14+
func TestClickHouseJoinColumnResolution(t *testing.T) {
15+
parser := clickhouse.NewParser()
16+
cat := clickhouse.NewCatalog()
17+
18+
// Create database and tables
19+
schemaSQL := `CREATE DATABASE IF NOT EXISTS test_db;
20+
CREATE TABLE test_db.users (
21+
id UInt32,
22+
name String,
23+
email String
24+
);
25+
CREATE TABLE test_db.posts (
26+
id UInt32,
27+
user_id UInt32,
28+
title String,
29+
content String
30+
)`
31+
32+
stmts, err := parser.Parse(strings.NewReader(schemaSQL))
33+
if err != nil {
34+
t.Fatalf("Parse schema failed: %v", err)
35+
}
36+
37+
for _, stmt := range stmts {
38+
if err := cat.Update(stmt, nil); err != nil {
39+
t.Fatalf("Update catalog failed: %v", err)
40+
}
41+
}
42+
43+
// Create compiler
44+
conf := config.SQL{
45+
Engine: config.EngineClickHouse,
46+
}
47+
combo := config.CombinedSettings{
48+
Global: config.Config{},
49+
}
50+
51+
c, err := NewCompiler(conf, combo)
52+
if err != nil {
53+
t.Fatalf("Failed to create compiler: %v", err)
54+
}
55+
56+
// Replace catalog
57+
c.catalog = cat
58+
59+
// Parse a JOIN query
60+
querySQL := "SELECT u.id, u.name, p.id as post_id, p.title FROM test_db.users u LEFT JOIN test_db.posts p ON u.id = p.user_id WHERE u.id = 1"
61+
queryStmts, err := parser.Parse(strings.NewReader(querySQL))
62+
if err != nil {
63+
t.Fatalf("Parse query failed: %v", err)
64+
}
65+
66+
if len(queryStmts) == 0 {
67+
t.Fatal("No queries parsed")
68+
}
69+
70+
selectStmt := queryStmts[0].Raw.Stmt
71+
if selectStmt == nil {
72+
t.Fatal("Select statement is nil")
73+
}
74+
75+
selectAst, ok := selectStmt.(*ast.SelectStmt)
76+
if !ok {
77+
t.Fatalf("Expected SelectStmt, got %T", selectStmt)
78+
}
79+
80+
// Build query catalog and get output columns
81+
qc, err := c.buildQueryCatalog(c.catalog, selectAst, nil)
82+
if err != nil {
83+
t.Fatalf("Failed to build query catalog: %v", err)
84+
}
85+
86+
cols, err := c.outputColumns(qc, selectAst)
87+
if err != nil {
88+
t.Fatalf("Failed to get output columns: %v", err)
89+
}
90+
91+
if len(cols) != 4 {
92+
t.Errorf("Expected 4 columns, got %d", len(cols))
93+
}
94+
95+
expectedNames := []string{"id", "name", "post_id", "title"}
96+
for i, expected := range expectedNames {
97+
if i < len(cols) {
98+
if cols[i].Name != expected {
99+
t.Errorf("Column %d: expected name %q, got %q", i, expected, cols[i].Name)
100+
}
101+
}
102+
}
103+
}
104+
105+
// TestClickHouseLeftJoinNullability tests that LEFT JOIN correctly marks right-side columns as nullable
106+
// In ClickHouse, columns are non-nullable by default unless wrapped in Nullable(T)
107+
func TestClickHouseLeftJoinNullability(t *testing.T) {
108+
parser := clickhouse.NewParser()
109+
cat := clickhouse.NewCatalog()
110+
111+
schemaSQL := `CREATE TABLE orders (
112+
order_id UInt32,
113+
customer_name String,
114+
amount Float64,
115+
created_at DateTime
116+
);
117+
CREATE TABLE shipments (
118+
shipment_id UInt32,
119+
order_id UInt32,
120+
address String,
121+
shipped_at DateTime
122+
)`
123+
124+
stmts, err := parser.Parse(strings.NewReader(schemaSQL))
125+
if err != nil {
126+
t.Fatalf("Parse schema failed: %v", err)
127+
}
128+
129+
for _, stmt := range stmts {
130+
if err := cat.Update(stmt, nil); err != nil {
131+
t.Fatalf("Update catalog failed: %v", err)
132+
}
133+
}
134+
135+
conf := config.SQL{
136+
Engine: config.EngineClickHouse,
137+
}
138+
combo := config.CombinedSettings{
139+
Global: config.Config{},
140+
}
141+
142+
c, err := NewCompiler(conf, combo)
143+
if err != nil {
144+
t.Fatalf("Failed to create compiler: %v", err)
145+
}
146+
c.catalog = cat
147+
148+
querySQL := "SELECT o.order_id, o.customer_name, o.amount, o.created_at, s.shipment_id, s.address, s.shipped_at FROM orders o LEFT JOIN shipments s ON o.order_id = s.order_id ORDER BY o.created_at DESC"
149+
queryStmts, err := parser.Parse(strings.NewReader(querySQL))
150+
if err != nil {
151+
t.Fatalf("Parse query failed: %v", err)
152+
}
153+
154+
selectAst := queryStmts[0].Raw.Stmt.(*ast.SelectStmt)
155+
qc, err := c.buildQueryCatalog(c.catalog, selectAst, nil)
156+
if err != nil {
157+
t.Fatalf("Failed to build query catalog: %v", err)
158+
}
159+
160+
cols, err := c.outputColumns(qc, selectAst)
161+
if err != nil {
162+
t.Fatalf("Failed to get output columns: %v", err)
163+
}
164+
165+
if len(cols) != 7 {
166+
t.Errorf("Expected 7 columns, got %d", len(cols))
167+
}
168+
169+
// Left table columns should be non-nullable
170+
leftTableNonNull := map[string]bool{
171+
"order_id": true,
172+
"customer_name": true,
173+
"amount": true,
174+
"created_at": true,
175+
}
176+
177+
// Right table columns should be nullable (because of LEFT JOIN)
178+
rightTableNullable := map[string]bool{
179+
"shipment_id": true,
180+
"address": true,
181+
"shipped_at": true,
182+
}
183+
184+
for _, col := range cols {
185+
if expected, ok := leftTableNonNull[col.Name]; ok {
186+
if col.NotNull != expected {
187+
t.Errorf("Column %q: expected NotNull=%v, got %v", col.Name, expected, col.NotNull)
188+
}
189+
}
190+
if expected, ok := rightTableNullable[col.Name]; ok {
191+
if col.NotNull == expected {
192+
t.Errorf("Column %q: expected NotNull=%v, got %v", col.Name, !expected, col.NotNull)
193+
}
194+
}
195+
}
196+
}

internal/compiler/expand.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ func (c *Compiler) quoteIdent(ident string) string {
7171

7272
func (c *Compiler) quote(x string) string {
7373
switch c.conf.Engine {
74+
case config.EngineClickHouse:
75+
return "`" + x + "`"
7476
case config.EngineMySQL:
7577
return "`" + x + "`"
7678
default:
@@ -84,6 +86,9 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
8486
return nil, err
8587
}
8688

89+
// Track USING columns to avoid duplicating them in SELECT * expansion
90+
usingMap := getJoinUsingMap(node)
91+
8792
var targets *ast.List
8893
switch n := node.(type) {
8994
case *ast.DeleteStmt:
@@ -126,8 +131,13 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
126131
counts := map[string]int{}
127132
if scope == "" {
128133
for _, t := range tables {
129-
for _, c := range t.Columns {
130-
counts[c.Name] += 1
134+
for _, col := range t.Columns {
135+
// Don't count columns that are in USING clause for this table
136+
// since they won't be included in the expansion
137+
if usingInfo, ok := usingMap[t.Rel.Name]; ok && usingInfo.HasColumn(col.Name) {
138+
continue
139+
}
140+
counts[col.Name] += 1
131141
}
132142
}
133143
}
@@ -138,6 +148,12 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
138148
tableName := c.quoteIdent(t.Rel.Name)
139149
scopeName := c.quoteIdent(scope)
140150
for _, column := range t.Columns {
151+
// Skip columns that are in USING clause for this table
152+
// to avoid duplication (USING naturally returns only one column)
153+
if usingInfo, ok := usingMap[t.Rel.Name]; ok && usingInfo.HasColumn(column.Name) {
154+
continue
155+
}
156+
141157
cname := column.Name
142158
if res.Name != nil {
143159
cname = *res.Name

0 commit comments

Comments
 (0)