|
4 | 4 | import torch |
5 | 5 | from torch import Tensor |
6 | 6 | from torch.utils.data import DataLoader |
7 | | -from tqdm import tqdm |
| 7 | +from tqdm import tqdm, trange |
8 | 8 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, is_torch_npu_available |
9 | 9 |
|
10 | 10 | import warnings |
@@ -269,32 +269,96 @@ def __init__( |
269 | 269 |
|
270 | 270 | @torch.no_grad() |
271 | 271 | def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16, |
272 | | - max_length: int = 512, prompt: str = None, normalize: bool = False) -> List[float]: |
| 272 | + max_length: int = 512, prompt: str = None, normalize: bool = False, |
| 273 | + use_dataloader: bool = True, num_workers: int = None) -> List[float]: |
273 | 274 | assert isinstance(sentence_pairs, list) |
274 | 275 | if isinstance(sentence_pairs[0], str): |
275 | 276 | sentence_pairs = [sentence_pairs] |
276 | 277 |
|
277 | 278 | length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs]) |
278 | 279 | sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx] |
279 | 280 |
|
280 | | - dataset = DatasetForReranker(sentences_sorted, |
281 | | - self.model_name_or_path, |
282 | | - max_length, |
283 | | - cache_dir=self.cache_dir, |
284 | | - prompt=prompt) |
285 | | - dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False, |
286 | | - num_workers=min(batch_size, 16), |
287 | | - collate_fn=collater(self.tokenizer, max_length)) |
| 281 | + dataset, dataloader = None, None |
| 282 | + if use_dataloader: |
| 283 | + if num_workers is None: |
| 284 | + num_workers = min(batch_size, 16) |
| 285 | + dataset = DatasetForReranker(sentences_sorted, |
| 286 | + self.model_name_or_path, |
| 287 | + max_length, |
| 288 | + cache_dir=self.cache_dir, |
| 289 | + prompt=prompt) |
| 290 | + dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False, |
| 291 | + num_workers=num_workers, |
| 292 | + collate_fn=collater(self.tokenizer, max_length)) |
288 | 293 |
|
289 | 294 | all_scores = [] |
290 | | - for inputs in tqdm(dataloader): |
291 | | - inputs = inputs.to(self.device) |
292 | | - |
293 | | - outputs = self.model(**inputs, output_hidden_states=True) |
294 | | - logits = outputs.logits |
295 | | - scores = last_logit_pool(logits, inputs['attention_mask']) |
296 | | - scores = scores[:, self.yes_loc] |
297 | | - all_scores.extend(scores.cpu().float().tolist()) |
| 295 | + if dataloader is not None: |
| 296 | + for inputs in tqdm(dataloader): |
| 297 | + inputs = inputs.to(self.device) |
| 298 | + |
| 299 | + outputs = self.model(**inputs, output_hidden_states=True) |
| 300 | + logits = outputs.logits |
| 301 | + scores = last_logit_pool(logits, inputs['attention_mask']) |
| 302 | + scores = scores[:, self.yes_loc] |
| 303 | + all_scores.extend(scores.cpu().float().tolist()) |
| 304 | + else: |
| 305 | + if prompt is None: |
| 306 | + prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." |
| 307 | + prompt_inputs = self.tokenizer(prompt, |
| 308 | + return_tensors=None, |
| 309 | + add_special_tokens=False)['input_ids'] |
| 310 | + sep = "\n" |
| 311 | + sep_inputs = self.tokenizer(sep, |
| 312 | + return_tensors=None, |
| 313 | + add_special_tokens=False)['input_ids'] |
| 314 | + encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs) |
| 315 | + for batch_start in trange(0, len(sentences_sorted), batch_size): |
| 316 | + batch_sentences = sentences_sorted[batch_start:batch_start + batch_size] |
| 317 | + batch_sentences = [(f'A: {q}', f'B: {p}') for q,p in batch_sentences] |
| 318 | + queries = [s[0] for s in batch_sentences] |
| 319 | + passages = [s[1] for s in batch_sentences] |
| 320 | + queries_inputs = self.tokenizer(queries, |
| 321 | + return_tensors=None, |
| 322 | + add_special_tokens=False, |
| 323 | + max_length=max_length * 3 // 4, |
| 324 | + truncation=True) |
| 325 | + passages_inputs = self.tokenizer(passages, |
| 326 | + return_tensors=None, |
| 327 | + add_special_tokens=False, |
| 328 | + max_length=max_length, |
| 329 | + truncation=True) |
| 330 | + |
| 331 | + batch_inputs = [] |
| 332 | + for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']): |
| 333 | + item = self.tokenizer.prepare_for_model( |
| 334 | + [self.tokenizer.bos_token_id] + query_inputs, |
| 335 | + sep_inputs + passage_inputs, |
| 336 | + truncation='only_second', |
| 337 | + max_length=encode_max_length, |
| 338 | + padding=False, |
| 339 | + return_attention_mask=False, |
| 340 | + return_token_type_ids=False, |
| 341 | + add_special_tokens=False |
| 342 | + ) |
| 343 | + item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs |
| 344 | + item['attention_mask'] = [1] * len(item['input_ids']) |
| 345 | + item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None |
| 346 | + if 'position_ids' in item.keys(): |
| 347 | + item['position_ids'] = list(range(len(item['input_ids']))) |
| 348 | + batch_inputs.append(item) |
| 349 | + |
| 350 | + collater_instance = collater(self.tokenizer, max_length) |
| 351 | + batch_inputs = collater_instance( |
| 352 | + [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in |
| 353 | + batch_inputs]) |
| 354 | + |
| 355 | + batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()} |
| 356 | + |
| 357 | + outputs = self.model(**batch_inputs, output_hidden_states=True) |
| 358 | + logits = outputs.logits |
| 359 | + scores = last_logit_pool(logits, batch_inputs['attention_mask']) |
| 360 | + scores = scores[:, self.yes_loc] |
| 361 | + all_scores.extend(scores.cpu().float().tolist()) |
298 | 362 |
|
299 | 363 | all_scores = [all_scores[idx] for idx in np.argsort(length_sorted_idx)] |
300 | 364 |
|
@@ -323,6 +387,7 @@ def _text_length(self, text: Union[List[int], List[List[int]]]): |
323 | 387 | else: |
324 | 388 | return sum([len(t) for t in text]) # Sum of length of individual strings |
325 | 389 |
|
| 390 | + |
326 | 391 | class LayerWiseFlagLLMReranker: |
327 | 392 | def __init__( |
328 | 393 | self, |
@@ -378,40 +443,112 @@ def __init__( |
378 | 443 | @torch.no_grad() |
379 | 444 | def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 16, |
380 | 445 | max_length: int = 512, cutoff_layers: List[int] = None, prompt: str = None, |
381 | | - normalize: bool = False) -> Union[float, List[float], List[List[float]]]: |
| 446 | + normalize: bool = False, use_dataloader: bool = True, |
| 447 | + num_workers: int = None) -> Union[float, List[float], List[List[float]]]: |
382 | 448 | assert isinstance(sentence_pairs, list) |
383 | 449 | if isinstance(sentence_pairs[0], str): |
384 | 450 | sentence_pairs = [sentence_pairs] |
385 | 451 |
|
386 | 452 | length_sorted_idx = np.argsort([-self._text_length(q) - self._text_length(p) for q, p in sentence_pairs]) |
387 | 453 | sentences_sorted = [sentence_pairs[idx] for idx in length_sorted_idx] |
388 | 454 |
|
389 | | - dataset = DatasetForReranker(sentences_sorted, |
390 | | - self.model_name_or_path, |
391 | | - max_length, |
392 | | - cache_dir=self.cache_dir, |
393 | | - prompt=prompt) |
394 | | - dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False, |
395 | | - num_workers=min(batch_size, 16), |
396 | | - collate_fn=collater(self.tokenizer, max_length)) |
| 455 | + dataset, dataloader = None, None |
| 456 | + if use_dataloader: |
| 457 | + if num_workers is None: |
| 458 | + num_workers = min(batch_size, 16) |
| 459 | + dataset = DatasetForReranker(sentences_sorted, |
| 460 | + self.model_name_or_path, |
| 461 | + max_length, |
| 462 | + cache_dir=self.cache_dir, |
| 463 | + prompt=prompt) |
| 464 | + dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False, |
| 465 | + num_workers=num_workers, |
| 466 | + collate_fn=collater(self.tokenizer, max_length)) |
397 | 467 |
|
398 | 468 | all_scores = [] |
399 | | - for inputs in tqdm(dataloader): |
400 | | - inputs = inputs.to(self.device) |
401 | | - |
402 | | - outputs = self.model(**inputs, output_hidden_states=True, cutoff_layers=cutoff_layers) |
403 | | - all_logits = outputs.logits |
404 | | - tmp_all_scores = [] |
405 | | - for logits in all_logits: |
406 | | - scores = last_logit_pool_layerwise(logits, inputs['attention_mask']) |
407 | | - tmp_all_scores.append(scores.contiguous()) |
408 | | - |
409 | | - if len(all_scores) == 0: |
410 | | - for _ in range(len(tmp_all_scores)): |
411 | | - all_scores.append([]) |
412 | | - |
413 | | - for i in range(len(tmp_all_scores)): |
414 | | - all_scores[i].extend(tmp_all_scores[i].cpu().float().tolist()) |
| 469 | + if dataloader is not None: |
| 470 | + for inputs in tqdm(dataloader): |
| 471 | + inputs = inputs.to(self.device) |
| 472 | + |
| 473 | + outputs = self.model(**inputs, output_hidden_states=True, cutoff_layers=cutoff_layers) |
| 474 | + all_logits = outputs.logits |
| 475 | + tmp_all_scores = [] |
| 476 | + for logits in all_logits: |
| 477 | + scores = last_logit_pool_layerwise(logits, inputs['attention_mask']) |
| 478 | + tmp_all_scores.append(scores.contiguous()) |
| 479 | + |
| 480 | + if len(all_scores) == 0: |
| 481 | + for _ in range(len(tmp_all_scores)): |
| 482 | + all_scores.append([]) |
| 483 | + |
| 484 | + for i in range(len(tmp_all_scores)): |
| 485 | + all_scores[i].extend(tmp_all_scores[i].cpu().float().tolist()) |
| 486 | + else: |
| 487 | + if prompt is None: |
| 488 | + prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." |
| 489 | + prompt_inputs = self.tokenizer(prompt, |
| 490 | + return_tensors=None, |
| 491 | + add_special_tokens=False)['input_ids'] |
| 492 | + sep = "\n" |
| 493 | + sep_inputs = self.tokenizer(sep, |
| 494 | + return_tensors=None, |
| 495 | + add_special_tokens=False)['input_ids'] |
| 496 | + encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs) |
| 497 | + for batch_start in trange(0, len(sentences_sorted), batch_size): |
| 498 | + batch_sentences = sentences_sorted[batch_start:batch_start + batch_size] |
| 499 | + batch_sentences = [(f'A: {q}', f'B: {p}') for q, p in batch_sentences] |
| 500 | + queries = [s[0] for s in batch_sentences] |
| 501 | + passages = [s[1] for s in batch_sentences] |
| 502 | + queries_inputs = self.tokenizer(queries, |
| 503 | + return_tensors=None, |
| 504 | + add_special_tokens=False, |
| 505 | + max_length=max_length * 3 // 4, |
| 506 | + truncation=True) |
| 507 | + passages_inputs = self.tokenizer(passages, |
| 508 | + return_tensors=None, |
| 509 | + add_special_tokens=False, |
| 510 | + max_length=max_length, |
| 511 | + truncation=True) |
| 512 | + |
| 513 | + batch_inputs = [] |
| 514 | + for query_inputs, passage_inputs in zip(queries_inputs['input_ids'], passages_inputs['input_ids']): |
| 515 | + item = self.tokenizer.prepare_for_model( |
| 516 | + [self.tokenizer.bos_token_id] + query_inputs, |
| 517 | + sep_inputs + passage_inputs, |
| 518 | + truncation='only_second', |
| 519 | + max_length=encode_max_length, |
| 520 | + padding=False, |
| 521 | + return_attention_mask=False, |
| 522 | + return_token_type_ids=False, |
| 523 | + add_special_tokens=False |
| 524 | + ) |
| 525 | + item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs |
| 526 | + item['attention_mask'] = [1] * len(item['input_ids']) |
| 527 | + item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None |
| 528 | + if 'position_ids' in item.keys(): |
| 529 | + item['position_ids'] = list(range(len(item['input_ids']))) |
| 530 | + batch_inputs.append(item) |
| 531 | + |
| 532 | + collater_instance = collater(self.tokenizer, max_length) |
| 533 | + batch_inputs = collater_instance( |
| 534 | + [{'input_ids': item['input_ids'], 'attention_mask': item['attention_mask']} for item in |
| 535 | + batch_inputs]) |
| 536 | + |
| 537 | + batch_inputs = {key: val.to(self.device) for key, val in batch_inputs.items()} |
| 538 | + |
| 539 | + outputs = self.model(**batch_inputs, output_hidden_states=True, cutoff_layers=cutoff_layers) |
| 540 | + all_logits = outputs.logits |
| 541 | + tmp_all_scores = [] |
| 542 | + for logits in all_logits: |
| 543 | + scores = last_logit_pool_layerwise(logits, batch_inputs['attention_mask']) |
| 544 | + tmp_all_scores.append(scores.contiguous()) |
| 545 | + |
| 546 | + if len(all_scores) == 0: |
| 547 | + for _ in range(len(tmp_all_scores)): |
| 548 | + all_scores.append([]) |
| 549 | + |
| 550 | + for i in range(len(tmp_all_scores)): |
| 551 | + all_scores[i].extend(tmp_all_scores[i].cpu().float().tolist()) |
415 | 552 |
|
416 | 553 | for i in range(len(all_scores)): |
417 | 554 | all_scores[i] = [all_scores[i][idx] for idx in np.argsort(length_sorted_idx)] |
|
0 commit comments