Skip to content

Commit 74d0d9d

Browse files
committed
auto embedder and reranker
1 parent f5c3891 commit 74d0d9d

2 files changed

Lines changed: 47 additions & 0 deletions

File tree

FlagEmbedding/inference/auto_embedder.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212

1313
class FlagAutoModel:
14+
"""
15+
Automatically choose the appropriate class to load the embedding model.
16+
"""
1417
def __init__(self):
1518
raise EnvironmentError(
1619
"FlagAutoModel is designed to be instantiated using the `FlagAutoModel.from_finetuned(model_name_or_path)` method."
@@ -30,6 +33,30 @@ def from_finetuned(
3033
query_instruction_format: Optional[str] = None,
3134
**kwargs,
3235
):
36+
"""
37+
Load a finetuned model according to the provided vars.
38+
39+
Args:
40+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
41+
load a model from HuggingFace Hub with the name.
42+
model_class (Optional[Union[str, EmbedderModelClass]], optional): The embedder class to use. Defaults to :data:`None`.
43+
normalize_embeddings (bool, optional): If True, the output embeddings will be normalized.
44+
Defaults to :data:`True`.
45+
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
46+
degradation. Defaults to :data:`True`.
47+
query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
48+
:attr:`query_instruction_format`. Defaults to :data:`None`.
49+
devices (Optional[Union[str, List[str]]], optional): Devices to use for model inference. Defaults to :data:`None`.
50+
pooling_method (Optional[str], optional): Pooling method to get embedding vector from the last hidden state. Defaults to :data:`None`.
51+
trust_remote_code (Optional[bool], optional): trust_remote_code for HF datasets or models. Defaults to :data:`None`.
52+
query_instruction_format (Optional[str], optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`None`.
53+
54+
Raises:
55+
ValueError
56+
57+
Returns:
58+
AbsEmbedder: The model class used to load the model, which is a child class of :class:`AbsEmbedder`.
59+
"""
3360
model_name = os.path.basename(model_name_or_path)
3461
if model_name.startswith("checkpoint-"):
3562
model_name = os.path.basename(os.path.dirname(model_name_or_path))

FlagEmbedding/inference/auto_reranker.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313

1414
class FlagAutoReranker:
15+
"""
16+
Automatically choose the appropriate class to load the reranker model.
17+
"""
1518
def __init__(self):
1619
raise EnvironmentError(
1720
"FlagAutoReranker is designed to be instantiated using the `FlagAutoReranker.from_finetuned(model_name_or_path)` method."
@@ -26,6 +29,23 @@ def from_finetuned(
2629
trust_remote_code: Optional[bool] = None,
2730
**kwargs,
2831
):
32+
"""
33+
Load a finetuned model according to the provided vars.
34+
35+
Args:
36+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
37+
load a model from HuggingFace Hub with the name.
38+
model_class (Optional[Union[str, RerankerModelClass]], optional): The reranker class to use. Defaults to :data:`None`.
39+
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
40+
degradation. Defaults to :data:`False`.
41+
trust_remote_code (Optional[bool], optional): trust_remote_code for HF datasets or models. Defaults to :data:`None`.
42+
43+
Raises:
44+
ValueError
45+
46+
Returns:
47+
AbsReranker: The reranker class used to load the model, which is a child class of :class:`AbsReranker`.
48+
"""
2949
model_name = os.path.basename(model_name_or_path)
3050
if model_name.startswith("checkpoint-"):
3151
model_name = os.path.basename(os.path.dirname(model_name_or_path))

0 commit comments

Comments
 (0)