1111
1212
1313class FlagAutoModel :
14+ """
15+ Automatically choose the appropriate class to load the embedding model.
16+ """
1417 def __init__ (self ):
1518 raise EnvironmentError (
1619 "FlagAutoModel is designed to be instantiated using the `FlagAutoModel.from_finetuned(model_name_or_path)` method."
@@ -30,6 +33,30 @@ def from_finetuned(
3033 query_instruction_format : Optional [str ] = None ,
3134 ** kwargs ,
3235 ):
36+ """
37+ Load a finetuned model according to the provided vars.
38+
39+ Args:
40+ model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
41+ load a model from HuggingFace Hub with the name.
42+ model_class (Optional[Union[str, EmbedderModelClass]], optional): The embedder class to use. Defaults to :data:`None`.
43+ normalize_embeddings (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
44+ Defaults to :data:`True`.
45+ use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
46+ degradation. Defaults to :data:`True`.
47+ query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
48+ :attr:`query_instruction_format`. Defaults to :data:`None`.
49+ devices (Optional[Union[str, List[str]]], optional): Devices to use for model inference. Defaults to :data:`None`.
50+ pooling_method (Optional[str], optional): Pooling method to get embedding vector from the last hidden state. Defaults to :data:`None`.
51+ trust_remote_code (Optional[bool], optional): trust_remote_code for HF datasets or models. Defaults to :data:`None`.
52+ query_instruction_format (Optional[str], optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`None`.
53+
54+ Raises:
55+ ValueError
56+
57+ Returns:
58+ AbsEmbedder: The model class to load model, which is child class of :clsss:`AbsEmbedder`.
59+ """
3360 model_name = os .path .basename (model_name_or_path )
3461 if model_name .startswith ("checkpoint-" ):
3562 model_name = os .path .basename (os .path .dirname (model_name_or_path ))
0 commit comments