Commit bae7503

Merge pull request #996 from hanhainebula/master
update flag_models.py to support new models
2 parents f5198e3 + e1dc2b6 · commit bae7503

1 file changed: FlagEmbedding/flag_models.py (118 additions, 2 deletions)
@@ -32,7 +32,7 @@ def __init__(
             self,
             model_name_or_path: str = None,
             normalize_embeddings: bool = True,
-            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passage that answer the query.',
+            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passages that answer the query.',
             examples_for_task: List[dict] = None,
             use_fp16: bool = True
     ) -> None:
@@ -215,6 +215,122 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
         return sum([len(t) for t in text])  # Sum of length of individual strings
 
 
+class FlagLLMModel:
+    def __init__(
+            self,
+            model_name_or_path: str = None,
+            normalize_embeddings: bool = True,
+            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passages that answer the query.',
+            use_fp16: bool = True,
+    ) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.model = AutoModel.from_pretrained(model_name_or_path)
+        self.query_instruction_for_retrieval = query_instruction_for_retrieval
+        self.normalize_embeddings = normalize_embeddings
+
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+        elif is_torch_npu_available():
+            self.device = torch.device("npu")
+        else:
+            self.device = torch.device("cpu")
+            use_fp16 = False
+        if use_fp16: self.model.half()
+        self.model = self.model.to(self.device)
+
+        self.num_gpus = torch.cuda.device_count()
+        if self.num_gpus > 1:
+            print(f"----------using {self.num_gpus}*GPUs----------")
+            self.model = torch.nn.DataParallel(self.model)
+
+    def encode_queries(self, queries: Union[List[str], str],
+                       batch_size: int = 256,
+                       max_length: int = 512,
+                       convert_to_numpy: bool = True) -> np.ndarray:
+        '''
+        This function will be used for retrieval task
+        if there is an instruction for queries, we will add it to the query text
+        '''
+        if isinstance(queries, str):
+            input_texts = get_detailed_instruct(self.query_instruction_for_retrieval, queries)
+        else:
+            input_texts = [get_detailed_instruct(self.query_instruction_for_retrieval, q) for q in queries]
+        return self.encode(input_texts, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy)
+
+    def encode_corpus(self,
+                      corpus: Union[List[str], str],
+                      batch_size: int = 256,
+                      max_length: int = 512,
+                      convert_to_numpy: bool = True) -> np.ndarray:
+        '''
+        This function will be used for retrieval task
+        encode corpus for retrieval task
+        '''
+        return self.encode(corpus, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy)
+
+    @torch.no_grad()
+    def encode(self,
+               sentences: Union[List[str], str],
+               batch_size: int = 256,
+               max_length: int = 512,
+               convert_to_numpy: bool = True) -> np.ndarray:
+        if self.num_gpus > 0:
+            batch_size = batch_size * self.num_gpus
+        self.model.eval()
+
+        input_was_string = False
+        if isinstance(sentences, str):
+            sentences = [sentences]
+            input_was_string = True
+
+        all_embeddings = []
+        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
+                                disable=len(sentences) < 256):
+            sentences_batch = sentences[start_index:start_index + batch_size]
+            inputs = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=max_length,
+                pad_to_multiple_of=8,
+            ).to(self.device)
+            last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
+            embeddings = self.last_token_pool(last_hidden_state, inputs['attention_mask'])
+            if self.normalize_embeddings:
+                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
+            embeddings = cast(torch.Tensor, embeddings)
+
+            if convert_to_numpy:
+                embeddings = embeddings.cpu().numpy()
+            all_embeddings.append(embeddings)
+
+        if convert_to_numpy:
+            all_embeddings = np.concatenate(all_embeddings, axis=0)
+        else:
+            all_embeddings = torch.cat(all_embeddings, dim=0)
+
+        if input_was_string:
+            return all_embeddings[0]
+        return all_embeddings
+
+    def last_token_pool(self,
+                        last_hidden_state: torch.Tensor,
+                        attention_mask: torch.Tensor = None):
+        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
+        if left_padding:
+            return last_hidden_state[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden_state.shape[0]
+            return last_hidden_state[
+                torch.arange(batch_size, device=last_hidden_state.device),
+                sequence_lengths,
+            ]
+
+
 class FlagModel:
     def __init__(
             self,
@@ -315,7 +431,7 @@ def encode(self,
         if convert_to_numpy:
             all_embeddings = np.concatenate(all_embeddings, axis=0)
         else:
-            all_embeddings = torch.stack(all_embeddings)
+            all_embeddings = torch.cat(all_embeddings, dim=0)
 
         if input_was_string:
             return all_embeddings[0]
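
Note on the torch.stack → torch.cat change in the last hunk: each loop iteration in FlagModel.encode appends a (batch_size, dim) tensor, so stacking would add a spurious leading batch dimension and would fail outright whenever the final batch is smaller than the rest. A minimal sketch (shapes illustrative, not from the commit):

import torch

# Two full batches plus a smaller final batch, as the batching loop in
# encode() produces whenever len(sentences) % batch_size != 0.
batches = [torch.randn(256, 768), torch.randn(256, 768), torch.randn(100, 768)]

flat = torch.cat(batches, dim=0)
print(flat.shape)  # torch.Size([612, 768]) -- one row per sentence

# torch.stack(batches) would raise here because stack requires equal shapes;
# even with equal batches it would yield (num_batches, 256, 768) rather than
# the flat (num_sentences, 768) matrix callers expect.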

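The new encode_queries relies on get_detailed_instruct, which is not defined in this diff. A minimal sketch, assuming it follows the E5-style instruction template commonly used with decoder-only embedding models; the exact wording and format are an assumption, not taken from this commit:

def get_detailed_instruct(task_description: str, query: str) -> str:
    # Assumed E5-style template: prefix the task instruction to the query.
    return f'Instruct: {task_description}\nQuery: {query}'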
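The last_token_pool helper embeds each sequence with the hidden state of its final non-padding token, which suits causal (decoder-only) models where only the last position attends to the whole input. With left padding, attention_mask[:, -1].sum() == attention_mask.shape[0] holds and position -1 works for every row; with right padding, the per-row index is recovered from the mask. A small illustration of the right-padding branch (toy shapes, not from the commit):

import torch

# Right-padded toy batch: row 0 has 3 real tokens, row 1 has 2.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
hidden = torch.randn(2, 4, 8)  # (batch, seq_len, dim)

# Index of each row's final real token, exactly as in last_token_pool.
sequence_lengths = attention_mask.sum(dim=1) - 1   # tensor([2, 1])
pooled = hidden[torch.arange(2), sequence_lengths]
print(pooled.shape)  # torch.Size([2, 8])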
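A hedged usage sketch of the new class; the import path mirrors the file location and the checkpoint name is a placeholder, and neither is confirmed by this PR:

from FlagEmbedding.flag_models import FlagLLMModel

model = FlagLLMModel(
    'path/to/decoder-only-embedding-model',  # placeholder checkpoint, not named in this commit
    query_instruction_for_retrieval='Given a query, retrieval relevant passages that answer the query.',
    use_fp16=True,
)

# Queries get the instruction prefix via get_detailed_instruct;
# corpus passages are encoded as-is.
q_emb = model.encode_queries(['how do dense retrievers work?'])
p_emb = model.encode_corpus(['Dense retrievers embed queries and passages into one vector space.',
                             'An unrelated passage about gardening.'])

# normalize_embeddings=True by default, so the inner product is cosine similarity.
scores = q_emb @ p_emb.T
print(scores)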