Skip to content

Commit 324a9a5

Browse files
authored
Merge branch 'FlagOpen:master' into master
2 parents ee3437f + d125f07 commit 324a9a5

1 file changed

Lines changed: 66 additions & 30 deletions

File tree

  • FlagEmbedding/inference/reranker/decoder_only

FlagEmbedding/inference/reranker/decoder_only/base.py

Lines changed: 66 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -88,16 +88,28 @@ def __len__(self):
8888
def __getitem__(self, item):
8989
query_inputs = self.all_queries_inputs[item]
9090
passage_inputs = self.all_passages_inputs[item]
91-
item = self.tokenizer.prepare_for_model(
92-
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
93-
self.sep_inputs + passage_inputs['input_ids'],
94-
truncation='only_second',
95-
max_length=self.encode_max_length,
96-
padding=False,
97-
return_attention_mask=False,
98-
return_token_type_ids=False,
99-
add_special_tokens=False
100-
)
91+
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
92+
item = self.tokenizer.prepare_for_model(
93+
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
94+
self.sep_inputs + passage_inputs['input_ids'],
95+
truncation='only_second',
96+
max_length=self.encode_max_length,
97+
padding=False,
98+
return_attention_mask=False,
99+
return_token_type_ids=False,
100+
add_special_tokens=False
101+
)
102+
else:
103+
item = self.tokenizer.prepare_for_model(
104+
query_inputs['input_ids'],
105+
self.sep_inputs + passage_inputs['input_ids'],
106+
truncation='only_second',
107+
max_length=self.encode_max_length,
108+
padding=False,
109+
return_attention_mask=False,
110+
return_token_type_ids=False,
111+
add_special_tokens=False
112+
)
101113
item['input_ids'] = item['input_ids'] + self.sep_inputs + self.prompt_inputs
102114
item['attention_mask'] = [1] * len(item['input_ids'])
103115
item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
@@ -357,16 +369,28 @@ def compute_score_single_gpu(
357369
all_queries_inputs_sorted[:min(len(all_queries_inputs_sorted), batch_size)],
358370
all_passages_inputs_sorted[:min(len(all_passages_inputs_sorted), batch_size)]
359371
):
360-
item = self.tokenizer.prepare_for_model(
361-
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
362-
sep_inputs + passage_inputs['input_ids'],
363-
truncation='only_second',
364-
max_length=encode_max_length,
365-
padding=False,
366-
return_attention_mask=False,
367-
return_token_type_ids=False,
368-
add_special_tokens=False
369-
)
372+
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
373+
item = self.tokenizer.prepare_for_model(
374+
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
375+
sep_inputs + passage_inputs['input_ids'],
376+
truncation='only_second',
377+
max_length=encode_max_length,
378+
padding=False,
379+
return_attention_mask=False,
380+
return_token_type_ids=False,
381+
add_special_tokens=False
382+
)
383+
else:
384+
item = self.tokenizer.prepare_for_model(
385+
query_inputs['input_ids'],
386+
sep_inputs + passage_inputs['input_ids'],
387+
truncation='only_second',
388+
max_length=encode_max_length,
389+
padding=False,
390+
return_attention_mask=False,
391+
return_token_type_ids=False,
392+
add_special_tokens=False
393+
)
370394
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
371395
item['attention_mask'] = [1] * len(item['input_ids'])
372396
item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
@@ -426,16 +450,28 @@ def compute_score_single_gpu(
426450

427451
batch_inputs = []
428452
for query_inputs, passage_inputs in zip(queries_inputs, passages_inputs):
429-
item = self.tokenizer.prepare_for_model(
430-
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
431-
sep_inputs + passage_inputs['input_ids'],
432-
truncation='only_second',
433-
max_length=encode_max_length,
434-
padding=False,
435-
return_attention_mask=False,
436-
return_token_type_ids=False,
437-
add_special_tokens=False
438-
)
453+
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
454+
item = self.tokenizer.prepare_for_model(
455+
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
456+
sep_inputs + passage_inputs['input_ids'],
457+
truncation='only_second',
458+
max_length=encode_max_length,
459+
padding=False,
460+
return_attention_mask=False,
461+
return_token_type_ids=False,
462+
add_special_tokens=False
463+
)
464+
else:
465+
item = self.tokenizer.prepare_for_model(
466+
query_inputs['input_ids'],
467+
sep_inputs + passage_inputs['input_ids'],
468+
truncation='only_second',
469+
max_length=encode_max_length,
470+
padding=False,
471+
return_attention_mask=False,
472+
return_token_type_ids=False,
473+
add_special_tokens=False
474+
)
439475
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
440476
item['attention_mask'] = [1] * len(item['input_ids'])
441477
item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None

0 commit comments

Comments
 (0)