Skip to content

Commit 6fee11b

Browse files
authored
Merge pull request #1258 from hanhainebula/master
update code and README for scripts
2 parents a719aaa + ada9af0 commit 6fee11b

4 files changed

Lines changed: 188 additions & 61 deletions

File tree

scripts/README.md

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ Hard negatives is a widely used method to improve the quality of sentence embedd
1717

1818
```shell
1919
python hn_mine.py \
20-
--model_name_or_path BAAI/bge-base-en-v1.5 \
2120
--input_file toy_finetune_data.jsonl \
2221
--output_file toy_finetune_data_minedHN.jsonl \
2322
--range_for_sampling 2-200 \
2423
--negative_number 15 \
25-
--use_gpu_for_searching
24+
--use_gpu_for_searching \
25+
--embedder_name_or_path BAAI/bge-base-en-v1.5
2626
```
2727

2828
- **`input_file`**: json data for finetuning. This script will retrieve top-k documents for each query, and random sample negatives from the top-k documents (not including the positive documents).
@@ -31,6 +31,19 @@ python hn_mine.py \
3131
- **`range_for_sampling`**: where to sample negative. For example, `2-100` means sampling `negative_number` negatives from top2-top100 documents. **You can set a larger value to reduce the difficulty of negatives (e.g., set it to `60-300` to sample negatives from the top60-300 passages)**
3232
- **`candidate_pool`**: The pool to retrieve from. The default value is None, and this script will retrieve from the combination of all `neg` in `input_file`. The format of this file is the same as [pretrain data](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/pretrain#2-data-format). If a candidate_pool is provided, this script will retrieve negatives from this file.
3333
- **`use_gpu_for_searching`**: whether to use faiss-gpu to retrieve negatives.
34+
- **`search_batch_size`**: batch size for searching. Default is 64.
35+
- **`embedder_name_or_path`**: The name or path to the embedder.
36+
- **`embedder_model_class`**: Class of the model used for embedding (current options include 'encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl'). Default is None. For a custom model, you should set this argument.
37+
- **`normalize_embeddings`**: Set to `True` to normalize embeddings.
38+
- **`pooling_method`**: The pooling method for the embedder.
39+
- **`use_fp16`**: Use FP16 precision for inference.
40+
- **`devices`**: List of devices used for inference.
41+
- **`query_instruction_for_retrieval`**, **`query_instruction_format_for_retrieval`**: Instructions and format for query during retrieval.
42+
- **`examples_for_task`**, **`examples_instruction_format`**: Example tasks and their instructions format. This is only used when `embedder_model_class` is set to `decoder-only-icl`.
43+
- **`trust_remote_code`**: Set to `True` to trust remote code execution.
44+
- **`cache_dir`**: Cache directory for models.
45+
- **`embedder_batch_size`**: Batch size for embedding inference. Default is 3000.
46+
- **`embedder_query_max_length`**, **`embedder_passage_max_length`**: Maximum length for embedding queries and passages.
3447

3548
### Teacher Scores
3649

@@ -40,9 +53,7 @@ Teacher scores can be used for model distillation. You can obtain the scores usi
4053
python add_reranker_score.py \
4154
--input_file toy_finetune_data_minedHN.jsonl \
4255
--output_file toy_finetune_data_score.jsonl \
43-
--range_for_sampling 2-200 \
44-
--negative_number 15 \
45-
--use_gpu_for_searching
56+
--reranker_name_or_path BAAI/bge-reranker-v2-m3
4657
```
4758

4859
- **`input_file`**: path to save JSON data with mined hard negatives for finetuning
@@ -80,15 +91,14 @@ python split_data_by_length.py \
8091
--log_name .split_log \
8192
--length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \
8293
--model_name_or_path BAAI/bge-m3 \
83-
--num_proc 16 \
84-
--overwrite False
94+
--num_proc 16
8595
```
8696

87-
- **`input_path`**: The path of input data. (Required)
88-
- **`output_dir`**: The directory of output data. (Required)
97+
- **`input_path`**: The path of input data. It can be a file or a directory containing multiple files.
98+
- **`output_dir`**: The directory of output data. The split data files will be saved to this directory.
8999
- **`cache_dir`**: The cache directory. Default: None
90100
- **`log_name`**: The name of the log file. Default: `.split_log`, which will be saved to `output_dir`
91101
- **`length_list`**: The length list to split. Default: [0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
92102
- **`model_name_or_path`**: The model name or path of the tokenizer. Default: `BAAI/bge-m3`
93103
- **`num_proc`**: The number of processes. Default: 16
94-
- **`overwrite`**: Whether to overwrite the output file. Default: False
104+
- **`overwrite`**: Whether to overwrite the output file. Default: False

scripts/add_reranker_score.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import json
22
from typing import Optional, List
33

4-
from FlagEmbedding import FlagAutoReranker
54
from dataclasses import dataclass, field
65
from transformers import HfArgumentParser
6+
from FlagEmbedding import FlagAutoReranker
7+
78

89
@dataclass
910
class ScoreArgs:
@@ -14,6 +15,7 @@ class ScoreArgs:
1415
default=None, metadata={"help": "The output jsonl file, it includes query, pos, neg, pos_scores and neg_scores."}
1516
)
1617

18+
1719
@dataclass
1820
class ModelArgs:
1921
use_fp16: bool = field(
@@ -78,7 +80,8 @@ class ModelArgs:
7880
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
7981
)
8082

81-
def main(score_args, model_args):
83+
84+
def main(score_args: ScoreArgs, model_args: ModelArgs):
8285
reranker = FlagAutoReranker.from_finetuned(
8386
model_name_or_path=model_args.reranker_name_or_path,
8487
model_class=model_args.reranker_model_class,
@@ -130,7 +133,7 @@ def main(score_args, model_args):
130133
f.write(json.dumps(d) + '\n')
131134

132135

133-
if __name__ == '__main__':
136+
if __name__ == "__main__":
134137
parser = HfArgumentParser((
135138
ScoreArgs,
136139
ModelArgs
@@ -139,4 +142,3 @@ def main(score_args, model_args):
139142
score_args: ScoreArgs
140143
model_args: ModelArgs
141144
main(score_args, model_args)
142-

scripts/hn_mine.py

Lines changed: 156 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,98 @@
1-
import argparse
21
import json
32
import random
43
import numpy as np
5-
import faiss
64
from tqdm import tqdm
5+
from typing import Optional
6+
from dataclasses import dataclass, field
77

8-
from FlagEmbedding import FlagModel
9-
10-
11-
def get_args():
12-
parser = argparse.ArgumentParser()
13-
parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
14-
parser.add_argument('--input_file', default=None, type=str)
15-
parser.add_argument('--candidate_pool', default=None, type=str)
16-
parser.add_argument('--output_file', default=None, type=str)
17-
parser.add_argument('--range_for_sampling', default="10-210", type=str, help="range to sample negatives")
18-
parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
19-
parser.add_argument('--negative_number', default=15, type=int, help='the number of negatives')
20-
parser.add_argument('--query_instruction_for_retrieval', default="")
21-
22-
return parser.parse_args()
23-
24-
25-
def create_index(embeddings, use_gpu):
8+
import faiss
9+
from transformers import HfArgumentParser
10+
from FlagEmbedding import FlagAutoModel
11+
from FlagEmbedding.abc.inference import AbsEmbedder
12+
13+
14+
@dataclass
class DataArgs:
    """
    Data arguments for hard negative mining.
    """
    # Input jsonl file holding the finetune data to mine hard negatives for.
    input_file: str = field(
        metadata={"help": "The input file for hard negative mining."}
    )
    # Output jsonl file the mined results are written to.
    output_file: str = field(
        metadata={"help": "The output file for hard negative mining."}
    )
    # Optional retrieval pool: a jsonl file, each line a dict with a 'text' key.
    # When None, the pool presumably comes from the input file itself — see find_knn_neg.
    candidate_pool: Optional[str] = field(
        default=None, metadata={"help": "The candidate pool for hard negative mining. If provided, it should be a jsonl file, each line is a dict with a key 'text'."}
    )
    # "start-end" string, e.g. "10-210": negatives are sampled from this rank window.
    range_for_sampling: str = field(
        default="10-210", metadata={"help": "The range to sample negatives."}
    )
    # How many negatives to keep per query.
    negative_number: int = field(
        default=15, metadata={"help": "The number of negatives."}
    )
    # Use faiss-gpu instead of the CPU index for the knn search.
    use_gpu_for_searching: bool = field(
        default=False, metadata={"help": "Whether to use faiss-gpu for searching."}
    )
    # Batch size for the faiss search loop (not for embedding inference).
    search_batch_size: int = field(
        default=64, metadata={"help": "The batch size for searching."}
    )
40+
41+
42+
@dataclass
class ModelArgs:
    """
    Model arguments for the embedder used during hard negative mining.
    """
    # Name or path of the embedder model (required).
    embedder_name_or_path: str = field(
        metadata={"help": "The embedder name or path.", "required": True}
    )
    # Fix: "specifiy" -> "specify" in the help text.
    embedder_model_class: Optional[str] = field(
        default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specify the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
    )
    normalize_embeddings: bool = field(
        default=True, metadata={"help": "whether to normalize the embeddings"}
    )
    # Fix: "fot" -> "for" in the help text.
    pooling_method: str = field(
        default="cls", metadata={"help": "The pooling method for the embedder."}
    )
    use_fp16: bool = field(
        default=True, metadata={"help": "whether to use fp16 for inference"}
    )
    # nargs="+" lets the CLI accept a list of device names.
    devices: Optional[str] = field(
        default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
    )
    query_instruction_for_retrieval: Optional[str] = field(
        default=None, metadata={"help": "Instruction for query"}
    )
    query_instruction_format_for_retrieval: str = field(
        default="{}{}", metadata={"help": "Format for query instruction"}
    )
    # Only used by the 'decoder-only-icl' model class per the README.
    examples_for_task: Optional[str] = field(
        default=None, metadata={"help": "Examples for task"}
    )
    examples_instruction_format: str = field(
        default="{}{}", metadata={"help": "Format for examples instruction"}
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code"}
    )
    # Fix: annotated Optional[str] (was plain str despite default=None).
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Cache directory for models."}
    )
    # ================ for inference ===============
    batch_size: int = field(
        default=3000, metadata={"help": "Batch size for inference."}
    )
    embedder_query_max_length: int = field(
        default=512, metadata={"help": "Max length for query."}
    )
    embedder_passage_max_length: int = field(
        default=512, metadata={"help": "Max length for passage."}
    )
93+
94+
95+
def create_index(embeddings: np.ndarray, use_gpu: bool = False):
2696
index = faiss.IndexFlatIP(len(embeddings[0]))
2797
embeddings = np.asarray(embeddings, dtype=np.float32)
2898
if use_gpu:
@@ -34,10 +104,12 @@ def create_index(embeddings, use_gpu):
34104
return index
35105

36106

37-
def batch_search(index,
38-
query,
39-
topk: int = 200,
40-
batch_size: int = 64):
107+
def batch_search(
108+
index: faiss.Index,
109+
query: np.ndarray,
110+
topk: int = 200,
111+
batch_size: int = 64
112+
):
41113
all_scores, all_inxs = [], []
42114
for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
43115
batch_query = query[start_index:start_index + batch_size]
@@ -47,15 +119,24 @@ def batch_search(index,
47119
return all_scores, all_inxs
48120

49121

def get_corpus(candidate_pool: str):
    """
    Load the candidate passage pool from a jsonl file.

    Each line of the file must be a JSON object with a 'text' key.

    Args:
        candidate_pool: Path to the jsonl file.

    Returns:
        list: The passage texts, in file order.
    """
    corpus = []
    with open(candidate_pool, "r", encoding="utf-8") as f:
        # Iterate the file object lazily instead of f.readlines(), which
        # materializes the whole file in memory for large pools.
        for line in f:
            data = json.loads(line.strip())
            corpus.append(data['text'])
    return corpus
56129

57130

58-
def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
131+
def find_knn_neg(
132+
model: AbsEmbedder,
133+
input_file: str,
134+
output_file: str,
135+
candidate_pool: Optional[str] = None,
136+
sample_range: str = "10-210",
137+
negative_number: int = 15,
138+
use_gpu: bool = False
139+
):
59140
corpus = []
60141
queries = []
61142
train_data = []
@@ -75,9 +156,9 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
75156
corpus = list(set(corpus))
76157

77158
print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
78-
p_vecs = model.encode(corpus, batch_size=256)
159+
p_vecs = model.encode(corpus)
79160
print(f'inferencing embedding for queries (number={len(queries)})--------------')
80-
q_vecs = model.encode_queries(queries, batch_size=256)
161+
q_vecs = model.encode_queries(queries)
81162

82163
print('create index and search------------------')
83164
index = create_index(p_vecs, use_gpu=use_gpu)
@@ -106,17 +187,47 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
106187
f.write(json.dumps(data, ensure_ascii=False) + '\n')
107188

108189

109-
if __name__ == '__main__':
110-
args = get_args()
111-
sample_range = args.range_for_sampling.split('-')
112-
sample_range = [int(x) for x in sample_range]
113-
114-
model = FlagModel(args.model_name_or_path, query_instruction_for_retrieval=args.query_instruction_for_retrieval)
115-
116-
find_knn_neg(model,
117-
input_file=args.input_file,
118-
candidate_pool=args.candidate_pool,
119-
output_file=args.output_file,
120-
sample_range=sample_range,
121-
negative_number=args.negative_number,
122-
use_gpu=args.use_gpu_for_searching)
190+
def load_model(model_args: ModelArgs):
    """
    Build the embedder from the parsed model arguments.

    Args:
        model_args: ModelArgs controlling model loading and inference.

    Returns:
        The embedder returned by ``FlagAutoModel.from_finetuned``.
    """
    # Collect every construction option in one place, then hand it off.
    embedder_kwargs = dict(
        model_name_or_path=model_args.embedder_name_or_path,
        model_class=model_args.embedder_model_class,
        normalize_embeddings=model_args.normalize_embeddings,
        pooling_method=model_args.pooling_method,
        use_fp16=model_args.use_fp16,
        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
        query_instruction_format=model_args.query_instruction_format_for_retrieval,
        devices=model_args.devices,
        examples_for_task=model_args.examples_for_task,
        examples_instruction_format=model_args.examples_instruction_format,
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir,
        batch_size=model_args.batch_size,
        query_max_length=model_args.embedder_query_max_length,
        passage_max_length=model_args.embedder_passage_max_length,
    )
    return FlagAutoModel.from_finetuned(**embedder_kwargs)
209+
210+
211+
def main(data_args: DataArgs, model_args: ModelArgs):
    """
    Entry point: load the embedder and mine hard negatives.

    Args:
        data_args: Input/output paths and sampling options.
        model_args: Embedder loading and inference options.
    """
    embedder = load_model(model_args)

    # "start-end" string -> [start, end] integer bounds for sampling.
    sample_bounds = [int(part) for part in data_args.range_for_sampling.split('-')]

    find_knn_neg(
        model=embedder,
        input_file=data_args.input_file,
        output_file=data_args.output_file,
        candidate_pool=data_args.candidate_pool,
        sample_range=sample_bounds,
        negative_number=data_args.negative_number,
        use_gpu=data_args.use_gpu_for_searching
    )
223+
224+
225+
if __name__ == "__main__":
    # Parse both argument groups from the command line in one pass.
    parser = HfArgumentParser((DataArgs, ModelArgs))
    data_args, model_args = parser.parse_args_into_dataclasses()
    main(data_args, model_args)

scripts/split_data_by_length.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ def run(self, input_path: str, output_dir: str, log_name: str=None):
187187
f.write('\n')
188188

189189

190-
if __name__ == '__main__':
191-
args = get_args()
190+
def main(args):
192191
input_path = args.input_path
193192
output_dir = args.output_dir
194193
log_name = args.log_name
@@ -207,3 +206,8 @@ def run(self, input_path: str, output_dir: str, log_name: str=None):
207206
log_name=log_name
208207
)
209208
print('\nDONE!')
209+
210+
211+
if __name__ == "__main__":
    # Parse CLI arguments and run the split.
    main(get_args())

0 commit comments

Comments
 (0)