Skip to content

Commit d125f07

Browse files
committed
update reranker inference
1 parent ce88f0f commit d125f07

1 file changed: 66 additions & 30 deletions

File tree

  • FlagEmbedding/inference/reranker/decoder_only

FlagEmbedding/inference/reranker/decoder_only/base.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,28 @@ def __len__(self):
6868
def __getitem__(self, item):
6969
query_inputs = self.all_queries_inputs[item]
7070
passage_inputs = self.all_passages_inputs[item]
71-
item = self.tokenizer.prepare_for_model(
72-
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
73-
self.sep_inputs + passage_inputs['input_ids'],
74-
truncation='only_second',
75-
max_length=self.encode_max_length,
76-
padding=False,
77-
return_attention_mask=False,
78-
return_token_type_ids=False,
79-
add_special_tokens=False
80-
)
71+
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
72+
item = self.tokenizer.prepare_for_model(
73+
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
74+
self.sep_inputs + passage_inputs['input_ids'],
75+
truncation='only_second',
76+
max_length=self.encode_max_length,
77+
padding=False,
78+
return_attention_mask=False,
79+
return_token_type_ids=False,
80+
add_special_tokens=False
81+
)
82+
else:
83+
item = self.tokenizer.prepare_for_model(
84+
query_inputs['input_ids'],
85+
self.sep_inputs + passage_inputs['input_ids'],
86+
truncation='only_second',
87+
max_length=self.encode_max_length,
88+
padding=False,
89+
return_attention_mask=False,
90+
return_token_type_ids=False,
91+
add_special_tokens=False
92+
)
8193
item['input_ids'] = item['input_ids'] + self.sep_inputs + self.prompt_inputs
8294
item['attention_mask'] = [1] * len(item['input_ids'])
8395
item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
@@ -289,16 +301,28 @@ def compute_score_single_gpu(
289301
all_queries_inputs_sorted[:min(len(all_queries_inputs_sorted), batch_size)],
290302
all_passages_inputs_sorted[:min(len(all_passages_inputs_sorted), batch_size)]
291303
):
292-
item = self.tokenizer.prepare_for_model(
293-
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
294-
sep_inputs + passage_inputs['input_ids'],
295-
truncation='only_second',
296-
max_length=encode_max_length,
297-
padding=False,
298-
return_attention_mask=False,
299-
return_token_type_ids=False,
300-
add_special_tokens=False
301-
)
304+
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
305+
item = self.tokenizer.prepare_for_model(
306+
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
307+
sep_inputs + passage_inputs['input_ids'],
308+
truncation='only_second',
309+
max_length=encode_max_length,
310+
padding=False,
311+
return_attention_mask=False,
312+
return_token_type_ids=False,
313+
add_special_tokens=False
314+
)
315+
else:
316+
item = self.tokenizer.prepare_for_model(
317+
query_inputs['input_ids'],
318+
sep_inputs + passage_inputs['input_ids'],
319+
truncation='only_second',
320+
max_length=encode_max_length,
321+
padding=False,
322+
return_attention_mask=False,
323+
return_token_type_ids=False,
324+
add_special_tokens=False
325+
)
302326
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
303327
item['attention_mask'] = [1] * len(item['input_ids'])
304328
item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
@@ -358,16 +382,28 @@ def compute_score_single_gpu(
358382

359383
batch_inputs = []
360384
for query_inputs, passage_inputs in zip(queries_inputs, passages_inputs):
361-
item = self.tokenizer.prepare_for_model(
362-
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
363-
sep_inputs + passage_inputs['input_ids'],
364-
truncation='only_second',
365-
max_length=encode_max_length,
366-
padding=False,
367-
return_attention_mask=False,
368-
return_token_type_ids=False,
369-
add_special_tokens=False
370-
)
385+
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
386+
item = self.tokenizer.prepare_for_model(
387+
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
388+
sep_inputs + passage_inputs['input_ids'],
389+
truncation='only_second',
390+
max_length=encode_max_length,
391+
padding=False,
392+
return_attention_mask=False,
393+
return_token_type_ids=False,
394+
add_special_tokens=False
395+
)
396+
else:
397+
item = self.tokenizer.prepare_for_model(
398+
query_inputs['input_ids'],
399+
sep_inputs + passage_inputs['input_ids'],
400+
truncation='only_second',
401+
max_length=encode_max_length,
402+
padding=False,
403+
return_attention_mask=False,
404+
return_token_type_ids=False,
405+
add_special_tokens=False
406+
)
371407
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
372408
item['attention_mask'] = [1] * len(item['input_ids'])
373409
item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None

Comments: 0 commit comments