1515
@dataclass
class EmbedderOutput(ModelOutput):
    """
    Output information returned by the model.
    """
    # Query representations; forward() annotates these as (batch_size, dim).
    q_reps: Optional[Tensor] = None
    # Passage representations; forward() annotates these as (batch_size * group_size, dim).
    p_reps: Optional[Tensor] = None
    # Training loss, when targets were available; otherwise None.
    loss: Optional[Tensor] = None
    # Query-passage similarity scores used to compute the loss.
    scores: Optional[Tensor] = None
2225
2326
2427class AbsEmbedderModel (ABC , nn .Module ):
28+ """Abstract class of embedding model for training.
29+
30+ Args:
31+ base_model: The base model to train on.
32+ tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
33+ negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
34+ temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
35+ sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
36+ Defaults to ``-1``.
37+ kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
38+ """
2539 def __init__ (
2640 self ,
2741 base_model ,
@@ -48,21 +62,53 @@ def __init__(
4862
    @abstractmethod
    def encode(self, features):
        """Abstract method to encode the input and produce its embedding.

        Args:
            features (Union[list, dict]): Features fed to the model.
        """
        pass
5271
    @abstractmethod
    def compute_loss(self, scores, target):
        """Abstract method to compute the training loss.

        Args:
            scores (torch.Tensor): Computed query-passage scores.
            target (torch.Tensor): The target values.
        """
        pass
5681
    @abstractmethod
    def compute_score(self, q_reps, p_reps):
        """Abstract method to compute the score.

        Args:
            q_reps (torch.Tensor): Queries representations.
            p_reps (torch.Tensor): Passages representations.
        """
        pass
6091
    @abstractmethod
    def save(self, output_dir: str):
        """Abstract method to save the model.

        Args:
            output_dir (str): Directory for saving the model.
        """
        pass
64100
65101 def get_local_score (self , q_reps , p_reps , all_scores ):
102+ """Get the local score of queries and passages.
103+
104+ Args:
105+ q_reps (torch.Tensor): Queries representations.
106+ p_reps (torch.Tensor): Passages representations.
107+ all_scores (torch.Tensor): All the query-passage scores computed.
108+
109+ Returns:
110+ torch.Tensor: Local scores to compute loss.
111+ """
66112 group_size = p_reps .size (0 ) // q_reps .size (0 )
67113 indices = torch .arange (0 , q_reps .size (0 ), device = q_reps .device ) * group_size
68114 specific_scores = []
@@ -73,6 +119,17 @@ def get_local_score(self, q_reps, p_reps, all_scores):
73119 return torch .stack (specific_scores , dim = 1 ).view (q_reps .size (0 ), - 1 )
74120
75121 def compute_local_score (self , q_reps , p_reps , compute_score_func = None , ** kwargs ):
122+ """Compute the local score of queries and passages.
123+
124+ Args:
125+ q_reps (torch.Tensor): Queries representations.
126+ p_reps (torch.Tensor): Passages representations.
127+ compute_score_func (function, optional): Function to compute score. Defaults to ``None``, which will use the
128+ :meth:`self.compute_score`.
129+
130+ Returns:
131+ torch.Tensor: Local scores to compute loss.
132+ """
76133 if compute_score_func is None :
77134 all_scores = self .compute_score (q_reps , p_reps )
78135 else :
@@ -181,6 +238,17 @@ def forward(
181238 teacher_scores : Union [None , List [float ]] = None ,
182239 no_in_batch_neg_flag : bool = False ,
183240 ):
241+ """The computation performed at every call.
242+
243+ Args:
244+ queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
245+ passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
246+ teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
247+ no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives. Defaults to ``False``.
248+
249+ Returns:
250+ EmbedderOutput: Output of the forward call of model.
251+ """
184252 q_reps = self .encode (queries ) # (batch_size, dim)
185253 p_reps = self .encode (passages ) # (batch_size * group_size, dim)
186254
@@ -210,6 +278,20 @@ def forward(
210278
211279 @staticmethod
212280 def distill_loss (kd_loss_type , teacher_targets , student_scores , group_size = None ):
281+ """Compute the distillation loss.
282+
283+ Args:
284+ kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
285+ teacher_targets (torch.Tensor): Targets from the teacher model.
286+ student_scores (torch.Tensor): Score of student model.
287+ group_size (int, optional): Number of passages grouped with each query, used when splitting scores for distillation. Defaults to ``None``.
288+
289+ Raises:
290+ ValueError: Invalid kd_loss_type
291+
292+ Returns:
293+ torch.Tensor: A scalar of computed distillation loss.
294+ """
213295 if kd_loss_type == 'kl_div' :
214296 # teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
215297 # student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
@@ -236,6 +318,15 @@ def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None)
236318 raise ValueError (f"Invalid kd_loss_type: { kd_loss_type } " )
237319
238320 def _dist_gather_tensor (self , t : Optional [torch .Tensor ]):
321+ """Gather a tensor from all processes in a distributed setting.
322+
323+ Args:
324+ t (Optional[torch.Tensor]): The input tensor to be gathered. If `None`, no gathering is performed.
325+
326+ Returns:
327+ Union[torch.Tensor, None]: A concatenated tensor from all processes if ``t`` is not ``None``,
328+ otherwise returns ``None``.
329+ """
239330 if t is None :
240331 return None
241332 t = t .contiguous ()
0 commit comments