
Commit c9512f9

Merge pull request #642 from 545999961/master
update reranker v2
2 parents (3bfd718 + b0802a5) · commit c9512f9

4 files changed

Lines changed: 185 additions & 46 deletions


FlagEmbedding/flag_reranker.py

Lines changed: 180 additions & 43 deletions
@@ -4,7 +4,7 @@
 import torch
 from torch import Tensor
 from torch.utils.data import DataLoader
-from tqdm import tqdm
+from tqdm import tqdm, trange
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, is_torch_npu_available
 
 import warnings
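
The only import change is pulling in trange, tqdm's progress-bar wrapper around range; the new non-dataloader paths below iterate batch offsets with it. A minimal illustration (the sizes here are made up, not from the commit):

from tqdm import trange

for batch_start in trange(0, 1000, 16):  # progress bar over batch offsets
    pass  # score sentence_pairs[batch_start:batch_start + 16] here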
@@ -269,32 +269,96 @@ def __init__(
 
     @torch.no_grad()
     def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
-                      max_length: int = 512, prompt: str = None, normalize: bool = False) -> List[float]:
+                      max_length: int = 512, prompt: str = None, normalize: bool = False,
+                      use_dataloader: bool = True, num_workers: int = None) -> List[float]:
         assert isinstance(sentence_pairs, list)
         if isinstance(sentence_pairs[0], str):
             sentence_pairs = [sentence_pairs]
 
         length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs])
         sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx]
 
-        dataset = DatasetForReranker(sentences_sorted,
-                                     self.model_name_or_path,
-                                     max_length,
-                                     cache_dir=self.cache_dir,
-                                     prompt=prompt)
-        dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
-                                num_workers=min(batch_size, 16),
-                                collate_fn=collater(self.tokenizer, max_length))
+        dataset, dataloader = None, None
+        if use_dataloader:
+            if num_workers is None:
+                num_workers = min(batch_size, 16)
+            dataset = DatasetForReranker(sentences_sorted,
+                                         self.model_name_or_path,
+                                         max_length,
+                                         cache_dir=self.cache_dir,
+                                         prompt=prompt)
+            dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
+                                    num_workers=num_workers,
+                                    collate_fn=collater(self.tokenizer, max_length))
 
         all_scores = []
-        for inputs in tqdm(dataloader):
-            inputs = inputs.to(self.device)
-
-            outputs = self.model(**inputs, output_hidden_states=True)
-            logits = outputs.logits
-            scores = last_logit_pool(logits, inputs['attention_mask'])
-            scores = scores[:, self.yes_loc]
-            all_scores.extend(scores.cpu().float().tolist())
+        if dataloader is not None:
+            for inputs in tqdm(dataloader):
+                inputs = inputs.to(self.device)
+
+                outputs = self.model(**inputs, output_hidden_states=True)
+                logits = outputs.logits
+                scores = last_logit_pool(logits, inputs['attention_mask'])
+                scores = scores[:, self.yes_loc]
+                all_scores.extend(scores.cpu().float().tolist())
+        else:
+            if prompt is None:
+                prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
+            prompt_inputs = self.tokenizer(prompt,
+                                           return_tensors=None,
+                                           add_special_tokens=False)['input_ids']
+            sep = "\n"
+            sep_inputs = self.tokenizer(sep,
+                                        return_tensors=None,
+                                        add_special_tokens=False)['input_ids']
+            encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
+            for batch_start in trange(0, len(sentences_sorted), batch_size):
+                batch_sentences = sentences_sorted[batch_start:batch_start + batch_size]
+                batch_sentences = [(f'A: {q}', f'B: {p}') for q, p in batch_sentences]
+                queries = [s[0] for s in batch_sentences]
+                passages = [s[1] for s in batch_sentences]
+                queries_inputs = self.tokenizer(queries,
+                                                return_tensors=None,
+                                                add_special_tokens=False,
+                                                max_length=max_length * 3 // 4,
+                                                truncation=True)
+                passages_inputs = self.tokenizer(passages,
+                                                 return_tensors=None,
+                                                 add_special_tokens=False,
+                                                 max_length=max_length,
+                                                 truncation=True)
+
+                batch_inputs = []
+                for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']):
+                    item = self.tokenizer.prepare_for_model(
+                        [self.tokenizer.bos_token_id] + query_inputs,
+                        sep_inputs + passage_inputs,
+                        truncation='only_second',
+                        max_length=encode_max_length,
+                        padding=False,
+                        return_attention_mask=False,
+                        return_token_type_ids=False,
+                        add_special_tokens=False
+                    )
+                    item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
+                    item['attention_mask'] = [1] * len(item['input_ids'])
+                    item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
+                    if 'position_ids' in item.keys():
+                        item['position_ids'] = list(range(len(item['input_ids'])))
+                    batch_inputs.append(item)
+
+                collater_instance = collater(self.tokenizer, max_length)
+                batch_inputs = collater_instance(
+                    [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in
+                     batch_inputs])
+
+                batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()}
+
+                outputs = self.model(**batch_inputs, output_hidden_states=True)
+                logits = outputs.logits
+                scores = last_logit_pool(logits, batch_inputs['attention_mask'])
+                scores = scores[:, self.yes_loc]
+                all_scores.extend(scores.cpu().float().tolist())
 
         all_scores = [all_scores[idx] for idx in np.argsort(length_sorted_idx)]
 
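In short, compute_score gains two keyword arguments. With use_dataloader=True it behaves as before, except that num_workers makes the previously hard-coded min(batch_size, 16) configurable; with use_dataloader=False it tokenizes in-process instead of spawning DataLoader workers, budgeting queries at 3/4 of max_length and passages at max_length, then appending the "\n" separator and the Yes/No prompt. A minimal usage sketch, assuming FlagLLMReranker is exported from the package as in the repo README; the checkpoint name and pairs are illustrative, not taken from this commit:

from FlagEmbedding import FlagLLMReranker

reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True)

pairs = [
    ('what is panda?', 'The giant panda is a bear species endemic to China.'),
    ('what is panda?', 'Paris is the capital of France.'),
]

# use_dataloader=False exercises the new in-process path; num_workers is only
# consulted when use_dataloader=True and still defaults to min(batch_size, 16).
scores = reranker.compute_score(pairs, batch_size=16, use_dataloader=False)
print(scores)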
@@ -323,6 +387,7 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
         else:
             return sum([len(t) for t in text])  # Sum of length of individual strings
 
+
 class LayerWiseFlagLLMReranker:
     def __init__(
         self,
@@ -378,40 +443,112 @@ def __init__(
     @torch.no_grad()
     def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16,
                       max_length: int = 512, cutoff_layers: List[int] = None, prompt: str = None,
-                      normalize: bool = False) -> Union[float, List[float], List[List[float]]]:
+                      normalize: bool = False, use_dataloader: bool = True,
+                      num_workers: int = None) -> Union[float, List[float], List[List[float]]]:
         assert isinstance(sentence_pairs, list)
         if isinstance(sentence_pairs[0], str):
             sentence_pairs = [sentence_pairs]
 
         length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs])
         sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx]
 
-        dataset = DatasetForReranker(sentences_sorted,
-                                     self.model_name_or_path,
-                                     max_length,
-                                     cache_dir=self.cache_dir,
-                                     prompt=prompt)
-        dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
-                                num_workers=min(batch_size, 16),
-                                collate_fn=collater(self.tokenizer, max_length))
+        dataset, dataloader = None, None
+        if use_dataloader:
+            if num_workers is None:
+                num_workers = min(batch_size, 16)
+            dataset = DatasetForReranker(sentences_sorted,
+                                         self.model_name_or_path,
+                                         max_length,
+                                         cache_dir=self.cache_dir,
+                                         prompt=prompt)
+            dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
+                                    num_workers=num_workers,
+                                    collate_fn=collater(self.tokenizer, max_length))
 
         all_scores = []
-        for inputs in tqdm(dataloader):
-            inputs = inputs.to(self.device)
-
-            outputs = self.model(**inputs, output_hidden_states=True, cutoff_layers=cutoff_layers)
-            all_logits = outputs.logits
-            tmp_all_scores = []
-            for logits in all_logits:
-                scores = last_logit_pool_layerwise(logits, inputs['attention_mask'])
-                tmp_all_scores.append(scores.contiguous())
-
-            if len(all_scores) == 0:
-                for _ in range(len(tmp_all_scores)):
-                    all_scores.append([])
-
-            for i in range(len(tmp_all_scores)):
-                all_scores[i].extend(tmp_all_scores[i].cpu().float().tolist())
+        if dataloader is not None:
+            for inputs in tqdm(dataloader):
+                inputs = inputs.to(self.device)
+
+                outputs = self.model(**inputs, output_hidden_states=True, cutoff_layers=cutoff_layers)
+                all_logits = outputs.logits
+                tmp_all_scores = []
+                for logits in all_logits:
+                    scores = last_logit_pool_layerwise(logits, inputs['attention_mask'])
+                    tmp_all_scores.append(scores.contiguous())
+
+                if len(all_scores) == 0:
+                    for _ in range(len(tmp_all_scores)):
+                        all_scores.append([])
+
+                for i in range(len(tmp_all_scores)):
+                    all_scores[i].extend(tmp_all_scores[i].cpu().float().tolist())
+        else:
+            if prompt is None:
+                prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
+            prompt_inputs = self.tokenizer(prompt,
+                                           return_tensors=None,
+                                           add_special_tokens=False)['input_ids']
+            sep = "\n"
+            sep_inputs = self.tokenizer(sep,
+                                        return_tensors=None,
+                                        add_special_tokens=False)['input_ids']
+            encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
+            for batch_start in trange(0, len(sentences_sorted), batch_size):
+                batch_sentences = sentences_sorted[batch_start:batch_start + batch_size]
+                batch_sentences = [(f'A: {q}', f'B: {p}') for q, p in batch_sentences]
+                queries = [s[0] for s in batch_sentences]
+                passages = [s[1] for s in batch_sentences]
+                queries_inputs = self.tokenizer(queries,
+                                                return_tensors=None,
+                                                add_special_tokens=False,
+                                                max_length=max_length * 3 // 4,
+                                                truncation=True)
+                passages_inputs = self.tokenizer(passages,
+                                                 return_tensors=None,
+                                                 add_special_tokens=False,
+                                                 max_length=max_length,
+                                                 truncation=True)
+
+                batch_inputs = []
+                for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']):
+                    item = self.tokenizer.prepare_for_model(
+                        [self.tokenizer.bos_token_id] + query_inputs,
+                        sep_inputs + passage_inputs,
+                        truncation='only_second',
+                        max_length=encode_max_length,
+                        padding=False,
+                        return_attention_mask=False,
+                        return_token_type_ids=False,
+                        add_special_tokens=False
+                    )
+                    item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
+                    item['attention_mask'] = [1] * len(item['input_ids'])
+                    item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
+                    if 'position_ids' in item.keys():
+                        item['position_ids'] = list(range(len(item['input_ids'])))
+                    batch_inputs.append(item)
+
+                collater_instance = collater(self.tokenizer, max_length)
+                batch_inputs = collater_instance(
+                    [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in
+                     batch_inputs])
+
+                batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()}
+
+                outputs = self.model(**batch_inputs, output_hidden_states=True, cutoff_layers=cutoff_layers)
+                all_logits = outputs.logits
+                tmp_all_scores = []
+                for logits in all_logits:
+                    scores = last_logit_pool_layerwise(logits, batch_inputs['attention_mask'])
+                    tmp_all_scores.append(scores.contiguous())
+
+                if len(all_scores) == 0:
+                    for _ in range(len(tmp_all_scores)):
+                        all_scores.append([])
+
+                for i in range(len(tmp_all_scores)):
+                    all_scores[i].extend(tmp_all_scores[i].cpu().float().tolist())
 
         for i in range(len(all_scores)):
             all_scores[i] = [all_scores[i][idx] for idx in np.argsort(length_sorted_idx)]
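
The layerwise reranker gets the same use_dataloader/num_workers switch; the only differences in its fallback path are the cutoff_layers argument forwarded to the model and the per-layer pooling, which yields one score list per returned layer. A usage sketch under the same assumptions as above (the checkpoint name and cutoff layer are illustrative):

from FlagEmbedding import LayerWiseFlagLLMReranker

reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_fp16=True)

scores = reranker.compute_score(
    [('what is panda?', 'The giant panda is a bear species endemic to China.')],
    cutoff_layers=[28],
    use_dataloader=False,
)
print(scores)  # one list of scores per requested layer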

FlagEmbedding/llm_reranker/finetune_for_instruction/data.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 import re
 import sys
+from typing import List
 
 import math
 import os.path
@@ -60,7 +61,7 @@ def is_chinese(self, text):
         chinese_pattern = re.compile('[\u4e00-\u9fa5]')
         return bool(chinese_pattern.search(text))
 
-    def __getitem__(self, item) -> list[BatchEncoding]:
+    def __getitem__(self, item) -> List[BatchEncoding]:
         query = self.dataset[item]['query']
 
         passages = []
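
The annotation change is a compatibility fix, not a behavior change: list[BatchEncoding] uses PEP 585 built-in generics, which require Python 3.9 or newer, while typing.List also works on older interpreters. A minimal demonstration of the failure mode being avoided (the function name is a stand-in):

from typing import List

# On Python 3.8, subscripting the built-in list in a signature fails at
# definition time with "TypeError: 'type' object is not subscriptable":
#     def __getitem__(self, item) -> list[BatchEncoding]: ...
# The typing.List spelling evaluates cleanly on every supported version:
def getitem_sketch(item) -> List[str]:
    return [str(item)]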

FlagEmbedding/llm_reranker/finetune_for_layerwise/arguments.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ class ModelArguments:
     from_peft: str = field(
         default=None
     )
-    lora_extra_parameters: str = field(
+    lora_extra_parameters: Optional[List[str]] = field(
        default=None
    )
    start_layer: int = field(
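
Retyping lora_extra_parameters from str to Optional[List[str]] lets several extra module names be supplied at once. A sketch of how the field would parse, assuming the finetune entry point feeds ModelArguments through transformers' HfArgumentParser (the dataclass is reproduced standalone here and the module names are illustrative):

from dataclasses import dataclass, field
from typing import List, Optional

from transformers import HfArgumentParser

@dataclass
class ModelArgumentsSketch:
    # Standalone copy of the retyped field, for demonstration only.
    lora_extra_parameters: Optional[List[str]] = field(default=None)

parser = HfArgumentParser(ModelArgumentsSketch)
(args,) = parser.parse_args_into_dataclasses(
    args=['--lora_extra_parameters', 'embed_tokens', 'norm'])
print(args.lora_extra_parameters)  # ['embed_tokens', 'norm']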

FlagEmbedding/llm_reranker/finetune_for_layerwise/data.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 import re
 import sys
+from typing import List
 
 import math
 import os.path
@@ -56,7 +57,7 @@ def __init__(
     def __len__(self):
         return self.total_len
 
-    def __getitem__(self, item) -> list[BatchEncoding]:
+    def __getitem__(self, item) -> List[BatchEncoding]:
         query = self.dataset[item]['query']
 
         passages = []
