Skip to content

Commit a7bbf09

Browse files
committed
Merge branch 'new-flagembedding-v1' of https://github.com/hanhainebula/FlagEmbedding into new-flagembedding-v1
2 parents fcd6937 + c6f5ec4 commit a7bbf09

7 files changed

Lines changed: 52 additions & 15 deletions

File tree

FlagEmbedding/abc/evaluation/data_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _download_file(self, download_url: str, save_dir: str):
192192
try:
193193
subprocess.run(cmd, check=True)
194194
except subprocess.CalledProcessError as e:
195-
logger.error(f"Error code: {e.returncode}. Error message: {e.stderr}")
195+
logger.warning(e.output)
196196

197197
if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
198198
raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
@@ -217,7 +217,7 @@ def _download_gz_file(self, download_url: str, save_dir: str):
217217
try:
218218
subprocess.run(cmd, check=True)
219219
except subprocess.CalledProcessError as e:
220-
logger.error(f"Error code: {e.returncode}. Error message: {e.output}")
220+
logger.warning(e.output)
221221

222222
file_path = gz_file_path.replace(".gz", "")
223223
if not os.path.exists(file_path) or self._get_fpath_size(file_path) == 0:
@@ -236,7 +236,7 @@ def _download_zip_file(self, download_url: str, save_dir: str):
236236
try:
237237
subprocess.run(cmd, check=True)
238238
except subprocess.CalledProcessError as e:
239-
logger.error(f"Error code: {e.returncode}. Error message: {e.output}")
239+
logger.warning(e.output)
240240

241241
if not os.path.exists(file_path) or self._get_fpath_size(file_path) == 0:
242242
raise FileNotFoundError(f"Failed to unzip file {zip_file_path}")

FlagEmbedding/abc/evaluation/evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def __call__(
173173
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
174174
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
175175

176+
retriever.stop_multi_process_pool()
176177
# Reranking Stage
177178
if reranker is not None:
178179
reranker_search_results_save_dir = os.path.join(
@@ -215,6 +216,7 @@ def __call__(
215216
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
216217
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
217218
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
219+
reranker.stop_multi_process_pool()
218220

219221
@staticmethod
220222
def save_search_results(

FlagEmbedding/abc/evaluation/searcher.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def __str__(self) -> str:
2525
"""
2626
return os.path.basename(self.embedder.model.config._name_or_path)
2727

28+
def stop_multi_process_pool(self):
29+
if self.embedder.pool is not None:
30+
self.embedder.stop_multi_process_pool(self.embedder.pool)
31+
2832
@abstractmethod
2933
def __call__(
3034
self,
@@ -144,6 +148,10 @@ def __str__(self) -> str:
144148
"""
145149
return os.path.basename(self.reranker.model.config._name_or_path)
146150

151+
def stop_multi_process_pool(self):
152+
if self.reranker.pool is not None:
153+
self.reranker.stop_multi_process_pool(self.reranker.pool)
154+
147155
def __call__(
148156
self,
149157
corpus: Dict[str, Dict[str, Any]],

FlagEmbedding/evaluation/air_bench/arguments.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@ class AIRBenchEvalModelArgs:
99
embedder_name_or_path: str = field(
1010
metadata={"help": "The embedder name or path.", "required": True}
1111
)
12+
embedder_model_class: Optional[str] = field(
13+
default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
14+
)
1215
normalize_embeddings: bool = field(
1316
default=True, metadata={"help": "whether to normalize the embeddings"}
1417
)
18+
pooling_method: str = field(
19+
default="cls", metadata={"help": "The pooling method fot the embedder."}
20+
)
1521
use_fp16: bool = field(
1622
default=True, metadata={"help": "whether to use fp16 for inference"}
1723
)
@@ -36,6 +42,9 @@ class AIRBenchEvalModelArgs:
3642
reranker_name_or_path: Optional[str] = field(
3743
default=None, metadata={"help": "The reranker name or path."}
3844
)
45+
reranker_model_class: Optional[str] = field(
46+
default=None, metadata={"help": "The reranker model class. Available classes: ['encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: None. For the custom model, you need to specify the model class.", "choices": ["encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
47+
)
3948
reranker_peft_path: Optional[str] = field(
4049
default=None, metadata={"help": "The reranker peft path."}
4150
)

FlagEmbedding/evaluation/beir/evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def __call__(
145145
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
146146
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
147147

148+
retriever.stop_multi_process_pool()
149+
148150
# Reranking Stage
149151
if reranker is not None:
150152
reranker_search_results_save_dir = os.path.join(
@@ -314,7 +316,8 @@ def __call__(
314316
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
315317
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
316318
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
317-
319+
if reranker is not None:
320+
reranker.stop_multi_process_pool()
318321
def evaluate_results(
319322
self,
320323
search_results_save_dir: str,

FlagEmbedding/evaluation/msmarco/data_loader.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,33 @@ def _load_remote_corpus(
4545
corpus_dict = {}
4646
with open(save_path, "w", encoding="utf-8") as f:
4747
for data in tqdm(corpus, desc="Loading and Saving corpus"):
48-
_data = {
49-
"id": data["docid"],
50-
"title": data["title"],
51-
"text": data.get("text", data.get("body", ""))
52-
}
53-
corpus_dict[data["docid"]] = {
54-
"title": data["title"],
55-
"text": data.get("text", data.get("body", ""))
56-
}
48+
if dataset_name == 'passage':
49+
_data = {
50+
"id": data["docid"],
51+
"title": data["title"],
52+
"text": data["text"]
53+
}
54+
corpus_dict[data["docid"]] = {
55+
"title": data["title"],
56+
"text": data["text"]
57+
}
58+
else:
59+
_data = {
60+
"id": data["doc_id"],
61+
"title": data["title"],
62+
"text": data["body"]
63+
}
64+
corpus_dict[data["doc_id"]] = {
65+
"title": data["title"],
66+
"text": data["body"]
67+
}
5768
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
5869
logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
5970
else:
60-
corpus_dict = {data["docid"]: {"title": data["title"], "text": data.get("text", data.get("body", ""))} for data in tqdm(corpus, desc="Loading corpus")}
71+
if dataset_name == 'passage':
72+
corpus_dict = {data["docid"]: {"title": data["title"], "text": data["text"]} for data in tqdm(corpus, desc="Loading corpus")}
73+
else:
74+
corpus_dict = {data["doc_id"]: {"title": data["title"], "text": data["body"]} for data in tqdm(corpus, desc="Loading corpus")}
6175
return datasets.DatasetDict(corpus_dict)
6276

6377
def _load_remote_qrels(

FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def get_model(
3838
colbert_dim: int = -1,
3939
cache_dir: str = None
4040
):
41+
cache_folder = os.getenv('HF_HUB_CACHE', None) if cache_dir is None else cache_dir
4142
if not os.path.exists(model_name_or_path):
42-
cache_folder = os.getenv('HF_HUB_CACHE', None) if cache_dir is None else cache_dir
4343
model_name_or_path = snapshot_download(
4444
repo_id=model_name_or_path,
4545
cache_dir=cache_folder,
@@ -48,6 +48,7 @@ def get_model(
4848

4949
model = AutoModel.from_pretrained(
5050
model_name_or_path,
51+
cache_dir=cache_folder,
5152
trust_remote_code=trust_remote_code
5253
)
5354
colbert_linear = torch.nn.Linear(

0 commit comments

Comments
 (0)