@@ -123,6 +123,55 @@ def __call__(self, data):
             return_tensors='pt',
         )
 
+
+class collater_for_lightweight():
+    def __init__(self, tokenizer, max_len):
+        self.tokenizer = tokenizer
+        self.max_len = max_len
+        self.pad_to_multiple_of = 8
+        self.label_pad_token_id = -100  # loss-ignore index used to pad labels
+        warnings.filterwarnings("ignore",
+                                message="`max_length` is ignored when `padding`=`True` and there is no truncation strategy.")
+
+    def __call__(self, data):
+        features = data[0]
+        query_lengths = data[1]
+        prompt_lengths = data[2]
+
+        labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
+        # We have to pad the labels before calling `tokenizer.pad`, since that method
+        # won't pad them and needs them to be the same length to return tensors.
+        if labels is not None:
+            max_label_length = max(len(l) for l in labels)
+            if self.pad_to_multiple_of is not None:  # round up to the next multiple of 8
+                max_label_length = (
+                    (max_label_length + self.pad_to_multiple_of - 1)
+                    // self.pad_to_multiple_of
+                    * self.pad_to_multiple_of
+                )
+
+            padding_side = self.tokenizer.padding_side
+            for feature in features:
+                remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
+                if isinstance(feature["labels"], list):
+                    feature["labels"] = (
+                        feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
+                    )
+                elif padding_side == "right":
+                    feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
+                else:
+                    feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
+
+        collected = self.tokenizer.pad(
+            features,
+            padding=True,
+            max_length=self.max_len,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors='pt',
+        )
+
+        return collected, query_lengths, prompt_lengths
+
 def last_logit_pool(logits: Tensor,
                     attention_mask: Tensor) -> Tensor:
     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
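A note on the label padding added above: it mirrors the label handling in Hugging Face's `DataCollatorForSeq2Seq`, rounding the longest label up to the next multiple of 8 and filling with `-100` so padded positions are ignored by the loss. A minimal standalone sketch of the same arithmetic (the function name and sample data are illustrative, not from this repo):

    import numpy as np

    def pad_labels(labels, pad_to_multiple_of=8, label_pad_token_id=-100, padding_side="right"):
        # Round the longest label up to the next multiple of `pad_to_multiple_of`,
        # the same (n + m - 1) // m * m arithmetic used in the collater above.
        max_len = max(len(l) for l in labels)
        max_len = (max_len + pad_to_multiple_of - 1) // pad_to_multiple_of * pad_to_multiple_of
        out = []
        for l in labels:
            remainder = [label_pad_token_id] * (max_len - len(l))
            out.append(l + remainder if padding_side == "right" else remainder + l)
        return np.asarray(out, dtype=np.int64)

    # Lengths 5 and 3 both pad to 8: (5 + 7) // 8 * 8 == 8.
    print(pad_labels([[1, 2, 3, 4, 5], [6, 7, 8]]))
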
@@ -143,6 +192,16 @@ def last_logit_pool_layerwise(logits: Tensor,
     batch_size = logits.shape[0]
     return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
+def last_logit_pool_lightweight(logits: Tensor,
+                                attention_mask: Tensor) -> Tensor:
+    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+    if left_padding:
+        return logits[:, -1]  # left padding: the final position holds each row's last real token
+    else:
+        sequence_lengths = attention_mask.sum(dim=1) - 1  # index of each row's last real token
+        batch_size = logits.shape[0]
+        return torch.stack([logits[i, sequence_lengths[i]] for i in range(batch_size)], dim=0)
+
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
 
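The `last_logit_pool_lightweight` variant above differs from `last_logit_pool` only in the right-padding branch, gathering per row with `torch.stack` instead of fancy indexing. A toy check of that branch (shapes and values invented for illustration):

    import torch

    logits = torch.arange(2 * 4 * 3, dtype=torch.float).reshape(2, 4, 3)
    attention_mask = torch.tensor([[1, 1, 1, 0],    # 3 real tokens
                                   [1, 1, 0, 0]])   # 2 real tokens
    lengths = attention_mask.sum(dim=1) - 1         # tensor([2, 1])
    pooled = torch.stack([logits[i, lengths[i]] for i in range(2)], dim=0)
    assert torch.equal(pooled[0], logits[0, 2]) and torch.equal(pooled[1], logits[1, 1])
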
@@ -593,6 +652,7 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                        cache_dir=cache_dir,
                                                        trust_remote_code=True)
+        self.tokenizer.padding_side = 'right'
 
         if use_bf16 is False and use_fp16 is False:
             warnings.warn("Due to model constraints, `use_bf16` and `use_fp16` cannot both be `False`. Here, `use_fp16` is set to `True` by default.", UserWarning)
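Forcing `padding_side = 'right'` here matters because the pooling above assumes that when rows do not all end in a real token, the padding sits on the right. A quick way to see the effect (using `gpt2` purely as a stand-in tokenizer):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer works here
    tok.pad_token = tok.eos_token
    tok.padding_side = 'right'
    batch = tok(["a b c", "a"], padding=True, return_tensors="pt")
    # The attention_mask rows now end in zeros for the shorter sequence,
    # e.g. tensor([[1, 1, 1], [1, 0, 0]]), so mask.sum(dim=1) - 1
    # points at each row's last real token.
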
@@ -697,7 +757,7 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str,
             query_lengths.append(len([self.tokenizer.bos_token_id] + query_inputs + sep_inputs))
             prompt_lengths.append(len(sep_inputs + prompt_inputs))
 
-        collater_instance = collater(self.tokenizer, max_length)
+        collater_instance = collater_for_lightweight(self.tokenizer, max_length)
         batch_inputs = collater_instance(
             [
                 [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in
@@ -717,7 +777,7 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str,
                              cutoff_layers=cutoff_layers)
         scores = []
         for i in range(len(outputs.logits)):
-            logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
+            logits = last_logit_pool_lightweight(outputs.logits[i], outputs.attention_masks[i])
             scores.append(logits.cpu().float().tolist())
         if len(all_scores) == 0:
             for i in range(len(scores)):
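For context, the two substitutions in `compute_score` plug the new collater and pooling in end to end. A hedged usage sketch (the class name `LightweightReranker` is a placeholder; only `compute_score`, `sentence_pairs`, and `cutoff_layers` appear in this diff):

    # Hypothetical caller; the constructor arguments and cutoff value are assumptions.
    reranker = LightweightReranker('some/model-name', use_fp16=True)
    pairs = [('what is a panda?', 'The giant panda is a bear species native to China.'),
             ('what is a panda?', 'Paris is the capital of France.')]
    scores = reranker.compute_score(pairs, cutoff_layers=[28])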