|
2 | 2 | import itertools |
3 | 3 | import math |
4 | 4 |
|
5 | | -from sqlglot import exp, generator, planner, tokens |
6 | | -from sqlglot.dialects.dialect import Dialect, inline_array_sql |
| 5 | +from sqlglot import exp, planner, tokens |
| 6 | +from sqlglot.dialects.dialect import Dialect |
7 | 7 | from sqlglot.errors import ExecuteError |
8 | 8 | from sqlglot.executor.context import Context |
9 | 9 | from sqlglot.executor.env import ENV |
10 | 10 | from sqlglot.executor.table import RowReader, Table |
11 | | -from sqlglot.helper import subclasses |
| 11 | +from sqlglot.generators.python import PythonGenerator |
12 | 12 |
|
13 | 13 |
|
14 | 14 | class PythonExecutor: |
@@ -324,112 +324,8 @@ def set_operation(self, step, context): |
324 | 324 | return self.context({step.name: sink}) |
325 | 325 |
|
326 | 326 |
|
327 | | -def _ordered_py(self, expression): |
328 | | - this = self.sql(expression, "this") |
329 | | - desc = "True" if expression.args.get("desc") else "False" |
330 | | - nulls_first = "True" if expression.args.get("nulls_first") else "False" |
331 | | - return f"ORDERED({this}, {desc}, {nulls_first})" |
332 | | - |
333 | | - |
334 | | -def _rename(self, e): |
335 | | - try: |
336 | | - values = list(e.args.values()) |
337 | | - |
338 | | - if len(values) == 1: |
339 | | - values = values[0] |
340 | | - if not isinstance(values, list): |
341 | | - return self.func(e.key, values) |
342 | | - return self.func(e.key, *values) |
343 | | - |
344 | | - if isinstance(e, exp.Func) and e.is_var_len_args: |
345 | | - args = itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in values) |
346 | | - return self.func(e.key, *args) |
347 | | - |
348 | | - return self.func(e.key, *values) |
349 | | - except Exception as ex: |
350 | | - raise Exception(f"Could not rename {repr(e)}") from ex |
351 | | - |
352 | | - |
353 | | -def _case_sql(self, expression): |
354 | | - this = self.sql(expression, "this") |
355 | | - chain = self.sql(expression, "default") or "None" |
356 | | - |
357 | | - for e in reversed(expression.args["ifs"]): |
358 | | - true = self.sql(e, "true") |
359 | | - condition = self.sql(e, "this") |
360 | | - condition = f"{this} = ({condition})" if this else condition |
361 | | - chain = f"{true} if {condition} else ({chain})" |
362 | | - |
363 | | - return chain |
364 | | - |
365 | | - |
366 | | -def _lambda_sql(self, e: exp.Lambda) -> str: |
367 | | - names = {e.name.lower() for e in e.expressions} |
368 | | - |
369 | | - e = e.transform( |
370 | | - lambda n: ( |
371 | | - exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n |
372 | | - ) |
373 | | - ).assert_is(exp.Lambda) |
374 | | - |
375 | | - return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" |
376 | | - |
377 | | - |
378 | | -def _div_sql(self: generator.Generator, e: exp.Div) -> str: |
379 | | - denominator = self.sql(e, "expression") |
380 | | - |
381 | | - if e.args.get("safe"): |
382 | | - denominator += " or None" |
383 | | - |
384 | | - sql = f"DIV({self.sql(e, 'this')}, {denominator})" |
385 | | - |
386 | | - if e.args.get("typed"): |
387 | | - sql = f"int({sql})" |
388 | | - |
389 | | - return sql |
390 | | - |
391 | | - |
392 | 327 | class Python(Dialect): |
393 | 328 | class Tokenizer(tokens.Tokenizer): |
394 | 329 | STRING_ESCAPES = ["\\"] |
395 | 330 |
|
396 | | - class Generator(generator.Generator): |
397 | | - TRANSFORMS = { |
398 | | - **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)}, |
399 | | - **{klass: _rename for klass in exp.ALL_FUNCTIONS}, |
400 | | - exp.Case: _case_sql, |
401 | | - exp.Alias: lambda self, e: self.sql(e.this), |
402 | | - exp.Array: inline_array_sql, |
403 | | - exp.And: lambda self, e: self.binary(e, "and"), |
404 | | - exp.Between: _rename, |
405 | | - exp.Boolean: lambda self, e: "True" if e.this else "False", |
406 | | - exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DType.{e.args['to']})", |
407 | | - exp.Column: lambda self, e: ( |
408 | | - f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]" |
409 | | - ), |
410 | | - exp.Concat: lambda self, e: self.func( |
411 | | - "SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions |
412 | | - ), |
413 | | - exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", |
414 | | - exp.Div: _div_sql, |
415 | | - exp.Extract: lambda self, e: ( |
416 | | - f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})" |
417 | | - ), |
418 | | - exp.In: lambda self, e: ( |
419 | | - f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}" |
420 | | - ), |
421 | | - exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", |
422 | | - exp.Is: lambda self, e: ( |
423 | | - self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is") |
424 | | - ), |
425 | | - exp.JSONExtract: lambda self, e: self.func(e.key, e.this, e.expression, *e.expressions), |
426 | | - exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]", |
427 | | - exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'", |
428 | | - exp.JSONPathSubscript: lambda self, e: f"'{e.this}'", |
429 | | - exp.Lambda: _lambda_sql, |
430 | | - exp.Not: lambda self, e: f"not {self.sql(e.this)}", |
431 | | - exp.Null: lambda *_: "None", |
432 | | - exp.Or: lambda self, e: self.binary(e, "or"), |
433 | | - exp.Ordered: _ordered_py, |
434 | | - exp.Star: lambda *_: "1", |
435 | | - } |
| 331 | + Generator = PythonGenerator |
0 commit comments