Skip to content

Commit 456899a

Browse files
committed
update reranker FT
1 parent 98d2621 commit 456899a

2 files changed

Lines changed: 16 additions & 15 deletions

File tree

  • FlagEmbedding

FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from FlagEmbedding.finetune.reranker.decoder_only.layerwise.arguments import RerankerModelArguments
99

1010
from .modeling_minicpm_reranker import LayerWiseMiniCPMForCausalLM, LayerWiseHead
11+
from .configuration_minicpm_reranker import LayerWiseMiniCPMConfig
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -41,7 +42,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
4142
config = AutoConfig.from_pretrained(
4243
model_args.model_name_or_path,
4344
trust_remote_code=model_args.trust_remote_code,
44-
token=model_args,
45+
token=model_args.token,
4546
cache_dir=model_args.cache_dir
4647
)
4748
else:
@@ -61,7 +62,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
6162
trust_remote_code=model_args.trust_remote_code,
6263
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
6364
use_flash_attention_2=True if model_args.use_flash_attn else False,
64-
token=model_args,
65+
token=model_args.token,
6566
cache_dir=model_args.cache_dir,
6667
from_tf=bool(".ckpt" in model_args.model_name_or_path),
6768
config=config,
@@ -115,7 +116,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
115116
model_args.model_name_or_path,
116117
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
117118
use_flash_attention_2=True if model_args.use_flash_attn else False,
118-
token=model_args,
119+
token=model_args.token,
119120
cache_dir=model_args.cache_dir,
120121
from_tf=bool(".ckpt" in model_args.model_name_or_path),
121122
config=config,
@@ -155,14 +156,14 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
155156
config = AutoConfig.from_pretrained(
156157
model_args.config_name,
157158
trust_remote_code=model_args.trust_remote_code,
158-
token=model_args,
159+
token=model_args.token,
159160
cache_dir=model_args.cache_dir
160161
)
161162
elif model_args.model_name_or_path:
162163
config = AutoConfig.from_pretrained(
163164
model_args.model_name_or_path,
164165
trust_remote_code=model_args.trust_remote_code,
165-
token=model_args,
166+
token=model_args.token,
166167
cache_dir=model_args.cache_dir
167168
)
168169
else:
@@ -172,19 +173,19 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
172173
config.use_cache = False
173174

174175
if model_args.model_type == 'from_raw_model':
175-
config = AutoConfig.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise',
176-
cache_dir=model_args.cache_dir,
177-
token=model_args,
178-
trust_remote_code=model_args.trust_remote_code)
176+
config = LayerWiseMiniCPMConfig.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise',
177+
cache_dir=model_args.cache_dir,
178+
token=model_args.token,
179+
trust_remote_code=model_args.trust_remote_code)
179180
config.start_layer = model_args.start_layer
180181
config.head_multi = model_args.head_multi
181182
config.head_type = model_args.head_type
182183

183-
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
184-
config=config,
185-
cache_dir=model_args.cache_dir,
186-
token=model_args,
187-
trust_remote_code=model_args.trust_remote_code)
184+
model = LayerWiseMiniCPMForCausalLM.from_pretrained(model_args.model_name_or_path,
185+
config=config,
186+
cache_dir=model_args.cache_dir,
187+
token=model_args.token,
188+
trust_remote_code=model_args.trust_remote_code)
188189

189190
if model_args.raw_peft is not None:
190191
for peft_path in model_args.raw_peft:

FlagEmbedding/inference/reranker/encoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
def sigmoid(x):
11-
return 1 / (1 + np.exp(-x))
11+
return float(1 / (1 + np.exp(-x)))
1212

1313

1414
class BaseReranker(AbsReranker):

0 commit comments

Comments (0)