@@ -88,16 +88,28 @@ def __len__(self):
8888 def __getitem__ (self , item ):
8989 query_inputs = self .all_queries_inputs [item ]
9090 passage_inputs = self .all_passages_inputs [item ]
91- item = self .tokenizer .prepare_for_model (
92- [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
93- self .sep_inputs + passage_inputs ['input_ids' ],
94- truncation = 'only_second' ,
95- max_length = self .encode_max_length ,
96- padding = False ,
97- return_attention_mask = False ,
98- return_token_type_ids = False ,
99- add_special_tokens = False
100- )
91+ if self .tokenizer .bos_token_id is not None and self .tokenizer .bos_token_id != self .tokenizer .pad_token_id :
92+ item = self .tokenizer .prepare_for_model (
93+ [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
94+ self .sep_inputs + passage_inputs ['input_ids' ],
95+ truncation = 'only_second' ,
96+ max_length = self .encode_max_length ,
97+ padding = False ,
98+ return_attention_mask = False ,
99+ return_token_type_ids = False ,
100+ add_special_tokens = False
101+ )
102+ else :
103+ item = self .tokenizer .prepare_for_model (
104+ query_inputs ['input_ids' ],
105+ self .sep_inputs + passage_inputs ['input_ids' ],
106+ truncation = 'only_second' ,
107+ max_length = self .encode_max_length ,
108+ padding = False ,
109+ return_attention_mask = False ,
110+ return_token_type_ids = False ,
111+ add_special_tokens = False
112+ )
101113 item ['input_ids' ] = item ['input_ids' ] + self .sep_inputs + self .prompt_inputs
102114 item ['attention_mask' ] = [1 ] * len (item ['input_ids' ])
103115 item .pop ('token_type_ids' ) if 'token_type_ids' in item .keys () else None
@@ -357,16 +369,28 @@ def compute_score_single_gpu(
357369 all_queries_inputs_sorted [:min (len (all_queries_inputs_sorted ), batch_size )],
358370 all_passages_inputs_sorted [:min (len (all_passages_inputs_sorted ), batch_size )]
359371 ):
360- item = self .tokenizer .prepare_for_model (
361- [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
362- sep_inputs + passage_inputs ['input_ids' ],
363- truncation = 'only_second' ,
364- max_length = encode_max_length ,
365- padding = False ,
366- return_attention_mask = False ,
367- return_token_type_ids = False ,
368- add_special_tokens = False
369- )
372+ if self .tokenizer .bos_token_id is not None and self .tokenizer .bos_token_id != self .tokenizer .pad_token_id :
373+ item = self .tokenizer .prepare_for_model (
374+ [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
375+ sep_inputs + passage_inputs ['input_ids' ],
376+ truncation = 'only_second' ,
377+ max_length = encode_max_length ,
378+ padding = False ,
379+ return_attention_mask = False ,
380+ return_token_type_ids = False ,
381+ add_special_tokens = False
382+ )
383+ else :
384+ item = self .tokenizer .prepare_for_model (
385+ query_inputs ['input_ids' ],
386+ sep_inputs + passage_inputs ['input_ids' ],
387+ truncation = 'only_second' ,
388+ max_length = encode_max_length ,
389+ padding = False ,
390+ return_attention_mask = False ,
391+ return_token_type_ids = False ,
392+ add_special_tokens = False
393+ )
370394 item ['input_ids' ] = item ['input_ids' ] + sep_inputs + prompt_inputs
371395 item ['attention_mask' ] = [1 ] * len (item ['input_ids' ])
372396 item .pop ('token_type_ids' ) if 'token_type_ids' in item .keys () else None
@@ -426,16 +450,28 @@ def compute_score_single_gpu(
426450
427451 batch_inputs = []
428452 for query_inputs , passage_inputs in zip (queries_inputs , passages_inputs ):
429- item = self .tokenizer .prepare_for_model (
430- [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
431- sep_inputs + passage_inputs ['input_ids' ],
432- truncation = 'only_second' ,
433- max_length = encode_max_length ,
434- padding = False ,
435- return_attention_mask = False ,
436- return_token_type_ids = False ,
437- add_special_tokens = False
438- )
453+ if self .tokenizer .bos_token_id is not None and self .tokenizer .bos_token_id != self .tokenizer .pad_token_id :
454+ item = self .tokenizer .prepare_for_model (
455+ [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
456+ sep_inputs + passage_inputs ['input_ids' ],
457+ truncation = 'only_second' ,
458+ max_length = encode_max_length ,
459+ padding = False ,
460+ return_attention_mask = False ,
461+ return_token_type_ids = False ,
462+ add_special_tokens = False
463+ )
464+ else :
465+ item = self .tokenizer .prepare_for_model (
466+ query_inputs ['input_ids' ],
467+ sep_inputs + passage_inputs ['input_ids' ],
468+ truncation = 'only_second' ,
469+ max_length = encode_max_length ,
470+ padding = False ,
471+ return_attention_mask = False ,
472+ return_token_type_ids = False ,
473+ add_special_tokens = False
474+ )
439475 item ['input_ids' ] = item ['input_ids' ] + sep_inputs + prompt_inputs
440476 item ['attention_mask' ] = [1 ] * len (item ['input_ids' ])
441477 item .pop ('token_type_ids' ) if 'token_type_ids' in item .keys () else None
0 commit comments