Skip to content

Commit 7c251f6

Browse files
committed
update trainer
1 parent 17443f6 commit 7c251f6

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

FlagEmbedding/abc/finetune/embedder/AbsTrainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class AbsEmbedderTrainer(ABC, Trainer):
1111
def _save(self, output_dir: Optional[str] = None, state_dict=None):
1212
pass
1313

14-
def compute_loss(self, model, inputs, return_outputs=False):
14+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
1515
"""
1616
How the loss is computed by Trainer. By default, all models return the loss in the first element.
1717

FlagEmbedding/abc/finetune/reranker/AbsTrainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class AbsRerankerTrainer(ABC, Trainer):
1111
def _save(self, output_dir: Optional[str] = None, state_dict=None):
1212
pass
1313

14-
def compute_loss(self, model, inputs, return_outputs=False):
14+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
1515
"""
1616
How the loss is computed by Trainer. By default, all models return the loss in the first element.
1717

0 commit comments

Comments
 (0)