Skip to content

Commit a902830

Browse files
committed
update eval
1 parent 5954566 commit a902830

4 files changed

Lines changed: 26 additions & 16 deletions

File tree

FlagEmbedding/abc/evaluation/evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def __call__(
173173
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
174174
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
175175

176+
retriever.stop_multi_process_pool()
176177
# Reranking Stage
177178
if reranker is not None:
178179
reranker_search_results_save_dir = os.path.join(
@@ -215,6 +216,7 @@ def __call__(
215216
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
216217
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
217218
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
219+
reranker.stop_multi_process_pool()
218220

219221
@staticmethod
220222
def save_search_results(

FlagEmbedding/abc/evaluation/searcher.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def __str__(self) -> str:
2525
"""
2626
return os.path.basename(self.embedder.model.config._name_or_path)
2727

28+
def stop_multi_process_pool(self):
29+
if self.embedder.pool is not None:
30+
self.embedder.stop_multi_process_pool(self.embedder.pool)
31+
2832
@abstractmethod
2933
def __call__(
3034
self,
@@ -144,6 +148,10 @@ def __str__(self) -> str:
144148
"""
145149
return os.path.basename(self.reranker.model.config._name_or_path)
146150

151+
def stop_multi_process_pool(self):
152+
if self.reranker.pool is not None:
153+
self.reranker.stop_multi_process_pool(self.reranker.pool)
154+
147155
def __call__(
148156
self,
149157
corpus: Dict[str, Dict[str, Any]],

FlagEmbedding/evaluation/beir/evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def __call__(
145145
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
146146
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
147147

148+
retriever.stop_multi_process_pool()
149+
148150
# Reranking Stage
149151
if reranker is not None:
150152
reranker_search_results_save_dir = os.path.join(
@@ -314,7 +316,8 @@ def __call__(
314316
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
315317
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
316318
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
317-
319+
if reranker is not None:
320+
reranker.stop_multi_process_pool()
318321
def evaluate_results(
319322
self,
320323
search_results_save_dir: str,

FlagEmbedding/evaluation/msmarco/__main__.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,19 @@
55
MSMARCOEvalRunner
66
)
77

8-
def main():
9-
parser = HfArgumentParser((
10-
MSMARCOEvalArgs,
11-
MSMARCOEvalModelArgs
12-
))
138

14-
eval_args, model_args = parser.parse_args_into_dataclasses()
15-
eval_args: MSMARCOEvalArgs
16-
model_args: MSMARCOEvalModelArgs
9+
parser = HfArgumentParser((
10+
MSMARCOEvalArgs,
11+
MSMARCOEvalModelArgs
12+
))
1713

18-
runner = MSMARCOEvalRunner(
19-
eval_args=eval_args,
20-
model_args=model_args
21-
)
14+
eval_args, model_args = parser.parse_args_into_dataclasses()
15+
eval_args: MSMARCOEvalArgs
16+
model_args: MSMARCOEvalModelArgs
2217

23-
runner.run()
18+
runner = MSMARCOEvalRunner(
19+
eval_args=eval_args,
20+
model_args=model_args
21+
)
2422

25-
if __name__ == "__main__":
26-
main()
23+
runner.run()

0 commit comments

Comments
 (0)