File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -325,8 +325,11 @@ def _trans_state_dict(state_dict):
325325
326326 self .model .save_pretrained (output_dir , state_dict = _trans_state_dict (self .model .state_dict ()))
327327
328- torch .save (_trans_state_dict (self .colbert_linear .state_dict ()), os .path .join (output_dir , 'colbert_linear.pt' ))
329- torch .save (_trans_state_dict (self .sparse_linear .state_dict ()), os .path .join (output_dir , 'sparse_linear.pt' ))
328+ if self .unified_finetuning :
329+ torch .save (_trans_state_dict (self .colbert_linear .state_dict ()),
330+ os .path .join (output_dir , 'colbert_linear.pt' ))
331+ torch .save (_trans_state_dict (self .sparse_linear .state_dict ()),
332+ os .path .join (output_dir , 'sparse_linear.pt' ))
330333
331334 def load_pooler (self , model_dir ):
332335 colbert_state_dict = torch .load (os .path .join (model_dir , 'colbert_linear.pt' ), map_location = 'cpu' )
You can’t perform that action at this time.
0 commit comments