1+ import hashlib
12import os
23import platform
34import re
5+ import socket
46import sys
7+ import threading
58import time
69import traceback
710import warnings
11+ from contextlib import suppress
812
913import pytest
1014from _pytest .outcomes import fail
2327 # We have a pytest >= 6.1
2428 pass
2529
30+ try :
31+ from xdist .newhooks import pytest_handlecrashitem
2632
27- PYTEST_GTE_54 = parse_version (pytest .__version__ ) >= parse_version ("5.4" )
33+ HAS_PYTEST_HANDLECRASHITEM = True
34+ del pytest_handlecrashitem
35+ except ImportError :
36+ HAS_PYTEST_HANDLECRASHITEM = False
2837
38+
39+ PYTEST_GTE_54 = parse_version (pytest .__version__ ) >= parse_version ("5.4" )
2940PYTEST_GTE_63 = parse_version (pytest .__version__ ) >= parse_version ("6.3.0.dev" )
3041
3142
@@ -78,16 +89,6 @@ def pytest_addoption(parser):
7889 )
7990
8091
81- def pytest_configure (config ):
82- # add flaky marker
83- config .addinivalue_line (
84- "markers" ,
85- "flaky(reruns=1, reruns_delay=0): mark test to re-run up "
86- "to 'reruns' times. Add a delay of 'reruns_delay' seconds "
87- "between re-runs." ,
88- )
89-
90-
9192def _get_resultlog (config ):
9293 if not HAS_RESULTLOG :
9394 return None
@@ -302,6 +303,167 @@ def _should_not_rerun(item, report, reruns):
302303 )
303304
304305
306+ def is_master (config ):
307+ return not (hasattr (config , "workerinput" ) or hasattr (config , "slaveinput" ))
308+
309+
310+ def pytest_configure (config ):
311+ # add flaky marker
312+ config .addinivalue_line (
313+ "markers" ,
314+ "flaky(reruns=1, reruns_delay=0): mark test to re-run up "
315+ "to 'reruns' times. Add a delay of 'reruns_delay' seconds "
316+ "between re-runs." ,
317+ )
318+
319+ if HAS_PYTEST_HANDLECRASHITEM :
320+ if is_master (config ):
321+ config .failures_db = ServerStatusDB ()
322+ else :
323+ config .failures_db = ClientStatusDB (config .workerinput ["sock_port" ])
324+ else :
325+ config .failures_db = StatusDB () # no-op db
326+
327+
328+ if HAS_PYTEST_HANDLECRASHITEM :
329+
330+ def pytest_configure_node (node ):
331+ """xdist hook"""
332+ node .workerinput ["sock_port" ] = node .config .failures_db .sock_port
333+
334+ def pytest_handlecrashitem (crashitem , report , sched ):
335+ """
336+ Return the crashitem from pending and collection.
337+ """
338+ db = sched .config .failures_db
339+ reruns = db .get_test_reruns (crashitem )
340+ if db .get_test_failures (crashitem ) < reruns :
341+ sched .mark_test_pending (crashitem )
342+ report .outcome = "rerun"
343+
344+ db .add_test_failure (crashitem )
345+
346+
347+ # An in-memory db residing in the master that records
348+ # the number of reruns (set before test setup)
349+ # and failures (set after each failure or crash)
350+ # accessible from both the master and worker
351+ class StatusDB :
352+ def __init__ (self ):
353+ self .delim = b"\n "
354+ self .hmap = {}
355+
356+ def _hash (self , crashitem : str ) -> str :
357+ if crashitem not in self .hmap :
358+ self .hmap [crashitem ] = hashlib .sha1 (
359+ crashitem .encode (),
360+ ).hexdigest ()[:10 ]
361+
362+ return self .hmap [crashitem ]
363+
364+ def add_test_failure (self , crashitem ):
365+ hash = self ._hash (crashitem )
366+ failures = self ._get (hash , "f" )
367+ failures += 1
368+ self ._set (hash , "f" , failures )
369+
370+ def get_test_failures (self , crashitem ):
371+ hash = self ._hash (crashitem )
372+ return self ._get (hash , "f" )
373+
374+ def set_test_reruns (self , crashitem , reruns ):
375+ hash = self ._hash (crashitem )
376+ self ._set (hash , "r" , reruns )
377+
378+ def get_test_reruns (self , crashitem ):
379+ hash = self ._hash (crashitem )
380+ return self ._get (hash , "r" )
381+
382+ # i is a hash of the test name, t_f.py::test_t
383+ # k is f for failures or r for reruns
384+ # v is the number of failures or reruns (an int)
385+ def _set (self , i : str , k : str , v : int ):
386+ pass
387+
388+ def _get (self , i : str , k : str ) -> int :
389+ return 0
390+
391+
392+ class SocketDB (StatusDB ):
393+ def __init__ (self ):
394+ super ().__init__ ()
395+ self .sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
396+ self .sock .setblocking (1 )
397+
398+ def _sock_recv (self , conn ) -> str :
399+ buf = b""
400+ while True :
401+ b = conn .recv (1 )
402+ if b == self .delim :
403+ break
404+ buf += b
405+
406+ return buf .decode ()
407+
408+ def _sock_send (self , conn , msg : str ):
409+ conn .send (msg .encode () + self .delim )
410+
411+
412+ class ServerStatusDB (SocketDB ):
413+ def __init__ (self ):
414+ super ().__init__ ()
415+ self .sock .bind (("" , 0 ))
416+ self .sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
417+
418+ self .rerunfailures_db = {}
419+ t = threading .Thread (target = self .run_server , daemon = True )
420+ t .start ()
421+
422+ @property
423+ def sock_port (self ):
424+ return self .sock .getsockname ()[1 ]
425+
426+ def run_server (self ):
427+ self .sock .listen ()
428+ while True :
429+ conn , _ = self .sock .accept ()
430+ t = threading .Thread (target = self .run_connection , args = (conn ,), daemon = True )
431+ t .start ()
432+
433+ def run_connection (self , conn ):
434+ with suppress (ConnectionError ):
435+ while True :
436+ op , i , k , v = self ._sock_recv (conn ).split ("|" )
437+ if op == "set" :
438+ self ._set (i , k , int (v ))
439+ elif op == "get" :
440+ self ._sock_send (conn , str (self ._get (i , k )))
441+
442+ def _set (self , i : str , k : str , v : int ):
443+ if i not in self .rerunfailures_db :
444+ self .rerunfailures_db [i ] = {}
445+ self .rerunfailures_db [i ][k ] = v
446+
447+ def _get (self , i : str , k : str ) -> int :
448+ try :
449+ return self .rerunfailures_db [i ][k ]
450+ except KeyError :
451+ return 0
452+
453+
454+ class ClientStatusDB (SocketDB ):
455+ def __init__ (self , sock_port ):
456+ super ().__init__ ()
457+ self .sock .connect (("localhost" , sock_port ))
458+
459+ def _set (self , i : str , k : str , v : int ):
460+ self ._sock_send (self .sock , "|" .join (("set" , i , k , str (v ))))
461+
462+ def _get (self , i : str , k : str ) -> int :
463+ self ._sock_send (self .sock , "|" .join (("get" , i , k , "" )))
464+ return int (self ._sock_recv (self .sock ))
465+
466+
305467def pytest_runtest_protocol (item , nextitem ):
306468 """
307469 Run the test protocol.
@@ -319,8 +481,14 @@ def pytest_runtest_protocol(item, nextitem):
319481 # first item if necessary
320482 check_options (item .session .config )
321483 delay = get_reruns_delay (item )
322- parallel = hasattr (item .config , "slaveinput" ) or hasattr (item .config , "workerinput" )
323- item .execution_count = 0
484+ parallel = not is_master (item .config )
485+ item_location = (item .location [0 ] + "::" + item .location [2 ]).replace ("\\ " , "/" )
486+ db = item .session .config .failures_db
487+ item .execution_count = db .get_test_failures (item_location )
488+ db .set_test_reruns (item_location , reruns )
489+
490+ if item .execution_count > reruns :
491+ return True
324492
325493 need_to_run = True
326494 while need_to_run :
0 commit comments