Skip to content

Commit e458c69

Browse files
authored
fix(queryables): adapt parameters' value before passing them to EODAG (#93)
Ensure the queryable parameter values match the expected type before passing them to the call to `dag.list_queryables(...)`. E.g. `ecmwf:data_format` must be passed as a string; `ecmwf:variable` must be passed as a list. A call to `dag.list_queryables(...)` with only provider and collection (if they are available) is used to assess the expected type of the parameters. The parameters given by the user are searched in the queryables by both field name and alias, e.g. `ecmwf_variable`, `ecmwf:variable` and `variable` are all searched.
1 parent 9ce6fe4 commit e458c69

3 files changed

Lines changed: 411 additions & 27 deletions

File tree

stac_fastapi/eodag/extensions/filter.py

Lines changed: 95 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
"""Get Queryables."""
1919

2020
import asyncio
21-
from typing import Any, Optional, cast
21+
from typing import Any, Literal, Optional, cast, get_args, get_origin
2222

2323
import attr
2424
from fastapi import Request
25-
from pydantic import BaseModel, ConfigDict, create_model
25+
from pydantic import AliasChoices, AliasPath, BaseModel, ConfigDict, create_model
2626
from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient
2727
from stac_fastapi.types.errors import NotFoundError
2828
from stac_fastapi.types.requests import get_base_url
@@ -174,23 +174,8 @@ async def get_queryables(
174174
under OGC CQL but it is allowed by the STAC API Filter Extension
175175
https://github.com/radiantearth/stac-api-spec/tree/master/fragments/filter#queryables
176176
"""
177-
params: dict[str, list[Any]] = {}
178-
for k, v in request.query_params.multi_items():
179-
params.setdefault(k, []).append(v)
180-
181-
# parameter provider is deprecated
182-
providers = params.pop("provider", [None])
183-
federation_backends = params.pop("federation:backends", [None])
177+
eodag_params = await self._get_eodag_params(request, collection_id)
184178

185-
# validate params and transform to eodag params
186-
validated_params_model = QueryablesGetParams.model_validate(
187-
{
188-
**{"provider": federation_backends[0] or providers[0], "collection": collection_id},
189-
**params,
190-
}
191-
)
192-
validated_params = validated_params_model.model_dump(exclude_none=True, by_alias=True)
193-
eodag_params = {self.stac_metadata_model.to_eodag(param): validated_params[param] for param in validated_params}
194179
# get queryables from eodag
195180
try:
196181
eodag_queryables = await asyncio.to_thread(request.app.state.dag.list_queryables, **eodag_params)
@@ -252,3 +237,95 @@ async def get_queryables(
252237
properties[pk] = pv
253238

254239
return queryables
240+
241+
async def _get_eodag_params(
    self,
    request: Request,
    collection_id: Optional[str] = None,
) -> dict[str, Any]:
    """Return the EODAG parameters from the given HTTP Request.

    The query parameters are validated with ``QueryablesGetParams`` and renamed
    to their EODAG names. Each user-supplied value is then adapted to the type
    declared by the matching queryable: a single value for ``str``/``Literal``
    queryables, a list for ``list``/``tuple`` queryables. The expected types are
    read from a preliminary ``list_queryables`` call made with only the provider
    and the collection (when present).

    :param request: The request object.
    :param collection_id: The collection ID.
    :return: The EODAG parameters.
    :raises NotFoundError: If EODAG raises ``UnsupportedCollection``.
    :raises NotImplementedError: If a queryable uses a validation alias kind or
        a base type that is not handled here.
    """
    # group the query string by key so that repeated keys keep all their values
    params: dict[str, list[Any]] = {}
    for k, v in request.query_params.multi_items():
        params.setdefault(k, []).append(v)

    # parameter provider is deprecated
    providers = params.pop("provider", [None])
    federation_backends = params.pop("federation:backends", [None])

    # validate params and transform to eodag params
    validated_params_model = QueryablesGetParams.model_validate(
        {
            **{"provider": federation_backends[0] or providers[0], "collection": collection_id},
            **params,
        }
    )
    validated_params = validated_params_model.model_dump(exclude_none=True, by_alias=True)
    eodag_params = {self.stac_metadata_model.to_eodag(param): validated_params[param] for param in validated_params}

    # the parameters in eodag_params are all lists:
    # adapt them to use list or primitive type according to the collection queryables
    eodag_params_pc = {k: eodag_params[k] for k in ("provider", "collection") if k in eodag_params}
    try:
        eodag_queryables = await asyncio.to_thread(request.app.state.dag.list_queryables, **eodag_params_pc)
    except UnsupportedCollection as err:
        raise NotFoundError(err) from err

    for queryables_key, annotation in eodag_queryables.items():
        if queryables_key in ("provider", "collection"):
            continue

        # assumes every queryable is `Annotated[<type>, <FieldInfo>, ...]` -- TODO confirm
        param_args = get_args(annotation)
        base_type = get_origin(param_args[0])
        if base_type is None:
            # plain (non-generic) type such as `str`
            base_type = param_args[0]
        field_info = param_args[1]

        # get the aliases of queryables_key
        validation_alias = field_info.validation_alias
        aliases: list[str]
        if validation_alias is None:
            aliases = [queryables_key]
        elif isinstance(validation_alias, str):
            aliases = [queryables_key, validation_alias]
        elif isinstance(validation_alias, AliasChoices):
            # e.g. aliases == ['ecmwf_data_format', 'ecmwf:data_format', 'data_format']
            if any(not isinstance(c, str) for c in validation_alias.choices):
                # currently only choices of type string are used by EODAG
                raise NotImplementedError(
                    f"Error for stac name {queryables_key}: "
                    "only AliasChoices of type string are handled to get field aliases"
                )
            choices: list[str] = [str(c) for c in validation_alias.choices]
            aliases = [queryables_key, *choices]
        elif isinstance(validation_alias, AliasPath):
            # currently AliasPath is not used by EODAG
            raise NotImplementedError(
                f"Error for stac name {queryables_key}: AliasPath is not currently handled to get field aliases"
            )
        else:
            raise NotImplementedError(
                f"Error for stac name {queryables_key}: validation alias not supported: {validation_alias}"
            )

        # check if any of the aliases is in eodag_params (the first match wins)
        eodag_key = next((n for n in aliases if n in eodag_params), None)
        if eodag_key is None:
            # queryables_key is not in eodag_params: skip
            continue

        # adapt the value
        value = eodag_params[eodag_key]
        if base_type in (Literal, str):
            if isinstance(value, list):
                # convert list to single value: the last occurrence is kept
                eodag_params[eodag_key] = value[-1]
        elif base_type in (tuple, list):
            if not isinstance(value, list):
                # convert single value to list
                eodag_params[eodag_key] = [value]
        else:
            raise NotImplementedError(f"Error for stac name {queryables_key}: type not supported: {param_args[0]}")
    return eodag_params

tests/conftest.py

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
"""main conftest"""
1919

2020
import os
21+
import typing
2122
import unittest.mock
2223
from dataclasses import dataclass, field
2324
from pathlib import Path
2425
from string import ascii_uppercase
2526
from tempfile import TemporaryDirectory
26-
from typing import Any, Iterator, Optional, Union
27+
from typing import Annotated, Any, Iterator, Optional, Union
2728
from urllib.parse import urljoin
2829

2930
import pytest
@@ -43,6 +44,7 @@
4344
from eodag.utils import StreamResponse
4445
from fastapi import FastAPI
4546
from httpx import ASGITransport, AsyncClient
47+
from pydantic import AliasChoices, Field
4648

4749
from stac_fastapi.eodag.app import api, stac_metadata_model
4850
from stac_fastapi.eodag.config import get_settings
@@ -315,6 +317,174 @@ def mock_list_queryables(mocker, app):
315317
return mocker.patch.object(app.state.dag, "list_queryables")
316318

317319

320+
@pytest.fixture(scope="function")
def mock_list_queryables_return_value(mock_list_queryables):
    """
    Set a fixed return value on the mocked `list_queryables` method of the `app.state.dag` object.

    The mapping mixes tuple, list, Literal and plain string queryables, with
    `AliasChoices` validation aliases, plain aliases and no alias at all.

    :param mock_list_queryables: The mock patching `app.state.dag.list_queryables`.
    :return: The mapping installed as the mock's return value.
    """
    # the four bounds of the "ecmwf_area" bounding box, in tuple order
    west = Annotated[
        float,
        Field(
            default=None,
            description="West border of the bounding box",
            ge=-180,
            le=180,
        ),
    ]
    south = Annotated[
        float,
        Field(
            default=None,
            description="South border of the bounding box",
            ge=-90,
            le=90,
        ),
    ]
    east = Annotated[
        float,
        Field(
            default=None,
            description="East border of the bounding box",
            ge=-180,
            le=180,
        ),
    ]
    north = Annotated[
        float,
        Field(
            default=None,
            description="North border of the bounding box",
            ge=-90,
            le=90,
        ),
    ]

    # allowed values of the Literal-based queryables
    data_formats = typing.Literal["grib", "netcdf_zip"]
    dates = typing.Literal[
        "2025-03-04",
        "2025-03-05/2025-03-21",
        "2025-03-22",
        "2025-03-23/2025-04-20",
        "2025-04-21",
        "2025-04-22/2025-04-29",
        "2025-04-30",
        "2025-05-01/2025-06-01",
        "2025-05-03",
        "2025-05-04/2026-02-03",
        "2025-06-02",
        "2025-06-03/2025-08-09",
        "2025-08-10",
        "2025-08-11",
        "2025-08-12/2025-12-06",
        "2025-12-07",
        "2025-12-08/2026-02-02",
        "2026-02-03",
        "2026-02-04",
    ]
    variables = typing.Literal[
        "alder_pollen",
        "ammonia",
        "birch_pollen",
        "carbon_monoxide",
        "dust",
        "formaldehyde",
        "glyoxal",
        "grass_pollen",
        "mugwort_pollen",
    ]

    queryables = {
        # tuple-based queryable
        "ecmwf_area": Annotated[
            tuple[west, south, east, north],
            Field(
                default=None,
                title="Sub-region extraction",
                description="Select a sub-region of the available area by"
                " providing its limits on latitude and longitude",
                validation_alias=AliasChoices("ecmwf:area", "area"),
                serialization_alias="ecmwf:area",
                alias_priority=2,
            ),
        ],
        # Literal-based queryables
        "ecmwf_data_format": Annotated[
            data_formats,
            Field(
                ...,
                title="Data format",
                validation_alias=AliasChoices("ecmwf:data_format", "data_format"),
                serialization_alias="ecmwf:data_format",
                alias_priority=2,
            ),
            "json_schema_required",
        ],
        "ecmwf_date": Annotated[
            dates,
            Field(
                ...,
                title="Date",
                validation_alias=AliasChoices("ecmwf:date", "date"),
                serialization_alias="ecmwf:date",
                alias_priority=2,
                description="date formatted like yyyy-mm-dd/yyyy-mm-dd",
            ),
            "json_schema_required",
        ],
        # list-based queryable
        "ecmwf_variable": Annotated[
            list[variables],
            Field(
                ...,
                title="Variable",
                validation_alias=AliasChoices("ecmwf:variable", "variable"),
                serialization_alias="ecmwf:variable",
                alias_priority=2,
            ),
            "json_schema_required",
        ],
        # plain string queryables with a single alias
        "end": Annotated[
            str,
            Field(
                default=None,
                alias="end_datetime",
                alias_priority=2,
                description="Date/time as string in ISO 8601 format (e.g. '2024-06-10T12:00:00Z')",
            ),
        ],
        "start": Annotated[
            str,
            Field(
                default=None,
                alias="start_datetime",
                alias_priority=2,
                description="Date/time as string in ISO 8601 format (e.g. '2024-06-10T12:00:00Z')",
            ),
        ],
        "dolorem": Annotated[
            str,
            Field(
                default=None,
                alias="dol",
                alias_priority=2,
            ),
        ],
        "ipsum": Annotated[
            str,
            Field(
                default=None,
                alias="ips",
                alias_priority=2,
            ),
        ],
        # plain string queryable without any alias
        "bar": Annotated[
            str,
            Field(
                default=None,
                alias_priority=2,
            ),
        ],
    }

    mock_list_queryables.return_value = queryables
    return queryables
486+
487+
318488
@pytest.fixture(scope="function")
319489
def mock_stac_discover_queryables(mocker):
320490
"""

0 commit comments

Comments
 (0)