@@ -389,7 +389,6 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
         else:
             return sum([len(t) for t in text])  # Sum of length of individual strings
 
-
 class LayerWiseFlagLLMReranker:
     def __init__(
         self,
@@ -561,10 +560,175 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str,
             if normalize:
                 all_scores[i] = [sigmoid(score) for score in all_scores[i]]
 
-        # if len(all_scores) == 1:
-        #     if len(all_scores[0]) == 1:
-        #         return all_scores[0][0]
-        #     return all_scores[0]
+        return all_scores
+
+
+    def _text_length(self, text: Union[List[int], List[List[int]]]):
+        """
+        Helper function to get the length of the input text. Text can be either
+        a list of ints (a single text as input) or a list of lists of ints
+        (representing several text inputs to the model).
+        """
+
+        if isinstance(text, dict):  # {key: value} case
+            return len(next(iter(text.values())))
+        elif not hasattr(text, '__len__'):  # Object has no len() method
+            return 1
+        elif len(text) == 0 or isinstance(text[0], int):  # Empty list or list of ints
+            return len(text)
+        else:
+            return sum([len(t) for t in text])  # Sum of length of individual strings
+
+
+class LightWeightFlagLLMReranker:
+    def __init__(
+        self,
+        model_name_or_path: str = None,
+        peft_path: str = None,
+        use_fp16: bool = False,
+        use_bf16: bool = False,
+        cache_dir: str = None,
+        device: Union[str, int] = None
+    ) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                                       cache_dir=cache_dir,
+                                                       trust_remote_code=True)
+
+        if use_bf16 is False and use_fp16 is False:
+            warnings.warn("Due to model constraints, `use_bf16` and `use_fp16` cannot both be `False`; defaulting `use_fp16` to `True`.", UserWarning)
+            use_fp16 = True
+
+        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                          cache_dir=cache_dir,
+                                                          trust_remote_code=True,
+                                                          local_files_only=True,
+                                                          torch_dtype=torch.bfloat16 if use_bf16 else torch.float32)
+        if peft_path:
+            self.model = PeftModel.from_pretrained(self.model, peft_path)
+            self.model = self.model.merge_and_unload()
+        self.model_name_or_path = model_name_or_path
+        self.cache_dir = cache_dir
+
+        if device and isinstance(device, str):
+            if device == 'cpu':
+                warnings.warn('The LLM-based lightweight reranker does not support CPU; the device has been set to CUDA.')
+                device = 'cuda'
+            self.device = torch.device(device)
+        else:
+            device = 0 if device is None else device
+            if torch.cuda.is_available():
+                torch.cuda.set_device(device)
+                self.device = torch.device("cuda")
+            elif torch.backends.mps.is_available():
+                self.device = torch.device("mps")
+            elif is_torch_npu_available():
+                self.device = torch.device("npu")
+            else:
+                self.device = torch.device("cpu")
+                use_fp16 = False
+
+        if use_fp16 and use_bf16 is False:
+            self.model.half()
+
+        self.model = self.model.to(self.device)
+
+        self.model.eval()
+
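+        # Token id of "Yes", cached so the model's yes/no logit can serve as a relevance signal.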
+        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
+
+    @torch.no_grad()
+    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
+                      max_length: int = 512,
+                      cutoff_layers: List[int] = None, compress_layer: List[int] = [8], compress_ratio: int = 1,
+                      prompt: str = None, normalize: bool = False) -> Union[float, List[float], List[List[float]]]:
+        assert isinstance(sentence_pairs, list)
+        if isinstance(sentence_pairs[0], str):
+            sentence_pairs = [sentence_pairs]
+
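+        # Sort pairs longest-first so each batch holds similarly sized inputs; order is restored at the end.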
+        length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs])
+        sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx]
+
+        if prompt is None:
+            prompt = "Predict whether passage B contains an answer to query A."
+        prompt_inputs = self.tokenizer(prompt,
+                                       return_tensors=None,
+                                       add_special_tokens=False)['input_ids']
+        sep = "\n"
+        sep_inputs = self.tokenizer(sep,
+                                    return_tensors=None,
+                                    add_special_tokens=False)['input_ids']
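+        # Allow room beyond max_length for the separator and prompt tokens appended to every pair.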
+        encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
+        all_scores = []
+        for batch_start in trange(0, len(sentences_sorted), batch_size):
+            batch_sentences = sentences_sorted[batch_start:batch_start + batch_size]
+            batch_sentences = [(f'A: {q}', f'B: {p}') for q, p in batch_sentences]
+            queries = [s[0] for s in batch_sentences]
+            passages = [s[1] for s in batch_sentences]
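+            # Queries are capped at 3/4 of max_length; passages may use the full max_length budget.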
+            queries_inputs = self.tokenizer(queries,
+                                            return_tensors=None,
+                                            add_special_tokens=False,
+                                            max_length=max_length * 3 // 4,
+                                            truncation=True)
+            passages_inputs = self.tokenizer(passages,
+                                             return_tensors=None,
+                                             add_special_tokens=False,
+                                             max_length=max_length,
+                                             truncation=True)
+            query_lengths = []
+            prompt_lengths = []
+            batch_inputs = []
+            for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']):
+                item = self.tokenizer.prepare_for_model(
+                    [self.tokenizer.bos_token_id] + query_inputs,
+                    sep_inputs + passage_inputs,
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
+                item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
+                item['attention_mask'] = [1] * len(item['input_ids'])
+                item.pop('token_type_ids', None)
+                if 'position_ids' in item.keys():
+                    item['position_ids'] = list(range(len(item['input_ids'])))
+                batch_inputs.append(item)
+                query_lengths.append(len([self.tokenizer.bos_token_id] + query_inputs + sep_inputs))
+                prompt_lengths.append(len(sep_inputs + prompt_inputs))
+
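+            # Pad the batch to a uniform length; query and prompt lengths ride along so the model can
+            # tell those token spans apart from the compressible passage tokens.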
+            collater_instance = collater(self.tokenizer, max_length)
+            batch_inputs = collater_instance(
+                [
+                    [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in
+                     batch_inputs],
+                    query_lengths,
+                    prompt_lengths
+                ])[0]
+
+            batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()}
+
+            outputs = self.model(**batch_inputs,
+                                 output_hidden_states=True,
+                                 compress_layer=compress_layer,
+                                 compress_ratio=compress_ratio,
+                                 query_lengths=query_lengths,
+                                 prompt_lengths=prompt_lengths,
+                                 cutoff_layers=cutoff_layers)
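+            # One logits tensor per requested cutoff layer; pool the logit at each sequence's last valid position.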
+            scores = []
+            for i in range(len(outputs.logits)):
+                logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
+                scores.append(logits.cpu().float().tolist())
+            if len(all_scores) == 0:
+                for i in range(len(scores)):
+                    all_scores.append([])
+            for i in range(len(scores)):
+                all_scores[i].extend(scores[i])
+
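+        # Undo the length-based sort and, if requested, map scores through a sigmoid.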
+        for i in range(len(all_scores)):
+            all_scores[i] = [all_scores[i][idx] for idx in np.argsort(length_sorted_idx)]
+            if normalize:
+                all_scores[i] = [sigmoid(score) for score in all_scores[i]]
 
         return all_scores
 
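For reference, a minimal usage sketch of the class added in this diff. The checkpoint name and the `FlagEmbedding` import path are assumptions for illustration, not part of the diff; the constructor and `compute_score` arguments follow the signatures above.

    from FlagEmbedding import LightWeightFlagLLMReranker  # assumed package export

    # Assumed checkpoint name; any model compatible with this class would do.
    reranker = LightWeightFlagLLMReranker('BAAI/bge-reranker-v2.5-gemma2-lightweight',
                                          use_bf16=True)

    # sentence_pairs must be a list; a single [query, passage] pair is wrapped internally.
    scores = reranker.compute_score(
        ['what is a panda?', 'The giant panda is a bear species endemic to China.'],
        cutoff_layers=[28],       # layers whose scores are returned (illustrative value)
        compress_ratio=2,         # passage-token compression factor (illustrative value)
        compress_layer=[24, 40],  # layers where compression is applied (illustrative values)
        normalize=True,           # sigmoid-map raw logits to [0, 1]
    )
    print(scores)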