@@ -437,15 +437,6 @@ def _can_substitute(item: Function) -> bool:
437437 """Returns whether the specified function can be replaced by this class"""
438438 raise NotImplementedError ()
439439
440-
441- class Coroutine (PytestAsyncioFunction ):
442- """Pytest item created by a coroutine"""
443-
444- @staticmethod
445- def _can_substitute (item : Function ) -> bool :
446- func = item .obj
447- return inspect .iscoroutinefunction (func )
448-
449440 def runtest (self ) -> None :
450441 marker = self .get_closest_marker ("asyncio" )
451442 assert marker is not None
@@ -454,11 +445,33 @@ def runtest(self) -> None:
454445 runner_fixture_id = f"_{ loop_scope } _scoped_runner"
455446 runner = self ._request .getfixturevalue (runner_fixture_id )
456447 context = contextvars .copy_context ()
457- synchronized_obj = wrap_in_sync (self .obj , runner , context )
448+ synchronized_obj = wrap_in_sync (
449+ getattr (* self ._synchronization_target_attr ), runner , context
450+ )
458451 with MonkeyPatch .context () as c :
459- c .setattr (self , "obj" , synchronized_obj )
452+ c .setattr (* self . _synchronization_target_attr , synchronized_obj )
460453 super ().runtest ()
461454
455+ @property
456+ def _synchronization_target_attr (self ) -> tuple [object , str ]:
457+ """
458+ Return the coroutine that needs to be synchronized during the test run.
459+
460+ This method is inteded to be overwritten by subclasses when they need to apply
461+ the coroutine synchronizer to a value that's different from self.obj
462+ e.g. the AsyncHypothesisTest subclass.
463+ """
464+ return self , "obj"
465+
466+
467+ class Coroutine (PytestAsyncioFunction ):
468+ """Pytest item created by a coroutine"""
469+
470+ @staticmethod
471+ def _can_substitute (item : Function ) -> bool :
472+ func = item .obj
473+ return inspect .iscoroutinefunction (func )
474+
462475
463476class AsyncGenerator (PytestAsyncioFunction ):
464477 """Pytest item created by an asynchronous generator"""
@@ -495,19 +508,6 @@ def _can_substitute(item: Function) -> bool:
495508 func .__func__
496509 )
497510
498- def runtest (self ) -> None :
499- marker = self .get_closest_marker ("asyncio" )
500- assert marker is not None
501- default_loop_scope = _get_default_test_loop_scope (self .config )
502- loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
503- runner_fixture_id = f"_{ loop_scope } _scoped_runner"
504- runner = self ._request .getfixturevalue (runner_fixture_id )
505- context = contextvars .copy_context ()
506- synchronized_obj = wrap_in_sync (self .obj , runner , context = context )
507- with MonkeyPatch .context () as c :
508- c .setattr (self , "obj" , synchronized_obj )
509- super ().runtest ()
510-
511511
512512class AsyncHypothesisTest (PytestAsyncioFunction ):
513513 """
@@ -524,18 +524,9 @@ def _can_substitute(item: Function) -> bool:
524524 and inspect .iscoroutinefunction (func .hypothesis .inner_test )
525525 )
526526
527- def runtest (self ) -> None :
528- marker = self .get_closest_marker ("asyncio" )
529- assert marker is not None
530- default_loop_scope = _get_default_test_loop_scope (self .config )
531- loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
532- runner_fixture_id = f"_{ loop_scope } _scoped_runner"
533- runner = self ._request .getfixturevalue (runner_fixture_id )
534- context = contextvars .copy_context ()
535- synchronized_obj = wrap_in_sync (self .obj .hypothesis .inner_test , runner , context )
536- with MonkeyPatch .context () as c :
537- c .setattr (self .obj .hypothesis , "inner_test" , synchronized_obj )
538- super ().runtest ()
527+ @property
528+ def _synchronization_target_attr (self ) -> tuple [object , str ]:
529+ return self .obj .hypothesis , "inner_test"
539530
540531
541532# The function name needs to start with "pytest_"
0 commit comments