Skip to content

Commit 4750923

Browse files
committed
Merge branch 'new-flagembedding-v1' of https://github.com/hanhainebula/FlagEmbedding into new-flagembedding-v1
2 parents 42042c3 + 4b17829 commit 4750923

12 files changed

Lines changed: 570 additions & 173 deletions

File tree

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,25 @@ class AbsEmbedder(ABC):
1919
"""
2020
Base class for embedder.
2121
Extend this class and implement `encode_queries`, `encode_passages`, `encode` for custom embedders.
22+
23+
Args:
24+
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
25+
load a model from HuggingFace Hub with the name.
26+
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Default: `True`.
27+
use_fp16 (bool, optional): If True, use half-precision floating-point to speed up computation with a slight performance
28+
degradation. Default: `True`.
29+
query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
30+
`query_instruction_format`. Default: `None`.
31+
query_instruction_format (str, optional): The template for `query_instruction_for_retrieval`. Default: `"{}{}"`.
32+
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Default: `None`.
33+
batch_size (int, optional): Batch size for inference. Default: `256`.
34+
query_max_length (int, optional): Maximum length for query. Default: `512`.
35+
passage_max_length (int, optional): Maximum length for passage. Default: `512`.
36+
instruction (Optional[str], optional): Instruction for embedding. Default: `None`.
37+
instruction_format (str, optional): Instruction format when using `instruction`. Default: `"{}{}"`.
38+
convert_to_numpy (bool, optional): If True, the output embedding will be a NumPy array. Otherwise, it will be a Torch Tensor.
39+
Default: `True`.
40+
kwargs (Dict[str, Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
2241
"""
2342
def __init__(
2443
self,

Tutorials/1_Embedding/1.1_Intro&Inference.ipynb

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,18 @@
8383
"%pip install -U FlagEmbedding sentence_transformers openai cohere"
8484
]
8585
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": 1,
89+
"metadata": {},
90+
"outputs": [],
91+
"source": [
92+
"import os \n",
93+
"os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'\n",
94+
"# single GPU is better for small tasks\n",
95+
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
96+
]
97+
},
8698
{
8799
"cell_type": "markdown",
88100
"metadata": {},
@@ -92,7 +104,7 @@
92104
},
93105
{
94106
"cell_type": "code",
95-
"execution_count": 4,
107+
"execution_count": 2,
96108
"metadata": {},
97109
"outputs": [],
98110
"source": [
@@ -138,19 +150,27 @@
138150
},
139151
{
140152
"cell_type": "code",
141-
"execution_count": 7,
153+
"execution_count": 4,
142154
"metadata": {},
143155
"outputs": [
156+
{
157+
"name": "stderr",
158+
"output_type": "stream",
159+
"text": [
160+
"initial target device: 100%|██████████| 8/8 [00:31<00:00, 3.89s/it]\n",
161+
"Chunks: 100%|██████████| 3/3 [00:04<00:00, 1.61s/it]\n"
162+
]
163+
},
144164
{
145165
"name": "stdout",
146166
"output_type": "stream",
147167
"text": [
148168
"Embeddings:\n",
149169
"(3, 768)\n",
150170
"Similarity scores:\n",
151-
"[[1. 0.7900386 0.57525384]\n",
152-
" [0.7900386 0.9999998 0.59190154]\n",
153-
" [0.57525384 0.59190154 0.99999994]]\n"
171+
"[[1. 0.79 0.575 ]\n",
172+
" [0.79 0.9995 0.592 ]\n",
173+
" [0.575 0.592 0.999 ]]\n"
154174
]
155175
}
156176
],
@@ -373,7 +393,7 @@
373393
],
374394
"metadata": {
375395
"kernelspec": {
376-
"display_name": "base",
396+
"display_name": "dev",
377397
"language": "python",
378398
"name": "python3"
379399
},
@@ -387,7 +407,7 @@
387407
"name": "python",
388408
"nbconvert_exporter": "python",
389409
"pygments_lexer": "ipython3",
390-
"version": "3.10.13"
410+
"version": "3.12.7"
391411
}
392412
},
393413
"nbformat": 4,

0 commit comments

Comments
 (0)