
Commit 5e0a2d7

upload embedder and reranker

1 parent f3a2472 commit 5e0a2d7

1 file changed

Lines changed: 169 additions & 5 deletions

File tree

FlagEmbedding/flag_reranker.py

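In brief, the diff below does three things: it removes a stray blank line above LayerWiseFlagLLMReranker, replaces a commented-out early-return block at the end of LayerWiseFlagLLMReranker.compute_score with a plain return of all_scores, and adds a _text_length helper plus a new LightWeightFlagLLMReranker class that scores query-passage pairs with a causal LM and supports layer cutoff (cutoff_layers) and token compression (compress_layer, compress_ratio).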
@@ -389,7 +389,6 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
         else:
             return sum([len(t) for t in text])  # Sum of length of individual strings
 
-
 class LayerWiseFlagLLMReranker:
     def __init__(
         self,
@@ -561,10 +560,175 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
             if normalize:
                 all_scores[i] = [sigmoid(score) for score in all_scores[i]]
 
-        # if len(all_scores) == 1:
-        #     if len(all_scores[0]) == 1:
-        #         return all_scores[0][0]
-        #     return all_scores[0]
+        return all_scores
+
+
+    def _text_length(self, text: Union[List[int], List[List[int]]]):
+        """
+        Helper function to get the length of the input text. Text can be either a
+        list of ints (a single tokenized text) or a list of lists of ints
+        (several tokenized texts).
+        """
+
+        if isinstance(text, dict):  # {key: value} case
+            return len(next(iter(text.values())))
+        elif not hasattr(text, '__len__'):  # Object has no len() method
+            return 1
+        elif len(text) == 0 or isinstance(text[0], int):  # Empty list or list of ints
+            return len(text)
+        else:
+            return sum([len(t) for t in text])  # Sum of lengths of individual texts
+
+
+class LightWeightFlagLLMReranker:
+    def __init__(
+        self,
+        model_name_or_path: str = None,
+        peft_path: str = None,
+        use_fp16: bool = False,
+        use_bf16: bool = False,
+        cache_dir: str = None,
+        device: Union[str, int] = None
+    ) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                                       cache_dir=cache_dir,
+                                                       trust_remote_code=True)
+
+        if use_bf16 is False and use_fp16 is False:
+            warnings.warn("Due to model constraints, `use_bf16` and `use_fp16` cannot both be `False`; defaulting `use_fp16` to `True`.", UserWarning)
+            use_fp16 = True
+
+        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                          cache_dir=cache_dir,
+                                                          trust_remote_code=True,
+                                                          local_files_only=True,
+                                                          torch_dtype=torch.bfloat16 if use_bf16 else torch.float32)
+        if peft_path:
+            self.model = PeftModel.from_pretrained(self.model, peft_path)
+            self.model = self.model.merge_and_unload()
+        self.model_name_or_path = model_name_or_path
+        self.cache_dir = cache_dir
+
+        if device and isinstance(device, str):
+            if device == 'cpu':
+                warnings.warn('The LLM-based lightweight reranker does not support CPU; it has been set to CUDA.')
+                device = 'cuda'
+            self.device = torch.device(device)
+        else:
+            device = 0 if device is None else device
+            if torch.cuda.is_available():
+                torch.cuda.set_device(device)
+                self.device = torch.device("cuda")
+            elif torch.backends.mps.is_available():
+                self.device = torch.device("mps")
+            elif is_torch_npu_available():
+                self.device = torch.device("npu")
+            else:
+                self.device = torch.device("cpu")
+                use_fp16 = False
+
+        if use_fp16 and use_bf16 is False:
+            self.model.half()
+
+        self.model = self.model.to(self.device)
+
+        self.model.eval()
+
+        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
+
+    @torch.no_grad()
+    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
+                      max_length: int = 512,
+                      cutoff_layers: List[int] = None, compress_layer: List[int] = [8], compress_ratio: int = 1,
+                      prompt: str = None, normalize: bool = False) -> Union[float, List[float], List[List[float]]]:
+        assert isinstance(sentence_pairs, list)
+        if isinstance(sentence_pairs[0], str):
+            sentence_pairs = [sentence_pairs]
+
+        length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs])
+        sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx]
+
+        if prompt is None:
+            prompt = "Predict whether passage B contains an answer to query A."
+        prompt_inputs = self.tokenizer(prompt,
+                                       return_tensors=None,
+                                       add_special_tokens=False)['input_ids']
+        sep = "\n"
+        sep_inputs = self.tokenizer(sep,
+                                    return_tensors=None,
+                                    add_special_tokens=False)['input_ids']
+        encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
+        all_scores = []
+        for batch_start in trange(0, len(sentences_sorted), batch_size):
+            batch_sentences = sentences_sorted[batch_start:batch_start + batch_size]
+            batch_sentences = [(f'A: {q}', f'B: {p}') for q, p in batch_sentences]
+            queries = [s[0] for s in batch_sentences]
+            passages = [s[1] for s in batch_sentences]
+            queries_inputs = self.tokenizer(queries,
+                                            return_tensors=None,
+                                            add_special_tokens=False,
+                                            max_length=max_length * 3 // 4,
+                                            truncation=True)
+            passages_inputs = self.tokenizer(passages,
+                                             return_tensors=None,
+                                             add_special_tokens=False,
+                                             max_length=max_length,
+                                             truncation=True)
+            query_lengths = []
+            prompt_lengths = []
+            batch_inputs = []
+            for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']):
+                item = self.tokenizer.prepare_for_model(
+                    [self.tokenizer.bos_token_id] + query_inputs,
+                    sep_inputs + passage_inputs,
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
+                item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
+                item['attention_mask'] = [1] * len(item['input_ids'])
+                item.pop('token_type_ids', None)
+                if 'position_ids' in item.keys():
+                    item['position_ids'] = list(range(len(item['input_ids'])))
+                batch_inputs.append(item)
+                query_lengths.append(len([self.tokenizer.bos_token_id] + query_inputs + sep_inputs))
+                prompt_lengths.append(len(sep_inputs + prompt_inputs))
+
+            collater_instance = collater(self.tokenizer, max_length)
+            batch_inputs = collater_instance(
+                [
+                    [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in
+                     batch_inputs],
+                    query_lengths,
+                    prompt_lengths
+                ])[0]
+
+            batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()}
+
+            outputs = self.model(**batch_inputs,
+                                 output_hidden_states=True,
+                                 compress_layer=compress_layer,
+                                 compress_ratio=compress_ratio,
+                                 query_lengths=query_lengths,
+                                 prompt_lengths=prompt_lengths,
+                                 cutoff_layers=cutoff_layers)
+            scores = []
+            for i in range(len(outputs.logits)):
+                logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
+                scores.append(logits.cpu().float().tolist())
+            if len(all_scores) == 0:
+                for i in range(len(scores)):
+                    all_scores.append([])
+            for i in range(len(scores)):
+                all_scores[i].extend(scores[i])
+
+        for i in range(len(all_scores)):
+            all_scores[i] = [all_scores[i][idx] for idx in np.argsort(length_sorted_idx)]
+            if normalize:
+                all_scores[i] = [sigmoid(score) for score in all_scores[i]]
 
         return all_scores
 
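For orientation, a minimal usage sketch of the new class follows. The import path comes from the changed file; the checkpoint name and the cutoff layer value are illustrative assumptions, not taken from this commit, and the remaining keyword arguments mirror the compute_score signature in the diff.

# Hypothetical usage of the class added by this commit.
from FlagEmbedding.flag_reranker import LightWeightFlagLLMReranker

reranker = LightWeightFlagLLMReranker(
    'BAAI/bge-reranker-v2.5-gemma2-lightweight',  # assumed checkpoint name
    use_bf16=True,
)
scores = reranker.compute_score(
    [('what is a panda?', 'The giant panda is a bear species endemic to China.'),
     ('what is a panda?', 'Paris is the capital of France.')],
    cutoff_layers=[28],  # assumed valid layer index for the chosen model
    compress_layer=[8],  # diff default
    compress_ratio=1,    # diff default
    normalize=True,      # map raw logits through sigmoid
)
print(scores)  # one list of per-pair scores per requested cutoff layer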
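One detail of compute_score worth calling out: pairs are sorted by descending total length before batching (so similarly sized inputs are padded together), and the original order is restored with np.argsort(length_sorted_idx), using the fact that the argsort of a permutation is its inverse permutation. A self-contained sketch of that idiom:

import numpy as np

# Sort by descending length, then undo the sort with argsort of the permutation.
lengths = [5, 2, 9, 1]
order = np.argsort([-l for l in lengths])     # longest first: [2, 0, 1, 3]
sorted_vals = [lengths[i] for i in order]     # [9, 5, 2, 1]
inverse = np.argsort(order)                   # inverse permutation: [1, 2, 0, 3]
restored = [sorted_vals[i] for i in inverse]  # back to [5, 2, 9, 1]
assert restored == lengths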