Skip to content

Commit dcf9ed5

Browse files
authored
Feat(mypyc)!: compile python generator (#7528)
1 parent 8e5e255 commit dcf9ed5

3 files changed

Lines changed: 112 additions & 112 deletions

File tree

sqlglot/executor/python.py

Lines changed: 4 additions & 108 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,13 @@
22
import itertools
33
import math
44

5-
from sqlglot import exp, generator, planner, tokens
6-
from sqlglot.dialects.dialect import Dialect, inline_array_sql
5+
from sqlglot import exp, planner, tokens
6+
from sqlglot.dialects.dialect import Dialect
77
from sqlglot.errors import ExecuteError
88
from sqlglot.executor.context import Context
99
from sqlglot.executor.env import ENV
1010
from sqlglot.executor.table import RowReader, Table
11-
from sqlglot.helper import subclasses
11+
from sqlglot.generators.python import PythonGenerator
1212

1313

1414
class PythonExecutor:
@@ -324,112 +324,8 @@ def set_operation(self, step, context):
324324
return self.context({step.name: sink})
325325

326326

327-
def _ordered_py(self, expression):
328-
this = self.sql(expression, "this")
329-
desc = "True" if expression.args.get("desc") else "False"
330-
nulls_first = "True" if expression.args.get("nulls_first") else "False"
331-
return f"ORDERED({this}, {desc}, {nulls_first})"
332-
333-
334-
def _rename(self, e):
335-
try:
336-
values = list(e.args.values())
337-
338-
if len(values) == 1:
339-
values = values[0]
340-
if not isinstance(values, list):
341-
return self.func(e.key, values)
342-
return self.func(e.key, *values)
343-
344-
if isinstance(e, exp.Func) and e.is_var_len_args:
345-
args = itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in values)
346-
return self.func(e.key, *args)
347-
348-
return self.func(e.key, *values)
349-
except Exception as ex:
350-
raise Exception(f"Could not rename {repr(e)}") from ex
351-
352-
353-
def _case_sql(self, expression):
354-
this = self.sql(expression, "this")
355-
chain = self.sql(expression, "default") or "None"
356-
357-
for e in reversed(expression.args["ifs"]):
358-
true = self.sql(e, "true")
359-
condition = self.sql(e, "this")
360-
condition = f"{this} = ({condition})" if this else condition
361-
chain = f"{true} if {condition} else ({chain})"
362-
363-
return chain
364-
365-
366-
def _lambda_sql(self, e: exp.Lambda) -> str:
367-
names = {e.name.lower() for e in e.expressions}
368-
369-
e = e.transform(
370-
lambda n: (
371-
exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
372-
)
373-
).assert_is(exp.Lambda)
374-
375-
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
376-
377-
378-
def _div_sql(self: generator.Generator, e: exp.Div) -> str:
379-
denominator = self.sql(e, "expression")
380-
381-
if e.args.get("safe"):
382-
denominator += " or None"
383-
384-
sql = f"DIV({self.sql(e, 'this')}, {denominator})"
385-
386-
if e.args.get("typed"):
387-
sql = f"int({sql})"
388-
389-
return sql
390-
391-
392327
class Python(Dialect):
393328
class Tokenizer(tokens.Tokenizer):
394329
STRING_ESCAPES = ["\\"]
395330

396-
class Generator(generator.Generator):
397-
TRANSFORMS = {
398-
**{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
399-
**{klass: _rename for klass in exp.ALL_FUNCTIONS},
400-
exp.Case: _case_sql,
401-
exp.Alias: lambda self, e: self.sql(e.this),
402-
exp.Array: inline_array_sql,
403-
exp.And: lambda self, e: self.binary(e, "and"),
404-
exp.Between: _rename,
405-
exp.Boolean: lambda self, e: "True" if e.this else "False",
406-
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DType.{e.args['to']})",
407-
exp.Column: lambda self, e: (
408-
f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]"
409-
),
410-
exp.Concat: lambda self, e: self.func(
411-
"SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions
412-
),
413-
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
414-
exp.Div: _div_sql,
415-
exp.Extract: lambda self, e: (
416-
f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})"
417-
),
418-
exp.In: lambda self, e: (
419-
f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}"
420-
),
421-
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
422-
exp.Is: lambda self, e: (
423-
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
424-
),
425-
exp.JSONExtract: lambda self, e: self.func(e.key, e.this, e.expression, *e.expressions),
426-
exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]",
427-
exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'",
428-
exp.JSONPathSubscript: lambda self, e: f"'{e.this}'",
429-
exp.Lambda: _lambda_sql,
430-
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
431-
exp.Null: lambda *_: "None",
432-
exp.Or: lambda self, e: self.binary(e, "or"),
433-
exp.Ordered: _ordered_py,
434-
exp.Star: lambda *_: "1",
435-
}
331+
Generator = PythonGenerator

