Skip to content

Commit 46062d9

Browse files
authored
Update flag_reranker.py
1 parent c9512f9 commit 46062d9

1 file changed

Lines changed: 17 additions & 17 deletions

File tree

FlagEmbedding/flag_reranker.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,14 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str,
282282
if use_dataloader:
283283
if num_workers is None:
284284
num_workers = min(batch_size, 16)
285-
dataset = DatasetForReranker(sentences_sorted,
286-
self.model_name_or_path,
287-
max_length,
288-
cache_dir=self.cache_dir,
289-
prompt=prompt)
290-
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
291-
num_workers=num_workers,
292-
collate_fn=collater(self.tokenizer, max_length))
285+
dataset = DatasetForReranker(sentences_sorted,
286+
self.model_name_or_path,
287+
max_length,
288+
cache_dir=self.cache_dir,
289+
prompt=prompt)
290+
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
291+
num_workers=num_workers,
292+
collate_fn=collater(self.tokenizer, max_length))
293293

294294
all_scores = []
295295
if dataloader is not None:
@@ -456,14 +456,14 @@ def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str,
456456
if use_dataloader:
457457
if num_workers is None:
458458
num_workers = min(batch_size, 16)
459-
dataset = DatasetForReranker(sentences_sorted,
460-
self.model_name_or_path,
461-
max_length,
462-
cache_dir=self.cache_dir,
463-
prompt=prompt)
464-
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
465-
num_workers=num_workers,
466-
collate_fn=collater(self.tokenizer, max_length))
459+
dataset = DatasetForReranker(sentences_sorted,
460+
self.model_name_or_path,
461+
max_length,
462+
cache_dir=self.cache_dir,
463+
prompt=prompt)
464+
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=False,
465+
num_workers=num_workers,
466+
collate_fn=collater(self.tokenizer, max_length))
467467

468468
all_scores = []
469469
if dataloader is not None:
@@ -577,4 +577,4 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
577577
elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
578578
return len(text)
579579
else:
580-
return sum([len(t) for t in text]) # Sum of length of individual strings
580+
return sum([len(t) for t in text]) # Sum of length of individual strings

0 commit comments

Comments
 (0)