@@ -316,11 +316,16 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                        cache_dir=cache_dir,
                                                        trust_remote_code=True)
-
-        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
-                                                          cache_dir=cache_dir,
-                                                          trust_remote_code=True,
-                                                          torch_dtype=torch.bfloat16 if use_bf16 else torch.float32)
+        if use_bf16:
+            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                              cache_dir=cache_dir,
+                                                              trust_remote_code=True,
+                                                              torch_dtype=torch.bfloat16)
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                              cache_dir=cache_dir,
+                                                              trust_remote_code=True,
+                                                              use_flash_attention_2=False)
         self.model_name_or_path = model_name_or_path
         self.cache_dir = cache_dir
 
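Aside: the new else branch drops the explicit torch.float32 dtype (already the Transformers default) and pins use_flash_attention_2=False, since the FlashAttention-2 kernels only run on fp16/bf16 weights. A minimal standalone sketch of the two load paths this branch selects between; the checkpoint name is a placeholder, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM

CHECKPOINT = "BAAI/bge-reranker-v2-gemma"  # placeholder; any causal-LM reranker checkpoint

# bf16 path: half-precision weights, the case FlashAttention-2 can accelerate.
model_bf16 = AutoModelForCausalLM.from_pretrained(CHECKPOINT,
                                                  trust_remote_code=True,
                                                  torch_dtype=torch.bfloat16)

# fp32 path: full-precision weights; FlashAttention-2 is explicitly disabled
# because its kernels do not support fp32 inputs.
model_fp32 = AutoModelForCausalLM.from_pretrained(CHECKPOINT,
                                                  trust_remote_code=True,
                                                  use_flash_attention_2=False)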
@@ -344,8 +349,8 @@ def __init__(
     @torch.no_grad()
     def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
                       max_length: int = 512, cutoff_layers: List[int] = None, prompt: str = None,
-                      normalize: bool = False) -> float | list[Any] | list[float | Any] | list[
-        list[Any] | list[float | Any]] | Any:
+                      normalize: bool = False) -> Union[float, List[Any], List[Union[float, Any]],
+                                                        List[Union[List[Any], List[Union[float, Any]]]], Any]:
         assert isinstance(sentence_pairs, list)
         if isinstance(sentence_pairs[0], str):
             sentence_pairs = [sentence_pairs]
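For reference, a hypothetical call against the updated signature. The reranker class name and checkpoint are assumptions (the class name is not visible in this diff), but the parameters mirror the method above; a bare (query, passage) tuple is also accepted and wrapped into a single-element list by the isinstance check:

# Hypothetical usage; LLMReranker is an assumed name for the enclosing class,
# and the checkpoint is a placeholder.
reranker = LLMReranker("BAAI/bge-reranker-v2-gemma", use_bf16=True)

pairs = [
    ("what is a panda?", "The giant panda is a bear species endemic to China."),
    ("what is a panda?", "Paris is the capital of France."),
]

# normalize=True is expected to squash raw relevance logits into [0, 1].
scores = reranker.compute_score(pairs, batch_size=16, max_length=512, normalize=True)
print(scores)  # e.g. [0.98, 0.01]: the first passage is ranked as relevant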
@@ -407,47 +412,4 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
         elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
             return len(text)
         else:
-            return sum([len(t) for t in text])  # Sum of length of individual strings
-
-    def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
-        if prompt is None:
-            prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
-        sep = "\n"
-        prompt_inputs = tokenizer(prompt,
-                                  return_tensors=None,
-                                  add_special_tokens=False)['input_ids']
-        sep_inputs = tokenizer(sep,
-                               return_tensors=None,
-                               add_special_tokens=False)['input_ids']
-        inputs = []
-        for query, passage in pairs:
-            query_inputs = tokenizer(query,
-                                     return_tensors=None,
-                                     add_special_tokens=False,
-                                     max_length=max_length * 3 // 4,
-                                     truncation=True)
-            passage_inputs = tokenizer(passage,
-                                       return_tensors=None,
-                                       add_special_tokens=False,
-                                       max_length=max_length,
-                                       truncation=True)
-            item = tokenizer.prepare_for_model(
-                [tokenizer.bos_token_id] + query_inputs['input_ids'],
-                sep_inputs + passage_inputs['input_ids'],
-                truncation='only_second',
-                max_length=max_length,
-                padding=False,
-                return_attention_mask=False,
-                return_token_type_ids=False,
-                add_special_tokens=False
-            )
-            item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
-            item['attention_mask'] = [1] * len(item['input_ids'])
-            inputs.append(item)
-        return tokenizer.pad(
-            inputs,
-            padding=True,
-            max_length=max_length,
-            pad_to_multiple_of=8,
-            return_tensors='pt',
-        )
+            return sum([len(t) for t in text])  # Sum of length of individual strings
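For the record, the deleted get_inputs packed each pair as [BOS] + query + "\n" + passage + "\n" + prompt, budgeting three quarters of max_length for the query and truncating only the passage in the final prepare_for_model pass. A condensed sketch of that token layout for a single pair (checkpoint is a placeholder; the second-pass truncation, batching, and padding are omitted):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-gemma",  # placeholder
                                          trust_remote_code=True)

prompt = ("Given a query A and a passage B, determine whether the passage contains "
          "an answer to the query by providing a prediction of either 'Yes' or 'No'.")
query, passage, max_length = "what is a panda?", "The giant panda is a bear species.", 1024

sep_ids = tokenizer("\n", add_special_tokens=False)["input_ids"]
prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
query_ids = tokenizer(query, add_special_tokens=False,
                      max_length=max_length * 3 // 4, truncation=True)["input_ids"]
passage_ids = tokenizer(passage, add_special_tokens=False,
                        max_length=max_length, truncation=True)["input_ids"]

# Same layout the deleted helper produced: [BOS] query \n passage \n prompt.
input_ids = ([tokenizer.bos_token_id] + query_ids + sep_ids
             + passage_ids + sep_ids + prompt_ids)
attention_mask = [1] * len(input_ids)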