Skip to content
This repository was archived by the owner on Mar 29, 2023. It is now read-only.

Commit bfe539a

Browse files
authored
fix: ensure that ScalarParameter names are used instead of Alias names (#135)
* fix: ensure that ScalarParameter names are used instead of Alias names * fix: handle new properties required of UDFs * test: fix scalar-parameterized tests
1 parent 71a01b9 commit bfe539a

File tree

4 files changed

+119
-45
lines changed

4 files changed

+119
-45
lines changed

ibis_bigquery/__init__.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
except ImportError:
3030
pass
3131

32+
try:
33+
from ibis.expr.operations import Alias
34+
except ImportError:
35+
# Allow older versions of ibis to work with ScalarParameters as well as
36+
# versions >= 3.0.0
37+
Alias = None
38+
3239

3340
__version__: str = ibis_bigquery_version.__version__
3441

@@ -222,7 +229,24 @@ def _execute(self, stmt, results=True, query_parameters=None):
222229

223230
def raw_sql(self, query: str, results=False, params=None):
224231
query_parameters = [
225-
bigquery_param(param, value) for param, value in (params or {}).items()
232+
bigquery_param(
233+
# unwrap Alias instances
234+
#
235+
# Without unwrapping we try to execute compiled code that uses
236+
# the ScalarParameter's raw name (e.g., @param_1) and not the
237+
# alias's name which will fail. By unwrapping, we always use
238+
# the raw name.
239+
#
240+
# This workaround is backwards compatible and doesn't require
241+
# changes to ibis.
242+
(
243+
param
244+
if Alias is None or not isinstance(param.op(), Alias)
245+
else param.op().arg
246+
),
247+
value,
248+
)
249+
for param, value in (params or {}).items()
226250
]
227251
return self._execute(query, results=results, query_parameters=query_parameters)
228252

ibis_bigquery/udf/__init__.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -181,21 +181,18 @@ def wrapper(f):
181181
signature = inspect.signature(f)
182182
parameter_names = signature.parameters.keys()
183183

184-
udf_node_fields = collections.OrderedDict(
185-
[
186-
(name, Arg(rlz.value(type)))
187-
for name, type in zip(parameter_names, input_type)
188-
]
189-
+ [
190-
(
191-
"output_type",
192-
lambda self, output_type=output_type: rlz.shape_like(
193-
self.args, dtype=output_type
194-
),
195-
),
196-
("__slots__", ("js",)),
197-
]
198-
)
184+
udf_node_fields = {
185+
name: Arg(rlz.value(type))
186+
for name, type in zip(parameter_names, input_type)
187+
}
188+
189+
try:
190+
udf_node_fields["output_type"] = rlz.shape_like("args", dtype=output_type)
191+
except TypeError:
192+
udf_node_fields["output_dtype"] = property(lambda _: output_type)
193+
udf_node_fields["output_shape"] = rlz.shape_like("args")
194+
195+
udf_node_fields["__slots__"] = ("js",)
199196

200197
udf_node = create_udf_node(f.__name__, udf_node_fields)
201198

tests/system/test_client.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections
22
import datetime
33
import decimal
4+
import re
45

56
import ibis
67
import ibis.expr.datatypes as dt
@@ -11,6 +12,7 @@
1112
import pandas.testing as tm
1213
import pytest
1314
import pytz
15+
from pytest import param
1416

1517
import ibis_bigquery
1618
from ibis_bigquery.client import bigquery_param
@@ -19,6 +21,13 @@
1921
IBIS_1_4_VERSION = packaging.version.Version("1.4.0")
2022
IBIS_3_0_VERSION = packaging.version.Version("3.0.0")
2123

24+
older_than_3 = pytest.mark.xfail(
25+
IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3"
26+
)
27+
at_least_3 = pytest.mark.xfail(
28+
IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3"
29+
)
30+
2231

2332
def test_table(alltypes):
2433
assert isinstance(alltypes, ir.TableExpr)
@@ -204,7 +213,43 @@ def test_different_partition_col_name(monkeypatch, client):
204213
assert col in parted_alltypes.columns
205214

206215

207-
def test_subquery_scalar_params(alltypes, project_id, dataset_id):
216+
def scalar_params_ibis3(project_id, dataset_id):
217+
return f"""\
218+
SELECT count\\(`foo`\\) AS `count`
219+
FROM \\(
220+
SELECT `string_col`, sum\\(`float_col`\\) AS `foo`
221+
FROM \\(
222+
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
223+
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`
224+
\\) t1
225+
WHERE `timestamp_col` < @param_\\d+
226+
GROUP BY 1
227+
\\) t0"""
228+
229+
230+
def scalar_params_not_ibis3(project_id, dataset_id):
231+
return f"""\
232+
SELECT count\\(`foo`\\) AS `count`
233+
FROM \\(
234+
SELECT `string_col`, sum\\(`float_col`\\) AS `foo`
235+
FROM \\(
236+
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
237+
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`
238+
WHERE `timestamp_col` < @my_param
239+
\\) t1
240+
GROUP BY 1
241+
\\) t0"""
242+
243+
244+
@pytest.mark.parametrize(
245+
"expected_fn",
246+
[
247+
param(scalar_params_ibis3, marks=[older_than_3], id="ibis3"),
248+
param(scalar_params_not_ibis3, marks=[at_least_3], id="not_ibis3"),
249+
],
250+
)
251+
def test_subquery_scalar_params(alltypes, project_id, dataset_id, expected_fn):
252+
expected = expected_fn(project_id, dataset_id)
208253
t = alltypes
209254
param = ibis.param("timestamp").name("my_param")
210255
expr = (
@@ -216,20 +261,7 @@ def test_subquery_scalar_params(alltypes, project_id, dataset_id):
216261
.foo.count()
217262
)
218263
result = expr.compile(params={param: "20140101"})
219-
expected = """\
220-
SELECT count(`foo`) AS `count`
221-
FROM (
222-
SELECT `string_col`, sum(`float_col`) AS `foo`
223-
FROM (
224-
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
225-
FROM `{}.{}.functional_alltypes`
226-
WHERE `timestamp_col` < @my_param
227-
) t1
228-
GROUP BY 1
229-
) t0""".format(
230-
project_id, dataset_id
231-
)
232-
assert result == expected
264+
assert re.match(expected, result) is not None
233265

234266

235267
def test_scalar_param_string(alltypes, df):
@@ -457,18 +489,21 @@ def test_raw_sql(client):
457489
assert client.raw_sql("SELECT 1").fetchall() == [(1,)]
458490

459491

460-
def test_scalar_param_scope(alltypes, project_id, dataset_id):
492+
@pytest.mark.parametrize(
493+
"pattern",
494+
[
495+
param(r"@param_\d+", marks=[older_than_3], id="ibis3"),
496+
param("@param", marks=[at_least_3], id="not_ibis3"),
497+
],
498+
)
499+
def test_scalar_param_scope(alltypes, project_id, dataset_id, pattern):
461500
t = alltypes
462501
param = ibis.param("timestamp")
463-
mut = t.mutate(param=param).compile(params={param: "2017-01-01"})
464-
assert (
465-
mut
466-
== """\
467-
SELECT *, @param AS `param`
468-
FROM `{}.{}.functional_alltypes`""".format(
469-
project_id, dataset_id
470-
)
471-
)
502+
result = t.mutate(param=param).compile(params={param: "2017-01-01"})
503+
expected = f"""\
504+
SELECT \\*, {pattern} AS `param`
505+
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`"""
506+
assert re.match(expected, result) is not None
472507

473508

474509
def test_parted_column_rename(parted_alltypes):

tests/system/test_compiler.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,42 @@
1+
import re
2+
13
import ibis
24
import ibis.expr.datatypes as dt
35
import packaging.version
46
import pytest
7+
from pytest import param
58

69
pytestmark = pytest.mark.bigquery
710

811
IBIS_VERSION = packaging.version.Version(ibis.__version__)
912
IBIS_1_VERSION = packaging.version.Version("1.4.0")
13+
IBIS_3_0_VERSION = packaging.version.Version("3.0.0")
1014

15+
older_than_3 = pytest.mark.xfail(
16+
IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3"
17+
)
18+
at_least_3 = pytest.mark.xfail(
19+
IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3"
20+
)
1121

12-
def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id):
22+
23+
@pytest.mark.parametrize(
24+
"pattern",
25+
[
26+
param(r"@param_\d+", marks=[older_than_3], id="ibis3"),
27+
param("@param", marks=[at_least_3], id="not_ibis3"),
28+
],
29+
)
30+
def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id, pattern):
1331
date_string = "2009-03-01"
1432
param = ibis.param(dt.timestamp).name("param_0")
1533
expr = alltypes.mutate(param=param)
1634
params = {param: date_string}
1735
result = expr.compile(params=params)
1836
expected = f"""\
19-
SELECT *, @param AS `param`
20-
FROM `{project_id}.{dataset_id}.functional_alltypes`"""
21-
assert result == expected
37+
SELECT \\*, {pattern} AS `param`
38+
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`"""
39+
assert re.match(expected, result) is not None
2240

2341

2442
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)