Skip to content

Commit 61ab7e0

Browse files
committed
update mteb eval
1 parent 5951aa6 commit 61ab7e0

5 files changed

Lines changed: 25 additions & 20 deletions

File tree

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ def encode(
264264
return embeddings
265265

266266
def __del__(self):
267-
if self.pool is not None:
268-
self.stop_multi_process_pool(self.pool)
267+
self.stop_self_pool()
269268

270269
@abstractmethod
271270
def encode_single_device(

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,7 @@ def compute_score(
210210
return scores
211211

212212
def __del__(self):
213-
if self.pool is not None:
214-
self.stop_multi_process_pool(self.pool)
213+
self.stop_self_pool()
215214

216215
@abstractmethod
217216
def compute_score_single_gpu(

FlagEmbedding/evaluation/mteb/runner.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,13 @@ def run(self):
114114
task_types=task_types
115115
)
116116
output_folder = self.eval_args.output_dir
117-
new_tasks = []
118-
for task in tasks:
119-
if task.languages is not None:
120-
if len(task.languages) == len([e for e in languages if e in task.languages]):
121-
new_tasks.append(task)
122117

123-
for task in new_tasks:
118+
for task in tasks:
124119
task_name = task.metadata.name
125120
task_type = task.metadata.type
126121

122+
self.retriever.stop_pool()
123+
127124
if self.eval_args.use_special_instructions:
128125
try:
129126
instruction = get_task_def_by_task_name_and_type(task_name, task_type)

FlagEmbedding/evaluation/mteb/searcher.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ def get_instruction(self):
2020
def set_normalize_embeddings(self, normalize_embeddings: bool = True):
2121
self.embedder.normalize_embeddings = normalize_embeddings
2222

23+
def stop_pool(self):
24+
self.embedder.stop_self_pool()
25+
try:
26+
self.embedder.stop_self_query_pool()
27+
except:
28+
pass
29+
2330
def encode_queries(self, queries: List[str], **kwargs):
2431
emb = self.embedder.encode_queries(queries)
2532
if isinstance(emb, dict):

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import queue
55
from multiprocessing import Queue
66

7+
import gc
78
import torch
89
import numpy as np
910
from transformers import AutoModel, AutoTokenizer
@@ -121,10 +122,8 @@ def __init__(
121122
self.query_pool = None
122123

123124
def __del__(self):
124-
if self.pool is not None:
125-
self.stop_multi_process_pool(self.pool)
126-
if self.query_pool is not None:
127-
self.stop_multi_process_pool(self.query_pool)
125+
self.stop_self_pool()
126+
self.stop_self_query_pool()
128127

129128
def set_examples(self, examples_for_task: Optional[List[dict]] = None):
130129
"""Set the prefix to the provided examples.
@@ -175,6 +174,14 @@ def get_detailed_example(instruction_format: str, instruction: str, query: str,
175174
"""
176175
return instruction_format.format(instruction, query, response)
177176

177+
def stop_self_query_pool(self):
178+
if self.query_pool is not None:
179+
self.stop_multi_process_pool(self.query_pool)
180+
self.query_pool = None
181+
self.model.to('cpu')
182+
gc.collect()
183+
torch.cuda.empty_cache()
184+
178185
def encode_queries(
179186
self,
180187
queries: Union[List[str], str],
@@ -209,9 +216,7 @@ def encode_queries(
209216
**kwargs
210217
)
211218

212-
if self.pool is not None:
213-
self.stop_multi_process_pool(self.pool)
214-
self.pool = None
219+
self.stop_self_pool()
215220
if self.query_pool is None:
216221
self.query_pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker)
217222
embeddings = self.encode_multi_process(
@@ -244,9 +249,7 @@ def encode_corpus(
244249
Returns:
245250
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
246251
"""
247-
if self.query_pool is not None:
248-
self.stop_multi_process_pool(self.query_pool)
249-
self.query_pool = None
252+
self.stop_self_query_pool()
250253
return super().encode_corpus(
251254
corpus,
252255
batch_size=batch_size,

0 commit comments

Comments
 (0)