
Commit 14ba730

Merge pull request #997 from 545999961/master
update light weight reranker
2 parents bae7503 + 37c250d

1 file changed

File tree

FlagEmbedding/flag_reranker.py

Lines changed: 62 additions & 2 deletions
@@ -123,6 +123,55 @@ def __call__(self, data):
             return_tensors='pt',
         )
 
+
+class collater_for_lightweight():
+    def __init__(self, tokenizer, max_len):
+        self.tokenizer = tokenizer
+        self.max_len = max_len
+        self.pad_to_multiple_of = 8
+        self.label_pad_token_id = -100
+        warnings.filterwarnings("ignore",
+                                message="`max_length` is ignored when `padding`=`True` and there is no truncation strategy.")
+
+    def __call__(self, data):
+        features = data[0]
+        query_lengths = data[1]
+        prompt_lengths = data[2]
+
+        labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
+        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
+        # same length to return tensors.
+        if labels is not None:
+            max_label_length = max(len(l) for l in labels)
+            if self.pad_to_multiple_of is not None:
+                max_label_length = (
+                    (max_label_length + self.pad_to_multiple_of - 1)
+                    // self.pad_to_multiple_of
+                    * self.pad_to_multiple_of
+                )
+
+            padding_side = self.tokenizer.padding_side
+            for feature in features:
+                remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
+                if isinstance(feature["labels"], list):
+                    feature["labels"] = (
+                        feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
+                    )
+                elif padding_side == "right":
+                    feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
+                else:
+                    feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
+
+        collected = self.tokenizer.pad(
+            features,
+            padding=True,
+            max_length=self.max_len,
+            pad_to_multiple_of=8,
+            return_tensors='pt',
+        )
+
+        return collected, query_lengths, prompt_lengths
+
 def last_logit_pool(logits: Tensor,
                     attention_mask: Tensor) -> Tensor:
     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
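For orientation, here is a minimal sketch of how the new collator is driven. It assumes the class is importable from FlagEmbedding.flag_reranker and uses bert-base-uncased purely as a stand-in tokenizer; compute_score passes pre-tokenized features together with the query/prompt lengths it tracked.

from transformers import AutoTokenizer
from FlagEmbedding.flag_reranker import collater_for_lightweight  # assumed import path

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # stand-in checkpoint
tokenizer.padding_side = 'right'

# Unequal-length, already-tokenized features, as compute_score builds them.
features = [
    {'input_ids': [101, 2023, 2003, 102], 'attention_mask': [1, 1, 1, 1]},
    {'input_ids': [101, 2172, 3524, 2062, 19204, 2015, 102], 'attention_mask': [1] * 7},
]
query_lengths = [4, 4]
prompt_lengths = [3, 3]

collator = collater_for_lightweight(tokenizer, max_len=512)
batch, q_lens, p_lens = collator([features, query_lengths, prompt_lengths])
print(batch['input_ids'].shape)  # padded to a multiple of 8 -> torch.Size([2, 8])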
@@ -143,6 +192,16 @@ def last_logit_pool_layerwise(logits: Tensor,
     batch_size = logits.shape[0]
     return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
+def last_logit_pool_lightweight(logits: Tensor,
+                                attention_mask: Tensor) -> Tensor:
+    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+    if left_padding:
+        return logits[:, -1]
+    else:
+        sequence_lengths = attention_mask.sum(dim=1) - 1
+        batch_size = logits.shape[0]
+        return torch.stack([logits[i, sequence_lengths[i]] for i in range(batch_size)], dim=0)
+
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
 
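A quick sanity check of the pooling logic with hand-built tensors (the values are arbitrary; only the shapes and the mask pattern matter), again assuming the function is importable from FlagEmbedding.flag_reranker:

import torch
from FlagEmbedding.flag_reranker import last_logit_pool_lightweight  # assumed import path

logits = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)  # [batch=2, seq=4, hidden=3]
# Right-padded batch: row 0 has 4 real tokens, row 1 has 2 real tokens + 2 pads.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

pooled = last_logit_pool_lightweight(logits, attention_mask)
print(pooled.shape)  # torch.Size([2, 3])
# Row 0 is taken from position 3, row 1 from position 1 (its last real token).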
@@ -593,6 +652,7 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                        cache_dir=cache_dir,
                                                        trust_remote_code=True)
+        self.tokenizer.padding_side = 'right'
 
         if use_bf16 is False and use_fp16 is False:
             warnings.warn("Due to model constraints, `use_bf16` and `use_fp16` cannot both be `False`. Here, `use_fp16` is set to `True` by default.", UserWarning)
@@ -697,7 +757,7 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str,
             query_lengths.append(len([self.tokenizer.bos_token_id] + query_inputs + sep_inputs))
             prompt_lengths.append(len(sep_inputs + prompt_inputs))
 
-            collater_instance = collater(self.tokenizer, max_length)
+            collater_instance = collater_for_lightweight(self.tokenizer, max_length)
             batch_inputs = collater_instance(
                 [
                     [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in
@@ -717,7 +777,7 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str,
                                 cutoff_layers=cutoff_layers)
             scores = []
             for i in range(len(outputs.logits)):
-                logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
+                logits = last_logit_pool_lightweight(outputs.logits[i], outputs.attention_masks[i])
                 scores.append(logits.cpu().float().tolist())
             if len(all_scores) == 0:
                 for i in range(len(scores)):
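End to end, the change routes the lightweight reranker's compute_score through the new collator and pooler. A usage sketch, assuming the LightWeightFlagLLMReranker export and the BAAI/bge-reranker-v2.5-gemma2-lightweight checkpoint from the project's docs (neither is introduced by this commit):

from FlagEmbedding import LightWeightFlagLLMReranker  # assumed export

reranker = LightWeightFlagLLMReranker('BAAI/bge-reranker-v2.5-gemma2-lightweight',
                                      use_fp16=True)
# cutoff_layers selects which layers emit scores, matching the
# outputs.logits[i] loop patched above.
scores = reranker.compute_score(
    [['what is panda?', 'The giant panda is a bear species native to China.']],
    cutoff_layers=[28],
)
print(scores)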
