@@ -68,16 +68,28 @@ def __len__(self):
6868 def __getitem__ (self , item ):
6969 query_inputs = self .all_queries_inputs [item ]
7070 passage_inputs = self .all_passages_inputs [item ]
71- item = self .tokenizer .prepare_for_model (
72- [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
73- self .sep_inputs + passage_inputs ['input_ids' ],
74- truncation = 'only_second' ,
75- max_length = self .encode_max_length ,
76- padding = False ,
77- return_attention_mask = False ,
78- return_token_type_ids = False ,
79- add_special_tokens = False
80- )
71+ if self .tokenizer .bos_token_id is not None and self .tokenizer .bos_token_id != self .tokenizer .pad_token_id :
72+ item = self .tokenizer .prepare_for_model (
73+ [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
74+ self .sep_inputs + passage_inputs ['input_ids' ],
75+ truncation = 'only_second' ,
76+ max_length = self .encode_max_length ,
77+ padding = False ,
78+ return_attention_mask = False ,
79+ return_token_type_ids = False ,
80+ add_special_tokens = False
81+ )
82+ else :
83+ item = self .tokenizer .prepare_for_model (
84+ query_inputs ['input_ids' ],
85+ self .sep_inputs + passage_inputs ['input_ids' ],
86+ truncation = 'only_second' ,
87+ max_length = self .encode_max_length ,
88+ padding = False ,
89+ return_attention_mask = False ,
90+ return_token_type_ids = False ,
91+ add_special_tokens = False
92+ )
8193 item ['input_ids' ] = item ['input_ids' ] + self .sep_inputs + self .prompt_inputs
8294 item ['attention_mask' ] = [1 ] * len (item ['input_ids' ])
8395 item .pop ('token_type_ids' ) if 'token_type_ids' in item .keys () else None
@@ -289,16 +301,28 @@ def compute_score_single_gpu(
289301 all_queries_inputs_sorted [:min (len (all_queries_inputs_sorted ), batch_size )],
290302 all_passages_inputs_sorted [:min (len (all_passages_inputs_sorted ), batch_size )]
291303 ):
292- item = self .tokenizer .prepare_for_model (
293- [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
294- sep_inputs + passage_inputs ['input_ids' ],
295- truncation = 'only_second' ,
296- max_length = encode_max_length ,
297- padding = False ,
298- return_attention_mask = False ,
299- return_token_type_ids = False ,
300- add_special_tokens = False
301- )
304+ if self .tokenizer .bos_token_id is not None and self .tokenizer .bos_token_id != self .tokenizer .pad_token_id :
305+ item = self .tokenizer .prepare_for_model (
306+ [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
307+ sep_inputs + passage_inputs ['input_ids' ],
308+ truncation = 'only_second' ,
309+ max_length = encode_max_length ,
310+ padding = False ,
311+ return_attention_mask = False ,
312+ return_token_type_ids = False ,
313+ add_special_tokens = False
314+ )
315+ else :
316+ item = self .tokenizer .prepare_for_model (
317+ query_inputs ['input_ids' ],
318+ sep_inputs + passage_inputs ['input_ids' ],
319+ truncation = 'only_second' ,
320+ max_length = encode_max_length ,
321+ padding = False ,
322+ return_attention_mask = False ,
323+ return_token_type_ids = False ,
324+ add_special_tokens = False
325+ )
302326 item ['input_ids' ] = item ['input_ids' ] + sep_inputs + prompt_inputs
303327 item ['attention_mask' ] = [1 ] * len (item ['input_ids' ])
304328 item .pop ('token_type_ids' ) if 'token_type_ids' in item .keys () else None
@@ -358,16 +382,28 @@ def compute_score_single_gpu(
358382
359383 batch_inputs = []
360384 for query_inputs , passage_inputs in zip (queries_inputs , passages_inputs ):
361- item = self .tokenizer .prepare_for_model (
362- [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
363- sep_inputs + passage_inputs ['input_ids' ],
364- truncation = 'only_second' ,
365- max_length = encode_max_length ,
366- padding = False ,
367- return_attention_mask = False ,
368- return_token_type_ids = False ,
369- add_special_tokens = False
370- )
385+ if self .tokenizer .bos_token_id is not None and self .tokenizer .bos_token_id != self .tokenizer .pad_token_id :
386+ item = self .tokenizer .prepare_for_model (
387+ [self .tokenizer .bos_token_id ] + query_inputs ['input_ids' ],
388+ sep_inputs + passage_inputs ['input_ids' ],
389+ truncation = 'only_second' ,
390+ max_length = encode_max_length ,
391+ padding = False ,
392+ return_attention_mask = False ,
393+ return_token_type_ids = False ,
394+ add_special_tokens = False
395+ )
396+ else :
397+ item = self .tokenizer .prepare_for_model (
398+ query_inputs ['input_ids' ],
399+ sep_inputs + passage_inputs ['input_ids' ],
400+ truncation = 'only_second' ,
401+ max_length = encode_max_length ,
402+ padding = False ,
403+ return_attention_mask = False ,
404+ return_token_type_ids = False ,
405+ add_special_tokens = False
406+ )
371407 item ['input_ids' ] = item ['input_ids' ] + sep_inputs + prompt_inputs
372408 item ['attention_mask' ] = [1 ] * len (item ['input_ids' ])
373409 item .pop ('token_type_ids' ) if 'token_type_ids' in item .keys () else None
0 commit comments