
Commit 6a504b9

Merge branch 'FlagOpen:master' into master
2 parents: 6856d64 + 34d24e8

164 files changed: 110,466 additions & 202 deletions


C_MTEB/MKQA/README.md

Lines changed: 408 additions & 0 deletions
Large diffs are not rendered by default.
step0-generate_embedding.py

Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@
"""
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--index_save_dir ./corpus-index \
--max_passage_length 512 \
--batch_size 256 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import faiss
import datasets
import numpy as np
from tqdm import tqdm
from pprint import pprint
from FlagEmbedding import FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser

sys.path.append("..")

from utils.normalize_text import normalize


@dataclass
class ModelArgs:
    encoder: str = field(
        default="BAAI/bge-m3",
        metadata={'help': 'Name or path of encoder'}
    )
    fp16: bool = field(
        default=True,
        metadata={'help': 'Use fp16 in inference?'}
    )
    pooling_method: str = field(
        default='cls',
        metadata={'help': "Pooling method. Available methods: 'cls', 'mean'"}
    )
    normalize_embeddings: bool = field(
        default=True,
        metadata={'help': "Normalize embeddings or not"}
    )


@dataclass
class EvalArgs:
    index_save_dir: str = field(
        default='./corpus-index',
        metadata={'help': 'Dir to save index. Corpus index will be saved to `index_save_dir/{encoder_name}/index`. Corpus ids will be saved to `index_save_dir/{encoder_name}/docid`.'}
    )
    max_passage_length: int = field(
        default=512,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    overwrite: bool = field(
        default=False,
        metadata={'help': 'Whether to overwrite existing embeddings'}
    )


def get_model(model_args: ModelArgs):
    model = FlagModel(
        model_args.encoder,
        pooling_method=model_args.pooling_method,
        normalize_embeddings=model_args.normalize_embeddings,
        use_fp16=model_args.fp16
    )
    return model


def parse_corpus(corpus: datasets.Dataset):
    # Flatten the corpus into (id, content) pairs; content is the lowercased,
    # normalized concatenation of title and text.
    corpus_list = []
    for data in tqdm(corpus, desc="Generating corpus"):
        _id = str(data['_id'])
        content = f"{data['title']}\n{data['text']}".lower()
        content = normalize(content)
        corpus_list.append({"id": _id, "content": content})

    corpus = datasets.Dataset.from_list(corpus_list)
    return corpus


def generate_index(model: FlagModel, corpus: datasets.Dataset, max_passage_length: int=512, batch_size: int=256):
    corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_passage_length)
    dim = corpus_embeddings.shape[-1]

    # Exact (flat) inner-product index; with normalized embeddings this is cosine similarity.
    faiss_index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)
    faiss_index.add(corpus_embeddings)
    return faiss_index, list(corpus["id"])


def save_result(index: faiss.Index, docid: list, index_save_dir: str):
    docid_save_path = os.path.join(index_save_dir, 'docid')
    index_save_path = os.path.join(index_save_dir, 'index')
    with open(docid_save_path, 'w', encoding='utf-8') as f:
        for _id in docid:
            f.write(str(_id) + '\n')
    faiss.write_index(index, index_save_path)


def main():
    parser = HfArgumentParser([ModelArgs, EvalArgs])
    model_args, eval_args = parser.parse_args_into_dataclasses()
    model_args: ModelArgs
    eval_args: EvalArgs

    if model_args.encoder[-1] == '/':
        model_args.encoder = model_args.encoder[:-1]

    model = get_model(model_args=model_args)

    encoder = model_args.encoder
    if os.path.basename(encoder).startswith('checkpoint-'):
        # Make index dirs of intermediate checkpoints distinguishable, e.g. `model_checkpoint-1000`.
        encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)

    print("==================================================")
    print("Start generating embedding with model:")
    print(model_args.encoder)

    index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder))
    if not os.path.exists(index_save_dir):
        os.makedirs(index_save_dir)
    if os.path.exists(os.path.join(index_save_dir, 'index')) and not eval_args.overwrite:
        print('Embedding already exists. Skip...')
        return

    corpus = datasets.load_dataset("BeIR/nq", 'corpus')['corpus']
    corpus = parse_corpus(corpus=corpus)

    index, docid = generate_index(
        model=model,
        corpus=corpus,
        max_passage_length=eval_args.max_passage_length,
        batch_size=eval_args.batch_size
    )
    save_result(index, docid, index_save_dir)

    print("==================================================")
    print("Finish generating embeddings with the following model:")
    pprint(model_args.encoder)


if __name__ == "__main__":
    main()
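
As a sanity check, the two artifacts written by save_result() can be loaded back directly: the docid file holds one corpus id per line, and the index is a flat inner-product FAISS index. The following is a minimal sketch, assuming the default BAAI/bge-m3 settings from the docstring above; the index path and the sample query are illustrative assumptions, not part of the commit.

import faiss
import numpy as np
from FlagEmbedding import FlagModel

# Load the artifacts written by step0 (paths assume --encoder BAAI/bge-m3).
index = faiss.read_index("./corpus-index/bge-m3/index")
with open("./corpus-index/bge-m3/docid", encoding="utf-8") as f:
    docids = [line.strip() for line in f]

# Encode a query with the same pooling/normalization settings and search the flat IP index.
model = FlagModel("BAAI/bge-m3", pooling_method="cls", normalize_embeddings=True, use_fp16=True)
query_embeddings = model.encode_queries(["who wrote the declaration of independence"]).astype(np.float32)

scores, indices = index.search(query_embeddings, 5)  # top-5; inner product == cosine here
for score, idx in zip(scores[0], indices[0]):
    print(docids[idx], float(score))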
step1-search_results.py

Lines changed: 214 additions & 0 deletions

@@ -0,0 +1,214 @@
"""
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--batch_size 32 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
"""
import os
import torch
import datasets
from tqdm import tqdm
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser, is_torch_npu_available
from pyserini.search.faiss import FaissSearcher, AutoQueryEncoder
from pyserini.output_writer import get_output_writer, OutputFormat


@dataclass
class ModelArgs:
    encoder: str = field(
        default="BAAI/bge-m3",
        metadata={'help': 'Name or path of encoder'}
    )
    add_instruction: bool = field(
        default=False,
        metadata={'help': 'Add query-side instruction?'}
    )
    query_instruction_for_retrieval: str = field(
        default=None,
        metadata={'help': 'Query instruction for retrieval'}
    )
    pooling_method: str = field(
        default='cls',
        metadata={'help': "Pooling method. Available methods: 'cls', 'mean'"}
    )
    normalize_embeddings: bool = field(
        default=True,
        metadata={'help': "Normalize embeddings or not"}
    )


@dataclass
class EvalArgs:
    languages: str = field(
        default="en",
        metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
                  "nargs": "+"}
    )
    index_save_dir: str = field(
        default='./corpus-index',
        metadata={'help': 'Dir to index and docid. Corpus index path is `index_save_dir/{encoder_name}/index`. Corpus ids path is `index_save_dir/{encoder_name}/docid`.'}
    )
    result_save_dir: str = field(
        default='./search_results',
        metadata={'help': 'Dir to save search results. Search results will be saved to `result_save_dir/{encoder_name}/{lang}.txt`'}
    )
    qa_data_dir: str = field(
        default='../qa_data',
        metadata={'help': 'Dir to qa data.'}
    )
    threads: int = field(
        default=1,
        metadata={'help': 'Maximum threads to use during search'}
    )
    batch_size: int = field(
        default=32,
        metadata={'help': 'Search batch size.'}
    )
    hits: int = field(
        default=1000,
        metadata={'help': 'Number of hits'}
    )
    overwrite: bool = field(
        default=False,
        metadata={'help': 'Whether to overwrite existing search results'}
    )


def get_query_encoder(model_args: ModelArgs):
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif is_torch_npu_available():
        device = torch.device("npu")
    else:
        device = torch.device("cpu")
    model = AutoQueryEncoder(
        encoder_dir=model_args.encoder,
        device=device,
        pooling=model_args.pooling_method,
        l2_norm=model_args.normalize_embeddings
    )
    return model


def check_languages(languages):
    if isinstance(languages, str):
        languages = [languages]
    available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
    return languages


def get_queries_and_qids(qa_data_dir: str, lang: str, add_instruction: bool=False, query_instruction_for_retrieval: str=None):
    topics_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
    if not os.path.exists(topics_path):
        raise FileNotFoundError(f"{topics_path} not found")

    dataset = datasets.load_dataset('json', data_files=topics_path)['train']

    queries = []
    qids = []
    for data in dataset:
        qids.append(str(data['id']))
        queries.append(str(data['question']))
    if add_instruction and query_instruction_for_retrieval is not None:
        queries = [f"{query_instruction_for_retrieval}{query}" for query in queries]
    return queries, qids


def save_result(search_results, result_save_path: str, qids: list, max_hits: int):
    output_writer = get_output_writer(result_save_path, OutputFormat(OutputFormat.TREC.value), 'w',
                                      max_hits=max_hits, tag='Faiss', topics=qids,
                                      use_max_passage=False,
                                      max_passage_delimiter='#',
                                      max_passage_hits=1000)
    with output_writer:
        for topic, hits in search_results:
            # For some test collections, a query is a doc from the corpus (e.g., arguana in BEIR).
            # Remove the query from the results.
            hits = [hit for hit in hits if hit.docid != topic]

            output_writer.write(topic, hits)


def main():
    parser = HfArgumentParser([ModelArgs, EvalArgs])
    model_args, eval_args = parser.parse_args_into_dataclasses()
    model_args: ModelArgs
    eval_args: EvalArgs

    languages = check_languages(eval_args.languages)

    if model_args.encoder[-1] == '/':
        model_args.encoder = model_args.encoder[:-1]

    query_encoder = get_query_encoder(model_args=model_args)

    encoder = model_args.encoder
    if os.path.basename(encoder).startswith('checkpoint-'):
        encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)

    index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder))
    if not os.path.exists(index_save_dir):
        raise FileNotFoundError(f"{index_save_dir} not found")
    searcher = FaissSearcher(
        index_dir=index_save_dir,
        query_encoder=query_encoder
    )

    print("==================================================")
    print("Start generating search results with model:", encoder)

    print('Generate search results of the following languages:', languages)
    for lang in languages:
        print("**************************************************")
        print(f"Start searching results of {lang} ...")

        result_save_path = os.path.join(eval_args.result_save_dir, os.path.basename(encoder), f"{lang}.txt")
        if not os.path.exists(os.path.dirname(result_save_path)):
            os.makedirs(os.path.dirname(result_save_path))

        if os.path.exists(result_save_path) and not eval_args.overwrite:
            print(f'Search results of {lang} already exist. Skip...')
            continue

        queries, qids = get_queries_and_qids(
            eval_args.qa_data_dir,
            lang=lang,
            add_instruction=model_args.add_instruction,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval
        )

        search_results = []
        for start_idx in tqdm(range(0, len(queries), eval_args.batch_size), desc="Searching"):
            batch_queries = queries[start_idx : start_idx+eval_args.batch_size]
            batch_qids = qids[start_idx : start_idx+eval_args.batch_size]
            batch_search_results = searcher.batch_search(
                queries=batch_queries,
                q_ids=batch_qids,
                k=eval_args.hits,
                threads=eval_args.threads
            )
            search_results.extend([(_id, batch_search_results[_id]) for _id in batch_qids])

        save_result(
            search_results=search_results,
            result_save_path=result_save_path,
            qids=qids,
            max_hits=eval_args.hits
        )

    print("==================================================")
    print("Finish generating search results with the following model:")
    pprint(model_args.encoder)


if __name__ == "__main__":
    main()
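
The run files written above use the standard TREC format, one hit per line: `qid Q0 docid rank score tag`. The following is a minimal sketch, not part of the commit, for loading such a file and computing Recall@k; the file path is illustrative, and `qrels` (a mapping from query id to the set of relevant docids) is a hypothetical placeholder you would build from the gold answers in the qa_data files.

from collections import defaultdict

def load_trec_run(path: str, k: int = 100) -> dict:
    # Parse `qid Q0 docid rank score tag` lines, keeping the top-k docids per query.
    run = defaultdict(list)
    with open(path, encoding="utf-8") as f:
        for line in f:
            qid, _, docid, rank, _score, _tag = line.split()
            if int(rank) <= k:
                run[qid].append(docid)
    return run

def recall_at_k(run: dict, qrels: dict, k: int = 100) -> float:
    # Fraction of queries with at least one relevant docid in the top-k.
    hits = sum(bool(set(run.get(qid, [])[:k]) & rel) for qid, rel in qrels.items())
    return hits / max(len(qrels), 1)

run = load_trec_run("./search_results/bge-m3/en.txt", k=100)
qrels = {}  # hypothetical: {qid: {relevant docids}}, built from the qa_data answers
print(recall_at_k(run, qrels, k=100))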
