Skip to content

Commit 4f573c5

Browse files
committed
update-reranker-v2
1 parent 8011c1a commit 4f573c5

1 file changed

Lines changed: 13 additions & 51 deletions

File tree

FlagEmbedding/flag_reranker.py

Lines changed: 13 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -316,11 +316,16 @@ def __init__(
316316
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
317317
cache_dir=cache_dir,
318318
trust_remote_code=True)
319-
320-
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
321-
cache_dir=cache_dir,
322-
trust_remote_code=True,
323-
torch_dtype=torch.bfloat16 if use_bf16 else torch.float32)
319+
if use_bf16:
320+
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
321+
cache_dir=cache_dir,
322+
trust_remote_code=True,
323+
torch_dtype=torch.bfloat16)
324+
else:
325+
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
326+
cache_dir=cache_dir,
327+
trust_remote_code=True,
328+
use_flash_attention_2=False)
324329
self.model_name_or_path = model_name_or_path
325330
self.cache_dir = cache_dir
326331

@@ -344,8 +349,8 @@ def __init__(
344349
@torch.no_grad()
345350
def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
346351
max_length: int = 512, cutoff_layers: List[int] = None, prompt: str = None,
347-
normalize: bool = False) -> float | list[Any] | list[float | Any] | list[
348-
list[Any] | list[float | Any]] | Any:
352+
normalize: bool = False) -> Union[float, List[Any], List[Union[float, Any]], List[
353+
List[Any], List[Union[float, Any]]], Any]:
349354
assert isinstance(sentence_pairs, list)
350355
if isinstance(sentence_pairs[0], str):
351356
sentence_pairs = [sentence_pairs]
@@ -407,47 +412,4 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
407412
elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
408413
return len(text)
409414
else:
410-
return sum([len(t) for t in text]) # Sum of length of individual strings
411-
412-
def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    """Tokenize (query, passage) pairs into one padded batch for an LLM-based reranker.

    Each example is laid out as:
        [BOS] query_ids  "\n" passage_ids  "\n" prompt_ids
    where the passage side may be truncated (`truncation='only_second'`) to fit
    `max_length`, and the instruction prompt is appended after truncation so it
    is never cut off.

    Args:
        pairs: iterable of (query, passage) string tuples.
        tokenizer: a HuggingFace-style tokenizer (needs `__call__`,
            `prepare_for_model`, `pad`, and `bos_token_id`).
        prompt: instruction text appended to every example; a default
            yes/no relevance instruction is used when None.
        max_length: token budget for the query+passage portion.

    Returns:
        The result of `tokenizer.pad(...)`: a batch dict of PyTorch tensors
        (`input_ids`, `attention_mask`), padded to a multiple of 8.
    """
    if prompt is None:
        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
    # Pre-tokenize the two fragments shared by every example.
    prompt_ids = tokenizer(prompt,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    sep_ids = tokenizer("\n",
                        return_tensors=None,
                        add_special_tokens=False)['input_ids']

    def _encode_pair(query, passage):
        # Reserve roughly a quarter of the budget for the passage by capping
        # the query at 3/4 of max_length.
        query_ids = tokenizer(query,
                              return_tensors=None,
                              add_special_tokens=False,
                              max_length=max_length * 3 // 4,
                              truncation=True)['input_ids']
        passage_ids = tokenizer(passage,
                                return_tensors=None,
                                add_special_tokens=False,
                                max_length=max_length,
                                truncation=True)['input_ids']
        example = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_ids,
            sep_ids + passage_ids,
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        # Append the instruction prompt after truncation so it always survives.
        example['input_ids'] = example['input_ids'] + sep_ids + prompt_ids
        example['attention_mask'] = [1] * len(example['input_ids'])
        return example

    batch = [_encode_pair(query, passage) for query, passage in pairs]
    return tokenizer.pad(
        batch,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )
415+
return sum([len(t) for t in text]) # Sum of length of individual strings

0 commit comments

Comments
 (0)