Skip to content

Commit 1ce4fec

Browse files
committed
feat: Option to disable using null for optional arguments
Signed-off-by: Dmitry Dygalo <dmitry.dygalo@workato.com>
1 parent 73ec25c commit 1ce4fec

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Unreleased] - TBD
44

5+
### Added
6+
7+
- The `allow_null` option that controls if optional arguments may be `null`. `True` by default.
8+
59
## [0.11.0] - 2023-11-29
610

711
### Added

src/hypothesis_graphql/_strategies/strategy.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class GraphQLStrategy:
5454
schema: graphql.GraphQLSchema
5555
alphabet: st.SearchStrategy[str]
5656
custom_scalars: CustomScalarStrategies = dataclasses.field(default_factory=dict)
57+
allow_null: bool = True
5758
# As the schema is assumed to be immutable, there are a few strategy caches possible for internal components
5859
# This is a per-method cache without limits as they are proportionate to the schema size
5960
_cache: Dict[str, Dict] = dataclasses.field(default_factory=dict)
@@ -75,6 +76,8 @@ def values(
7576
- GraphQLNonNull -> T (processed with nullable=False)
7677
"""
7778
type_, nullable = check_nullable(type_)
79+
if not self.allow_null:
80+
nullable = False
7881
# Types without children
7982
if isinstance(type_, graphql.GraphQLScalarType):
8083
type_name = type_.name
@@ -374,15 +377,16 @@ def _make_strategy(
374377
fields: Optional[Iterable[str]] = None,
375378
custom_scalars: Optional[CustomScalarStrategies] = None,
376379
alphabet: st.SearchStrategy[str],
380+
allow_null: bool = True,
377381
) -> st.SearchStrategy[List[graphql.FieldNode]]:
378382
if fields is not None:
379383
fields = tuple(fields)
380384
validation.validate_fields(fields, list(type_.fields))
381385
if custom_scalars:
382386
validation.validate_custom_scalars(custom_scalars)
383-
return GraphQLStrategy(schema=schema, alphabet=alphabet, custom_scalars=custom_scalars or {}).selections(
384-
type_, fields=fields
385-
)
387+
return GraphQLStrategy(
388+
schema=schema, alphabet=alphabet, custom_scalars=custom_scalars or {}, allow_null=allow_null
389+
).selections(type_, fields=fields)
386390

387391

388392
def _build_alphabet(allow_x00: bool = True, codec: Optional[str] = "utf-8") -> st.SearchStrategy[str]:
@@ -399,6 +403,7 @@ def queries(
399403
custom_scalars: Optional[CustomScalarStrategies] = None,
400404
print_ast: AstPrinter = graphql.print_ast,
401405
allow_x00: bool = True,
406+
allow_null: bool = True,
402407
codec: Optional[str] = "utf-8",
403408
) -> st.SearchStrategy[str]:
404409
r"""A strategy for generating valid queries for the given GraphQL schema.
@@ -410,6 +415,7 @@ def queries(
410415
:param custom_scalars: Strategies for generating custom scalars.
411416
:param print_ast: A function to convert the generated AST to a string.
412417
:param allow_x00: Determines whether to allow the generation of `\x00` bytes within strings.
418+
:param allow_null: Whether `null` values should be used for optional arguments.
413419
:param codec: Specifies the codec used for generating strings.
414420
"""
415421
parsed_schema = validation.maybe_parse_schema(schema)
@@ -423,6 +429,7 @@ def queries(
423429
fields=fields,
424430
custom_scalars=custom_scalars,
425431
alphabet=alphabet,
432+
allow_null=allow_null,
426433
)
427434
.map(make_query)
428435
.map(print_ast)
@@ -437,6 +444,7 @@ def mutations(
437444
custom_scalars: Optional[CustomScalarStrategies] = None,
438445
print_ast: AstPrinter = graphql.print_ast,
439446
allow_x00: bool = True,
447+
allow_null: bool = True,
440448
codec: Optional[str] = "utf-8",
441449
) -> st.SearchStrategy[str]:
442450
r"""A strategy for generating valid mutations for the given GraphQL schema.
@@ -448,6 +456,7 @@ def mutations(
448456
:param custom_scalars: Strategies for generating custom scalars.
449457
:param print_ast: A function to convert the generated AST to a string.
450458
:param allow_x00: Determines whether to allow the generation of `\x00` bytes within strings.
459+
:param allow_null: Whether `null` values should be used for optional arguments.
451460
:param codec: Specifies the codec used for generating strings.
452461
"""
453462
parsed_schema = validation.maybe_parse_schema(schema)
@@ -461,6 +470,7 @@ def mutations(
461470
fields=fields,
462471
custom_scalars=custom_scalars,
463472
alphabet=alphabet,
473+
allow_null=allow_null,
464474
)
465475
.map(make_mutation)
466476
.map(print_ast)
@@ -475,6 +485,7 @@ def from_schema(
475485
custom_scalars: Optional[CustomScalarStrategies] = None,
476486
print_ast: AstPrinter = graphql.print_ast,
477487
allow_x00: bool = True,
488+
allow_null: bool = True,
478489
codec: Optional[str] = "utf-8",
479490
) -> st.SearchStrategy[str]:
480491
r"""A strategy for generating valid queries and mutations for the given GraphQL schema.
@@ -484,6 +495,7 @@ def from_schema(
484495
:param custom_scalars: Strategies for generating custom scalars.
485496
:param print_ast: A function to convert the generated AST to a string.
486497
:param allow_x00: Determines whether to allow the generation of `\x00` bytes within strings.
498+
:param allow_null: Whether `null` values should be used for optional arguments.
487499
:param codec: Specifies the codec used for generating strings.
488500
"""
489501
parsed_schema = validation.maybe_parse_schema(schema)
@@ -506,7 +518,9 @@ def from_schema(
506518
validation.validate_fields(fields, available_fields)
507519

508520
alphabet = _build_alphabet(allow_x00=allow_x00, codec=codec)
509-
strategy = GraphQLStrategy(parsed_schema, alphabet=alphabet, custom_scalars=custom_scalars or {})
521+
strategy = GraphQLStrategy(
522+
parsed_schema, alphabet=alphabet, custom_scalars=custom_scalars or {}, allow_null=allow_null
523+
)
510524
strategies = [
511525
strategy.selections(type_, fields=type_fields).map(node_factory).map(print_ast)
512526
for (type_, type_fields, node_factory) in (

test/test_queries.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def test_query_from_graphql_schema(data, schema, validate_operation):
6767
validate_operation(schema, query)
6868

6969

70+
@pytest.mark.parametrize("allow_null", (True, False))
7071
@pytest.mark.parametrize("notnull", (True, False))
7172
@pytest.mark.parametrize(
7273
"arguments, node_names",
@@ -96,20 +97,22 @@ def test_query_from_graphql_schema(data, schema, validate_operation):
9697
),
9798
)
9899
@given(data=st.data())
99-
def test_arguments(data, schema, arguments, node_names, notnull, validate_operation):
100+
def test_arguments(data, schema, arguments, node_names, allow_null, notnull, validate_operation):
100101
if notnull:
101102
arguments += "!"
102103
query_type = f"""type Query {{
103104
getModel({arguments}): Model
104105
}}"""
105106

106107
schema = schema + query_type
107-
query = data.draw(queries(schema))
108+
query = data.draw(queries(schema, allow_null=allow_null))
108109
validate_operation(schema, query)
109110
for node_name in node_names:
110111
assert node_name not in query
111112
if notnull:
112113
assert "getModel(" in query
114+
if not allow_null:
115+
assert "null)" not in query
113116
parsed = graphql.parse(query)
114117
selection = parsed.definitions[0].selection_set.selections[0]
115118
if notnull:

0 commit comments

Comments
 (0)