File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -173,6 +173,7 @@ def __call__(
173173 retriever_eval_results = self .evaluate_results (no_reranker_search_results_save_dir , k_values = k_values )
174174 self .output_eval_results_to_json (retriever_eval_results , eval_results_save_path )
175175
176+ retriever .stop_multi_process_pool ()
176177 # Reranking Stage
177178 if reranker is not None :
178179 reranker_search_results_save_dir = os .path .join (
@@ -215,6 +216,7 @@ def __call__(
215216 eval_results_save_path = os .path .join (reranker_search_results_save_dir , 'EVAL' , 'eval_results.json' )
216217 reranker_eval_results = self .evaluate_results (reranker_search_results_save_dir , k_values = k_values )
217218 self .output_eval_results_to_json (reranker_eval_results , eval_results_save_path )
219+ reranker .stop_multi_process_pool ()
218220
219221 @staticmethod
220222 def save_search_results (
Original file line number Diff line number Diff line change @@ -25,6 +25,10 @@ def __str__(self) -> str:
2525 """
2626 return os .path .basename (self .embedder .model .config ._name_or_path )
2727
28+ def stop_multi_process_pool (self ):
29+ if self .embedder .pool is not None :
30+ self .embedder .stop_multi_process_pool (self .embedder .pool )
31+
2832 @abstractmethod
2933 def __call__ (
3034 self ,
@@ -144,6 +148,10 @@ def __str__(self) -> str:
144148 """
145149 return os .path .basename (self .reranker .model .config ._name_or_path )
146150
151+ def stop_multi_process_pool (self ):
152+ if self .reranker .pool is not None :
153+ self .reranker .stop_multi_process_pool (self .reranker .pool )
154+
147155 def __call__ (
148156 self ,
149157 corpus : Dict [str , Dict [str , Any ]],
Original file line number Diff line number Diff line change @@ -145,6 +145,8 @@ def __call__(
145145 retriever_eval_results = self .evaluate_results (no_reranker_search_results_save_dir , k_values = k_values )
146146 self .output_eval_results_to_json (retriever_eval_results , eval_results_save_path )
147147
148+ retriever .stop_multi_process_pool ()
149+
148150 # Reranking Stage
149151 if reranker is not None :
150152 reranker_search_results_save_dir = os .path .join (
@@ -314,7 +316,8 @@ def __call__(
314316 eval_results_save_path = os .path .join (reranker_search_results_save_dir , 'EVAL' , 'eval_results.json' )
315317 reranker_eval_results = self .evaluate_results (reranker_search_results_save_dir , k_values = k_values )
316318 self .output_eval_results_to_json (reranker_eval_results , eval_results_save_path )
317-
319+ if reranker is not None :
320+ reranker .stop_multi_process_pool ()
318321 def evaluate_results (
319322 self ,
320323 search_results_save_dir : str ,
Original file line number Diff line number Diff line change 55 MSMARCOEvalRunner
66)
77
8- def main ():
9- parser = HfArgumentParser ((
10- MSMARCOEvalArgs ,
11- MSMARCOEvalModelArgs
12- ))
138
14- eval_args , model_args = parser .parse_args_into_dataclasses ()
15- eval_args : MSMARCOEvalArgs
16- model_args : MSMARCOEvalModelArgs
9+ parser = HfArgumentParser ((
10+ MSMARCOEvalArgs ,
11+ MSMARCOEvalModelArgs
12+ ))
1713
18- runner = MSMARCOEvalRunner (
19- eval_args = eval_args ,
20- model_args = model_args
21- )
14+ eval_args , model_args = parser .parse_args_into_dataclasses ()
15+ eval_args : MSMARCOEvalArgs
16+ model_args : MSMARCOEvalModelArgs
2217
23- runner .run ()
18+ runner = MSMARCOEvalRunner (
19+ eval_args = eval_args ,
20+ model_args = model_args
21+ )
2422
25- if __name__ == "__main__" :
26- main ()
23+ runner .run ()
You can’t perform that action at this time.
0 commit comments