Commit 82b0d01

Merge pull request #1001 from hanhainebula/master
clean code for flag_models.py
2 parents 24f9aa9 + 1f87d60 commit 82b0d01

1 file changed: FlagEmbedding/flag_models.py

Lines changed: 21 additions & 39 deletions
@@ -1,11 +1,9 @@
-from typing import cast, List, Union, Tuple
-
+from typing import cast, List, Union
 import numpy as np
-import torch
 from tqdm import tqdm
+from transformers import AutoModel, AutoTokenizer, is_torch_npu_available
+import torch
 from torch import Tensor
-from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, is_torch_npu_available
-
 import torch.nn.functional as F
 
 
@@ -23,6 +21,7 @@ def last_token_pool(last_hidden_states: Tensor,
 def get_detailed_instruct(task_description: str, query: str) -> str:
     return f'<instruct>{task_description}\n<query>{query}'
 
+
 def get_detailed_example(task_description: str, query: str, response: str) -> str:
     return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'
 
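For context, the two helpers touched in this hunk build the model's instruction-formatted prompts. A minimal usage sketch; the task description below is an invented example, not taken from the repository:

    # Invented example task, for illustration only.
    task = 'Given a web search query, retrieve relevant passages that answer the query.'

    # get_detailed_instruct wraps a task description and a query:
    print(get_detailed_instruct(task, 'what is dense retrieval'))
    # <instruct>Given a web search query, retrieve relevant passages that answer the query.
    # <query>what is dense retrieval

    # get_detailed_example additionally carries a response, for in-context examples:
    print(get_detailed_example(task, 'what is dense retrieval', 'Dense retrieval encodes text into vectors ...'))
    # <instruct>Given a web search query, retrieve relevant passages that answer the query.
    # <query>what is dense retrieval
    # <response>Dense retrieval encodes text into vectors ...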
@@ -98,7 +97,6 @@ def set_examples(self, examples_for_task: List[dict] = None):
             )
         self.prefix = '\n\n'.join(eg_paris) + '\n\n'
 
-
     @torch.no_grad()
     def encode_queries(self, queries: Union[List[str], str],
                        batch_size: int = 256,
@@ -217,11 +215,11 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
 
 class FlagLLMModel:
     def __init__(
-            self,
-            model_name_or_path: str = None,
-            normalize_embeddings: bool = True,
-            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passages that answer the query.',
-            use_fp16: bool = True,
+        self,
+        model_name_or_path: str = None,
+        normalize_embeddings: bool = True,
+        query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passages that answer the query.',
+        use_fp16: bool = True
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModel.from_pretrained(model_name_or_path)
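The parameter list is only re-indented and loses a trailing comma, so existing call sites are unaffected. A hedged instantiation sketch; the checkpoint name is a placeholder, not named in this diff:

    # Placeholder checkpoint; substitute a real decoder-style embedding model.
    model = FlagLLMModel(
        model_name_or_path='BAAI/bge-en-icl',
        normalize_embeddings=True,
        query_instruction_for_retrieval='Given a query, retrieval relevant passages that answer the query.',
        use_fp16=True,
    )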
@@ -298,7 +296,7 @@ def encode(self,
                 pad_to_multiple_of=8,
             ).to(self.device)
             last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
-            embeddings = self.last_token_pool(last_hidden_state, inputs['attention_mask'])
+            embeddings = last_token_pool(last_hidden_state, inputs['attention_mask'])
             if self.normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
             embeddings = cast(torch.Tensor, embeddings)
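This hunk redirects encode() from the instance method to the module-level last_token_pool helper, which already exists near the top of the file (see the hunk header of the -23,6 hunk above); the duplicated method itself is deleted in the next hunk. For reference, a self-contained sketch matching that deleted body:

    import torch
    from torch import Tensor

    def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # With left padding, every sequence ends at the final position.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        # With right padding, pick each row's last non-padding position.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths,
        ]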
@@ -316,31 +314,16 @@ def encode(self,
             return all_embeddings[0]
         return all_embeddings
 
-    def last_token_pool(self,
-                        last_hidden_state: torch.Tensor,
-                        attention_mask: torch.Tensor = None):
-        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
-        if left_padding:
-            return last_hidden_state[:, -1]
-        else:
-            sequence_lengths = attention_mask.sum(dim=1) - 1
-            batch_size = last_hidden_state.shape[0]
-            return last_hidden_state[
-                torch.arange(batch_size, device=last_hidden_state.device),
-                sequence_lengths,
-            ]
-
 
 class FlagModel:
     def __init__(
-            self,
-            model_name_or_path: str = None,
-            pooling_method: str = 'cls',
-            normalize_embeddings: bool = True,
-            query_instruction_for_retrieval: str = None,
-            use_fp16: bool = True
+        self,
+        model_name_or_path: str = None,
+        pooling_method: str = 'cls',
+        normalize_embeddings: bool = True,
+        query_instruction_for_retrieval: str = None,
+        use_fp16: bool = True
     ) -> None:
-
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModel.from_pretrained(model_name_or_path)
         self.query_instruction_for_retrieval = query_instruction_for_retrieval
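FlagModel gets the same cosmetic treatment: parameters re-indented and a stray blank line dropped, with the interface unchanged. A hedged usage sketch; the checkpoint name is a placeholder, not named in this diff:

    # Placeholder checkpoint; any BGE-style encoder fits this constructor.
    model = FlagModel(
        model_name_or_path='BAAI/bge-large-en-v1.5',
        pooling_method='cls',
        normalize_embeddings=True,
        query_instruction_for_retrieval=None,
        use_fp16=True,
    )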
@@ -476,11 +459,11 @@ class LLMEmbedder:
     }
 
     def __init__(
-            self,
-            model_name_or_path: str = None,
-            pooling_method: str = 'cls',
-            normalize_embeddings: bool = True,
-            use_fp16: bool = True
+        self,
+        model_name_or_path: str = None,
+        pooling_method: str = 'cls',
+        normalize_embeddings: bool = True,
+        use_fp16: bool = True
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModel.from_pretrained(model_name_or_path)
@@ -583,4 +566,3 @@ def pooling(self,
             return s / d
         else:
             raise NotImplementedError(f"Pooling method {self.pooling_method} not implemented!")
-
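The "return s / d" in this hunk's context is the mean-pooling branch of pooling(). Assuming s is the mask-weighted sum of hidden states and d the per-row token count (the standard construction; the lines defining them are not shown in this diff), the computation amounts to:

    import torch

    def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Assumed shapes: last_hidden_state (batch, seq, dim), attention_mask (batch, seq).
        mask = attention_mask.unsqueeze(-1).float()     # (batch, seq, 1)
        s = torch.sum(last_hidden_state * mask, dim=1)  # masked sum over the sequence
        d = mask.sum(dim=1).clamp(min=1e-9)             # number of real tokens per row
        return s / d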
