Skip to content

Commit 6f693a2

Browse files
committed
base & m3
1 parent ac3a9ce commit 6f693a2

3 files changed

Lines changed: 247 additions & 6 deletions

File tree

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def _concatenate_results_from_multi_process(self, results_list: List[Union[torch
394394
"""concatenate and return the results from all the processes
395395
396396
Args:
397-
results_list (List[Union[torch.Tensor, np.ndarray, Any]]): a list of results from all the processes
397+
results_list (List[Union[torch.Tensor, np.ndarray, Any]]): A list of results from all the processes.
398398
399399
Raises:
400400
NotImplementedError: Unsupported type for results_list

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,34 @@
99

1010

1111
class BaseEmbedder(AbsEmbedder):
12+
"""
13+
Base embedder for encoder only models.
14+
15+
Args:
16+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
17+
load a model from HuggingFace Hub with the name.
18+
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
19+
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
20+
degradation. Defaults to :data:`True`.
21+
query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used
22+
with :attr:`query_instruction_format`. Defaults to :data:`None`.
23+
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
24+
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
25+
pooling_method (str, optional): Pooling method to get embedding vector from the last hidden state. Defaults to :data:`"cls"`.
26+
trust_remote_code (bool, optional): Whether to trust and execute remote code when loading HF datasets or models. Defaults to :data:`False`.
27+
cache_dir (Optional[str], optional): Cache directory for the model. Defaults to :data:`None`.
28+
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
29+
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
30+
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
31+
instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Defaults to :data:`None`.
32+
instruction_format (str, optional): Instruction format when using :attr:`instruction`. Defaults to :data:`"{}{}"`.
33+
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
34+
Defaults to :data:`True`.
35+
36+
Attributes:
37+
DEFAULT_POOLING_METHOD: The default pooling method when running the model.
38+
"""
39+
1240
DEFAULT_POOLING_METHOD = "cls"
1341

1442
def __init__(
@@ -68,6 +96,18 @@ def encode_queries(
6896
convert_to_numpy: Optional[bool] = None,
6997
**kwargs: Any
7098
) -> Union[np.ndarray, torch.Tensor]:
99+
"""Encode the queries using the instruction if provided.
100+
101+
Args:
102+
queries (Union[List[str], str]): Input queries to encode.
103+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
104+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
105+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
106+
be a Torch Tensor. Defaults to :data:`None`.
107+
108+
Returns:
109+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
110+
"""
71111
return super().encode_queries(
72112
queries,
73113
batch_size=batch_size,
@@ -84,6 +124,18 @@ def encode_corpus(
84124
convert_to_numpy: Optional[bool] = None,
85125
**kwargs: Any
86126
) -> Union[np.ndarray, torch.Tensor]:
127+
"""Encode the corpus using the instruction if provided.
128+
129+
Args:
130+
corpus (Union[List[str], str]): Input corpus to encode.
131+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
132+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
133+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
134+
be a Torch Tensor. Defaults to :data:`None`.
135+
136+
Returns:
137+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
138+
"""
87139
return super().encode_corpus(
88140
corpus,
89141
batch_size=batch_size,
@@ -100,6 +152,18 @@ def encode(
100152
convert_to_numpy: Optional[bool] = None,
101153
**kwargs: Any
102154
) -> Union[np.ndarray, torch.Tensor]:
155+
"""Encode the input sentences with the embedding model.
156+
157+
Args:
158+
sentences (Union[List[str], str]): Input sentences to encode.
159+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
160+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
161+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
162+
be a Torch Tensor. Defaults to :data:`None`.
163+
164+
Returns:
165+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
166+
"""
103167
return super().encode(
104168
sentences,
105169
batch_size=batch_size,
@@ -118,6 +182,19 @@ def encode_single_device(
118182
device: Optional[str] = None,
119183
**kwargs: Any
120184
):
185+
"""Encode input sentences on a single device.
186+
187+
Args:
188+
sentences (Union[List[str], str]): Input sentences to encode.
189+
batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
190+
max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
191+
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
192+
be a Torch Tensor. Defaults to :data:`True`.
193+
device (Optional[str], optional): Device to use for encoding. Defaults to :data:`None`.
194+
195+
Returns:
196+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
197+
"""
121198
if device is None:
122199
device = self.target_devices[0]
123200

@@ -214,6 +291,18 @@ def pooling(
214291
last_hidden_state: torch.Tensor,
215292
attention_mask: Optional[torch.Tensor] = None
216293
):
294+
"""The pooling function.
295+
296+
Args:
297+
last_hidden_state (torch.Tensor): The last hidden state of the model.
298+
attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to :data:`None`.
299+
300+
Raises:
301+
NotImplementedError: pooling method not implemented.
302+
303+
Returns:
304+
torch.Tensor: The embedding vectors after pooling.
305+
"""
217306
if self.pooling_method == 'cls':
218307
return last_hidden_state[:, 0]
219308
elif self.pooling_method == 'mean':

0 commit comments

Comments
 (0)