1+ import json
2+
3+ from FlagEmbedding import FlagAutoReranker
4+ from FlagEmbedding .abc .evaluation import AbsEvalModelArgs
5+ from dataclasses import dataclass , field
6+ from transformers import HfArgumentParser
7+
8+ @dataclass
9+ class ScoreArgs :
10+ input_file : str = field (
11+ default = None , metadata = {"help" : "The input json file, each line includes query, pos and neg." }
12+ )
13+ output_file : str = field (
14+ default = None , metadata = {"help" : "The output json file, it includes query, pos, neg, pos_scores and neg_scores." }
15+ )
16+
17+
18+ if __name__ == '__main__' :
19+ parser = HfArgumentParser ((
20+ ScoreArgs ,
21+ AbsEvalModelArgs
22+ ))
23+ score_args , model_args = parser .parse_args_into_dataclasses ()
24+ eval_args : ScoreArgs
25+ model_args : AbsEvalModelArgs
26+
27+ reranker = FlagAutoReranker .from_finetuned (
28+ model_name_or_path = model_args .reranker_name_or_path ,
29+ model_class = model_args .reranker_model_class ,
30+ peft_path = model_args .reranker_peft_path ,
31+ use_fp16 = model_args .use_fp16 ,
32+ use_bf16 = model_args .use_bf16 ,
33+ query_instruction_for_rerank = model_args .query_instruction_for_rerank ,
34+ query_instruction_format = model_args .query_instruction_format_for_rerank ,
35+ passage_instruction_for_rerank = model_args .passage_instruction_for_rerank ,
36+ passage_instruction_format = model_args .passage_instruction_format_for_rerank ,
37+ cache_dir = model_args .cache_dir ,
38+ trust_remote_code = model_args .trust_remote_code ,
39+ devices = model_args .devices ,
40+ normalize = model_args .normalize ,
41+ prompt = model_args .prompt ,
42+ cutoff_layers = model_args .cutoff_layers ,
43+ compress_layers = model_args .compress_layers ,
44+ compress_ratio = model_args .compress_ratio ,
45+ batch_size = model_args .reranker_batch_size ,
46+ query_max_length = model_args .reranker_query_max_length ,
47+ max_length = model_args .reranker_max_length ,
48+ )
49+
50+ pairs = []
51+ data = []
52+ with open (score_args .input_file ) as f :
53+ for line in f :
54+ data .append (json .loads (line ))
55+ for p in data [- 1 ]['pos' ]:
56+ pairs .append ((data [- 1 ]['query' ], p ))
57+ for p in data [- 1 ]['neg' ]:
58+ pairs .append ((data [- 1 ]['query' ], p ))
59+
60+ scores = reranker .compute_score (pairs )
61+
62+ score_idx = 0
63+ for i in range (len (data )):
64+ data [i ]['pos_scores' ] = []
65+ data [i ]['neg_scores' ] = []
66+ for _ in range (len (data [i ]['pos' ])):
67+ data [i ]['pos_scores' ].append (float (scores [score_idx ]))
68+ score_idx += 1
69+ for _ in range (len (data [i ]['neg' ])):
70+ data [i ]['neg_scores' ].append (float (scores [score_idx ]))
71+ score_idx += 1
72+
73+ with open (score_args .output_dir , 'w' ) as f :
74+ for d in data :
75+ f .write (json .dumps (d ) + '\n ' )
0 commit comments