Skip to content

Commit 7ae0ecf

Browse files
committed
evaluation docstring
1 parent 134a1ad commit 7ae0ecf

4 files changed

Lines changed: 99 additions & 2 deletions

File tree

FlagEmbedding/abc/evaluation/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __call__(
111111
dataset_name: Optional[str] = None,
112112
**kwargs,
113113
):
114-
"""Called to the whole evaluation process.
114+
"""This is called during the evaluation process.
115115
116116
Args:
117117
splits (Union[str, List[str]]): Splits of datasets.

FlagEmbedding/abc/evaluation/runner.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414

1515

1616
class AbsEvalRunner:
17+
"""
18+
Abstract class of evaluation runner.
19+
20+
Args:
21+
eval_args (AbsEvalArgs): :class:`AbsEvalArgs` object with the evaluation arguments.
22+
model_args (AbsEvalModelArgs): :class:`AbsEvalModelArgs` object with the model arguments.
23+
"""
1724
def __init__(
1825
self,
1926
eval_args: AbsEvalArgs,
@@ -28,6 +35,15 @@ def __init__(
2835

2936
@staticmethod
3037
def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]:
38+
"""Get the embedding and reranker models.
39+
40+
Args:
41+
model_args (AbsEvalModelArgs): :class:`AbsEvalModelArgs` object with the model arguments.
42+
43+
Returns:
44+
Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:`FlagAutoModel` object of embedding model, and
45+
:class:`FlagAutoReranker` object of reranker model if path provided.
46+
"""
3147
embedder = FlagAutoModel.from_finetuned(
3248
model_name_or_path=model_args.embedder_name_or_path,
3349
model_class=model_args.embedder_model_class,
@@ -74,6 +90,12 @@ def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagA
7490
return embedder, reranker
7591

7692
def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
93+
"""Load retriever and reranker for evaluation.
94+
95+
Returns:
96+
Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:`EvalDenseRetriever` object for retrieval, and a
97+
:class:`EvalReranker` object if reranker provided.
98+
"""
7799
embedder, reranker = self.get_models(self.model_args)
78100
retriever = EvalDenseRetriever(
79101
embedder,
@@ -85,6 +107,11 @@ def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalRer
85107
return retriever, reranker
86108

87109
def load_data_loader(self) -> AbsEvalDataLoader:
110+
"""Load the data loader.
111+
112+
Returns:
113+
AbsEvalDataLoader: Data loader object for that specific task.
114+
"""
88115
data_loader = AbsEvalDataLoader(
89116
eval_name=self.eval_args.eval_name,
90117
dataset_dir=self.eval_args.dataset_dir,
@@ -95,6 +122,11 @@ def load_data_loader(self) -> AbsEvalDataLoader:
95122
return data_loader
96123

97124
def load_evaluator(self) -> AbsEvaluator:
125+
"""Load the evaluator for evaluation.
126+
127+
Returns:
128+
AbsEvaluator: the evaluator to run the evaluation.
129+
"""
98130
evaluator = AbsEvaluator(
99131
eval_name=self.eval_args.eval_name,
100132
data_loader=self.data_loader,
@@ -109,6 +141,18 @@ def evaluate_metrics(
109141
output_path: str = "./eval_dev_results.md",
110142
metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"]
111143
):
144+
"""Evaluate the provided metrics and write the results.
145+
146+
Args:
147+
search_results_save_dir (str): Path to save the search results.
148+
output_method (str, optional): Output results to `json` or `markdown`. Defaults to "markdown".
149+
output_path (str, optional): Path to write the output. Defaults to "./eval_dev_results.md".
150+
metrics (Union[str, List[str]], optional): metrics to use. Defaults to ["ndcg_at_10", "recall_at_10"].
151+
152+
Raises:
153+
FileNotFoundError: Eval results not found
154+
ValueError: Invalid output method
155+
"""
112156
eval_results_dict = {}
113157
for model_name in sorted(os.listdir(search_results_save_dir)):
114158
model_search_results_save_dir = os.path.join(search_results_save_dir, model_name)
@@ -136,6 +180,9 @@ def evaluate_metrics(
136180
raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']")
137181

138182
def run(self):
183+
"""
184+
Run the whole evaluation.
185+
"""
139186
if self.eval_args.dataset_names is None:
140187
dataset_names = self.data_loader.available_dataset_names()
141188
else:

FlagEmbedding/abc/evaluation/searcher.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717

1818
class EvalRetriever(ABC):
19+
"""
20+
This is the base class for retriever.
21+
"""
1922
def __init__(self, embedder: AbsEmbedder, search_top_k: int = 1000, overwrite: bool = False):
2023
self.embedder = embedder
2124
self.search_top_k = search_top_k
@@ -45,7 +48,7 @@ def __call__(
4548
**kwargs,
4649
) -> Dict[str, Dict[str, float]]:
4750
"""
48-
This is called during the retrieval process.
51+
Abstract method to be overridden. This is called during the retrieval process.
4952
5053
Parameters:
5154
corpus: Dict[str, Dict[str, Any]]: Corpus of documents.
@@ -63,6 +66,9 @@ def __call__(
6366

6467

6568
class EvalDenseRetriever(EvalRetriever):
69+
"""
70+
Child class of :class:`EvalRetriever` for dense retrieval.
71+
"""
6672
def __call__(
6773
self,
6874
corpus: Dict[str, Dict[str, Any]],
@@ -144,6 +150,9 @@ def __call__(
144150

145151

146152
class EvalReranker:
153+
"""
154+
Class for reranker during evaluation.
155+
"""
147156
def __init__(self, reranker: AbsReranker, rerank_top_k: int = 100):
148157
self.reranker = reranker
149158
self.rerank_top_k = rerank_top_k

FlagEmbedding/abc/evaluation/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@ def evaluate_mrr(
1616
results: Dict[str, Dict[str, float]],
1717
k_values: List[int],
1818
) -> Tuple[Dict[str, float]]:
19+
"""Compute mean reciprocal rank (MRR).
20+
21+
Args:
22+
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
23+
results (Dict[str, Dict[str, float]]): Search results to evaluate.
24+
k_values (List[int]): Cutoffs.
25+
26+
Returns:
27+
Tuple[Dict[str, float]]: MRR results at provided k values.
28+
"""
1929
mrr = defaultdict(list)
2030

2131
k_max, top_hits = max(k_values), {}
@@ -53,6 +63,17 @@ def evaluate_metrics(
5363
Dict[str, float],
5464
Dict[str, float],
5565
]:
66+
"""Evaluate the main metrics.
67+
68+
Args:
69+
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
70+
results (Dict[str, Dict[str, float]]): Search results to evaluate.
71+
k_values (List[int]): Cutoffs.
72+
73+
Returns:
74+
Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]: Results of different metrics at
75+
different provided k values.
76+
"""
5677
all_ndcgs, all_aps, all_recalls, all_precisions = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
5778

5879
map_string = "map_cut." + ",".join([str(k) for k in k_values])
@@ -93,6 +114,17 @@ def index(
93114
load_path: Optional[str] = None,
94115
device: Optional[str] = None
95116
):
117+
"""Create and add embeddings into a Faiss index.
118+
119+
Args:
120+
index_factory (str, optional): Type of Faiss index to create. Defaults to "Flat".
121+
corpus_embeddings (Optional[np.ndarray], optional): The embedding vectors of the corpus. Defaults to None.
122+
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
123+
device (Optional[str], optional): Device to hold Faiss index. Defaults to None.
124+
125+
Returns:
126+
faiss.Index: The Faiss index that contains all the corpus embeddings.
127+
"""
96128
if corpus_embeddings is None:
97129
corpus_embeddings = np.load(load_path)
98130

@@ -127,6 +159,15 @@ def search(
127159
"""
128160
1. Encode queries into dense embeddings;
129161
2. Search through faiss index
162+
163+
Args:
164+
faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings.
165+
k (int, optional): Top k numbers of closest neighbours. Defaults to 100.
166+
query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to None.
167+
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
168+
169+
Returns:
170+
Tuple[np.ndarray, np.ndarray]: The scores of search results and their corresponding indices.
130171
"""
131172
if query_embeddings is None:
132173
query_embeddings = np.load(load_path)

0 commit comments

Comments
 (0)