Skip to content

Commit 43bdddf

Browse files
committed
docstring
1 parent d2be215 commit 43bdddf

1 file changed

Lines changed: 40 additions & 1 deletion

File tree

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,26 @@
1616

1717
class AbsReranker(ABC):
1818
"""
19-
Base class for embedder.
19+
Base class for Reranker.
2020
Extend this class and implement `compute_score_single_gpu` for custom rerankers.
21+
22+
Args:
23+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
24+
load a model from HuggingFace Hub with the name.
25+
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
26+
degradation. Default: `False`.
27+
query_instruction_for_rerank: (Optional[str], optional): Query instruction for reranking, which will be used with
28+
with `query_instruction_format`. Default: `None`.
29+
query_instruction_format: (str, optional): The template for `query_instruction_for_rerank`. Default: `"{}{}"`.
30+
passage_instruction_for_rerank (Optional[str], optional): Passage instruction for reranking. Default: `None`.
31+
passage_instruction_format (str, optional): Passage instruction format when using `passage_instruction_for_rerank`.
32+
Default: `"{}{}"`.
33+
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Default: `None`.
34+
batch_size (int, optional): Batch size for inference. Default: `128`.
35+
query_max_length (int, optional): Maximum length for query. Default: `None`.
36+
passage_max_length (int, optional): Maximum length for passage. Default: `512`.
37+
normalize (bool, optional): If true, normalize the result. Default: `False`.
38+
kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
2139
"""
2240

2341
def __init__(
@@ -61,6 +79,17 @@ def __init__(
6179

6280
@staticmethod
6381
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
82+
"""
83+
84+
Args:
85+
devices (Union[str, int, List[str], List[int]]): specified devices, can be `str`, `int`, list of `str`, or list of `int`.
86+
87+
Raises:
88+
ValueError: devices should be a string or an integer or a list of strings or a list of integers.
89+
90+
Returns:
91+
List[str]: a list of target devices in format
92+
"""
6493
if devices is None:
6594
if torch.cuda.is_available():
6695
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
@@ -85,6 +114,16 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
85114
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
86115

87116
def get_detailed_instruct(self, instruction_format: str, instruction: str, sentence: str):
117+
"""Combine the instruction and sentence along with the instruction format.
118+
119+
Args:
120+
instruction_format (str): Format for instruction.
121+
instruction (str): The text of instruction.
122+
sentence (str): The sentence to concatenate with.
123+
124+
Returns:
125+
str: the complete sentence with instruction
126+
"""
88127
return instruction_format.format(instruction, sentence)
89128

90129
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):

0 commit comments

Comments
 (0)