sqlglot/generators/python.py

Lines changed: 108 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,108 @@
1+
from __future__ import annotations
2+
3+
import itertools
4+
5+
from sqlglot import exp, generator
6+
from sqlglot.dialects.dialect import inline_array_sql
7+
from sqlglot.helper import subclasses
8+
9+
10+
def _ordered_py(self, expression):
11+
this = self.sql(expression, "this")
12+
desc = "True" if expression.args.get("desc") else "False"
13+
nulls_first = "True" if expression.args.get("nulls_first") else "False"
14+
return f"ORDERED({this}, {desc}, {nulls_first})"
15+
16+
17+
def _rename(self, e):
18+
try:
19+
values = list(e.args.values())
20+
21+
if len(values) == 1:
22+
values = values[0]
23+
if not isinstance(values, list):
24+
return self.func(e.key, values)
25+
return self.func(e.key, *values)
26+
27+
if isinstance(e, exp.Func) and e.is_var_len_args:
28+
args = itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in values)
29+
return self.func(e.key, *args)
30+
31+
return self.func(e.key, *values)
32+
except Exception as ex:
33+
raise Exception(f"Could not rename {repr(e)}") from ex
34+
35+
36+
def _case_sql(self, expression):
37+
this = self.sql(expression, "this")
38+
chain = self.sql(expression, "default") or "None"
39+
40+
for e in reversed(expression.args["ifs"]):
41+
true = self.sql(e, "true")
42+
condition = self.sql(e, "this")
43+
condition = f"{this} = ({condition})" if this else condition
44+
chain = f"{true} if {condition} else ({chain})"
45+
46+
return chain
47+
48+
49+
def _lambda_sql(self, e: exp.Lambda) -> str:
50+
names = {e.name.lower() for e in e.expressions}
51+
52+
e = e.transform(
53+
lambda n: (
54+
exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
55+
)
56+
).assert_is(exp.Lambda)
57+
58+
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
59+
60+
61+
def _div_sql(self: generator.Generator, e: exp.Div) -> str:
62+
denominator = self.sql(e, "expression")
63+
64+
if e.args.get("safe"):
65+
denominator += " or None"
66+
67+
sql = f"DIV({self.sql(e, 'this')}, {denominator})"
68+
69+
if e.args.get("typed"):
70+
sql = f"int({sql})"
71+
72+
return sql
73+
74+
75+
class PythonGenerator(generator.Generator):
76+
TRANSFORMS = {
77+
**{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
78+
**{klass: _rename for klass in exp.ALL_FUNCTIONS},
79+
exp.Case: _case_sql,
80+
exp.Alias: lambda self, e: self.sql(e.this),
81+
exp.Array: inline_array_sql,
82+
exp.And: lambda self, e: self.binary(e, "and"),
83+
exp.Between: _rename,
84+
exp.Boolean: lambda self, e: "True" if e.this else "False",
85+
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DType.{e.args['to']})",
86+
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
87+
exp.Concat: lambda self, e: self.func(
88+
"SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions
89+
),
90+
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
91+
exp.Div: _div_sql,
92+
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
93+
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
94+
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
95+
exp.Is: lambda self, e: (
96+
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
97+
),
98+
exp.JSONExtract: lambda self, e: self.func(e.key, e.this, e.expression, *e.expressions),
99+
exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]",
100+
exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'",
101+
exp.JSONPathSubscript: lambda self, e: f"'{e.this}'",
102+
exp.Lambda: _lambda_sql,
103+
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
104+
exp.Null: lambda *_: "None",
105+
exp.Or: lambda self, e: self.binary(e, "or"),
106+
exp.Ordered: _ordered_py,
107+
exp.Star: lambda *_: "1",
108+
}

tests/test_executor.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import pandas as pd
1111
from pandas.testing import assert_frame_equal
1212

13-
import sqlglot.generator as _generator_module
1413
from sqlglot import exp, find_tables, parse_one, transpile
1514
from sqlglot.errors import ExecuteError
1615
from sqlglot.executor import execute
@@ -26,8 +25,6 @@
2625
load_sql_fixture_pairs,
2726
)
2827

29-
_GENERATOR_IS_COMPILED = getattr(_generator_module, "__file__", "").endswith(".so")
30-
3128
DIR_TPCH = FIXTURES_DIR + "/optimizer/tpc-h/"
3229
DIR_TPCDS = FIXTURES_DIR + "/optimizer/tpc-ds/"
3330

@@ -69,7 +66,6 @@ def mp_execute(expression, meta):
6966

7067

7168
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
72-
@unittest.skipIf(_GENERATOR_IS_COMPILED, "executor requires interpreted Generator subclass")
7369
class TestExecutor(unittest.TestCase):
7470
@classmethod
7571
def setUpClass(cls):

0 commit comments

Comments
 (0)