Skip to content

Commit 820715b

Browse files
committed
Adopt a common validate_type method across all the models
1 parent 22309c4 commit 820715b

6 files changed

Lines changed: 180 additions & 121 deletions

File tree

pygeoapi/models/config.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from dataclasses import dataclass, fields, asdict
3434
from typing import Any, Dict
3535

36+
from pygeoapi.models.validation import validate_type
37+
3638

3739
SEMVER_PATTERN = re.compile(r'^\d+\.\d+\..+$')
3840

@@ -74,31 +76,15 @@ class APIRules:
7476
strict_slashes: bool = False
7577

7678
def __post_init__(self):
77-
if not isinstance(self.api_version, str):
78-
raise APIRulesValidationError(
79-
"api_version must be a string, "
80-
f"got {type(self.api_version).__name__}"
81-
)
79+
try:
80+
validate_type(self)
81+
except ValueError as e:
82+
raise APIRulesValidationError(str(e)) from e
8283
if not SEMVER_PATTERN.match(self.api_version):
8384
raise APIRulesValidationError(
8485
f"Invalid semantic version: '{self.api_version}'. "
8586
f"Expected format: MAJOR.MINOR.PATCH"
8687
)
87-
if not isinstance(self.url_prefix, str):
88-
raise APIRulesValidationError(
89-
"url_prefix must be a string, "
90-
f"got {type(self.url_prefix).__name__}"
91-
)
92-
if not isinstance(self.version_header, str):
93-
raise APIRulesValidationError(
94-
"version_header must be a string, "
95-
f"got {type(self.version_header).__name__}"
96-
)
97-
if not isinstance(self.strict_slashes, bool):
98-
raise APIRulesValidationError(
99-
"strict_slashes must be a bool, "
100-
f"got {type(self.strict_slashes).__name__}"
101-
)
10288

10389
@classmethod
10490
def create(cls, **rules_config) -> 'APIRules':

pygeoapi/models/openapi.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from enum import Enum
3232
from typing import Any, Dict
3333

34+
from pygeoapi.models.validation import validate_type
35+
3436

3537
class SupportedFormats(Enum):
3638
JSON = 'json'
@@ -51,8 +53,7 @@ class OAPIFormat:
5153
root: SupportedFormats = SupportedFormats.YAML
5254

5355
def __post_init__(self):
54-
if isinstance(self.root, SupportedFormats):
55-
return
56+
# Coerce str to enum before type validation
5657
if isinstance(self.root, str):
5758
try:
5859
self.root = SupportedFormats(self.root)
@@ -62,11 +63,7 @@ def __post_init__(self):
6263
f"Must be one of: "
6364
f"{[f.value for f in SupportedFormats]}"
6465
)
65-
else:
66-
raise ValueError(
67-
f"root must be a string or SupportedFormats, "
68-
f"got {type(self.root).__name__}"
69-
)
66+
validate_type(self)
7067

7168
def __eq__(self, other):
7269
if isinstance(other, str):

pygeoapi/models/provider/base.py

Lines changed: 11 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -31,90 +31,19 @@
3131
#
3232
# =================================================================
3333

34-
from dataclasses import dataclass, field, fields as dc_fields
34+
from dataclasses import dataclass, field
3535
from datetime import datetime
3636
from enum import Enum
3737
import json
3838
from pathlib import Path
39-
from typing import Any, Dict, List, Optional, get_type_hints
39+
from typing import Any, Dict, List, Optional
4040

41+
from pygeoapi.models.validation import validate_type
4142
from pygeoapi.util import DEFINITIONSDIR
4243

4344
TMS_DIR = DEFINITIONSDIR / 'tiles'
4445

4546

46-
def _validate_type(dc_instance: Any) -> None:
47-
"""
48-
Validate field types on a dataclass instance.
49-
50-
Checks each field value against its declared type,
51-
matching dataclass runtime type.
52-
Supports Optional[T], List[T], and plain types.
53-
54-
:param dc_instance: dataclass instance to validate
55-
56-
:raises ValueError: if a field value has the wrong type
57-
"""
58-
hints = get_type_hints(dc_instance.__class__)
59-
for f in dc_fields(dc_instance):
60-
value = getattr(dc_instance, f.name)
61-
expected = hints[f.name]
62-
63-
# Extract inner type from Optional[T]
64-
origin = getattr(expected, '__origin__', None)
65-
args = getattr(expected, '__args__', ())
66-
67-
is_optional = (
68-
origin is type(None) # noqa: E721
69-
or (origin is not None
70-
and type(None) in args)
71-
)
72-
73-
if is_optional and value is None:
74-
continue
75-
76-
# Unwrap Optional to get the inner type
77-
if is_optional and args:
78-
inner_types = [
79-
a for a in args if a is not type(None)
80-
]
81-
if len(inner_types) == 1:
82-
expected = inner_types[0]
83-
origin = getattr(expected, '__origin__', None)
84-
args = getattr(expected, '__args__', ())
85-
86-
# Check List[T]
87-
if origin is list:
88-
if not isinstance(value, list):
89-
raise ValueError(
90-
f"{f.name} must be a list, "
91-
f"got {type(value).__name__}"
92-
)
93-
# Check plain types (str, int, float, bool, Enum)
94-
elif origin is None:
95-
if isinstance(expected, type):
96-
# bool is subclass of int, check bool first
97-
if expected is bool:
98-
if not isinstance(value, bool):
99-
raise ValueError(
100-
f"{f.name} must be a bool, "
101-
f"got {type(value).__name__}"
102-
)
103-
elif expected is int:
104-
if isinstance(value, bool) \
105-
or not isinstance(value, int):
106-
raise ValueError(
107-
f"{f.name} must be an int, "
108-
f"got {type(value).__name__}"
109-
)
110-
elif not isinstance(value, expected):
111-
raise ValueError(
112-
f"{f.name} must be a "
113-
f"{expected.__name__}, "
114-
f"got {type(value).__name__}"
115-
)
116-
117-
11847
class TilesMetadataFormat(str, Enum):
11948
# Tile Set Metadata
12049
JSON = "JSON"
@@ -160,7 +89,7 @@ class TileMatrixSetEnumType:
16089
tileMatrices: List[dict] = field(default_factory=list)
16190

