@@ -2,7 +2,7 @@
 import torch
 import logging
 from typing import Optional
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+# from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from FlagEmbedding.abc.finetune.reranker import AbsRerankerTrainer
 from peft import get_peft_model_state_dict
@@ -29,13 +29,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
 
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
 
-        if is_deepspeed_zero3_enabled():
-            if state_dict is None:
-                state_dict = self.model.state_dict()
-            prefix = 'model.'
-            assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
-            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
-            lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
-            if self.args.process_index <= 0:
-                torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
-                print(f"Save adapter model at {output_dir}")
+        # if is_deepspeed_zero3_enabled():
+        #     if state_dict is None:
+        #         state_dict = self.model.state_dict()
+        #     prefix = 'model.'
+        #     assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
+        #     state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
+        #     lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
+        #     if self.args.process_index <= 0:
+        #         torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
+        #         print(f"Save adapter model at {output_dir}")
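For context, the branch this commit disables handled saving LoRA weights under DeepSpeed ZeRO-3, where `_save` receives a fully gathered state dict whose keys carry a `model.` wrapper prefix. Below is a minimal sketch of that flow as a standalone function, assuming a PEFT-wrapped module at `trainer.model.model`; the function name `save_lora_adapter` and the `trainer` parameter are illustrative, not part of the FlagEmbedding API.

```python
import os
import torch
from peft import get_peft_model_state_dict

def save_lora_adapter(trainer, output_dir: str) -> None:
    """Sketch of the ZeRO-3 adapter-save path commented out above.

    Assumes `trainer.model` wraps a PEFT model at `trainer.model.model`
    and that state-dict keys carry a leading 'model.' prefix.
    """
    # Under ZeRO-3 this returns the gathered full weights.
    state_dict = trainer.model.state_dict()
    prefix = "model."
    # Strip the wrapper prefix so keys match the inner PEFT model.
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    # Keep only the LoRA adapter tensors.
    lora_state_dict = get_peft_model_state_dict(trainer.model.model, state_dict)
    # Write once, from the main process only.
    if trainer.args.process_index <= 0:
        torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
```

With this path disabled, adapter saving falls back to whatever the parent `AbsRerankerTrainer._save` does for the non-ZeRO-3 case.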