Skip to content

Commit 7960b17

Browse files
committed
Add ClickHouse engine unit tests
Comprehensive test coverage for parser and converter: - Basic parsing functionality (parse_test.go) - AST conversion correctness (new_conversions_test.go) - Catalog initialization (catalog_test.go) - Edge cases and boundary conditions (parse_boundary_test.go) - Real-world queries (parse_actual_queries_test.go) - Type handling, identifiers, joins, arrays
1 parent f30b317 commit 7960b17

14 files changed

Lines changed: 4026 additions & 0 deletions
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package clickhouse
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
8+
)
9+
10+
// TestArrayJoinColumnAliases validates that ARRAY JOIN creates properly aliased columns
11+
// These columns should be available for reference in the SELECT list
12+
func TestArrayJoinColumnAliases(t *testing.T) {
13+
parser := NewParser()
14+
15+
tests := []struct {
16+
name string
17+
query string
18+
expectedColnames []string // column names from ARRAY JOIN
19+
wantErr bool
20+
}{
21+
{
22+
name: "simple array join with alias",
23+
query: `
24+
SELECT id, tag
25+
FROM users
26+
ARRAY JOIN tags AS tag
27+
`,
28+
expectedColnames: []string{"tag"},
29+
wantErr: false,
30+
},
31+
{
32+
name: "single array join with table alias and qualified name",
33+
query: `
34+
SELECT u.id, u.name, tag
35+
FROM users u
36+
ARRAY JOIN u.tags AS tag
37+
`,
38+
expectedColnames: []string{"tag"},
39+
wantErr: false,
40+
},
41+
{
42+
name: "multiple array joins with aliases",
43+
query: `
44+
SELECT event_id, event_name, prop_key, prop_value
45+
FROM events
46+
ARRAY JOIN properties.keys AS prop_key, properties.values AS prop_value
47+
`,
48+
expectedColnames: []string{"prop_key", "prop_value"},
49+
wantErr: false,
50+
},
51+
}
52+
53+
for _, tt := range tests {
54+
t.Run(tt.name, func(t *testing.T) {
55+
stmts, err := parser.Parse(strings.NewReader(tt.query))
56+
if (err != nil) != tt.wantErr {
57+
t.Fatalf("Parse error: %v, wantErr %v", err, tt.wantErr)
58+
}
59+
60+
if len(stmts) == 0 {
61+
t.Fatal("No statements parsed")
62+
}
63+
64+
selectStmt, ok := stmts[0].Raw.Stmt.(*ast.SelectStmt)
65+
if !ok {
66+
t.Fatalf("Expected SelectStmt, got %T", stmts[0].Raw.Stmt)
67+
}
68+
69+
// Check that the FROM clause contains the ARRAY JOIN as a RangeSubselect
70+
if selectStmt.FromClause == nil || len(selectStmt.FromClause.Items) == 0 {
71+
t.Fatal("No FROM clause items found")
72+
}
73+
74+
// Find the RangeSubselect that represents the ARRAY JOIN
75+
var arrayJoinRangeSubselect *ast.RangeSubselect
76+
for _, item := range selectStmt.FromClause.Items {
77+
if rs, ok := item.(*ast.RangeSubselect); ok {
78+
arrayJoinRangeSubselect = rs
79+
break
80+
}
81+
}
82+
83+
if arrayJoinRangeSubselect == nil {
84+
t.Fatal("No RangeSubselect found for ARRAY JOIN")
85+
}
86+
87+
// Verify that the RangeSubselect has a Subquery (synthetic SELECT statement)
88+
if arrayJoinRangeSubselect.Subquery == nil {
89+
t.Error("ARRAY JOIN RangeSubselect has no Subquery")
90+
return
91+
}
92+
93+
syntheticSelect, ok := arrayJoinRangeSubselect.Subquery.(*ast.SelectStmt)
94+
if !ok {
95+
t.Errorf("Expected SelectStmt subquery, got %T", arrayJoinRangeSubselect.Subquery)
96+
return
97+
}
98+
99+
// Verify the target list has the expected column names
100+
if syntheticSelect.TargetList == nil || len(syntheticSelect.TargetList.Items) == 0 {
101+
t.Error("Synthetic SELECT has no target list")
102+
return
103+
}
104+
105+
if len(syntheticSelect.TargetList.Items) != len(tt.expectedColnames) {
106+
t.Errorf("Expected %d targets, got %d", len(tt.expectedColnames), len(syntheticSelect.TargetList.Items))
107+
return
108+
}
109+
110+
// Verify the target values (which should be ResTargets with Name set)
111+
for i, expected := range tt.expectedColnames {
112+
target, ok := syntheticSelect.TargetList.Items[i].(*ast.ResTarget)
113+
if !ok {
114+
t.Errorf("Target %d is not a ResTarget: %T", i, syntheticSelect.TargetList.Items[i])
115+
continue
116+
}
117+
118+
if target.Name == nil || *target.Name != expected {
119+
var name string
120+
if target.Name != nil {
121+
name = *target.Name
122+
}
123+
t.Errorf("Target %d: expected name %q, got %q", i, expected, name)
124+
}
125+
}
126+
})
127+
}
128+
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package clickhouse
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
8+
)
9+
10+
func TestCaseSensitiveColumns(t *testing.T) {
11+
// ClickHouse is case-sensitive for identifiers
12+
// This test demonstrates the issue where columns with different cases
13+
// are incorrectly treated as the same column
14+
15+
sql := `
16+
CREATE TABLE test_table
17+
(
18+
UserId UInt32,
19+
userName String,
20+
EMAIL String
21+
)
22+
ENGINE = MergeTree()
23+
ORDER BY UserId;
24+
`
25+
26+
p := NewParser()
27+
stmts, err := p.Parse(strings.NewReader(sql))
28+
if err != nil {
29+
t.Fatalf("Parse failed: %v", err)
30+
}
31+
32+
if len(stmts) != 1 {
33+
t.Fatalf("Expected 1 statement, got %d", len(stmts))
34+
}
35+
36+
createStmt, ok := stmts[0].Raw.Stmt.(*ast.CreateTableStmt)
37+
if !ok {
38+
t.Fatalf("Expected CreateTableStmt, got %T", stmts[0].Raw.Stmt)
39+
}
40+
41+
// Check that column names preserve their case
42+
expectedColumns := map[string]bool{
43+
"UserId": true,
44+
"userName": true,
45+
"EMAIL": true,
46+
}
47+
48+
actualColumns := make(map[string]bool)
49+
for _, col := range createStmt.Cols {
50+
actualColumns[col.Colname] = true
51+
}
52+
53+
if len(actualColumns) != len(expectedColumns) {
54+
t.Errorf("Expected %d distinct columns, got %d", len(expectedColumns), len(actualColumns))
55+
}
56+
57+
for expected := range expectedColumns {
58+
if !actualColumns[expected] {
59+
t.Errorf("Column '%s' not found. Found columns: %v", expected, actualColumns)
60+
}
61+
}
62+
}
63+
64+
func TestCaseSensitiveColumnReference(t *testing.T) {
65+
// Test that column references preserve case in SELECT statements
66+
sql := "SELECT UserId, userName, EMAIL FROM test_table;"
67+
68+
p := NewParser()
69+
stmts, err := p.Parse(strings.NewReader(sql))
70+
if err != nil {
71+
t.Fatalf("Parse failed: %v", err)
72+
}
73+
74+
if len(stmts) != 1 {
75+
t.Fatalf("Expected 1 statement, got %d", len(stmts))
76+
}
77+
78+
selectStmt, ok := stmts[0].Raw.Stmt.(*ast.SelectStmt)
79+
if !ok {
80+
t.Fatalf("Expected SelectStmt, got %T", stmts[0].Raw.Stmt)
81+
}
82+
83+
expectedColRefs := []string{"UserId", "userName", "EMAIL"}
84+
if len(selectStmt.TargetList.Items) != len(expectedColRefs) {
85+
t.Fatalf("Expected %d target items, got %d", len(expectedColRefs), len(selectStmt.TargetList.Items))
86+
}
87+
88+
for i, expected := range expectedColRefs {
89+
target, ok := selectStmt.TargetList.Items[i].(*ast.ResTarget)
90+
if !ok {
91+
t.Fatalf("Item %d is not a ResTarget: %T", i, selectStmt.TargetList.Items[i])
92+
}
93+
94+
// Check if Name is set (for aliased columns) or extract from ColumnRef
95+
var got string
96+
if target.Name != nil && *target.Name != "" {
97+
got = *target.Name
98+
} else if colRef, ok := target.Val.(*ast.ColumnRef); ok && colRef != nil && colRef.Fields != nil && len(colRef.Fields.Items) > 0 {
99+
// Extract the column name from the ColumnRef
100+
if s, ok := colRef.Fields.Items[len(colRef.Fields.Items)-1].(*ast.String); ok {
101+
got = s.Str
102+
}
103+
}
104+
105+
if got != expected {
106+
t.Errorf("Column %d: expected '%s', got '%s'", i, expected, got)
107+
}
108+
}
109+
}
110+
111+
func TestCaseSensitiveWhereClauses(t *testing.T) {
112+
// Test that WHERE clauses with case-sensitive column names work correctly
113+
sql := "SELECT * FROM users WHERE UserId = 123 AND userName = 'John';"
114+
115+
p := NewParser()
116+
stmts, err := p.Parse(strings.NewReader(sql))
117+
if err != nil {
118+
t.Fatalf("Parse failed: %v", err)
119+
}
120+
121+
selectStmt, ok := stmts[0].Raw.Stmt.(*ast.SelectStmt)
122+
if !ok {
123+
t.Fatalf("Expected SelectStmt, got %T", stmts[0].Raw.Stmt)
124+
}
125+
126+
// Verify WHERE clause references preserve case
127+
if selectStmt.WhereClause == nil {
128+
t.Fatal("WHERE clause is nil")
129+
}
130+
131+
// The WHERE clause should contain column references with preserved case
132+
// This is a simple check - we'd need to traverse the AST to verify
133+
// that column names in the WHERE clause preserve their case
134+
whereStr := astToString(selectStmt.WhereClause)
135+
136+
// Check that the case is preserved in the where clause
137+
if !strings.Contains(whereStr, "UserId") || !strings.Contains(whereStr, "userName") {
138+
t.Errorf("WHERE clause should preserve column name case. Got: %s", whereStr)
139+
}
140+
}
141+
142+
// astToString converts AST nodes to a string representation for testing
143+
func astToString(node ast.Node) string {
144+
if node == nil {
145+
return ""
146+
}
147+
148+
switch n := node.(type) {
149+
case *ast.A_Expr:
150+
left := astToString(n.Lexpr)
151+
right := astToString(n.Rexpr)
152+
return left + " " + right
153+
case *ast.ColumnRef:
154+
if n.Fields != nil && len(n.Fields.Items) > 0 {
155+
if s, ok := n.Fields.Items[len(n.Fields.Items)-1].(*ast.String); ok {
156+
return s.Str
157+
}
158+
}
159+
case *ast.A_Const:
160+
if s, ok := n.Val.(*ast.String); ok {
161+
return s.Str
162+
}
163+
}
164+
return ""
165+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package clickhouse
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
8+
)
9+
10+
func TestCatalogIntegration(t *testing.T) {
11+
schema := `
12+
CREATE TABLE IF NOT EXISTS users
13+
(
14+
id UInt32,
15+
name String,
16+
email String
17+
)
18+
ENGINE = MergeTree()
19+
ORDER BY id;
20+
`
21+
22+
p := NewParser()
23+
stmts, err := p.Parse(strings.NewReader(schema))
24+
if err != nil {
25+
t.Fatalf("Parse failed: %v", err)
26+
}
27+
28+
if len(stmts) != 1 {
29+
t.Fatalf("Expected 1 statement, got %d", len(stmts))
30+
}
31+
32+
// Debug: check what's in the statement
33+
if stmts[0].Raw != nil && stmts[0].Raw.Stmt != nil {
34+
if createStmt, ok := stmts[0].Raw.Stmt.(*ast.CreateTableStmt); ok {
35+
t.Logf("CreateTableStmt: Schema='%s', Table='%s'", createStmt.Name.Schema, createStmt.Name.Name)
36+
t.Logf("CreateTableStmt: Cols count=%d", len(createStmt.Cols))
37+
} else {
38+
t.Logf("Statement type: %T", stmts[0].Raw.Stmt)
39+
}
40+
}
41+
42+
cat := NewCatalog()
43+
if cat.DefaultSchema != "default" {
44+
t.Errorf("Expected default schema 'default', got '%s'", cat.DefaultSchema)
45+
}
46+
47+
// Try to update catalog with the CREATE TABLE
48+
t.Logf("Calling catalog.Update()...")
49+
err = cat.Update(stmts[0], nil)
50+
if err != nil {
51+
t.Fatalf("Catalog update failed: %v", err)
52+
}
53+
t.Logf("Catalog update succeeded")
54+
55+
// Check if table was added
56+
t.Logf("Catalog has %d schemas", len(cat.Schemas))
57+
for i, schema := range cat.Schemas {
58+
t.Logf("Schema[%d]: Name='%s', Tables=%d", i, schema.Name, len(schema.Tables))
59+
}
60+
61+
if len(cat.Schemas) == 0 {
62+
t.Fatal("No schemas in catalog")
63+
}
64+
65+
defaultSchema := cat.Schemas[0]
66+
if len(defaultSchema.Tables) == 0 {
67+
t.Fatal("No tables in default schema")
68+
}
69+
70+
table := defaultSchema.Tables[0]
71+
if table.Rel.Name != "users" {
72+
t.Errorf("Expected table name 'users', got '%s'", table.Rel.Name)
73+
}
74+
75+
if len(table.Columns) != 3 {
76+
t.Errorf("Expected 3 columns, got %d", len(table.Columns))
77+
}
78+
79+
// Log column types for debugging
80+
for i, col := range table.Columns {
81+
t.Logf("Column[%d]: Name='%s', Type.Name='%s', NotNull=%v", i, col.Name, col.Type.Name, col.IsNotNull)
82+
}
83+
}

0 commit comments

Comments
 (0)