diff --git a/FlagEmbedding/inference/reranker/decoder_only/base.py b/FlagEmbedding/inference/reranker/decoder_only/base.py index 4d5b26ec..7fcbb0fa 100644 --- a/FlagEmbedding/inference/reranker/decoder_only/base.py +++ b/FlagEmbedding/inference/reranker/decoder_only/base.py @@ -297,7 +297,8 @@ def compute_score_single_gpu( device = self.target_devices[0] if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16: + self.model.half() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py index 4b75da36..410c55ac 100644 --- a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py +++ b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py @@ -179,7 +179,8 @@ def compute_score_single_gpu( device = self.target_devices[0] if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16: + self.model.half() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py index 000478af..297d5704 100644 --- a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py +++ b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py @@ -258,7 +258,8 @@ def compute_score_single_gpu( device = self.target_devices[0] if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16: + self.model.half() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/reranker/encoder_only/base.py b/FlagEmbedding/inference/reranker/encoder_only/base.py index 1a4d8b6a..87e42b9e 100644 --- a/FlagEmbedding/inference/reranker/encoder_only/base.py +++ b/FlagEmbedding/inference/reranker/encoder_only/base.py @@ -110,8 +110,9 @@ def compute_score_single_gpu( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + use_fp16 = self.use_fp16 and device != "cpu" + if use_fp16 and next(self.model.parameters()).dtype != torch.float16: + self.model.half() self.model.to(device) self.model.eval() diff --git a/research/Matroyshka_reranker/inference/rank_model.py b/research/Matroyshka_reranker/inference/rank_model.py index bbdca4ce..2cd1c063 100644 --- a/research/Matroyshka_reranker/inference/rank_model.py +++ b/research/Matroyshka_reranker/inference/rank_model.py @@ -203,7 +203,8 @@ def compute_score_single_gpu( device = self.target_devices[0] if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if self.use_fp16 and next(self.model.parameters()).dtype != torch.float16: + self.model.half() self.model.to(device) self.model.eval()