Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion FlagEmbedding/inference/reranker/decoder_only/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def compute_score_single_gpu(
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
- if self.use_fp16: self.model.half()
+ if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16:
+     self.model.half()

self.model.to(device)
self.model.eval()
Expand Down
3 changes: 2 additions & 1 deletion FlagEmbedding/inference/reranker/decoder_only/layerwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def compute_score_single_gpu(
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
- if self.use_fp16: self.model.half()
+ if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16:
+     self.model.half()

self.model.to(device)
self.model.eval()
Expand Down
3 changes: 2 additions & 1 deletion FlagEmbedding/inference/reranker/decoder_only/lightweight.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ def compute_score_single_gpu(
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
- if self.use_fp16: self.model.half()
+ if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16:
+     self.model.half()

self.model.to(device)
self.model.eval()
Expand Down
3 changes: 2 additions & 1 deletion FlagEmbedding/inference/reranker/encoder_only/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def compute_score_single_gpu(
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
- if self.use_fp16: self.model.half()
+ if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16:
+     self.model.half()

Comment thread
nedeadinside marked this conversation as resolved.
Outdated
self.model.to(device)
self.model.eval()
Expand Down
3 changes: 2 additions & 1 deletion research/Matroyshka_reranker/inference/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def compute_score_single_gpu(
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
- if self.use_fp16: self.model.half()
+ if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16:
+     self.model.half()

self.model.to(device)
self.model.eval()
Expand Down
Loading