Commit ffb1b34

add reranker evaluate code

1 parent 9da55b7

2 files changed: 191 additions & 0 deletions

FlagEmbedding/llm_reranker/README.md

Lines changed: 11 additions & 0 deletions

@@ -370,6 +370,17 @@ If you download reranker-v2-minicpm-layerwise, you can load it with the following
 },
 ```
 
+## Evaluate Script
+
+```shell
+python evaluate.py \
+--input_path ../LLARA/data/finetune/toy_finetune_data.jsonl \
+--metrics mrr recall ndcg map precision \
+--k_values 1 10 100
+```
+
+If you want to use another reranker, replace the line `reranker = FlagReranker('BAAI/bge-reranker-v2-m3', cache_dir=cache_dir, use_fp16=use_fp16)` in `evaluate.py` with your own reranker.
+
 ## Evaluation
 
 - llama-index.
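
A note on the input format: per the `Args` help text in the new file below, each JSONL line carries a `query` string, `pos` and `neg` lists of passage strings, and optionally `pos_label_scores` with graded relevance labels for the positives. A minimal sketch of writing such a file; the file name `toy_eval_data.jsonl` and the example texts are invented for illustration:

```python
import json

# Two toy evaluation records; `pos_label_scores` is optional and, when absent,
# evaluate.py treats every positive passage as having relevance label 1.
records = [
    {
        "query": "what is a reranker",
        "pos": ["A reranker scores query-passage pairs directly."],
        "neg": ["Paris is the capital of France."],
        "pos_label_scores": [2],  # graded relevance, used for nDCG
    },
    {
        "query": "how to speed up inference",
        "pos": ["Use fp16 to accelerate inference on GPU."],
        "neg": ["Bananas are rich in potassium."],
    },
]

with open("toy_eval_data.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
```

Passing a file like this via `--input_path` prints one dict per requested metric.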
FlagEmbedding/llm_reranker/evaluate.py

Lines changed: 180 additions & 0 deletions

@@ -0,0 +1,180 @@
import json
from dataclasses import dataclass, field
from typing import List

import numpy as np
import pytrec_eval
from transformers import HfArgumentParser
from FlagEmbedding import FlagReranker


@dataclass
class Args:
    input_path: str = field(
        default="",
        metadata={'help': """
        The data path points to a file in JSONL format.
        Each line contains `query`, `pos`, and `neg`: `query` is a string (`str`),
        while `pos` and `neg` are lists of strings (`List[str]`).
        If a line includes `pos_label_scores`, those scores are used as the graded
        relevance labels for `ndcg@k`; otherwise each positive defaults to `1`.
        """}
    )
    metrics: List[str] = field(
        default=None,  # usage example: --metrics recall mrr ndcg
        metadata={'help': 'The evaluation metrics; choose from recall / mrr / ndcg / map / precision.'}
    )
    k_values: List[int] = field(
        default=None,
        metadata={'help': 'The cutoff values k at which to report each metric.'}
    )
    cache_dir: str = field(
        default=None,
        metadata={'help': 'The path to store the cache of the reranker.'}
    )
    use_fp16: bool = field(
        default=True,
        metadata={'help': 'Whether to use fp16 to accelerate inference; not suitable for CPU-only inference.'}
    )
    batch_size: int = field(
        default=512  # reranker inference batch size
    )
    max_length: int = field(
        default=1024  # maximum sequence length for a (query, passage) pair
    )


def evaluate_mrr(predicts, labels, cutoffs):
    """
    Evaluate MRR (mean reciprocal rank) at each cutoff.

    `predicts` holds the score-sorted passage ids for each query;
    `labels` holds the ids of each query's positive passages.
    """
    metrics = {}

    # Credit each query with the reciprocal rank of its first relevant
    # passage, under every cutoff that rank falls within.
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(predicts, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(predicts)
    for i, cutoff in enumerate(cutoffs):
        metrics[f"MRR@{cutoff}"] = mrrs[i]

    return metrics
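
# A quick sanity check of evaluate_mrr: with ranked passage ids [5, 3, 7] for a
# single query whose only positive id is 3, the first hit sits at rank 2, so
# evaluate_mrr([[5, 3, 7]], [[3]], [1, 10]) gives MRR@1 = 0.0 and MRR@10 = 0.5.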

def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    input_path = args.input_path
    metrics = args.metrics if args.metrics is not None else ['recall', 'mrr', 'ndcg', 'map', 'precision']
    k_values = args.k_values if args.k_values is not None else [1, 5, 10, 50, 100]
    cache_dir = args.cache_dir
    use_fp16 = args.use_fp16
    batch_size = args.batch_size
    max_length = args.max_length

    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', cache_dir=cache_dir, use_fp16=use_fp16)

    # Load one JSON record per line.
    data = []
    data_num = []
    with open(input_path) as f:
        for line in f:
            data.append(json.loads(line))

    # Build one (query, passage) pair per positive and negative passage;
    # data_num tracks how many pairs belong to each query.
    pairs = []
    for d in data:
        data_num.append(0)
        passages = []
        passages.extend(d['pos'])
        passages.extend(d['neg'])
        for p in passages:
            pairs.append((d['query'], p))
            data_num[-1] += 1

    # Score every pair with the reranker in one flat batch.
    scores = reranker.compute_score(pairs, batch_size=batch_size, max_length=max_length)
    scores = np.asarray(scores)
    scores = scores.reshape(-1)

    # Build the qrels: passages are numbered globally, positives come first
    # within each query, and labels fall back to 1 when `pos_label_scores`
    # is absent or unusable.
    start_num = 0
    ground_truths = {}
    labels = []
    for i in range(len(data)):
        tmp = {}
        tmp_labels = []
        for ind in range(len(data[i]['pos'])):
            try:
                tmp[str(start_num + ind)] = int(data[i]['pos_label_scores'][ind])
            except (KeyError, IndexError, TypeError, ValueError):
                tmp[str(start_num + ind)] = 1
            tmp_labels.append(start_num + ind)
        ground_truths[str(i)] = tmp
        start_num += data_num[i]
        labels.append(tmp_labels)

    # Build the run: per query, every passage with its reranker score, plus a
    # score-sorted id list for the MRR computation.
    start_num = 0
    rerank_results = {}
    predicts = []
    for i in range(len(data)):
        tmp = {}
        tmp_predicts = [(start_num + ind, scores[start_num + ind]) for ind in range(data_num[i])]
        tmp_predicts = [idx for (idx, _) in sorted(tmp_predicts, key=lambda x: x[1], reverse=True)]
        for ind in range(data_num[i]):
            tmp[str(start_num + ind)] = float(scores[start_num + ind])
        rerank_results[str(i)] = tmp
        start_num += data_num[i]
        predicts.append(tmp_predicts)
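
    # Shape illustration with toy numbers: for two queries, each with one
    # positive and one negative passage, the qrels look like
    #     ground_truths = {'0': {'0': 1}, '1': {'2': 1}}
    # and the run looks like
    #     rerank_results = {'0': {'0': 1.7, '1': -3.2}, '1': {'2': 0.4, '3': -0.9}}
    # where outer keys are query ids, inner keys are global passage ids, and
    # run values are reranker scores.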

    # Accumulators for the pytrec_eval metrics at each cutoff.
    ndcg = {}
    _map = {}
    recall = {}
    precision = {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0
        precision[f"Precision@{k}"] = 0.0

    # pytrec_eval measure strings, e.g. "ndcg_cut.1,5,10,50,100".
    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])
    evaluator = pytrec_eval.RelevanceEvaluator(ground_truths,
                                               {map_string, ndcg_string, recall_string, precision_string})

    eval_results = evaluator.evaluate(rerank_results)

    # Sum each per-query measure, then average over queries and round.
    for query_id in eval_results.keys():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += eval_results[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += eval_results[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += eval_results[query_id]["recall_" + str(k)]
            precision[f"Precision@{k}"] += eval_results[query_id]["P_" + str(k)]

    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(eval_results), 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(eval_results), 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(eval_results), 5)
        precision[f"Precision@{k}"] = round(precision[f"Precision@{k}"] / len(eval_results), 5)

    # MRR is computed from the score-sorted ids directly, not via pytrec_eval.
    mrr = evaluate_mrr(predicts, labels, k_values)

    if 'mrr' in metrics:
        print(mrr)
    if 'recall' in metrics:
        print(recall)
    if 'ndcg' in metrics:
        print(ndcg)
    if 'map' in metrics:
        print(_map)
    if 'precision' in metrics:
        print(precision)


if __name__ == "__main__":
    main()
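
As the README note says, any reranker with a compatible `compute_score` can be swapped in. A minimal sketch, assuming your installed FlagEmbedding version also exposes the LLM-based `FlagLLMReranker` class and that `BAAI/bge-reranker-v2-gemma` is available to you; verify both against your version before relying on this:

```python
from FlagEmbedding import FlagLLMReranker  # assumed available in your FlagEmbedding install

# Drop-in replacement for the FlagReranker line in evaluate.py; the rest of the
# script only relies on compute_score returning one score per (query, passage) pair.
reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True)

scores = reranker.compute_score([
    ("what is panda?", "The giant panda is a bear species endemic to China."),
])
print(scores)
```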
