@@ -18,6 +18,13 @@ class RerankerOutput(ModelOutput):
1818
1919
2020class AbsRerankerModel (ABC , nn .Module ):
21+ """Abstract class of embedding model for training.
22+
23+ Args:
24+ base_model: The base model to train on.
25+ tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
26+ train_batch_size (int, optional): Batch size used for training. Defaults to ``4``.
27+ """
2128 def __init__ (
2229 self ,
2330 base_model : None ,
@@ -38,16 +45,36 @@ def __init__(
3845 self .yes_loc = self .tokenizer ('Yes' , add_special_tokens = False )['input_ids' ][- 1 ]
3946
4047 def gradient_checkpointing_enable (self , ** kwargs ):
48+ """
49+ Activates gradient checkpointing for the current model.
50+ """
4151 self .model .gradient_checkpointing_enable (** kwargs )
4252
4353 def enable_input_require_grads (self , ** kwargs ):
54+ """
55+ Enables the gradients for the input embeddings.
56+ """
4457 self .model .enable_input_require_grads (** kwargs )
4558
4659 @abstractmethod
4760 def encode (self , features ):
61+ """Abstract method of encode.
62+
63+ Args:
64+ features (dict): Teatures to pass to the model.
65+ """
4866 pass
4967
5068 def forward (self , pair : Union [Dict [str , Tensor ], List [Dict [str , Tensor ]]] = None , teacher_scores : Optional [Tensor ] = None ):
69+ """The computation performed at every call.
70+
71+ Args:
72+ pair (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): The query-document pair. Defaults to ``None``.
73+ teacher_scores (Optional[Tensor], optional): Teacher scores of knowledge distillation. Defaults to None.
74+
75+ Returns:
76+ RerankerOutput: Output of reranker model.
77+ """
5178 ranker_logits = self .encode (pair ) # (batch_size * num, dim)
5279 if teacher_scores is not None :
5380 teacher_scores = torch .Tensor (teacher_scores )
@@ -72,9 +99,23 @@ def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None
7299 )
73100
74101 def compute_loss (self , scores , target ):
102+ """Compute the loss.
103+
104+ Args:
105+ scores (torch.Tensor): Computed scores.
106+ target (torch.Tensor): The target value.
107+
108+ Returns:
109+ torch.Tensor: The computed loss.
110+ """
75111 return self .cross_entropy (scores , target )
76112
77113 def save (self , output_dir : str ):
114+ """Save the model.
115+
116+ Args:
117+ output_dir (str): Directory for saving the model.
118+ """
78119 # self.model.save_pretrained(output_dir)
79120 state_dict = self .model .state_dict ()
80121 state_dict = type (state_dict )(
@@ -84,5 +125,8 @@ def save(self, output_dir: str):
84125 self .model .save_pretrained (output_dir , state_dict = state_dict )
85126
86127 def save_pretrained (self , * args , ** kwargs ):
128+ """
129+ Save the tokenizer and model.
130+ """
87131 self .tokenizer .save_pretrained (* args , ** kwargs )
88132 return self .model .save_pretrained (* args , ** kwargs )
0 commit comments