Skip to content

Commit 58fecb4

Browse files
committed
update reranker trainer
1 parent 5cef26d commit 58fecb4

2 files changed

Lines changed: 22 additions & 22 deletions

File tree

FlagEmbedding/finetune/reranker/decoder_only/base/trainer.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import logging
44
from typing import Optional
5-
from transformers.deepspeed import is_deepspeed_zero3_enabled
5+
# from transformers.deepspeed import is_deepspeed_zero3_enabled
66

77
from FlagEmbedding.abc.finetune.reranker import AbsRerankerTrainer
88
from peft import get_peft_model_state_dict
@@ -29,13 +29,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
2929

3030
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
3131

32-
if is_deepspeed_zero3_enabled():
33-
if state_dict is None:
34-
state_dict = self.model.state_dict()
35-
prefix = 'model.'
36-
assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
37-
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
38-
lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
39-
if self.args.process_index <= 0:
40-
torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
41-
print(f"Save adapter model at {output_dir}")
32+
# if is_deepspeed_zero3_enabled():
33+
# if state_dict is None:
34+
# state_dict = self.model.state_dict()
35+
# prefix = 'model.'
36+
# assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
37+
# state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
38+
# lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
39+
# if self.args.process_index <= 0:
40+
# torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
41+
# print(f"Save adapter model at {output_dir}")

FlagEmbedding/finetune/reranker/decoder_only/layerwise/trainer.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import logging
44
from typing import Optional
5-
from transformers.deepspeed import is_deepspeed_zero3_enabled
5+
# from transformers.deepspeed import is_deepspeed_zero3_enabled
66
from peft import get_peft_model_state_dict
77

88
from FlagEmbedding.abc.finetune.reranker import AbsRerankerTrainer
@@ -29,13 +29,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
2929

3030
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
3131

32-
if is_deepspeed_zero3_enabled():
33-
if state_dict is None:
34-
state_dict = self.model.state_dict()
35-
prefix = 'model.'
36-
assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
37-
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
38-
lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
39-
if self.args.process_index <= 0:
40-
torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
41-
print(f"Save adapter model at {output_dir}")
32+
# if is_deepspeed_zero3_enabled():
33+
# if state_dict is None:
34+
# state_dict = self.model.state_dict()
35+
# prefix = 'model.'
36+
# assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
37+
# state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
38+
# lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
39+
# if self.args.process_index <= 0:
40+
# torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
41+
# print(f"Save adapter model at {output_dir}")

0 commit comments

Comments (0)