Skip to content

Commit 1b971d0

Browse files
authored
Merge pull request #1219 from hanhainebula/master
Fix bugs
2 parents f3f9800 + c9026b4 commit 1b971d0

3 files changed

Lines changed: 7 additions & 5 deletions

File tree

FlagEmbedding/abc/evaluation/data_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
266266
Returns:
267267
datasets.DatasetDict: A dict of relevance of query and document.
268268
"""
269-
checked_split = self.check_splits(split)
269+
checked_split = self.check_splits(split, dataset_name=dataset_name)
270270
if len(checked_split) == 0:
271271
raise ValueError(f"Split {split} not found in the dataset.")
272272
split = checked_split[0]
@@ -301,7 +301,7 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None,
301301
Returns:
302302
datasets.DatasetDict: A dict of queries with id as key, query text as value.
303303
"""
304-
checked_split = self.check_splits(split)
304+
checked_split = self.check_splits(split, dataset_name=dataset_name)
305305
if len(checked_split) == 0:
306306
raise ValueError(f"Split {split} not found in the dataset.")
307307
split = checked_split[0]

FlagEmbedding/abc/finetune/embedder/AbsDataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,8 @@ def _get_train_group_size(self, batch_raw_data):
426426
return 2, data_type
427427
elif data_type in ['symmetric_class']:
428428
return min(len(batch_raw_data['neg'][0]) + 1, self.args.train_group_size), data_type
429+
else:
430+
return self.args.train_group_size, data_type
429431
return self.args.train_group_size, None
430432

431433
def _create_batch_data(self, batch_raw_data):

FlagEmbedding/inference/auto_reranker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def from_finetuned(
5454
_model_class = RERANKER_CLASS_MAPPING[RerankerModelClass(model_class)]
5555
if trust_remote_code is None:
5656
trust_remote_code = False
57-
logging.warning(
58-
f"`trust_remote_code` is not specified, set to default value '{trust_remote_code}'."
59-
)
57+
logging.warning(
58+
f"`trust_remote_code` is not specified, set to default value '{trust_remote_code}'."
59+
)
6060
else:
6161
if model_name not in AUTO_RERANKER_MAPPING:
6262
raise ValueError(

0 commit comments

Comments
 (0)