
Commit f5198e3

Merge pull request #995 from 545999961/master

upload embedder and reranker

2 parents: ad08b9a + 5e0a2d7

3 files changed: 382 additions & 8 deletions

FlagEmbedding/__init__.py (1 addition, 1 deletion)

@@ -1,3 +1,3 @@
-from .flag_models import FlagModel, LLMEmbedder
+from .flag_models import FlagModel, LLMEmbedder, FlagICLModel
 from .bge_m3 import BGEM3FlagModel
 from .flag_reranker import FlagReranker, FlagLLMReranker, LayerWiseFlagLLMReranker
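
The net effect of this one-line change is that the new in-context-learning embedder is exported at the package root. A minimal sanity check, assuming this revision of FlagEmbedding is installed:

# FlagICLModel is now importable alongside the existing entry points.
from FlagEmbedding import FlagModel, FlagICLModel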

FlagEmbedding/flag_models.py (211 additions, 2 deletions)

@@ -3,8 +3,217 @@
 import numpy as np
 import torch
 from tqdm import tqdm
+from torch import Tensor
 from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, is_torch_npu_available
 
+import torch.nn.functional as F
+
+
+def last_token_pool(last_hidden_states: Tensor,
+                    attention_mask: Tensor) -> Tensor:
+    # If the last position is attended for every row, the batch is left-padded
+    # and the final hidden state is the last real token for all sequences.
+    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+    if left_padding:
+        return last_hidden_states[:, -1]
+    else:
+        # Right padding: index each row at its last attended position.
+        sequence_lengths = attention_mask.sum(dim=1) - 1
+        batch_size = last_hidden_states.shape[0]
+        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+
+def get_detailed_instruct(task_description: str, query: str) -> str:
+    return f'<instruct>{task_description}\n<query>{query}'
+
+def get_detailed_example(task_description: str, query: str, response: str) -> str:
+    return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'
+
+
+class FlagICLModel:
+    def __init__(
+            self,
+            model_name_or_path: str = None,
+            normalize_embeddings: bool = True,
+            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passage that answer the query.',
+            examples_for_task: List[dict] = None,
+            use_fp16: bool = True
+    ) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.model = AutoModel.from_pretrained(model_name_or_path)
+        self.query_instruction_for_retrieval = query_instruction_for_retrieval
+        self.examples_for_task = examples_for_task
+
+        self.set_examples()
+        self.suffix = '\n<response>'
+
+        self.normalize_embeddings = normalize_embeddings
+
+        # Choose the device once; fp16 is disabled when falling back to CPU.
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+        elif is_torch_npu_available():
+            self.device = torch.device("npu")
+        else:
+            self.device = torch.device("cpu")
+            use_fp16 = False
+        if use_fp16: self.model.half()
+        self.model = self.model.to(self.device)
+
+        self.num_gpus = torch.cuda.device_count()
+        if self.num_gpus > 1:
+            print(f"----------using {self.num_gpus}*GPUs----------")
+            self.model = torch.nn.DataParallel(self.model)
+
+    def set_examples(self, examples_for_task: List[dict] = None):
+        if examples_for_task is None and self.examples_for_task is None:
+            self.prefix = ''
+        elif examples_for_task is not None:
+            eg_pairs = []
+            for i in range(len(examples_for_task)):
+                eg_pairs.append(
+                    get_detailed_example(
+                        examples_for_task[i].get('instruct', self.query_instruction_for_retrieval),
+                        examples_for_task[i].get('query', ''),
+                        examples_for_task[i].get('response', '')
+                    )
+                )
+            self.prefix = '\n\n'.join(eg_pairs) + '\n\n'
+        else:
+            eg_pairs = []
+            for i in range(len(self.examples_for_task)):
+                eg_pairs.append(
+                    get_detailed_example(
+                        self.examples_for_task[i].get('instruct', self.query_instruction_for_retrieval),
+                        self.examples_for_task[i].get('query', ''),
+                        self.examples_for_task[i].get('response', '')
+                    )
+                )
+            self.prefix = '\n\n'.join(eg_pairs) + '\n\n'
+
+
+    @torch.no_grad()
+    def encode_queries(self, queries: Union[List[str], str],
+                       batch_size: int = 256,
+                       max_length: int = 512) -> np.ndarray:
+        '''
+        This function is used for the retrieval task:
+        if there is an instruction for queries, it is added to the query text.
+        '''
+        self.model.eval()
+
+        if isinstance(queries, str):
+            sentences = [get_detailed_instruct(self.query_instruction_for_retrieval, queries)]
+        else:
+            sentences = [get_detailed_instruct(self.query_instruction_for_retrieval, q) for q in queries]
+
+        prefix_ids = self.tokenizer(self.prefix, add_special_tokens=False)['input_ids']
+        suffix_ids = self.tokenizer(self.suffix, add_special_tokens=False)['input_ids']
+
+        all_embeddings = []
+        # Sort longest-first so each batch pads to a similar length.
+        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
+        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+        for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
+                                disable=len(sentences_sorted) < 256):
+            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
+            # First pass: truncate the instructed query alone, reserving room
+            # for the BOS token and the '\n<response></s>' suffix.
+            inputs = self.tokenizer(
+                sentences_batch,
+                max_length=max_length - len(self.tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
+                    self.tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']),
+                return_token_type_ids=False,
+                truncation=True,
+                return_tensors=None,
+                add_special_tokens=False
+            )
+            # Second pass: re-attach the in-context examples (prefix) and the
+            # response tag (suffix); the total budget is rounded up to a multiple of 8.
+            new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length) // 8 * 8 + 8
+            sentences_batch = self.tokenizer.batch_decode(inputs['input_ids'])
+            for i in range(len(sentences_batch)):
+                sentences_batch[i] = self.prefix + sentences_batch[i] + self.suffix
+            inputs = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=new_max_length,
+                add_special_tokens=True
+            ).to(self.device)
+
+            outputs = self.model(**inputs, return_dict=True)
+            embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
+
+            if self.normalize_embeddings:
+                embeddings = F.normalize(embeddings, p=2, dim=1)
+            all_embeddings.extend(embeddings.float().cpu())
+
+        # Restore the original input order before stacking.
+        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
+        all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+        return all_embeddings
+
+    @torch.no_grad()
+    def encode_corpus(self,
+                      corpus: Union[List[str], str],
+                      batch_size: int = 256,
+                      max_length: int = 512) -> np.ndarray:
+        '''
+        Encode the corpus for the retrieval task; passages are encoded as-is,
+        without instruction, examples prefix, or response suffix.
+        '''
+        self.model.eval()
+
+        if isinstance(corpus, str):
+            sentences = [corpus]
+        else:
+            sentences = corpus
+
+        all_embeddings = []
+        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
+        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+        for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
+                                disable=len(sentences_sorted) < 256):
+            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
+            inputs = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=max_length,
+                add_special_tokens=True
+            ).to(self.device)
+            outputs = self.model(**inputs, return_dict=True)
+            embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
+
+            if self.normalize_embeddings:
+                embeddings = F.normalize(embeddings, p=2, dim=1)
+            all_embeddings.extend(embeddings.float().cpu())
+
+        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
+        all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+        return all_embeddings
+
+    def _text_length(self, text: Union[List[int], List[List[int]]]):
+        """
+        Helper function to get the length of the input text. Text can be either
+        a list of ints (a single tokenized input) or a list of lists of ints
+        (several tokenized inputs to the model).
+        """
+        if isinstance(text, dict):  # {key: value} case
+            return len(next(iter(text.values())))
+        elif not hasattr(text, '__len__'):  # Object has no len() method
+            return 1
+        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
+            return len(text)
+        else:
+            return sum([len(t) for t in text])  # Sum of lengths of individual strings
 
 class FlagModel:
     def __init__(
@@ -185,7 +394,7 @@ def encode_queries(self, queries: Union[List[str], str],
                        max_length: int = 256,
                        task: str = 'qa') -> np.ndarray:
         '''
-        Encode queries into dense vectors. 
+        Encode queries into dense vectors.
         Automatically add instructions according to given task.
         '''
         instruction = self.instructions[task]["query"]
@@ -202,7 +411,7 @@ def encode_keys(self, keys: Union[List[str], str],
                     max_length: int = 512,
                     task: str = 'qa') -> np.ndarray:
         '''
-        Encode keys into dense vectors. 
+        Encode keys into dense vectors.
         Automatically add instructions according to given task.
         '''
         instruction = self.instructions[task]["key"]
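
The last_token_pool helper added above pools each sequence at its final real token, handling both left- and right-padded batches. A toy check of the right-padding branch (a sketch; it assumes last_token_pool from the first hunk is in scope):

import torch

# Hidden states: batch of 2, sequence length 3, hidden size 2.
h = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
# Right-padded attention mask: row 0 has 2 real tokens, row 1 has 3.
mask = torch.tensor([[1, 1, 0],
                     [1, 1, 1]])
pooled = last_token_pool(h, mask)
print(pooled)  # picks h[0, 1] and h[1, 2]: [[2., 3.], [10., 11.]]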

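Putting the pieces together: FlagICLModel renders each example dict into an '<instruct>...\n<query>...\n<response>...' block, prepends those blocks to every instructed query, appends the '\n<response>' suffix, and last-token-pools the decoder output. A minimal usage sketch; the checkpoint name BAAI/bge-en-icl and the texts below are illustrative assumptions, not taken from this diff:

from FlagEmbedding import FlagICLModel

# Checkpoint name is an assumption for illustration.
examples = [
    {
        'instruct': 'Given a query, retrieval relevant passage that answer the query.',
        'query': 'what is the capital of France?',
        'response': 'Paris is the capital and largest city of France.',
    },
]
model = FlagICLModel('BAAI/bge-en-icl',
                     examples_for_task=examples,  # rendered into the prompt prefix
                     use_fp16=True)

q_emb = model.encode_queries(['how do vaccines work?'])                 # instruction + examples added
p_emb = model.encode_corpus(['Vaccines train the immune system ...'])   # passages encoded as-is

# Embeddings are L2-normalized by default, so the dot product is cosine similarity.
print(q_emb @ p_emb.T)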