Skip to content

Commit d1c3b3f

Browse files
committed
evaluation docs
1 parent 7caa598 commit d1c3b3f

32 files changed

Lines changed: 744 additions & 0 deletions

FlagEmbedding/evaluation/air_bench/arguments.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
@dataclass
66
class AIRBenchEvalModelArgs:
7+
"""
8+
Evaluation Model arguments for AIR Bench.
9+
"""
710
embedder_name_or_path: str = field(
811
metadata={"help": "The embedder name or path.", "required": True}
912
)

FlagEmbedding/evaluation/air_bench/runner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010

1111

1212
class AIRBenchEvalRunner:
13+
"""
14+
Evaluation runner for AIR Bench.
15+
16+
Args:
17+
eval_args (AIRBenchEvalArgs): :class:AIRBenchEvalArgs object with the evaluation arguments.
18+
model_args (AIRBenchEvalModelArgs): :class:AIRBenchEvalModelArgs object with the model arguments.
19+
"""
1320
def __init__(
1421
self,
1522
eval_args: AIRBenchEvalArgs,
@@ -22,6 +29,12 @@ def __init__(
2229
self.retriever, self.reranker = self.load_retriever_and_reranker()
2330

2431
def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
32+
"""Load retriever and reranker for evaluation
33+
34+
Returns:
35+
Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:EvalDenseRetriever object for retrieval, and a
36+
:class:EvalReranker object if reranker provided.
37+
"""
2538
embedder, reranker = AbsEvalRunner.get_models(self.model_args)
2639
retriever = EvalDenseRetriever(
2740
embedder,
@@ -33,6 +46,9 @@ def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalRer
3346
return retriever, reranker
3447

3548
def run(self):
49+
"""
50+
Run the whole evaluation.
51+
"""
3652
evaluation = AIRBench(
3753
benchmark_version=self.eval_args.benchmark_version,
3854
task_types=self.eval_args.task_types,

FlagEmbedding/evaluation/beir/arguments.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
@dataclass
77
class BEIREvalArgs(AbsEvalArgs):
8+
"""
9+
Argument class for BEIR evaluation.
10+
"""
811
use_special_instructions: bool = field(
912
default=False, metadata={"help": "Whether to use specific instructions in `prompts.py` for evaluation. Default: False"}
1013
)

FlagEmbedding/evaluation/beir/data_loader.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,42 @@
1313

1414

1515
class BEIREvalDataLoader(AbsEvalDataLoader):
16+
"""
17+
Data loader class for BEIR.
18+
"""
1619
def available_dataset_names(self) -> List[str]:
20+
"""
21+
Get the available dataset names.
22+
23+
Returns:
24+
List[str]: All the available dataset names.
25+
"""
1726
return ['arguana', 'climate-fever', 'cqadupstack', 'dbpedia-entity', 'fever', 'fiqa', 'hotpotqa', 'msmarco', 'nfcorpus', 'nq', 'quora', 'scidocs', 'scifact', 'trec-covid', 'webis-touche2020']
1827

1928
def available_sub_dataset_names(self, dataset_name: Optional[str] = None) -> List[str]:
29+
"""
30+
Get the available sub-dataset names.
31+
32+
Args:
33+
dataset_name (Optional[str], optional): Name of the dataset to list sub-datasets for. Defaults to ``None``.
34+
35+
Returns:
36+
List[str]: All the available sub-dataset names.
37+
"""
2038
if dataset_name == 'cqadupstack':
2139
return ['android', 'english', 'gaming', 'gis', 'mathematica', 'physics', 'programmers', 'stats', 'tex', 'unix', 'webmasters', 'wordpress']
2240
return None
2341

2442
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
43+
"""
44+
Get the available splits.
45+
46+
Args:
47+
dataset_name (str): Dataset name.
48+
49+
Returns:
50+
List[str]: All the available splits for the dataset.
51+
"""
2552
if dataset_name == 'msmarco':
2653
return ['dev']
2754
return ['test']
@@ -32,6 +59,16 @@ def _load_remote_corpus(
3259
sub_dataset_name: Optional[str] = None,
3360
save_dir: Optional[str] = None
3461
) -> datasets.DatasetDict:
62+
"""Load the corpus dataset from HF.
63+
64+
Args:
65+
dataset_name (str): Name of the dataset.
66+
sub_dataset_name (Optional[str]): Name of the sub-dataset. Defaults to ``None``.
67+
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
68+
69+
Returns:
70+
datasets.DatasetDict: Loaded datasets instance of corpus.
71+
"""
3572
if dataset_name != 'cqadupstack':
3673
corpus = datasets.load_dataset(
3774
'BeIR/{d}'.format(d=dataset_name),
@@ -94,6 +131,17 @@ def _load_remote_qrels(
94131
split: str = 'dev',
95132
save_dir: Optional[str] = None
96133
) -> datasets.DatasetDict:
134+
"""Load the qrels from HF.
135+
136+
Args:
137+
dataset_name (str): Name of the dataset.
138+
sub_dataset_name (Optional[str]): Name of the sub-dataset. Defaults to ``None``.
139+
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
140+
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
141+
142+
Returns:
143+
datasets.DatasetDict: Loaded datasets instance of qrel.
144+
"""
97145
if dataset_name != 'cqadupstack':
98146
qrels = datasets.load_dataset(
99147
'BeIR/{d}-qrels'.format(d=dataset_name),
@@ -168,6 +216,17 @@ def _load_remote_queries(
168216
split: str = 'test',
169217
save_dir: Optional[str] = None
170218
) -> datasets.DatasetDict:
219+
"""Load the queries from HF.
220+
221+
Args:
222+
dataset_name (str): Name of the dataset.
223+
sub_dataset_name (Optional[str]): Name of the sub-dataset. Defaults to ``None``.
224+
split (str, optional): Split of the dataset. Defaults to ``'test'``.
225+
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
226+
227+
Returns:
228+
datasets.DatasetDict: Loaded datasets instance of queries.
229+
"""
171230
qrels = self.load_qrels(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
172231

173232
if dataset_name != 'cqadupstack':
@@ -230,6 +289,15 @@ def _load_remote_queries(
230289
return datasets.DatasetDict(queries_dict)
231290

232291
def load_corpus(self, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None) -> datasets.DatasetDict:
292+
"""Load the corpus from the dataset.
293+
294+
Args:
295+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
296+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
297+
298+
Returns:
299+
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
300+
"""
233301
if self.dataset_dir is not None:
234302
if dataset_name is None:
235303
save_dir = self.dataset_dir
@@ -240,6 +308,19 @@ def load_corpus(self, dataset_name: Optional[str] = None, sub_dataset_name: Opti
240308
return self._load_remote_corpus(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name)
241309

242310
def load_qrels(self, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
311+
"""Load the qrels from the dataset.
312+
313+
Args:
314+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
315+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
316+
split (str, optional): The split to load relevance from. Defaults to ``'test'``.
317+
318+
Raises:
319+
ValueError
320+
321+
Returns:
322+
datasets.DatasetDict: A dict of relevance of query and document.
323+
"""
243324
if self.dataset_dir is not None:
244325
if dataset_name is None:
245326
save_dir = self.dataset_dir
@@ -256,6 +337,19 @@ def load_qrels(self, dataset_name: Optional[str] = None, sub_dataset_name: Optio
256337
return self._load_remote_qrels(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
257338

258339
def load_queries(self, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
340+
"""Load the queries from the dataset.
341+
342+
Args:
343+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
344+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
345+
split (str, optional): The split to load queries from. Defaults to ``'test'``.
346+
347+
Raises:
348+
ValueError
349+
350+
Returns:
351+
datasets.DatasetDict: A dict of queries with id as key, query text as value.
352+
"""
259353
if self.dataset_dir is not None:
260354
if dataset_name is None:
261355
save_dir = self.dataset_dir
@@ -272,6 +366,16 @@ def load_queries(self, dataset_name: Optional[str] = None, sub_dataset_name: Opt
272366
return self._load_remote_queries(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
273367

274368
def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None) -> datasets.DatasetDict:
369+
"""Load corpus from local dataset.
370+
371+
Args:
372+
save_dir (str): Path to save the loaded corpus.
373+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
374+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
375+
376+
Returns:
377+
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
378+
"""
275379
if sub_dataset_name is None:
276380
corpus_path = os.path.join(save_dir, 'corpus.jsonl')
277381
else:
@@ -291,6 +395,20 @@ def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None,
291395
return datasets.DatasetDict(corpus)
292396

293397
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
398+
"""Load relevance from local dataset.
399+
400+
Args:
401+
save_dir (str): Path to save the loaded relevance.
402+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
403+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
404+
split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.
405+
406+
Raises:
407+
ValueError
408+
409+
Returns:
410+
datasets.DatasetDict: A dict of relevance of query and document.
411+
"""
294412
checked_split = self.check_splits(split)
295413
if len(checked_split) == 0:
296414
raise ValueError(f"Split {split} not found in the dataset.")
@@ -318,6 +436,20 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
318436
return datasets.DatasetDict(qrels)
319437

320438
def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
439+
"""Load queries from local dataset.
440+
441+
Args:
442+
save_dir (str): Path to save the loaded queries.
443+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
444+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
445+
split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.
446+
447+
Raises:
448+
ValueError
449+
450+
Returns:
451+
datasets.DatasetDict: A dict of queries with id as key, query text as value.
452+
"""
321453
checked_split = self.check_splits(split)
322454
if len(checked_split) == 0:
323455
raise ValueError(f"Split {split} not found in the dataset.")

FlagEmbedding/evaluation/beir/evaluator.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111

1212
class BEIREvaluator(AbsEvaluator):
13+
"""
14+
Evaluator class of BEIR.
15+
"""
1316
def check_data_info(
1417
self,
1518
data_info: Dict[str, str],
@@ -19,6 +22,23 @@ def check_data_info(
1922
dataset_name: Optional[str] = None,
2023
sub_dataset_name: Optional[str] = None,
2124
):
25+
"""Check the validity of data info.
26+
27+
Args:
28+
data_info (Dict[str, str]): The loaded data info to be checked.
29+
model_name (str): Name of model used.
30+
reranker_name (str): Name of reranker used.
31+
split (str): Split used in searching.
32+
dataset_name (Optional[str], optional): Name of dataset used. Defaults to None.
33+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
34+
35+
Raises:
36+
ValueError: eval_name mismatch
37+
ValueError: model_name or reranker_name mismatch
38+
ValueError: split mismatch
39+
ValueError: dataset_name mismatch
40+
ValueError: sub_dataset_name mismatch
41+
"""
2242
if data_info["eval_name"] != self.eval_name:
2343
raise ValueError(
2444
f'eval_name mismatch: {data_info["eval_name"]} vs {self.eval_name}'
@@ -317,11 +337,21 @@ def __call__(
317337
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
318338
if reranker is not None:
319339
reranker.stop_multi_process_pool()
340+
320341
def evaluate_results(
321342
self,
322343
search_results_save_dir: str,
323344
k_values: List[int] = [1, 3, 5, 10, 100, 1000]
324345
):
346+
"""Compute metrics according to the results in the directory.
347+
348+
Args:
349+
search_results_save_dir (str): Path to the search results.
350+
k_values (List[int], optional): Cutoffs. Defaults to :data:`[1, 3, 5, 10, 100, 1000]`.
351+
352+
Returns:
353+
dict: Evaluation results.
354+
"""
325355
eval_results_dict = {}
326356
cqadupstack_results = None
327357
cqadupstack_num = 0
@@ -386,6 +416,18 @@ def save_search_results(
386416
dataset_name: Optional[str] = None,
387417
sub_dataset_name: Optional[str] = None,
388418
):
419+
"""Save the metadata and search results into a file.
420+
421+
Args:
422+
eval_name (str): The experiment name of current evaluation.
423+
model_name (str): Name of model used.
424+
reranker_name (str): Name of reranker used.
425+
search_results (Dict[str, Dict[str, float]]): Dictionary of search results.
426+
output_path (str): Output path to write the results.
427+
split (str): Split used in searching.
428+
dataset_name (Optional[str], optional): Name of dataset used. Defaults to ``None``.
429+
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
430+
"""
389431
data = {
390432
"eval_name": eval_name,
391433
"model_name": model_name,

FlagEmbedding/evaluation/beir/runner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99

1010

1111
class BEIREvalRunner(AbsEvalRunner):
12+
"""
13+
Runner class of BEIR evaluation.
14+
"""
1215
def run(self):
16+
"""
17+
Run the whole evaluation.
18+
"""
1319
if self.eval_args.dataset_names is None:
1420
dataset_names = self.data_loader.available_dataset_names()
1521
else:
@@ -54,6 +60,11 @@ def run(self):
5460
)
5561

5662
def load_data_loader(self) -> BEIREvalDataLoader:
63+
"""Load the data loader
64+
65+
Returns:
66+
BEIREvalDataLoader: BEIR data loader object.
67+
"""
5768
data_loader = BEIREvalDataLoader(
5869
eval_name=self.eval_args.eval_name,
5970
dataset_dir=self.eval_args.dataset_dir,
@@ -64,6 +75,11 @@ def load_data_loader(self) -> BEIREvalDataLoader:
6475
return data_loader
6576

6677
def load_evaluator(self) -> BEIREvaluator:
78+
"""Load the evaluator for evaluation
79+
80+
Returns:
81+
BEIREvaluator: The BEIR evaluator to run the evaluation.
82+
"""
6783
evaluator = BEIREvaluator(
6884
eval_name=self.eval_args.eval_name,
6985
data_loader=self.data_loader,

0 commit comments

Comments
 (0)