@@ -1,11 +1,9 @@
-from typing import cast, List, Union, Tuple
-
+from typing import cast, List, Union
 import numpy as np
-import torch
 from tqdm import tqdm
+from transformers import AutoModel, AutoTokenizer, is_torch_npu_available
+import torch
 from torch import Tensor
-from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, is_torch_npu_available
-
 import torch.nn.functional as F
 
 
@@ -23,6 +21,7 @@ def last_token_pool(last_hidden_states: Tensor,
 def get_detailed_instruct(task_description: str, query: str) -> str:
     return f'<instruct>{task_description}\n<query>{query}'
 
+
 def get_detailed_example(task_description: str, query: str, response: str) -> str:
     return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'
 
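The two helpers above define the few-shot prompt format. A minimal sketch of how they compose, using a hypothetical task description and example pair (none are named in this commit); the `'\n\n'.join(...)` prefix pattern mirrors `set_examples` in the next hunk:

```python
# Hypothetical values for illustration only; not taken from this commit.
task = 'Given a web search query, retrieve relevant passages that answer the query.'

# Render one few-shot example in the <instruct>/<query>/<response> format.
example = get_detailed_example(
    task,
    'What is the capital of France?',
    'Paris is the capital and largest city of France.',
)

# Examples are joined with blank lines to build a prefix (cf. set_examples),
# then the live query is rendered without a response.
prefix = '\n\n'.join([example]) + '\n\n'
query_prompt = prefix + get_detailed_instruct(task, 'capital of Germany')
```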
@@ -98,7 +97,6 @@ def set_examples(self, examples_for_task: List[dict] = None):
         )
         self.prefix = '\n\n'.join(eg_paris) + '\n\n'
 
-
     @torch.no_grad()
     def encode_queries(self, queries: Union[List[str], str],
                        batch_size: int = 256,
@@ -217,11 +215,11 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
 
 class FlagLLMModel:
     def __init__(
-            self,
-            model_name_or_path: str = None,
-            normalize_embeddings: bool = True,
-            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passages that answer the query.',
-            use_fp16: bool = True,
+        self,
+        model_name_or_path: str = None,
+        normalize_embeddings: bool = True,
+        query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passages that answer the query.',
+        use_fp16: bool = True
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModel.from_pretrained(model_name_or_path)
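The hunk above re-indents `FlagLLMModel.__init__` without changing its parameters. A hedged instantiation sketch; the checkpoint path is a placeholder, not a model named in this diff:

```python
# 'path/to/llm-embedder' is a placeholder checkpoint path (assumption).
model = FlagLLMModel(
    model_name_or_path='path/to/llm-embedder',
    normalize_embeddings=True,   # L2-normalize outputs (default)
    use_fp16=True,               # half precision for faster GPU inference
)
```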
@@ -298,7 +296,7 @@ def encode(self,
                 pad_to_multiple_of=8,
             ).to(self.device)
             last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
-            embeddings = self.last_token_pool(last_hidden_state, inputs['attention_mask'])
+            embeddings = last_token_pool(last_hidden_state, inputs['attention_mask'])
             if self.normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
             embeddings = cast(torch.Tensor, embeddings)
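Since `encode` L2-normalizes its output when `normalize_embeddings` is set, a plain dot product equals cosine similarity. A small scoring sketch, assuming `q_emb` and `p_emb` are torch tensors returned by the query and passage encoders:

```python
# Unit-norm vectors: dot product == cosine similarity.
scores = q_emb @ p_emb.T      # (num_queries, num_passages)
best = scores.argmax(dim=-1)  # index of the top passage per query
```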
@@ -316,31 +314,16 @@ def encode(self,
             return all_embeddings[0]
         return all_embeddings
 
-    def last_token_pool(self,
-                        last_hidden_state: torch.Tensor,
-                        attention_mask: torch.Tensor = None):
-        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
-        if left_padding:
-            return last_hidden_state[:, -1]
-        else:
-            sequence_lengths = attention_mask.sum(dim=1) - 1
-            batch_size = last_hidden_state.shape[0]
-            return last_hidden_state[
-                torch.arange(batch_size, device=last_hidden_state.device),
-                sequence_lengths,
-            ]
-
 
 class FlagModel:
     def __init__(
-            self,
-            model_name_or_path: str = None,
-            pooling_method: str = 'cls',
-            normalize_embeddings: bool = True,
-            query_instruction_for_retrieval: str = None,
-            use_fp16: bool = True
+        self,
+        model_name_or_path: str = None,
+        pooling_method: str = 'cls',
+        normalize_embeddings: bool = True,
+        query_instruction_for_retrieval: str = None,
+        use_fp16: bool = True
     ) -> None:
-
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModel.from_pretrained(model_name_or_path)
         self.query_instruction_for_retrieval = query_instruction_for_retrieval
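The method deleted above is replaced by a call to a module-level `last_token_pool` (its signature appears in the first hunk header). That module-level definition is not shown in this diff; a sketch reconstructed from the deleted method body:

```python
def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # With left padding, the final position holds the last real token
    # of every sequence in the batch.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # Otherwise, index each row at its own last non-padding position.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[
        torch.arange(batch_size, device=last_hidden_states.device),
        sequence_lengths,
    ]
```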
@@ -476,11 +459,11 @@ class LLMEmbedder:
     }
 
     def __init__(
-            self,
-            model_name_or_path: str = None,
-            pooling_method: str = 'cls',
-            normalize_embeddings: bool = True,
-            use_fp16: bool = True
+        self,
+        model_name_or_path: str = None,
+        pooling_method: str = 'cls',
+        normalize_embeddings: bool = True,
+        use_fp16: bool = True
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModel.from_pretrained(model_name_or_path)
@@ -583,4 +566,3 @@ def pooling(self,
             return s / d
         else:
             raise NotImplementedError(f"Pooling method {self.pooling_method} not implemented!")
-
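`return s / d` closes what is very likely a masked mean-pooling branch of `pooling`. The defining lines fall outside this hunk, so the following reconstruction is an assumption about how `s` and `d` are computed:

```python
# Assumed 'mean' branch feeding `return s / d` (not shown in the diff):
mask = attention_mask.unsqueeze(-1).float()           # (batch, seq, 1)
s = torch.sum(last_hidden_state * mask, dim=1)        # sum over real tokens
d = attention_mask.sum(dim=1, keepdim=True).float()   # count of real tokens
# s / d is then the mean hidden state over non-padding positions.
```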