99
1010
1111class BiEncoderOnlyEmbedderModel (AbsEmbedderModel ):
12+ """Embedder class for encoder only model.
13+
14+ Args:
15+ base_model (AutoModel): The base model to train on.
16+ tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
17+ negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
18+ temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
19+ sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
20+ Defaults to ``-1``.
21+ kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
        sentence_pooling_method (str, optional): Pooling method to get sentence embedding. Defaults to ``"cls"``.
23+ normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to ``False``.
24+ """
1225 TRANSFORMER_CLS = AutoModel
1326
1427 def __init__ (
@@ -35,6 +48,14 @@ def __init__(
3548 self .cross_entropy = torch .nn .CrossEntropyLoss (reduction = 'mean' )
3649
3750 def encode (self , features ):
51+ """Encode and get the embedding.
52+
53+ Args:
54+ features (Union[list, dict]): Features feed to the model.
55+
56+ Returns:
57+ torch.Tensor: The embedding vectors.
58+ """
3859 if features is None :
3960 return None
4061 if not isinstance (features , list ):
@@ -70,6 +91,18 @@ def encode(self, features):
7091 return all_p_reps .contiguous ()
7192
7293 def _sentence_embedding (self , last_hidden_state , attention_mask ):
94+ """Use the pooling method to get the sentence embedding.
95+
96+ Args:
97+ last_hidden_state (torch.Tensor): The model output's last hidden state.
98+ attention_mask (torch.Tensor): Mask out padding tokens during pooling.
99+
100+ Raises:
101+ NotImplementedError: Specified pooling method not implemented.
102+
103+ Returns:
104+ torch.Tensor: The sentence embeddings.
105+ """
73106 if self .sentence_pooling_method == "cls" :
74107 return last_hidden_state [:, 0 ]
75108 elif self .sentence_pooling_method == "mean" :
@@ -93,25 +126,63 @@ def _sentence_embedding(self, last_hidden_state, attention_mask):
93126 raise NotImplementedError (f"pooling method { self .sentence_pooling_method } not implemented" )
94127
95128 def compute_score (self , q_reps , p_reps ):
129+ """Computes the scores between query and passage representations.
130+
131+ Args:
132+ q_reps (torch.Tensor): Query representations.
133+ p_reps (torch.Tensor): Passage representations.
134+
135+ Returns:
136+ torch.Tensor: The computed scores, adjusted by temperature.
137+ """
96138 scores = self ._compute_similarity (q_reps , p_reps ) / self .temperature
97139 scores = scores .view (q_reps .size (0 ), - 1 )
98140 return scores
99141
100142 def _compute_similarity (self , q_reps , p_reps ):
143+ """Computes the similarity between query and passage representations using inner product.
144+
145+ Args:
146+ q_reps (torch.Tensor): Query representations.
147+ p_reps (torch.Tensor): Passage representations.
148+
149+ Returns:
150+ torch.Tensor: The computed similarity matrix.
151+ """
101152 if len (p_reps .size ()) == 2 :
102153 return torch .matmul (q_reps , p_reps .transpose (0 , 1 ))
103154 return torch .matmul (q_reps , p_reps .transpose (- 2 , - 1 ))
104155
105156 def compute_loss (self , scores , target ):
157+ """Compute the loss using cross entropy.
158+
159+ Args:
160+ scores (torch.Tensor): Computed score.
161+ target (torch.Tensor): The target value.
162+
163+ Returns:
164+ torch.Tensor: The computed cross entropy loss.
165+ """
106166 return self .cross_entropy (scores , target )
107167
    def gradient_checkpointing_enable(self, **kwargs):
        """Activate gradient checkpointing for the current model.

        Args:
            **kwargs: Forwarded unchanged to ``self.model.gradient_checkpointing_enable``.
        """
        self.model.gradient_checkpointing_enable(**kwargs)
110173
    def enable_input_require_grads(self, **kwargs):
        """Enable gradients for the input embeddings.

        Args:
            **kwargs: Forwarded unchanged to ``self.model.enable_input_require_grads``.
        """
        self.model.enable_input_require_grads(**kwargs)
113179
114180 def save (self , output_dir : str ):
181+ """Save the model to the directory.
182+
183+ Args:
184+ output_dir (str): Directory for saving the model.
185+ """
115186 state_dict = self .model .state_dict ()
116187 state_dict = type (state_dict )(
117188 {k : v .clone ().cpu ()
0 commit comments