Skip to content

Commit d681574

Browse files
committed
update stop pool
1 parent 1b971d0 commit d681574

10 files changed

Lines changed: 90 additions & 50 deletions

File tree

FlagEmbedding/abc/evaluation/evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -210,11 +210,11 @@ def __call__(
210210
dataset_name=dataset_name,
211211
)
212212
no_reranker_search_results_dict[split] = search_results
213+
retriever.stop_multi_process_pool()
213214
eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
214215
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
215216
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
216217

217-
retriever.stop_multi_process_pool()
218218
# Reranking Stage
219219
if reranker is not None:
220220
reranker_search_results_save_dir = os.path.join(
@@ -254,10 +254,10 @@ def __call__(
254254
split=split,
255255
dataset_name=dataset_name,
256256
)
257+
reranker.stop_multi_process_pool()
257258
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
258259
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
259260
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
260-
reranker.stop_multi_process_pool()
261261

262262
@staticmethod
263263
def save_search_results(

FlagEmbedding/abc/evaluation/searcher.py

Lines changed: 14 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -31,12 +31,13 @@ def __str__(self) -> str:
3131
return os.path.basename(self.embedder.model.config._name_or_path)
3232

3333
def stop_multi_process_pool(self):
34-
if self.embedder.pool is not None:
35-
self.embedder.stop_multi_process_pool(self.embedder.pool)
36-
self.embedder.pool = None
37-
self.embedder.model.to('cpu')
38-
gc.collect()
39-
torch.cuda.empty_cache()
34+
self.embedder.stop_self_pool()
35+
# if self.embedder.pool is not None:
36+
# self.embedder.stop_multi_process_pool(self.embedder.pool)
37+
# self.embedder.pool = None
38+
# self.embedder.model.to('cpu')
39+
# gc.collect()
40+
# torch.cuda.empty_cache()
4041

4142
@abstractmethod
4243
def __call__(
@@ -168,12 +169,13 @@ def __str__(self) -> str:
168169
return os.path.basename(self.reranker.model.config._name_or_path)
169170

170171
def stop_multi_process_pool(self):
171-
if self.reranker.pool is not None:
172-
self.reranker.stop_multi_process_pool(self.reranker.pool)
173-
self.reranker.pool = None
174-
self.reranker.model.to('cpu')
175-
gc.collect()
176-
torch.cuda.empty_cache()
172+
self.reranker.stop_self_pool()
173+
# if self.reranker.pool is not None:
174+
# self.reranker.stop_multi_process_pool(self.reranker.pool)
175+
# self.reranker.pool = None
176+
# self.reranker.model.to('cpu')
177+
# gc.collect()
178+
# torch.cuda.empty_cache()
177179

178180
def __call__(
179181
self,

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 11 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -8,6 +8,7 @@
88
from multiprocessing import Queue
99

1010
import math
11+
import gc
1112
import torch
1213
import numpy as np
1314
from transformers import is_torch_npu_available
@@ -53,6 +54,7 @@ def __init__(
5354
convert_to_numpy: bool = True,
5455
**kwargs: Any,
5556
):
57+
query_instruction_format = query_instruction_format.replace('\\n', '\n')
5658
self.model_name_or_path = model_name_or_path
5759
self.normalize_embeddings = normalize_embeddings
5860
self.use_fp16 = use_fp16
@@ -74,6 +76,14 @@ def __init__(
7476
self.tokenizer = None
7577
self.model = None
7678
self.pool = None
79+
80+
def stop_self_pool(self):
81+
if self.pool is not None:
82+
self.stop_multi_process_pool(self.pool)
83+
self.pool = None
84+
self.model.to('cpu')
85+
gc.collect()
86+
torch.cuda.empty_cache()
7787

7888
@staticmethod
7989
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
@@ -355,6 +365,7 @@ def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"],
355365

356366
pool["input"].close()
357367
pool["output"].close()
368+
pool = None
358369

359370
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L877
360371
def encode_multi_process(

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 9 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -6,6 +6,7 @@
66
from multiprocessing import Queue
77

88
import math
9+
import gc
910
import torch
1011
import numpy as np
1112
from tqdm import tqdm, trange
@@ -77,6 +78,14 @@ def __init__(
7778
self.tokenizer = None
7879
self.pool = None
7980

81+
def stop_self_pool(self):
82+
if self.pool is not None:
83+
self.stop_multi_process_pool(self.pool)
84+
self.pool = None
85+
self.model.to('cpu')
86+
gc.collect()
87+
torch.cuda.empty_cache()
88+
8089
@staticmethod
8190
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
8291
"""

FlagEmbedding/evaluation/beir/evaluator.py

Lines changed: 1 addition & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -141,12 +141,11 @@ def __call__(
141141
sub_dataset_name=sub_dataset_name,
142142
)
143143
no_reranker_search_results_dict[split] = search_results
144+
retriever.stop_multi_process_pool()
144145
eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
145146
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
146147
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
147148

148-
retriever.stop_multi_process_pool()
149-
150149
# Reranking Stage
151150
if reranker is not None:
152151
reranker_search_results_save_dir = os.path.join(

FlagEmbedding/evaluation/mteb/runner.py

Lines changed: 5 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -14,6 +14,10 @@
1414

1515
logger = logging.getLogger(__name__)
1616

17+
def ensure_dir(file_path):
18+
directory = os.path.dirname(file_path)
19+
if not os.path.exists(directory):
20+
os.makedirs(directory)
1721

1822
class MTEBEvalRunner(AbsEvalRunner):
1923
def __init__(
@@ -147,6 +151,7 @@ def run(self):
147151
evaluation = mteb.MTEB(tasks=[task])
148152
results = evaluation.run(self.retriever, output_folder=f"{output_folder}/{str(self.retriever)}")
149153

154+
ensure_dir(self.eval_args.eval_output_path)
150155
logger.info("Start computing metrics. Only save results as json.")
151156
tasks_results = self.read_results(f"{output_folder}/{str(self.retriever)}/no_model_name_available/no_revision_available", new_tasks)
152157
self.output_json(tasks_results, self.eval_args.eval_output_path)

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,4 @@
1-
from tqdm import tqdm
1+
from tqdm import tqdm, trange
22
from typing import cast, Any, List, Union, Optional
33

44
import torch
@@ -224,7 +224,7 @@ def encode_single_device(
224224

225225
# tokenize without padding to get the correct length
226226
all_inputs = []
227-
for start_index in range(0, len(sentences), batch_size):
227+
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
228228
sentences_batch = sentences[start_index:start_index + batch_size]
229229
inputs_batch = self.tokenizer(
230230
sentences_batch,

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 43 additions & 29 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,4 @@
1-
from tqdm import tqdm
1+
from tqdm import tqdm, trange
22
from typing import cast, Any, List, Union, Optional
33

44
import queue
@@ -69,6 +69,7 @@ def __init__(
6969
use_fp16: bool = True,
7070
query_instruction_for_retrieval: Optional[str] = None,
7171
query_instruction_format: str = "<instruct>{}\n<query>{}", # specify the format of query_instruction_for_retrieval
72+
suffix: str = '\n<response>',
7273
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
7374
# Additional parameters for ICLLLMEmbedder
7475
examples_for_task: Optional[List[dict]] = None,
@@ -82,6 +83,8 @@ def __init__(
8283
convert_to_numpy: bool = True,
8384
**kwargs: Any,
8485
):
86+
query_instruction_format = query_instruction_format.replace('\\n', '\n')
87+
examples_instruction_format = examples_instruction_format.replace('\\n', '\n')
8588
super().__init__(
8689
model_name_or_path,
8790
normalize_embeddings=normalize_embeddings,
@@ -113,7 +116,15 @@ def __init__(
113116
raise ValueError("Pooling method must be 'last_token' for LLM-based models.")
114117

115118
self.set_examples()
116-
self.suffix = '\n<response>'
119+
self.suffix = suffix
120+
121+
self.query_pool = None
122+
123+
def __del__(self):
124+
if self.pool is not None:
125+
self.stop_multi_process_pool(self.pool)
126+
if self.query_pool is not None:
127+
self.stop_multi_process_pool(self.query_pool)
117128

118129
def set_examples(self, examples_for_task: Optional[List[dict]] = None):
119130
"""Set the prefix to the provided examples.
@@ -198,16 +209,19 @@ def encode_queries(
198209
**kwargs
199210
)
200211

201-
pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker)
212+
if self.pool is not None:
213+
self.stop_multi_process_pool(self.pool)
214+
self.pool = None
215+
if self.query_pool is None:
216+
self.query_pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker)
202217
embeddings = self.encode_multi_process(
203218
queries,
204-
pool,
219+
self.query_pool,
205220
batch_size=batch_size,
206221
max_length=max_length,
207222
convert_to_numpy=convert_to_numpy,
208223
**kwargs
209224
)
210-
self.stop_multi_process_pool(pool)
211225
return embeddings
212226

213227
def encode_corpus(
@@ -230,6 +244,9 @@ def encode_corpus(
230244
Returns:
231245
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
232246
"""
247+
if self.query_pool is not None:
248+
self.stop_multi_process_pool(self.query_pool)
249+
self.query_pool = None
233250
return super().encode_corpus(
234251
corpus,
235252
batch_size=batch_size,
@@ -338,16 +355,27 @@ def encode_queries_single_device(
338355
suffix_ids = self.tokenizer(self.suffix, add_special_tokens=False)['input_ids']
339356

340357
_len_1 = len(self.tokenizer('<s>', add_special_tokens=False)['input_ids'])
341-
_len_2 = len(self.tokenizer('\n<response></s>', add_special_tokens=False)['input_ids'])
358+
_len_2 = len(self.tokenizer(f'{self.suffix}</s>', add_special_tokens=False)['input_ids'])
359+
new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length + 8) // 8 * 8 + 8
342360

343361
# tokenize without padding to get the correct length
344362
all_inputs = []
345-
for start_index in range(0, len(input_texts), batch_size):
363+
for start_index in trange(0, len(input_texts), batch_size, desc='pre tokenize'):
346364
sentences_batch = input_texts[start_index:start_index + batch_size]
347365
inputs_batch = self.tokenizer(
348366
sentences_batch,
349367
truncation=True,
350-
max_length=max_length,
368+
max_length=max_length - _len_1 - _len_2,
369+
add_special_tokens=False,
370+
**kwargs
371+
)
372+
sentences_batch = self.tokenizer.batch_decode(inputs_batch['input_ids'])
373+
for i in range(len(sentences_batch)):
374+
sentences_batch[i] = self.prefix + sentences_batch[i] + self.suffix
375+
inputs_batch = self.tokenizer(
376+
sentences_batch,
377+
truncation=True,
378+
max_length=new_max_length,
351379
**kwargs
352380
)
353381
inputs_batch = [{
@@ -385,30 +413,16 @@ def encode_queries_single_device(
385413
all_embeddings = []
386414
for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
387415
disable=len(sentences_sorted) < 256):
388-
sentences_batch = sentences_sorted[start_index:start_index + batch_size]
389-
inputs = self.tokenizer(
390-
sentences_batch,
391-
max_length=max_length - _len_1 - _len_2,
392-
return_token_type_ids=False,
393-
truncation=True,
394-
return_tensors=None,
395-
add_special_tokens=False
396-
)
397-
new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length + 8) // 8 * 8 + 8
398-
sentences_batch = self.tokenizer.batch_decode(inputs['input_ids'])
399-
for i in range(len(sentences_batch)):
400-
sentences_batch[i] = self.prefix + sentences_batch[i] + self.suffix
401-
inputs = self.tokenizer(
402-
sentences_batch,
416+
inputs_batch = all_inputs_sorted[start_index:start_index + batch_size]
417+
inputs_batch = self.tokenizer.pad(
418+
inputs_batch,
403419
padding=True,
404-
truncation=True,
405420
return_tensors='pt',
406-
max_length=new_max_length,
407-
add_special_tokens=True
421+
**kwargs
408422
).to(device)
409423

410-
last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
411-
embeddings = last_token_pool(last_hidden_state, inputs['attention_mask'])
424+
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
425+
embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
412426
if self.normalize_embeddings:
413427
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
414428
embeddings = cast(torch.Tensor, embeddings)
@@ -469,7 +483,7 @@ def encode_single_device(
469483

470484
# tokenize without padding to get the correct length
471485
all_inputs = []
472-
for start_index in range(0, len(sentences), batch_size):
486+
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
473487
sentences_batch = sentences[start_index:start_index + batch_size]
474488
inputs_batch = self.tokenizer(
475489
sentences_batch,

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,4 @@
1-
from tqdm import tqdm
1+
from tqdm import tqdm, trange
22
from typing import cast, Any, List, Union, Optional
33

44
import torch
@@ -205,7 +205,7 @@ def encode_single_device(
205205

206206
# tokenize without padding to get the correct length
207207
all_inputs = []
208-
for start_index in range(0, len(sentences), batch_size):
208+
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
209209
sentences_batch = sentences[start_index:start_index + batch_size]
210210
inputs_batch = self.tokenizer(
211211
sentences_batch,

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -369,7 +369,7 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
369369

370370
# tokenize without padding to get the correct length
371371
all_inputs = []
372-
for start_index in range(0, len(sentences), batch_size):
372+
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
373373
sentences_batch = sentences[start_index:start_index + batch_size]
374374
inputs_batch = self.tokenizer(
375375
sentences_batch,

0 commit comments

Comments (0)