Skip to content

Commit 1d4a772

Browse files
authored
fixed bugs
1 parent 107cfbe commit 1d4a772

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

FlagEmbedding/visual/modeling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self,
100100
self.to(self.device)
101101
else:
102102
self.device = torch.device('cpu')
103+
self.dtype = next(bge.parameters()).dtype
103104

104105
def load_model(self, model_weight):
105106
self.load_state_dict(torch.load(model_weight, map_location='cpu'))
@@ -191,7 +192,7 @@ def encode_text(self, texts):
191192
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
192193

193194
head_mask = [None] * self.depth
194-
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
195+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape).to(self.dtype)
195196

196197
embedding_output = self.bge_embeddings(
197198
input_ids=input_ids,
@@ -270,7 +271,7 @@ def encode_mm(self, images:torch.Tensor, texts):
270271
prom_img_input_shape = prompt_img_embedding.size()
271272

272273
head_mask = [None] * self.depth
273-
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(prom_img_attention_mask, prom_img_input_shape).to(prompt_img_embedding.dtype)
274+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(prom_img_attention_mask, prom_img_input_shape).to(self.dtype)
274275

275276

276277
encoder_outputs = self.bge_encoder(

0 commit comments

Comments (0)