Skip to content

Commit 60ddaf0

Browse files
committed
Merge branch 'new-flagembedding-v1' of https://github.com/hanhainebula/FlagEmbedding into new-flagembedding-v1
2 parents 071caeb + a7bbf09 commit 60ddaf0

6 files changed

Lines changed: 131 additions & 9 deletions

File tree

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ def __init__(
8383

8484
@staticmethod
8585
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
86+
"""
87+
88+
Args:
89+
devices (Union[str, int, List[str], List[int]]): specified devices, can be `str`, `int`, list of `str`, or list of `int`.
90+
91+
Raises:
92+
ValueError: Devices should be a string or an integer or a list of strings or a list of integers.
93+
94+
Returns:
95+
List[str]: A list of target devices in string format.
96+
"""
8697
if devices is None:
8798
if torch.cuda.is_available():
8899
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
@@ -108,6 +119,16 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
108119

109120
@staticmethod
110121
def get_detailed_instruct(instruction_format: str, instruction: str, sentence: str):
122+
"""Combine the instruction and sentence along with the instruction format.
123+
124+
Args:
125+
instruction_format (str): Format for instruction.
126+
instruction (str): The text of instruction.
127+
sentence (str): The sentence to concatenate with.
128+
129+
Returns:
130+
str: The complete sentence with instruction.
131+
"""
111132
return instruction_format.format(instruction, sentence)
112133

113134
def encode_queries(
@@ -118,6 +139,18 @@ def encode_queries(
118139
convert_to_numpy: Optional[bool] = None,
119140
**kwargs: Any
120141
):
142+
"""encode the queries using the instruction if provided.
143+
144+
Args:
145+
queries (Union[List[str], str]): Input queries to encode.
146+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to None.
147+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to None.
148+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
149+
be a Torch Tensor. Defaults to None.
150+
151+
Returns:
152+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
153+
"""
121154
if batch_size is None: batch_size = self.batch_size
122155
if max_length is None: max_length = self.query_max_length
123156
if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy
@@ -140,6 +173,18 @@ def encode_corpus(
140173
convert_to_numpy: Optional[bool] = None,
141174
**kwargs: Any
142175
):
176+
"""encode the corpus using the instruction if provided.
177+
178+
Args:
179+
corpus (Union[List[str], str]): Input corpus to encode.
180+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to None.
181+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to None.
182+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
183+
be a Torch Tensor. Defaults to None.
184+
185+
Returns:
186+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
187+
"""
143188
passage_instruction_for_retrieval = self.kwargs.get("passage_instruction_for_retrieval", None)
144189
passage_instruction_format = self.kwargs.get("passage_instruction_format", "{}{}")
145190

@@ -167,6 +212,20 @@ def encode(
167212
instruction_format: Optional[str] = None,
168213
**kwargs: Any
169214
):
215+
"""encode the input sentences with the embedding model.
216+
217+
Args:
218+
sentences (Union[List[str], str]): Input sentences to encode.
219+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to None.
220+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to None.
221+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
222+
be a Torch Tensor. Defaults to None.
223+
instruction (Optional[str], optional): The text of instruction. Defaults to None.
224+
instruction_format (Optional[str], optional): Format for instruction. Defaults to None.
225+
226+
Returns:
227+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
228+
"""
170229
if batch_size is None: batch_size = self.batch_size
171230
if max_length is None: max_length = self.passage_max_length
172231
if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy
@@ -338,6 +397,17 @@ def encode_multi_process(
338397
return embeddings
339398

340399
def _concatenate_results_from_multi_process(self, results_list: List[Union[torch.Tensor, np.ndarray, Any]]):
400+
"""concatenate and return the results from all the processes
401+
402+
Args:
403+
results_list (List[Union[torch.Tensor, np.ndarray, Any]]): a list of results from all the processes
404+
405+
Raises:
406+
NotImplementedError: Unsupported type for results_list
407+
408+
Returns:
409+
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
410+
"""
341411
if isinstance(results_list[0], torch.Tensor):
342412
return torch.cat(results_list, dim=0)
343413
elif isinstance(results_list[0], np.ndarray):

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,26 @@
1616

1717
class AbsReranker(ABC):
1818
"""
19-
Base class for embedder.
19+
Base class for Reranker.
2020
Extend this class and implement `compute_score_single_gpu` for custom rerankers.
21+
22+
Args:
23+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
24+
load a model from HuggingFace Hub with the name.
25+
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
26+
degradation. Default: `False`.
27+
query_instruction_for_rerank (Optional[str], optional): Query instruction for reranking, which will be used with
28+
`query_instruction_format`. Default: `None`.
29+
query_instruction_format (str, optional): The template for `query_instruction_for_rerank`. Default: `"{}{}"`.
30+
passage_instruction_for_rerank (Optional[str], optional): Passage instruction for reranking. Default: `None`.
31+
passage_instruction_format (str, optional): Passage instruction format when using `passage_instruction_for_rerank`.
32+
Default: `"{}{}"`.
33+
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Default: `None`.
34+
batch_size (int, optional): Batch size for inference. Default: `128`.
35+
query_max_length (Optional[int], optional): Maximum length for query. Default: `None`.
36+
passage_max_length (int, optional): Maximum length for passage. Default: `512`.
37+
normalize (bool, optional): If true, normalize the result. Default: `False`.
38+
kwargs (Dict[str, Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
2139
"""
2240

2341
def __init__(
@@ -61,6 +79,17 @@ def __init__(
6179

6280
@staticmethod
6381
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
82+
"""
83+
84+
Args:
85+
devices (Union[str, int, List[str], List[int]]): Specified devices, can be `str`, `int`, list of `str`, or list of `int`.
86+
87+
Raises:
88+
ValueError: Devices should be a string or an integer or a list of strings or a list of integers.
89+
90+
Returns:
91+
List[str]: A list of target devices in string format.
92+
"""
6493
if devices is None:
6594
if torch.cuda.is_available():
6695
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
@@ -85,9 +114,27 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
85114
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
86115

87116
def get_detailed_instruct(self, instruction_format: str, instruction: str, sentence: str):
117+
"""Combine the instruction and sentence along with the instruction format.
118+
119+
Args:
120+
instruction_format (str): Format for instruction.
121+
instruction (str): The text of instruction.
122+
sentence (str): The sentence to concatenate with.
123+
124+
Returns:
125+
str: The complete sentence with instruction.
126+
"""
88127
return instruction_format.format(instruction, sentence)
89128

90129
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):
130+
"""get detailed instruct for all the inputs
131+
132+
Args:
133+
sentence_pairs (Union[str, List[str]]): Input sentence pairs
134+
135+
Returns:
136+
list[list[str]]: The complete sentence pairs with instruction
137+
"""
91138
if isinstance(sentence_pairs, str):
92139
sentence_pairs = [sentence_pairs]
93140

@@ -127,6 +174,14 @@ def compute_score(
127174
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
128175
**kwargs
129176
):
177+
"""Compute score for each sentence pair
178+
179+
Args:
180+
sentence_pairs (Union[List[Tuple[str, str]], Tuple[str, str]]): Input sentence pairs to compute.
181+
182+
Returns:
183+
numpy.ndarray: scores of all the sentence pairs.
184+
"""
130185
if isinstance(sentence_pairs[0], str):
131186
sentence_pairs = [sentence_pairs]
132187
sentence_pairs = self.get_detailed_inputs(sentence_pairs)

docs/source/API/evaluation.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Visualized BGE
2-
==============
1+
Evaluation
2+
==========

docs/source/API/inference.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Evaluation
2-
==========
1+
Inference
2+
=========

docs/source/bge/llm_embedder.rst

Lines changed: 0 additions & 2 deletions
This file was deleted.

docs/source/index.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ We are aiming to enhance text and multi-modal retrieval by leveraging advanced e
5151

5252
bge/introduction
5353
bge/bge_v1
54-
bge/llm_embedder
5554
bge/bge_m3
5655
bge/bge_icl
5756
bge/bge_reranker
@@ -62,8 +61,8 @@ We are aiming to enhance text and multi-modal retrieval by leveraging advanced e
6261
:caption: API
6362

6463
API/abc
65-
API/evaluation
6664
API/inference
65+
API/evaluation
6766

6867
.. toctree::
6968
:hidden:

0 commit comments

Comments
 (0)