Skip to content

Commit ccf5199

Browse files
committed
update model_class parameters in evaluation
- delete 'auto' option - support import `EmbedderModelClass` and `RerankerModelClass`
1 parent 2fd0457 commit ccf5199

6 files changed

Lines changed: 14 additions & 6 deletions

File tree

FlagEmbedding/abc/evaluation/arguments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class AbsEvalModelArgs:
8181
metadata={"help": "The embedder name or path.", "required": True}
8282
)
8383
embedder_model_class: Optional[str] = field(
84-
default="auto", metadata={"help": "The embedder model class. Available classes: ['auto', 'encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: auto.", "choices": ["auto", "encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
84+
default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
8585
)
8686
normalize_embeddings: bool = field(
8787
default=True, metadata={"help": "whether to normalize the embeddings"}
@@ -114,7 +114,7 @@ class AbsEvalModelArgs:
114114
default=None, metadata={"help": "The reranker name or path."}
115115
)
116116
reranker_model_class: Optional[str] = field(
117-
default="auto", metadata={"help": "The reranker model class. Available classes: ['auto', 'encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: auto.", "choices": ["auto", "encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
117+
default=None, metadata={"help": "The reranker model class. Available classes: ['encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: None. For the custom model, you need to specify the model class.", "choices": ["encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
118118
)
119119
reranker_peft_path: Optional[str] = field(
120120
default=None, metadata={"help": "The reranker peft path."}

FlagEmbedding/inference/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@
22
from .auto_reranker import FlagAutoReranker
33
from .embedder import (
44
FlagModel, BGEM3FlagModel,
5-
FlagICLModel, FlagLLMModel
5+
FlagICLModel, FlagLLMModel,
6+
EmbedderModelClass
67
)
78
from .reranker import (
89
FlagReranker,
9-
FlagLLMReranker, LayerWiseFlagLLMReranker, LightWeightFlagLLMReranker
10+
FlagLLMReranker, LayerWiseFlagLLMReranker, LightWeightFlagLLMReranker,
11+
RerankerModelClass
1012
)
1113

1214

1315
__all__ = [
1416
"FlagAutoModel",
1517
"FlagAutoReranker",
18+
"EmbedderModelClass",
19+
"RerankerModelClass",
1620
"FlagModel",
1721
"BGEM3FlagModel",
1822
"FlagICLModel",

FlagEmbedding/inference/auto_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def from_finetuned(
3434
if model_name.startswith("checkpoint-"):
3535
model_name = os.path.basename(os.path.dirname(model_name_or_path))
3636

37-
if model_class is not None and model_class != 'auto':
37+
if model_class is not None:
3838
_model_class = EMBEDDER_CLASS_MAPPING[EmbedderModelClass(model_class)]
3939
if pooling_method is None:
4040
pooling_method = _model_class.DEFAULT_POOLING_METHOD

FlagEmbedding/inference/auto_reranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def from_finetuned(
3030
if model_name.startswith("checkpoint-"):
3131
model_name = os.path.basename(os.path.dirname(model_name_or_path))
3232

33-
if model_class is not None and model_class != 'auto':
33+
if model_class is not None:
3434
_model_class = RERANKER_CLASS_MAPPING[RerankerModelClass(model_class)]
3535
if trust_remote_code is None:
3636
trust_remote_code = False
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .encoder_only import FlagModel, BGEM3FlagModel
22
from .decoder_only import FlagICLModel, FlagLLMModel
3+
from .model_mapping import EmbedderModelClass
34

45
__all__ = [
56
"FlagModel",
67
"BGEM3FlagModel",
78
"FlagICLModel",
89
"FlagLLMModel",
10+
"EmbedderModelClass",
911
]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .decoder_only import FlagLLMReranker, LayerWiseFlagLLMReranker, LightWeightFlagLLMReranker
22
from .encoder_only import FlagReranker
3+
from .model_mapping import RerankerModelClass
34

45
__all__ = [
56
"FlagReranker",
67
"FlagLLMReranker",
78
"LayerWiseFlagLLMReranker",
89
"LightWeightFlagLLMReranker",
10+
"RerankerModelClass",
911
]

0 commit comments

Comments
 (0)