Skip to content

Commit a0beff0

Browse files
committed
add reranker score
1 parent 637f862 commit a0beff0

1 file changed

Lines changed: 75 additions & 0 deletions

File tree

scripts/add_reranker_score.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import json
2+
3+
from FlagEmbedding import FlagAutoReranker
4+
from FlagEmbedding.abc.evaluation import AbsEvalModelArgs
5+
from dataclasses import dataclass, field
6+
from transformers import HfArgumentParser
7+
8+
@dataclass
9+
class ScoreArgs:
10+
input_file: str = field(
11+
default=None, metadata={"help": "The input json file, each line includes query, pos and neg."}
12+
)
13+
output_file: str = field(
14+
default=None, metadata={"help": "The output json file, it includes query, pos, neg, pos_scores and neg_scores."}
15+
)
16+
17+
18+
if __name__ == '__main__':
19+
parser = HfArgumentParser((
20+
ScoreArgs,
21+
AbsEvalModelArgs
22+
))
23+
score_args, model_args = parser.parse_args_into_dataclasses()
24+
eval_args: ScoreArgs
25+
model_args: AbsEvalModelArgs
26+
27+
reranker = FlagAutoReranker.from_finetuned(
28+
model_name_or_path=model_args.reranker_name_or_path,
29+
model_class=model_args.reranker_model_class,
30+
peft_path=model_args.reranker_peft_path,
31+
use_fp16=model_args.use_fp16,
32+
use_bf16=model_args.use_bf16,
33+
query_instruction_for_rerank=model_args.query_instruction_for_rerank,
34+
query_instruction_format=model_args.query_instruction_format_for_rerank,
35+
passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
36+
passage_instruction_format=model_args.passage_instruction_format_for_rerank,
37+
cache_dir=model_args.cache_dir,
38+
trust_remote_code=model_args.trust_remote_code,
39+
devices=model_args.devices,
40+
normalize=model_args.normalize,
41+
prompt=model_args.prompt,
42+
cutoff_layers=model_args.cutoff_layers,
43+
compress_layers=model_args.compress_layers,
44+
compress_ratio=model_args.compress_ratio,
45+
batch_size=model_args.reranker_batch_size,
46+
query_max_length=model_args.reranker_query_max_length,
47+
max_length=model_args.reranker_max_length,
48+
)
49+
50+
pairs = []
51+
data = []
52+
with open(score_args.input_file) as f:
53+
for line in f:
54+
data.append(json.loads(line))
55+
for p in data[-1]['pos']:
56+
pairs.append((data[-1]['query'], p))
57+
for p in data[-1]['neg']:
58+
pairs.append((data[-1]['query'], p))
59+
60+
scores = reranker.compute_score(pairs)
61+
62+
score_idx = 0
63+
for i in range(len(data)):
64+
data[i]['pos_scores'] = []
65+
data[i]['neg_scores'] = []
66+
for _ in range(len(data[i]['pos'])):
67+
data[i]['pos_scores'].append(float(scores[score_idx]))
68+
score_idx += 1
69+
for _ in range(len(data[i]['neg'])):
70+
data[i]['neg_scores'].append(float(scores[score_idx]))
71+
score_idx += 1
72+
73+
with open(score_args.output_dir, 'w') as f:
74+
for d in data:
75+
f.write(json.dumps(d) + '\n')

0 commit comments

Comments
 (0)