import json
from dataclasses import dataclass, field
from typing import List

import numpy as np
import pytrec_eval
from transformers import HfArgumentParser
from FlagEmbedding import FlagReranker

@dataclass
class Args:
    input_path: str = field(
        default="",
        metadata={'help': """
        The data path points to a file in JSONL format.
        Each line contains `query`, `pos`, and `neg`. Here, `query` is a string (`str`),
        while both `pos` and `neg` are lists of strings (`List[str]`).
        If a line also includes `pos_label_scores`, those scores are used to compute `ndcg@k`;
        otherwise every positive passage defaults to a relevance label of `1`.
        """}
    )
    metrics: List[str] = field(
        default=None,  # usage example: recall mrr ndcg
        metadata={'help': 'The evaluation metrics; you can set any of recall / mrr / ndcg / map / precision.'}
    )
    k_values: List[int] = field(
        default=None,
        metadata={'help': 'Report the top-k metrics at these cutoffs.'}
    )
    cache_dir: str = field(
        default=None,
        metadata={'help': 'The path to store the cache of the reranker.'}
    )
    use_fp16: bool = field(
        default=True,
        metadata={'help': 'Whether to use fp16 to accelerate inference; not suitable for CPU-only inference.'}
    )
    batch_size: int = field(
        default=512
    )
    max_length: int = field(
        default=1024
    )
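
# For reference, a single input line in the expected JSONL format might look like
# the following (the field values are made up purely for illustration):
#
#   {"query": "what is deep learning",
#    "pos": ["Deep learning is a subfield of machine learning ..."],
#    "neg": ["The weather in Paris is mild in spring ..."],
#    "pos_label_scores": [2]}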


def evaluate_mrr(predicts, labels, cutoffs):
    """
    Evaluate MRR (mean reciprocal rank) at each cutoff.

    `predicts` holds, for each query, the passage indices sorted by descending
    reranker score; `labels` holds the indices of the relevant (positive)
    passages; `cutoffs` is the list of k values to report.
    """
    metrics = {}

    # MRR: for each query, find the rank of the first relevant passage and add
    # its reciprocal rank to every cutoff bucket that covers that position.
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(predicts, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(predicts)
    for i, cutoff in enumerate(cutoffs):
        mrr = mrrs[i]
        metrics[f"MRR@{cutoff}"] = mrr

    return metrics

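# A quick illustrative check of evaluate_mrr (toy numbers, not from the original
# script): with predicts=[[3, 1, 2]], labels=[[1]] and cutoffs=[1, 5], the first
# relevant passage sits at rank 2, so the result is {"MRR@1": 0.0, "MRR@5": 0.5}.
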
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    input_path = args.input_path
    metrics = args.metrics if args.metrics is not None else ['recall', 'mrr', 'ndcg', 'map', 'precision']
    k_values = args.k_values if args.k_values is not None else [1, 5, 10, 50, 100]
    cache_dir = args.cache_dir
    use_fp16 = args.use_fp16
    batch_size = args.batch_size
    max_length = args.max_length

    # Load the cross-encoder reranker used to score (query, passage) pairs.
    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', cache_dir=cache_dir, use_fp16=use_fp16)

    data = []
    data_num = []
    with open(input_path) as f:
        for line in f:
            data.append(json.loads(line))

    # Flatten all (query, passage) pairs into one list; data_num[i] records how
    # many passages (positives followed by negatives) belong to query i, so the
    # scores can be sliced back per query later.
    pairs = []
    for d in data:
        data_num.append(0)
        passages = []
        passages.extend(d['pos'])
        passages.extend(d['neg'])
        for p in passages:
            pairs.append((d['query'], p))
            data_num[-1] += 1

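    # Score every pair with the reranker. compute_score should return one
    # relevance score per (query, passage) pair, where a higher score means the
    # passage is judged more relevant to the query.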
    scores = reranker.compute_score(pairs, batch_size=batch_size, max_length=max_length)
    scores = np.asarray(scores)
    scores = scores.reshape(-1)

    # Build the pytrec_eval qrels: for each query, map the global index of every
    # positive passage to its relevance label.
    start_num = 0
    ground_truths = {}
    labels = []
    for i in range(len(data)):
        tmp = {}
        tmp_labels = []
        for ind in range(len(data[i]['pos'])):
            try:
                tmp[str(start_num + ind)] = int(data[i]['pos_label_scores'][ind])
            except Exception:
                # No usable `pos_label_scores` for this passage; fall back to a
                # relevance label of 1.
                tmp[str(start_num + ind)] = 1
            tmp_labels.append(start_num + ind)
        ground_truths[str(i)] = tmp
        start_num += data_num[i]
        labels.append(tmp_labels)

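    # For example, if query 0 has two positives with pos_label_scores [2, 1] and
    # three negatives, this produces ground_truths["0"] == {"0": 2, "1": 1} and
    # labels[0] == [0, 1] (the numbers here are made up for illustration).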

    # Build the pytrec_eval run: for each query, map every passage's global index
    # to its reranker score, and keep the indices sorted by descending score for
    # the MRR computation.
    start_num = 0
    rerank_results = {}
    predicts = []
    for i in range(len(data)):
        tmp = {}
        tmp_predicts = [(start_num + ind, scores[start_num + ind]) for ind in range(data_num[i])]
        tmp_predicts = [idx for (idx, _) in sorted(tmp_predicts, key=lambda x: x[1], reverse=True)]
        for ind in range(data_num[i]):
            tmp[str(start_num + ind)] = float(scores[start_num + ind])
        rerank_results[str(i)] = tmp
        start_num += data_num[i]
        predicts.append(tmp_predicts)

    ndcg = {}
    _map = {}
    recall = {}
    precision = {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0
        precision[f"Precision@{k}"] = 0.0

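    # pytrec_eval accepts parameterised measure names such as "ndcg_cut.1,5,10",
    # which asks for NDCG at each of the listed cutoffs; the same pattern is used
    # for map_cut, recall and P (precision) below.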
    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])
    evaluator = pytrec_eval.RelevanceEvaluator(ground_truths,
                                               {map_string, ndcg_string, recall_string, precision_string})

    eval_results = evaluator.evaluate(rerank_results)

    # Sum the per-query values, then average and round them below.
    for query_id in eval_results.keys():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += eval_results[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += eval_results[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += eval_results[query_id]["recall_" + str(k)]
            precision[f"Precision@{k}"] += eval_results[query_id]["P_" + str(k)]

    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(eval_results), 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(eval_results), 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(eval_results), 5)
        precision[f"Precision@{k}"] = round(precision[f"Precision@{k}"] / len(eval_results), 5)

    mrr = evaluate_mrr(predicts, labels, k_values)

    if 'mrr' in metrics:
        print(mrr)
    if 'recall' in metrics:
        print(recall)
    if 'ndcg' in metrics:
        print(ndcg)
    if 'map' in metrics:
        print(_map)
    if 'precision' in metrics:
        print(precision)

if __name__ == "__main__":
    main()
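
# Example invocation (illustrative only; the script name and data path are
# assumptions, not from the original repository):
#
#   python eval_reranker.py \
#       --input_path ./data/test_data.jsonl \
#       --metrics recall mrr ndcg \
#       --k_values 1 5 10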