Skip to content

Commit 3f84da0

Browse files
authored
Update modeling.py
1 parent b768035 commit 3f84da0

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

FlagEmbedding/visual/modeling.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,21 @@ def __init__(self,
3535
):
3636
super().__init__()
3737

38-
assert model_name_bge in ["BAAI/bge-base-en-v1.5", "BAAI/bge-m3"]
38+
assert 'bge' in model_name_bge
3939
assert model_weight is not None
4040

4141
self.model_name_bge = model_name_bge
4242

43-
if model_name_bge == 'BAAI/bge-base-en-v1.5':
43+
if 'bge-base-en-v1.5' in model_name_bge:
4444
model_name_eva = "EVA02-CLIP-B-16"
4545
self.hidden_dim = 768
4646
self.depth = 12
47-
elif model_name_bge == 'BAAI/bge-m3':
47+
elif 'bge-m3' in model_name_bge:
4848
model_name_eva = "EVA02-CLIP-L-14"
4949
self.hidden_dim = 1024
5050
self.depth = 24
51+
else:
52+
raise Exception(f'Unavailable model_name {model_name_bge}')
5153

5254
if not from_pretrained:
5355
bge_config = AutoConfig.from_pretrained(model_name_bge)

0 commit comments

Comments
 (0)