Skip to content

Commit fcd6937

Browse files
committed
abc inference docstring
1 parent 43bdddf commit fcd6937

2 files changed

Lines changed: 25 additions & 9 deletions

File tree

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
8989
devices (Union[str, int, List[str], List[int]]): specified devices, can be `str`, `int`, list of `str`, or list of `int`.
9090
9191
Raises:
92-
ValueError: devices should be a string or an integer or a list of strings or a list of integers.
92+
ValueError: Devices should be a string or an integer or a list of strings or a list of integers.
9393
9494
Returns:
95-
List[str]: a list of target devices in format
95+
List[str]: A list of target devices in format
9696
"""
9797
if devices is None:
9898
if torch.cuda.is_available():
@@ -127,7 +127,7 @@ def get_detailed_instruct(instruction_format: str, instruction: str, sentence: s
127127
sentence (str): The sentence to concatenate with.
128128
129129
Returns:
130-
str: the complete sentence with instruction
130+
str: The complete sentence with instruction
131131
"""
132132
return instruction_format.format(instruction, sentence)
133133

@@ -149,7 +149,7 @@ def encode_queries(
149149
be a Torch Tensor. Defaults to None.
150150
151151
Returns:
152-
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
152+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
153153
"""
154154
if batch_size is None: batch_size = self.batch_size
155155
if max_length is None: max_length = self.query_max_length
@@ -183,7 +183,7 @@ def encode_corpus(
183183
be a Torch Tensor. Defaults to None.
184184
185185
Returns:
186-
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
186+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
187187
"""
188188
passage_instruction_for_retrieval = self.kwargs.get("passage_instruction_for_retrieval", None)
189189
passage_instruction_format = self.kwargs.get("passage_instruction_format", "{}{}")

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
8282
"""
8383
8484
Args:
85-
devices (Union[str, int, List[str], List[int]]): specified devices, can be `str`, `int`, list of `str`, or list of `int`.
85+
devices (Union[str, int, List[str], List[int]]): Specified devices, can be `str`, `int`, list of `str`, or list of `int`.
8686
8787
Raises:
88-
ValueError: devices should be a string or an integer or a list of strings or a list of integers.
88+
ValueError: Devices should be a string or an integer or a list of strings or a list of integers.
8989
9090
Returns:
91-
List[str]: a list of target devices in format
91+
List[str]: A list of target devices in format
9292
"""
9393
if devices is None:
9494
if torch.cuda.is_available():
@@ -122,11 +122,19 @@ def get_detailed_instruct(self, instruction_format: str, instruction: str, sente
122122
sentence (str): The sentence to concatenate with.
123123
124124
Returns:
125-
str: the complete sentence with instruction
125+
str: The complete sentence with instruction
126126
"""
127127
return instruction_format.format(instruction, sentence)
128128

129129
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):
130+
"""get detailed instruct for all the inputs
131+
132+
Args:
133+
sentence_pairs (Union[str, List[str]]): Input sentence pairs
134+
135+
Returns:
136+
list[list[str]]: The complete sentence pairs with instruction
137+
"""
130138
if isinstance(sentence_pairs, str):
131139
sentence_pairs = [sentence_pairs]
132140

@@ -166,6 +174,14 @@ def compute_score(
166174
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
167175
**kwargs
168176
):
177+
"""Compute score for each sentence pair
178+
179+
Args:
180+
sentence_pairs (Union[List[Tuple[str, str]], Tuple[str, str]]): Input sentence pairs to compute.
181+
182+
Returns:
183+
numpy.ndarray: scores of all the sentence pairs.
184+
"""
169185
if isinstance(sentence_pairs[0], str):
170186
sentence_pairs = [sentence_pairs]
171187
sentence_pairs = self.get_detailed_inputs(sentence_pairs)

0 commit comments

Comments
 (0)