Skip to content

Commit 919a53b

Browse files
committed
add reranker score
1 parent a0beff0 commit 919a53b

1 file changed

Lines changed: 69 additions & 5 deletions

File tree

scripts/add_reranker_score.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,92 @@
11
import json
2+
from typing import Optional, List
23

34
from FlagEmbedding import FlagAutoReranker
4-
from FlagEmbedding.abc.evaluation import AbsEvalModelArgs
55
from dataclasses import dataclass, field
66
from transformers import HfArgumentParser
77

88
@dataclass
99
class ScoreArgs:
1010
input_file: str = field(
11-
default=None, metadata={"help": "The input json file, each line includes query, pos and neg."}
11+
default=None, metadata={"help": "The input jsonl file, each line includes query, pos and neg."}
1212
)
1313
output_file: str = field(
14-
default=None, metadata={"help": "The output json file, it includes query, pos, neg, pos_scores and neg_scores."}
14+
default=None, metadata={"help": "The output jsonl file, it includes query, pos, neg, pos_scores and neg_scores."}
15+
)
16+
17+
@dataclass
18+
class ModelArgs:
19+
use_fp16: bool = field(
20+
default=True, metadata={"help": "whether to use fp16 for inference"}
21+
)
22+
devices: Optional[str] = field(
23+
default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
24+
)
25+
trust_remote_code: bool = field(
26+
default=False, metadata={"help": "Trust remote code"}
27+
)
28+
reranker_name_or_path: Optional[str] = field(
29+
default=None, metadata={"help": "The reranker name or path."}
30+
)
31+
reranker_model_class: Optional[str] = field(
32+
default="auto", metadata={"help": "The reranker model class. Available classes: ['auto', 'encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: auto.", "choices": ["auto", "encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
33+
)
34+
reranker_peft_path: Optional[str] = field(
35+
default=None, metadata={"help": "The reranker peft path."}
36+
)
37+
use_bf16: bool = field(
38+
default=False, metadata={"help": "whether to use bf16 for inference"}
39+
)
40+
query_instruction_for_rerank: Optional[str] = field(
41+
default=None, metadata={"help": "Instruction for query"}
42+
)
43+
query_instruction_format_for_rerank: str = field(
44+
default="{}{}", metadata={"help": "Format for query instruction"}
45+
)
46+
passage_instruction_for_rerank: Optional[str] = field(
47+
default=None, metadata={"help": "Instruction for passage"}
48+
)
49+
passage_instruction_format_for_rerank: str = field(
50+
default="{}{}", metadata={"help": "Format for passage instruction"}
51+
)
52+
cache_dir: str = field(
53+
default=None, metadata={"help": "Cache directory for models."}
54+
)
55+
# ================ for inference ===============
56+
reranker_batch_size: int = field(
57+
default=3000, metadata={"help": "Batch size for inference."}
58+
)
59+
reranker_query_max_length: Optional[int] = field(
60+
default=None, metadata={"help": "Max length for reranking."}
61+
)
62+
reranker_max_length: int = field(
63+
default=512, metadata={"help": "Max length for reranking."}
64+
)
65+
normalize: bool = field(
66+
default=False, metadata={"help": "whether to normalize the reranking scores"}
67+
)
68+
prompt: Optional[str] = field(
69+
default=None, metadata={"help": "The prompt for the reranker."}
70+
)
71+
cutoff_layers: List[int] = field(
72+
default=None, metadata={"help": "The output layers of layerwise/lightweight reranker."}
73+
)
74+
compress_ratio: int = field(
75+
default=1, metadata={"help": "The compress ratio of lightweight reranker."}
76+
)
77+
compress_layers: Optional[int] = field(
78+
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
1579
)
1680

1781

1882
if __name__ == '__main__':
1983
parser = HfArgumentParser((
2084
ScoreArgs,
21-
AbsEvalModelArgs
85+
ModelArgs
2286
))
2387
score_args, model_args = parser.parse_args_into_dataclasses()
2488
eval_args: ScoreArgs
25-
model_args: AbsEvalModelArgs
89+
model_args: ModelArgs
2690

2791
reranker = FlagAutoReranker.from_finetuned(
2892
model_name_or_path=model_args.reranker_name_or_path,

0 commit comments

Comments
 (0)