Skip to content

Commit 842c130

Browse files
committed
update reranker model
1 parent a90ee45 commit 842c130

1 file changed

Lines changed: 10 additions & 10 deletions

File tree

FlagEmbedding/flag_reranker.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,16 @@ def __init__(
316316
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
317317
cache_dir=cache_dir,
318318
trust_remote_code=True)
319-
if use_bf16:
320-
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
321-
cache_dir=cache_dir,
322-
trust_remote_code=True,
323-
torch_dtype=torch.bfloat16)
324-
else:
325-
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
326-
cache_dir=cache_dir,
327-
trust_remote_code=True,
328-
use_flash_attention_2=False)
319+
320+
if use_bf16 is False and use_fp16 is False:
321+
warnings.warn("Due to model constraints, `use_bf16` and `use_fp16` cannot both be `False`. Here, `use_fp16` is set to `True` by default.", UserWarning)
322+
use_fp16 = True
323+
324+
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
325+
cache_dir=cache_dir,
326+
trust_remote_code=True,
327+
torch_dtype=torch.bfloat16 if use_bf16 else False)
328+
329329
self.model_name_or_path = model_name_or_path
330330
self.cache_dir = cache_dir
331331

0 commit comments

Comments
 (0)