Skip to content

Commit 2970a69

Browse files
authored
fix: run eodag using asyncio.to_thread (#97)
1 parent 6d538d1 commit 2970a69

2 files changed

Lines changed: 27 additions & 19 deletions

File tree

stac_fastapi/eodag/core.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@
8181

8282
logger = logging.getLogger(__name__)
8383

84-
loop = asyncio.get_event_loop()
85-
8684

8785
@attr.s
8886
class EodagCoreClient(CustomCoreClient):
@@ -166,7 +164,7 @@ def has_ecmwf_search_plugin(federation_backends, request):
166164

167165
return Collection(**extended_collection)
168166

169-
def _search_base(self, search_request: BaseSearchPostRequest, request: Request) -> ItemCollection:
167+
async def _search_base(self, search_request: BaseSearchPostRequest, request: Request) -> ItemCollection:
170168
eodag_args = prepare_search_base_args(search_request=search_request, model=self.stac_metadata_model)
171169

172170
request.state.eodag_args = eodag_args
@@ -177,7 +175,7 @@ def _search_base(self, search_request: BaseSearchPostRequest, request: Request)
177175

178176
# check if the collection exists
179177
if collection := eodag_args.get("collection"):
180-
all_coll = request.app.state.dag.list_collections(fetch_providers=False)
178+
all_coll = await asyncio.to_thread(request.app.state.dag.list_collections, fetch_providers=False)
181179
# only check the first collection (EODAG search only support a single collection)
182180
existing_coll = [coll for coll in all_coll if coll.id == collection]
183181
if not existing_coll:
@@ -191,19 +189,20 @@ def _search_base(self, search_request: BaseSearchPostRequest, request: Request)
191189
search_result = SearchResult([])
192190
for item_id in ids:
193191
eodag_args["id"] = item_id
194-
search_result.extend(request.app.state.dag.search(validate=validate, **eodag_args))
192+
result = await asyncio.to_thread(request.app.state.dag.search, validate=validate, **eodag_args)
193+
search_result.extend(result)
195194
search_result.number_matched = len(search_result)
196195
elif eodag_args.get("token") and eodag_args.get("provider"):
197196
# search with pagination
198-
search_result = eodag_search_next_page(request.app.state.dag, eodag_args)
197+
search_result = await asyncio.to_thread(eodag_search_next_page, request.app.state.dag, eodag_args)
199198
else:
200199
# search without ids or pagination
201-
search_result = request.app.state.dag.search(validate=validate, **eodag_args)
200+
search_result = await asyncio.to_thread(request.app.state.dag.search, validate=validate, **eodag_args)
202201

203202
if search_result.errors and not len(search_result):
204203
raise ResponseSearchError(search_result.errors, self.stac_metadata_model)
205204

206-
request_json = loop.run_until_complete(request.json()) if request.method == "POST" else None
205+
request_json = await request.json() if request.method == "POST" else None
207206

208207
features: list[Item] = []
209208
extension_names = [type(ext).__name__ for ext in self.extensions]
@@ -268,7 +267,9 @@ async def all_collections(
268267
provider = parsed_query.get("federation:backends")
269268
provider = provider[0] if isinstance(provider, list) else provider
270269

271-
all_colls = request.app.state.dag.list_collections(provider=provider, fetch_providers=False)
270+
all_colls = await asyncio.to_thread(
271+
request.app.state.dag.list_collections, provider=provider, fetch_providers=False
272+
)
272273

273274
# datetime & free-text-search filters
274275
if any((q, datetime)):
@@ -279,8 +280,11 @@ async def all_collections(
279280
free_text = " AND ".join(q or [])
280281

281282
try:
282-
guessed_collections = request.app.state.dag.guess_collection(
283-
free_text=free_text, start_date=start, end_date=end
283+
guessed_collections = await asyncio.to_thread(
284+
request.app.state.dag.guess_collection,
285+
free_text=free_text,
286+
start_date=start,
287+
end_date=end,
284288
)
285289
guessed_collections_ids = [coll.id for coll in guessed_collections]
286290
except EodagNoMatchingCollection:
@@ -359,8 +363,9 @@ async def get_collection(self, collection_id: str, request: Request, **kwargs: A
359363
:returns: The collection.
360364
:raises NotFoundError: If the collection does not exist.
361365
"""
366+
all_collections = await asyncio.to_thread(request.app.state.dag.list_collections, fetch_providers=False)
362367
collection = next(
363-
(c for c in request.app.state.dag.list_collections(fetch_providers=False) if c.id == collection_id),
368+
(c for c in all_collections if c.id == collection_id),
364369
None,
365370
)
366371
if collection is None:
@@ -411,15 +416,17 @@ async def item_collection(
411416
)
412417

413418
search_request = self.post_request_model.model_validate(clean)
414-
item_collection = self._search_base(search_request, request)
419+
item_collection = await self._search_base(search_request, request)
415420
extension_names = [type(ext).__name__ for ext in self.extensions]
416421
links = ItemCollectionLinks(collection_id=collection_id, request=request).get_links(
417422
extensions=extension_names, extra_links=item_collection["links"]
418423
)
419424
item_collection["links"] = links
420425
return item_collection
421426

422-
def post_search(self, search_request: BaseSearchPostRequest, request: Request, **kwargs: Any) -> ItemCollection:
427+
async def post_search(
428+
self, search_request: BaseSearchPostRequest, request: Request, **kwargs: Any
429+
) -> ItemCollection:
423430
"""
424431
Handle POST search requests.
425432
@@ -428,9 +435,9 @@ def post_search(self, search_request: BaseSearchPostRequest, request: Request, *
428435
:param kwargs: Additional keyword arguments.
429436
:returns: Found items.
430437
"""
431-
return self._search_base(search_request, request)
438+
return await self._search_base(search_request, request)
432439

433-
def get_search(
440+
async def get_search(
434441
self,
435442
request: Request,
436443
collections: Optional[list[str]] = None,
@@ -489,7 +496,7 @@ def get_search(
489496
except ValidationError as err:
490497
raise HTTPException(status_code=400, detail=f"Invalid parameters provided {err}") from err
491498

492-
return self._search_base(search_request, request)
499+
return await self._search_base(search_request, request)
493500

494501
async def get_item(self, item_id: str, collection_id: str, request: Request, **kwargs: Any) -> Item:
495502
"""
@@ -506,7 +513,7 @@ async def get_item(self, item_id: str, collection_id: str, request: Request, **k
506513
await self.get_collection(collection_id, request=request)
507514

508515
search_request = self.post_request_model(ids=[item_id], collections=[collection_id], limit=1)
509-
item_collection = self._search_base(search_request, request)
516+
item_collection = await self._search_base(search_request, request)
510517
if not item_collection["features"]:
511518
raise NotFoundError(f"Item {item_id} in Collection {collection_id} does not exist.")
512519

stac_fastapi/eodag/extensions/filter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# limitations under the License.
1818
"""Get Queryables."""
1919

20+
import asyncio
2021
from typing import Any, Optional, cast
2122

2223
import attr
@@ -192,7 +193,7 @@ async def get_queryables(
192193
eodag_params = {self.stac_metadata_model.to_eodag(param): validated_params[param] for param in validated_params}
193194
# get queryables from eodag
194195
try:
195-
eodag_queryables = request.app.state.dag.list_queryables(**eodag_params)
196+
eodag_queryables = await asyncio.to_thread(request.app.state.dag.list_queryables, **eodag_params)
196197
except UnsupportedCollection as err:
197198
raise NotFoundError(err) from err
198199

0 commit comments

Comments
 (0)