88from FlagEmbedding .finetune .reranker .decoder_only .layerwise .arguments import RerankerModelArguments
99
1010from .modeling_minicpm_reranker import LayerWiseMiniCPMForCausalLM , LayerWiseHead
11+ from .configuration_minicpm_reranker import LayerWiseMiniCPMConfig
1112
1213logger = logging .getLogger (__name__ )
1314
@@ -41,7 +42,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
4142 config = AutoConfig .from_pretrained (
4243 model_args .model_name_or_path ,
4344 trust_remote_code = model_args .trust_remote_code ,
44- token = model_args ,
45+ token = model_args . token ,
4546 cache_dir = model_args .cache_dir
4647 )
4748 else :
@@ -61,7 +62,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
6162 trust_remote_code = model_args .trust_remote_code ,
6263 # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
6364 use_flash_attention_2 = True if model_args .use_flash_attn else False ,
64- token = model_args ,
65+ token = model_args . token ,
6566 cache_dir = model_args .cache_dir ,
6667 from_tf = bool (".ckpt" in model_args .model_name_or_path ),
6768 config = config ,
@@ -115,7 +116,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
115116 model_args .model_name_or_path ,
116117 # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
117118 use_flash_attention_2 = True if model_args .use_flash_attn else False ,
118- token = model_args ,
119+ token = model_args . token ,
119120 cache_dir = model_args .cache_dir ,
120121 from_tf = bool (".ckpt" in model_args .model_name_or_path ),
121122 config = config ,
@@ -155,14 +156,14 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
155156 config = AutoConfig .from_pretrained (
156157 model_args .config_name ,
157158 trust_remote_code = model_args .trust_remote_code ,
158- token = model_args ,
159+ token = model_args . token ,
159160 cache_dir = model_args .cache_dir
160161 )
161162 elif model_args .model_name_or_path :
162163 config = AutoConfig .from_pretrained (
163164 model_args .model_name_or_path ,
164165 trust_remote_code = model_args .trust_remote_code ,
165- token = model_args ,
166+ token = model_args . token ,
166167 cache_dir = model_args .cache_dir
167168 )
168169 else :
@@ -172,19 +173,19 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
172173 config .use_cache = False
173174
174175 if model_args .model_type == 'from_raw_model' :
175- config = AutoConfig .from_pretrained ('BAAI/bge-reranker-v2-minicpm-layerwise' ,
176- cache_dir = model_args .cache_dir ,
177- token = model_args ,
178- trust_remote_code = model_args .trust_remote_code )
176+ config = LayerWiseMiniCPMConfig .from_pretrained ('BAAI/bge-reranker-v2-minicpm-layerwise' ,
177+ cache_dir = model_args .cache_dir ,
178+ token = model_args . token ,
179+ trust_remote_code = model_args .trust_remote_code )
179180 config .start_layer = model_args .start_layer
180181 config .head_multi = model_args .head_multi
181182 config .head_type = model_args .head_type
182183
183- model = AutoModelForCausalLM .from_pretrained (model_args .model_name_or_path ,
184- config = config ,
185- cache_dir = model_args .cache_dir ,
186- token = model_args ,
187- trust_remote_code = model_args .trust_remote_code )
184+ model = LayerWiseMiniCPMForCausalLM .from_pretrained (model_args .model_name_or_path ,
185+ config = config ,
186+ cache_dir = model_args .cache_dir ,
187+ token = model_args . token ,
188+ trust_remote_code = model_args .trust_remote_code )
188189
189190 if model_args .raw_peft is not None :
190191 for peft_path in model_args .raw_peft :
0 commit comments