Skip to content

Commit b5e4597

Browse files
authored
Merge pull request #1186 from ZiyiXia/master
Inference docstring
2 parents 29fb8f4 + f370b72 commit b5e4597

13 files changed

Lines changed: 564 additions & 74 deletions

File tree

FlagEmbedding/abc/evaluation/data_loader.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ class AbsEvalDataLoader(ABC):
1717
1818
Args:
1919
eval_name (str): The experiment name of current evaluation.
20-
dataset_dir (str, optional): path to the datasets. Defaults to None.
21-
cache_dir (str, optional): Path to HuggingFace cache directory. Defaults to None.
22-
token (str, optional): HF_TOKEN to access the private datasets/models in HF. Defaults to None.
23-
force_redownload: If True, will force redownload the dataset to cover the local dataset. Defaults to False.
20+
dataset_dir (str, optional): path to the datasets. Defaults to :data:`None`.
21+
cache_dir (str, optional): Path to HuggingFace cache directory. Defaults to :data:`None`.
22+
token (str, optional): HF_TOKEN to access the private datasets/models in HF. Defaults to :data:`None`.
23+
force_redownload (bool, optional): If True, will force redownload the dataset to cover the local dataset. Defaults to :data:`False`.
2424
"""
2525
def __init__(
2626
self,
@@ -98,7 +98,7 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
9898
"""Load the corpus from the dataset.
9999
100100
Args:
101-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
101+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
102102
103103
Returns:
104104
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
@@ -116,8 +116,8 @@ def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') ->
116116
"""Load the qrels (relevance judgments) from the dataset.
117117
118118
Args:
119-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
120-
split (str, optional): The split to load relevance from. Defaults to 'test'.
119+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
120+
split (str, optional): The split to load relevance from. Defaults to :data:`'test'`.
121121
122122
Raises:
123123
ValueError
@@ -144,8 +144,8 @@ def load_queries(self, dataset_name: Optional[str] = None, split: str = 'test')
144144
"""Load the queries from the dataset.
145145
146146
Args:
147-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
148-
split (str, optional): The split to load queries from. Defaults to 'test'.
147+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
148+
split (str, optional): The split to load queries from. Defaults to :data:`'test'`.
149149
150150
Raises:
151151
ValueError
@@ -176,8 +176,8 @@ def _load_remote_corpus(
176176
"""Abstract method to load corpus from remote dataset, to be overridden in child class.
177177
178178
Args:
179-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
180-
save_dir (Optional[str], optional): Path to save the new downloaded corpus. Defaults to None.
179+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
180+
save_dir (Optional[str], optional): Path to save the new downloaded corpus. Defaults to :data:`None`.
181181
182182
Raises:
183183
NotImplementedError: Loading remote corpus is not implemented.
@@ -196,9 +196,9 @@ def _load_remote_qrels(
196196
"""Abstract method to load relevance from remote dataset, to be overridden in child class.
197197
198198
Args:
199-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
200-
split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
201-
save_dir (Optional[str], optional): Path to save the new downloaded relevance. Defaults to None.
199+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
200+
split (str, optional): Split to load from the remote dataset. Defaults to :data:`'test'`.
201+
save_dir (Optional[str], optional): Path to save the new downloaded relevance. Defaults to :data:`None`.
202202
203203
Raises:
204204
NotImplementedError: Loading remote qrels is not implemented.
@@ -217,9 +217,9 @@ def _load_remote_queries(
217217
"""Abstract method to load queries from remote dataset, to be overridden in child class.
218218
219219
Args:
220-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
221-
split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
222-
save_dir (Optional[str], optional): Path to save the new downloaded queries. Defaults to None.
220+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
221+
split (str, optional): Split to load from the remote dataset. Defaults to :data:`'test'`.
222+
save_dir (Optional[str], optional): Path to save the new downloaded queries. Defaults to :data:`None`.
223223
224224
Raises:
225225
NotImplementedError
@@ -234,7 +234,7 @@ def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None)
234234
235235
Args:
236236
save_dir (str): Path to save the loaded corpus.
237-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
237+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
238238
239239
Returns:
240240
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
@@ -257,8 +257,8 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
257257
258258
Args:
259259
save_dir (str): Path to save the loaded relevance.
260-
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
261-
split (str, optional): Split to load from the local dataset. Defaults to 'test'.
260+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
261+
split (str, optional): Split to load from the local dataset. Defaults to :data:`'test'`.
262262
263263
Raises:
264264
ValueError

FlagEmbedding/abc/evaluation/evaluator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,12 @@ def __call__(
116116
Args:
117117
splits (Union[str, List[str]]): Splits of datasets.
118118
search_results_save_dir (str): Directory to save the search results.
119-
retriever (EvalRetriever): object of :class:EvalRetriever
120-
reranker (Optional[EvalReranker], optional): Object of :class:EvalReranker. Defaults to None.
121-
corpus_embd_save_dir (Optional[str], optional): Directory to save the embedded corpus. Defaults to None.
122-
ignore_identical_ids (bool, optional): If True, will ignore identical ids in search results. Defaults to False.
123-
k_values (List[int], optional): Cutoffs. Defaults to [1, 3, 5, 10, 100, 1000].
124-
dataset_name (Optional[str], optional): Name of the datasets. Defaults to None.
119+
retriever (EvalRetriever): Object of :class:`EvalRetriever`.
120+
reranker (Optional[EvalReranker], optional): Object of :class:`EvalReranker`. Defaults to :data:`None`.
121+
corpus_embd_save_dir (Optional[str], optional): Directory to save the embedded corpus. Defaults to :data:`None`.
122+
ignore_identical_ids (bool, optional): If True, will ignore identical ids in search results. Defaults to :data:`False`.
123+
k_values (List[int], optional): Cutoffs. Defaults to :data:`[1, 3, 5, 10, 100, 1000]`.
124+
dataset_name (Optional[str], optional): Name of the datasets. Defaults to :data:`None`.
125125
"""
126126
# Check Splits
127127
checked_splits = self.data_loader.check_splits(splits, dataset_name=dataset_name)
@@ -278,7 +278,7 @@ def save_search_results(
278278
search_results (Dict[str, Dict[str, float]]): Dictionary of search results.
279279
output_path (str): Output path to write the results.
280280
split (str): Split used in searching.
281-
dataset_name (Optional[str], optional): Name of dataset used. Defaults to None.
281+
dataset_name (Optional[str], optional): Name of dataset used. Defaults to :data:`None`.
282282
"""
283283
data = {
284284
"eval_name": eval_name,
@@ -354,7 +354,7 @@ def evaluate_results(
354354
355355
Args:
356356
search_results_save_dir (str): Path to the search results.
357-
k_values (List[int], optional): Cutoffs. Defaults to [1, 3, 5, 10, 100, 1000].
357+
k_values (List[int], optional): Cutoffs. Defaults to :data:`[1, 3, 5, 10, 100, 1000]`.
358358
359359
Returns:
360360
_type_: _description_

FlagEmbedding/abc/evaluation/runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ def evaluate_metrics(
145145
146146
Args:
147147
search_results_save_dir (str): Path to save the search results.
148-
output_method (str, optional): Output results to `json` or `markdown`. Defaults to "markdown".
149-
output_path (str, optional): Path to write the output. Defaults to "./eval_dev_results.md".
150-
metrics (Union[str, List[str]], optional): metrics to use. Defaults to ["ndcg_at_10", "recall_at_10"].
148+
output_method (str, optional): Output results to `json` or `markdown`. Defaults to :data:`"markdown"`.
149+
output_path (str, optional): Path to write the output. Defaults to :data:`"./eval_dev_results.md"`.
150+
metrics (Union[str, List[str]], optional): metrics to use. Defaults to :data:`["ndcg_at_10", "recall_at_10"]`.
151151
152152
Raises:
153153
FileNotFoundError: Eval results not found

FlagEmbedding/abc/evaluation/searcher.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def __call__(
5757
queries: Dict[str, str]: Queries to search for.
5858
Structure: {<qid>: <query>}.
5959
Example: {"q-0": "This is a query."}
60+
corpus_embd_save_dir (Optional[str]): Defaults to :data:`None`.
61+
ignore_identical_ids (bool): Defaults to :data:`False`.
6062
**kwargs: Any: Additional arguments.
6163
6264
Returns: Dict[str, Dict[str, float]]: Top-k search results for each query. k is specified by search_top_k.
@@ -87,6 +89,8 @@ def __call__(
8789
queries: Dict[str, str]: Queries to search for.
8890
Structure: {<qid>: <query>}.
8991
Example: {"q-0": "This is a query."}
92+
corpus_embd_save_dir (Optional[str]): Defaults to :data:`None`.
93+
ignore_identical_ids (bool): Defaults to :data:`False`.
9094
**kwargs: Any: Additional arguments.
9195
9296
Returns: Dict[str, Dict[str, float]]: Top-k search results for each query. k is specified by search_top_k.

FlagEmbedding/abc/evaluation/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,9 @@ def search(
162162
163163
Args:
164164
faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings.
165-
k (int, optional): Top k numbers of closest neighbours. Defaults to 100.
166-
query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to None.
167-
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
165+
k (int, optional): Top k numbers of closest neighbours. Defaults to :data:`100`.
166+
query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to :data:`None`.
167+
load_path (Optional[str], optional): Path to load embeddings from. Defaults to :data:`None`.
168168
169169
Returns:
170170
Tuple[np.ndarray, np.ndarray]: The scores of search results and their corresponding indices.

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,23 @@
1818
class AbsEmbedder(ABC):
1919
"""
2020
Base class for embedder.
21-
Extend this class and implement :meth:`encode_queries`, :meth:`encode_passages`, :meth:`encode` for custom embedders.
21+
Extend this class and implement :meth:`encode_queries`, :meth:`encode_corpus`, :meth:`encode` for custom embedders.
2222
2323
Args:
2424
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
2525
load a model from HuggingFace Hub with the name.
26-
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Default: `True`.
26+
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
2727
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
28-
degradation. Default: `True`.
28+
degradation. Defaults to :data:`True`.
2929
query_instruction_for_retrieval: (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
30-
with :attr:`query_instruction_format`. Default: `None`.
31-
query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_retrieval`. Default: `"{}{}"`.
32-
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Default: `None`.
33-
batch_size (int, optional): Batch size for inference. Default: `256`.
34-
query_max_length (int, optional): Maximum length for query. Default: `512`.
35-
passage_max_length (int, optional): Maximum length for passage. Default: `512`.
36-
instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Default: `None`.
37-
instruction_format (str, optional): Instruction format when using :attr:`instruction`. Default: `"{}{}"`.
30+
with :attr:`query_instruction_format`. Defaults to :data:`None`.
31+
query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
32+
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
33+
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
34+
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
35+
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
3836
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
39-
Default: `True`.
37+
Defaults to :data:`True`.
4038
kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
4139
"""
4240

@@ -139,10 +137,10 @@ def encode_queries(
139137
140138
Args:
141139
queries (Union[List[str], str]): Input queries to encode.
142-
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to None.
143-
max_length (Optional[int], optional): Maximum length of tokens. Defaults to None.
140+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
141+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
144142
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
145-
be a Torch Tensor. Defaults to None.
143+
be a Torch Tensor. Defaults to :data:`None`.
146144
147145
Returns:
148146
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
@@ -173,10 +171,10 @@ def encode_corpus(
173171
174172
Args:
175173
corpus (Union[List[str], str]): Input corpus to encode.
176-
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to None.
177-
max_length (Optional[int], optional): Maximum length of tokens. Defaults to None.
174+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
175+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
178176
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
179-
be a Torch Tensor. Defaults to None.
177+
be a Torch Tensor. Defaults to :data:`None`.
180178
181179
Returns:
182180
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
@@ -212,12 +210,12 @@ def encode(
212210
213211
Args:
214212
sentences (Union[List[str], str]): Input sentences to encode.
215-
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to None.
216-
max_length (Optional[int], optional): Maximum length of tokens. Defaults to None.
213+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
214+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
217215
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
218-
be a Torch Tensor. Defaults to None.
219-
instruction (Optional[str], optional): The text of instruction. Defaults to None.
220-
instruction_format (Optional[str], optional): Format for instruction. Defaults to None.
216+
be a Torch Tensor. Defaults to :data:`None`.
217+
instruction (Optional[str], optional): The text of instruction. Defaults to :data:`None`.
218+
instruction_format (Optional[str], optional): Format for instruction. Defaults to :data:`None`.
221219
222220
Returns:
223221
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
@@ -396,7 +394,7 @@ def _concatenate_results_from_multi_process(self, results_list: List[Union[torch
396394
"""concatenate and return the results from all the processes
397395
398396
Args:
399-
results_list (List[Union[torch.Tensor, np.ndarray, Any]]): a list of results from all the processes
397+
results_list (List[Union[torch.Tensor, np.ndarray, Any]]): A list of results from all the processes.
400398
401399
Raises:
402400
NotImplementedError: Unsupported type for results_list

0 commit comments

Comments
 (0)