16291
def __post_init__(self):
163-
_validate_type(self)
92+
validate_type(self)
16493

16594
def model_dump(
16695
self, exclude_none: bool = False
@@ -244,7 +173,7 @@ class TileMatrixLimitsType:
244173
maxTileCol: int = 0
245174

246175
def __post_init__(self):
247-
_validate_type(self)
176+
validate_type(self)
248177

249178
def model_dump(
250179
self, exclude_none: bool = False
@@ -274,7 +203,7 @@ class TwoDBoundingBoxType:
274203
crs: Optional[str] = None
275204

276205
def __post_init__(self):
277-
_validate_type(self)
206+
validate_type(self)
278207

279208
def model_dump(
280209
self, exclude_none: bool = False
@@ -305,7 +234,7 @@ class LinkType:
305234
length: Optional[int] = None
306235

307236
def __post_init__(self):
308-
_validate_type(self)
237+
validate_type(self)
309238

310239
def model_dump(
311240
self, exclude_none: bool = False
@@ -351,7 +280,7 @@ class GeospatialDataType:
351280
propertiesSchema: Optional[dict] = None
352281

353282
def __post_init__(self):
354-
_validate_type(self)
283+
validate_type(self)
355284

356285
def model_dump(
357286
self, exclude_none: bool = False
@@ -383,7 +312,7 @@ class StyleType:
383312
links: Optional[LinkType] = None
384313

385314
def __post_init__(self):
386-
_validate_type(self)
315+
validate_type(self)
387316

388317
def model_dump(
389318
self, exclude_none: bool = False
@@ -413,7 +342,7 @@ class TilePointType:
413342
tileMatrix: str = ''
414343

415344
def __post_init__(self):
416-
_validate_type(self)
345+
validate_type(self)
417346

418347
def model_dump(
419348
self, exclude_none: bool = False
@@ -463,7 +392,7 @@ class TileSetMetadata:
463392
links: Optional[List[LinkType]] = None
464393

465394
def __post_init__(self):
466-
_validate_type(self)
395+
validate_type(self)
467396

468397
def model_dump(
469398
self, exclude_none: bool = False

pygeoapi/models/provider/mvt.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
#
3030
# =================================================================
3131

32-
from dataclasses import dataclass
32+
from dataclasses import dataclass, fields as dc_fields
3333
from typing import Any, Dict, List, Optional
3434

35-
from pygeoapi.models.provider.base import _validate_type
35+
from pygeoapi.models.validation import validate_type
3636

3737

3838
@dataclass
@@ -46,7 +46,7 @@ class VectorLayers:
4646
fields: Optional[dict] = None
4747

4848
def __post_init__(self):
49-
_validate_type(self)
49+
validate_type(self)
5050

5151
def model_dump(
5252
self, exclude_none: bool = False
@@ -63,7 +63,12 @@ def model_dump(
6363

6464
@dataclass
6565
class MVTTilesJson:
66-
"""TileJSON 3.0 specification."""
66+
"""TileJSON 3.0 specification.
67+
68+
Accepts and silently ignores unknown kwargs to match
69+
the validation behaviour when instantiated from arbitrary
70+
JSON metadata dicts (e.g. ``MVTTilesJson(**json_data)``).
71+
"""
6772

6873
tilejson: str = "3.0.0"
6974
name: Optional[str] = None
@@ -76,8 +81,19 @@ class MVTTilesJson:
7681
description: Optional[str] = None
7782
vector_layers: Optional[List[VectorLayers]] = None
7883

79-
def __post_init__(self):
80-
_validate_type(self)
84+
def __init__(self, **kwargs):
85+
for f in dc_fields(self):
86+
value = kwargs.get(f.name, getattr(self, f.name))
87+
# Coerce str to int for Optional[int] fields
88+
if (value is not None
89+
and isinstance(value, str)
90+
and 'int' in str(f.type)):
91+
try:
92+
value = int(value)
93+
except (ValueError, TypeError):
94+
pass
95+
setattr(self, f.name, value)
96+
validate_type(self)
8197

8298
def model_dump(
8399
self, exclude_none: bool = False

0 commit comments

Comments
 (0)