Skip to content

Commit f370b72

Browse files
committed
Inference docstring
1 parent 6f693a2 commit f370b72

3 files changed

Lines changed: 204 additions & 4 deletions

File tree

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
# Pooling function for LLM-based embedding models
1212
def last_token_pool(last_hidden_states: torch.Tensor,
1313
attention_mask: torch.Tensor) -> torch.Tensor:
14+
"""Last token pooling method.
15+
16+
Args:
17+
last_hidden_states (torch.Tensor): The last hidden states of the model.
18+
attention_mask (torch.Tensor): Attention mask of the input tokens.
19+
20+
Returns:
21+
torch.Tensor: The embedding vectors after pooling.
22+
"""
1423
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
1524
if left_padding:
1625
return last_hidden_states[:, -1]
@@ -21,6 +30,31 @@ def last_token_pool(last_hidden_states: torch.Tensor,
2130

2231

2332
class BaseLLMEmbedder(AbsEmbedder):
33+
"""Base embedder for LLM like decoder only models.
34+
35+
Args:
36+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
37+
load a model from HuggingFace Hub with the name.
38+
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
39+
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
40+
degradation. Defaults to :data:`True`.
41+
query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
42+
:attr:`query_instruction_format`. Defaults to :data:`None`.
43+
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
44+
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
45+
trust_remote_code (bool, optional): trust_remote_code for HF datasets or models. Defaults to :data:`False`.
46+
cache_dir (Optional[str], optional): Cache directory for the model. Defaults to :data:`None`.
47+
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
48+
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
49+
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
50+
instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Defaults to :data:`None`.
51+
instruction_format (str, optional): Instruction format when using :attr:`instruction`. Defaults to :data:`"{}{}"`.
52+
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
53+
Defaults to :data:`True`.
54+
55+
Attributes:
56+
DEFAULT_POOLING_METHOD: The default pooling method when running the model.
57+
"""
2458
DEFAULT_POOLING_METHOD = "last_token"
2559

2660
def __init__(
@@ -81,6 +115,18 @@ def encode_queries(
81115
convert_to_numpy: Optional[bool] = None,
82116
**kwargs: Any
83117
) -> Union[np.ndarray, torch.Tensor]:
118+
"""Encode the queries.
119+
120+
Args:
121+
queries (Union[List[str], str]): Input queries to encode.
122+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
123+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
124+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
125+
be a Torch Tensor. Defaults to :data:`None`.
126+
127+
Returns:
128+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
129+
"""
84130
return super().encode_queries(
85131
queries,
86132
batch_size=batch_size,
@@ -97,6 +143,18 @@ def encode_corpus(
97143
convert_to_numpy: Optional[bool] = None,
98144
**kwargs: Any
99145
) -> Union[np.ndarray, torch.Tensor]:
146+
"""Encode the corpus.
147+
148+
Args:
149+
corpus (Union[List[str], str]): Input corpus to encode.
150+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
151+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
152+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
153+
be a Torch Tensor. Defaults to :data:`None`.
154+
155+
Returns:
156+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
157+
"""
100158
return super().encode_corpus(
101159
corpus,
102160
batch_size=batch_size,
@@ -113,6 +171,18 @@ def encode(
113171
convert_to_numpy: Optional[bool] = None,
114172
**kwargs: Any
115173
) -> Union[np.ndarray, torch.Tensor]:
174+
"""Encode the input sentences with the embedding model.
175+
176+
Args:
177+
sentences (Union[List[str], str]): Input sentences to encode.
178+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
179+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
180+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
181+
be a Torch Tensor. Defaults to :data:`None`.
182+
183+
Returns:
184+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
185+
"""
116186
return super().encode(
117187
sentences,
118188
batch_size=batch_size,
@@ -131,6 +201,19 @@ def encode_single_device(
131201
device: Optional[str] = None,
132202
**kwargs: Any # add `pad_to_multiple_of=8` for bge-multilingual-gemmma2
133203
):
204+
"""Encode input sentences by a single device.
205+
206+
Args:
207+
sentences (Union[List[str], str]): Input sentences to encode.
208+
batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
209+
max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
210+
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
211+
be a Torch Tensor. Defaults to :data:`True`.
212+
device (Optional[str], optional): Device to use for encoding. Defaults to :data:`None`.
213+
214+
Returns:
215+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
216+
"""
134217
if device is None:
135218
device = self.target_devices[0]
136219

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
# Pooling function for LLM-based embedding models
1515
def last_token_pool(last_hidden_states: torch.Tensor,
1616
attention_mask: torch.Tensor) -> torch.Tensor:
17+
"""Last token pooling method.
18+
19+
Args:
20+
last_hidden_states (torch.Tensor): The last hidden states of the model.
21+
attention_mask (torch.Tensor): Attention mask of the input tokens.
22+
23+
Returns:
24+
torch.Tensor: The embedding vectors after pooling.
25+
"""
1726
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
1827
if left_padding:
1928
return last_hidden_states[:, -1]
@@ -24,6 +33,35 @@ def last_token_pool(last_hidden_states: torch.Tensor,
2433

2534

2635
class ICLLLMEmbedder(AbsEmbedder):
36+
"""Embedder for LLM-based models that use in-context learning (ICL) with few-shot examples.
37+
38+
Args:
39+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
40+
load a model from HuggingFace Hub with the name.
41+
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
42+
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
43+
degradation. Defaults to :data:`True`.
44+
query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
45+
:attr:`query_instruction_format`. Defaults to :data:`None`.
46+
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
47+
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
48+
examples_for_task (Optional[List[dict]], optional): Few-shot examples for the model to enhance model's ability. Defaults to
49+
:data:`None`.
50+
examples_instruction_format (str, optional): Example format when using :attr:`examples_for_task`. Defaults to
51+
:data:`"<instruct>{}\n<query>{}\n<response>{}"`.
52+
trust_remote_code (bool, optional): trust_remote_code for HF datasets or models. Defaults to :data:`False`.
53+
cache_dir (Optional[str], optional): Cache directory for the model. Defaults to :data:`None`.
54+
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
55+
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
56+
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
57+
instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Defaults to :data:`None`.
58+
instruction_format (str, optional): Instruction format when using :attr:`instruction`. Defaults to :data:`"{}{}"`.
59+
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
60+
Defaults to :data:`True`.
61+
62+
Attributes:
63+
DEFAULT_POOLING_METHOD: The default pooling method when running the model.
64+
"""
2765
DEFAULT_POOLING_METHOD = "last_token"
2866

2967
def __init__(
@@ -84,6 +122,12 @@ def __init__(
84122
self.suffix = '\n<response>'
85123

86124
def set_examples(self, examples_for_task: Optional[List[dict]] = None):
125+
"""Set the prefix to the provided examples.
126+
127+
Args:
128+
examples_for_task (Optional[List[dict]], optional): Few-shot examples for the model to enhance model's ability.
129+
Defaults to :data:`None`.
130+
"""
87131
if examples_for_task is None and self.examples_for_task is None:
88132
self.prefix = ''
89133
elif examples_for_task is not None:
@@ -113,6 +157,17 @@ def set_examples(self, examples_for_task: Optional[List[dict]] = None):
113157

114158
@staticmethod
115159
def get_detailed_example(instruction_format: str, instruction: str, query: str, response: str):
160+
"""Combine the instruction and sentence along with the instruction format.
161+
162+
Args:
163+
instruction_format (str): Format for instruction.
164+
instruction (str): The text of instruction.
165+
query (str): The text of example query.
166+
response (str): The text of example response.
167+
168+
Returns:
169+
str: The complete example following the given format.
170+
"""
116171
return instruction_format.format(instruction, query, response)
117172

118173
def encode_queries(
@@ -123,6 +178,18 @@ def encode_queries(
123178
convert_to_numpy: Optional[bool] = None,
124179
**kwargs: Any
125180
) -> Union[np.ndarray, torch.Tensor]:
181+
"""Encode the queries.
182+
183+
Args:
184+
queries (Union[List[str], str]): Input queries to encode.
185+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
186+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
187+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
188+
be a Torch Tensor. Defaults to :data:`None`.
189+
190+
Returns:
191+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
192+
"""
126193
if batch_size is None: batch_size = self.batch_size
127194
if max_length is None: max_length = self.query_max_length
128195
if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy
@@ -157,6 +224,18 @@ def encode_corpus(
157224
convert_to_numpy: Optional[bool] = None,
158225
**kwargs: Any
159226
) -> Union[np.ndarray, torch.Tensor]:
227+
"""Encode the corpus.
228+
229+
Args:
230+
corpus (Union[List[str], str]): Input corpus to encode.
231+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
232+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
233+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
234+
be a Torch Tensor. Defaults to :data:`None`.
235+
236+
Returns:
237+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
238+
"""
160239
return super().encode_corpus(
161240
corpus,
162241
batch_size=batch_size,
@@ -173,6 +252,18 @@ def encode(
173252
convert_to_numpy: Optional[bool] = None,
174253
**kwargs: Any
175254
) -> Union[np.ndarray, torch.Tensor]:
255+
"""Encode the input sentences with the embedding model.
256+
257+
Args:
258+
sentences (Union[List[str], str]): Input sentences to encode.
259+
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
260+
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
261+
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
262+
be a Torch Tensor. Defaults to :data:`None`.
263+
264+
Returns:
265+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
266+
"""
176267
return super().encode(
177268
sentences,
178269
batch_size=batch_size,
@@ -214,6 +305,19 @@ def encode_queries_single_device(
214305
device: Optional[str] = None,
215306
**kwargs: Any
216307
):
308+
"""Encode queries by a single device.
309+
310+
Args:
311+
queries (Union[List[str], str]): Input queries to encode.
312+
batch_size (int, optional): Number of queries for each iter. Defaults to :data:`256`.
313+
max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
314+
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
315+
be a Torch Tensor. Defaults to :data:`True`.
316+
device (Optional[str], optional): Device to use for encoding. Defaults to :data:`None`.
317+
318+
Returns:
319+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
320+
"""
217321
if device is None:
218322
device = self.target_devices[0]
219323

@@ -342,6 +446,19 @@ def encode_single_device(
342446
device: Optional[str] = None,
343447
**kwargs: Any
344448
):
449+
"""Encode input sentences by a single device.
450+
451+
Args:
452+
sentences (Union[List[str], str]): Input sentences to encode.
453+
batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
454+
max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
455+
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
456+
be a Torch Tensor. Defaults to :data:`True`.
457+
device (Optional[str], optional): Device to use for encoding. Defaults to :data:`None`.
458+
459+
Returns:
460+
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
461+
"""
345462
if device is None:
346463
device = self.target_devices[0]
347464

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ class BaseEmbedder(AbsEmbedder):
1818
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
1919
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
2020
degradation. Defaults to :data:`True`.
21-
query_instruction_for_retrieval: (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
21+
query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
2222
:attr:`query_instruction_format`. Defaults to :data:`None`.
23-
query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
23+
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
2424
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
2525
pooling_method (str, optional): Pooling method to get embedding vector from the last hidden state. Defaults to :data:`"cls"`.
2626
trust_remote_code (bool, optional): trust_remote_code for HF datasets or models. Defaults to :data:`False`.
@@ -96,7 +96,7 @@ def encode_queries(
9696
convert_to_numpy: Optional[bool] = None,
9797
**kwargs: Any
9898
) -> Union[np.ndarray, torch.Tensor]:
99-
"""Encode the queries using the instruction if provided.
99+
"""Encode the queries.
100100
101101
Args:
102102
queries (Union[List[str], str]): Input queries to encode.
@@ -182,7 +182,7 @@ def encode_single_device(
182182
device: Optional[str] = None,
183183
**kwargs: Any
184184
):
185-
"""Encode input sentences on a single device.
185+
"""Encode input sentences by a single device.
186186
187187
Args:
188188
sentences (Union[List[str], str]): Input sentences to encode.

0 commit comments

Comments
 (0)