99
1010
1111class BaseEmbedder (AbsEmbedder ):
12+ """
13+ Base embedder for encoder only models.
14+
15+ Args:
16+ model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
17+ load a model from HuggingFace Hub with the name.
18+ normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
19+ use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
20+ degradation. Defaults to :data:`True`.
21+ query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
22+ :attr:`query_instruction_format`. Defaults to :data:`None`.
23+ query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
24+ devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
25+ pooling_method (str, optional): Pooling method to get embedding vector from the last hidden state. Defaults to :data:`"cls"`.
26+ trust_remote_code (bool, optional): trust_remote_code for HF datasets or models. Defaults to :data:`False`.
27+ cache_dir (Optional[str], optional): Cache directory for the model. Defaults to :data:`None`.
28+ batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
29+ query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
30+ passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
31+ instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Defaults to :data:`None`.
32+ instruction_format (str, optional): Instruction format when using :attr:`instruction`. Defaults to :data:`"{}{}"`.
33+ convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
34+ Defaults to :data:`True`.
35+
36+ Attributes:
37+ DEFAULT_POOLING_METHOD: The default pooling method when running the model.
38+ """
39+
1240 DEFAULT_POOLING_METHOD = "cls"
1341
1442 def __init__ (
@@ -68,6 +96,18 @@ def encode_queries(
6896 convert_to_numpy : Optional [bool ] = None ,
6997 ** kwargs : Any
7098 ) -> Union [np .ndarray , torch .Tensor ]:
99+ """Encode the queries using the instruction if provided.
100+
101+ Args:
102+ queries (Union[List[str], str]): Input queries to encode.
103+ batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
104+ max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
105+ convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
106+ be a Torch Tensor. Defaults to :data:`None`.
107+
108+ Returns:
109+ Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
110+ """
71111 return super ().encode_queries (
72112 queries ,
73113 batch_size = batch_size ,
@@ -84,6 +124,18 @@ def encode_corpus(
84124 convert_to_numpy : Optional [bool ] = None ,
85125 ** kwargs : Any
86126 ) -> Union [np .ndarray , torch .Tensor ]:
127+ """Encode the corpus using the instruction if provided.
128+
129+ Args:
130+ corpus (Union[List[str], str]): Input corpus to encode.
131+ batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
132+ max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
133+ convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
134+ be a Torch Tensor. Defaults to :data:`None`.
135+
136+ Returns:
137+ Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
138+ """
87139 return super ().encode_corpus (
88140 corpus ,
89141 batch_size = batch_size ,
@@ -100,6 +152,18 @@ def encode(
100152 convert_to_numpy : Optional [bool ] = None ,
101153 ** kwargs : Any
102154 ) -> Union [np .ndarray , torch .Tensor ]:
155+ """Encode the input sentences with the embedding model.
156+
157+ Args:
158+ sentences (Union[List[str], str]): Input sentences to encode.
159+ batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
160+ max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
161+ convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
162+ be a Torch Tensor. Defaults to :data:`None`.
163+
164+ Returns:
165+ Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
166+ """
103167 return super ().encode (
104168 sentences ,
105169 batch_size = batch_size ,
@@ -118,6 +182,19 @@ def encode_single_device(
118182 device : Optional [str ] = None ,
119183 ** kwargs : Any
120184 ):
185+ """Encode input sentences on a single device.
186+
187+ Args:
188+ sentences (Union[List[str], str]): Input sentences to encode.
189+ batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
190+ max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
191+ convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
192+ be a Torch Tensor. Defaults to :data:`True`.
193+ device (Optional[str], optional): Device to use for encoding. Defaults to :data:`None`.
194+
195+ Returns:
196+ Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
197+ """
121198 if device is None :
122199 device = self .target_devices [0 ]
123200
@@ -214,6 +291,18 @@ def pooling(
214291 last_hidden_state : torch .Tensor ,
215292 attention_mask : Optional [torch .Tensor ] = None
216293 ):
294+ """Pool the last hidden state into embedding vectors using :attr:`pooling_method`.
295+
296+ Args:
297+ last_hidden_state (torch.Tensor): The last hidden state of the model.
298+ attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to :data:`None`.
299+
300+ Raises:
301+ NotImplementedError: pooling method not implemented.
302+
303+ Returns:
304+ torch.Tensor: The embedding vectors after pooling.
305+ """
217306 if self .pooling_method == 'cls' :
218307 return last_hidden_state [:, 0 ]
219308 elif self .pooling_method == 'mean' :
0 commit comments