from typing import List, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, is_torch_npu_available
11+
def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Pool each sequence down to the hidden state of its final real token.

    If every row's mask ends in 1 the batch is left-padded, so the last
    position already holds the final token for all rows and that column is
    returned directly. Otherwise (right padding) the index of the last
    non-padding token is derived per row from the attention mask.
    """
    n_rows = attention_mask.shape[0]
    if attention_mask[:, -1].sum() == n_rows:  # left-padded batch
        return last_hidden_states[:, -1]
    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(n_rows, device=last_hidden_states.device)
    return last_hidden_states[rows, last_idx]
21+
22+
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap a query with its task instruction in the model's ICL template."""
    return '<instruct>{}\n<query>{}'.format(task_description, query)
25+
def get_detailed_example(task_description: str, query: str, response: str) -> str:
    """Render one in-context example (instruction, query, response) in the ICL template."""
    parts = ('<instruct>' + task_description,
             '<query>' + query,
             '<response>' + response)
    return '\n'.join(parts)
28+
29+
class FlagICLModel:
    """Embedding model that prepends in-context-learning (ICL) examples to
    queries before encoding, pooling the last non-padding token as the
    sentence embedding.

    Queries are wrapped as ``<instruct>...<query>...`` with an optional
    few-shot example prefix and a ``<response>`` suffix; corpus passages are
    encoded as plain text. Fix vs. previous revision: `__init__` contained a
    duplicated device-selection block whose first copy called ``.half()``
    unconditionally (even with ``use_fp16=False`` / on CPU) and moved the
    model twice; only the complete (CUDA/MPS/NPU/CPU-aware) copy is kept.
    """

    def __init__(
            self,
            model_name_or_path: str = None,
            normalize_embeddings: bool = True,
            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passage that answer the query.',
            examples_for_task: List[dict] = None,
            use_fp16: bool = True
    ) -> None:
        """Load tokenizer/model and move the model to the best available device.

        :param model_name_or_path: HF hub id or local path of the model.
        :param normalize_embeddings: L2-normalize output embeddings if True.
        :param query_instruction_for_retrieval: instruction prepended to queries.
        :param examples_for_task: optional few-shot examples, each a dict with
            (optional) keys ``instruct``, ``query``, ``response``.
        :param use_fp16: run the model in half precision (disabled on CPU).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.examples_for_task = examples_for_task

        self.set_examples()
        self.suffix = '\n<response>'

        self.normalize_embeddings = normalize_embeddings

        # Single device-selection pass; fp16 is only meaningful on accelerators.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False
        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    def set_examples(self, examples_for_task: List[dict] = None):
        """Build ``self.prefix`` (the few-shot block) from the given examples.

        Falls back to the examples passed at construction time when called
        with no argument; with no examples at all the prefix is empty.
        Missing ``instruct`` keys default to the retrieval instruction.
        """
        if examples_for_task is None:
            examples_for_task = self.examples_for_task
        if examples_for_task is None:
            self.prefix = ''
            return
        eg_pairs = [
            get_detailed_example(
                example.get('instruct', self.query_instruction_for_retrieval),
                example.get('query', ''),
                example.get('response', '')
            )
            for example in examples_for_task
        ]
        self.prefix = '\n\n'.join(eg_pairs) + '\n\n'

    @torch.no_grad()
    def encode_queries(self, queries: Union[List[str], str],
                       batch_size: int = 256,
                       max_length: int = 512) -> np.ndarray:
        '''
        Encode queries for retrieval.

        Each query is wrapped with the retrieval instruction, truncated so
        that prefix + query + suffix fits the token budget, surrounded by the
        ICL prefix/suffix, then embedded via last-token pooling.

        :param queries: a single query string or a list of queries.
        :param batch_size: number of sentences encoded per forward pass.
        :param max_length: token budget for the bare (instructed) query.
        :return: float32 array of shape (n_queries, hidden_size).
        '''
        self.model.eval()

        if isinstance(queries, str):
            sentences = [get_detailed_instruct(self.query_instruction_for_retrieval, queries)]
        else:
            sentences = [get_detailed_instruct(self.query_instruction_for_retrieval, q) for q in queries]

        prefix_ids = self.tokenizer(self.prefix, add_special_tokens=False)['input_ids']
        suffix_ids = self.tokenizer(self.suffix, add_special_tokens=False)['input_ids']
        # Loop-invariant: tokens reserved for BOS and the response/EOS tail.
        reserved = (len(self.tokenizer('<s>', add_special_tokens=False)['input_ids'])
                    + len(self.tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']))

        all_embeddings = []
        # Sort by descending length so each batch pads to a similar length.
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
                                disable=len(sentences_sorted) < 256):
            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
            # First pass: truncate the bare query to leave room for special tokens.
            inputs = self.tokenizer(
                sentences_batch,
                max_length=max_length - reserved,
                return_token_type_ids=False,
                truncation=True,
                return_tensors=None,
                add_special_tokens=False
            )
            # Round the full-sequence budget up to a multiple of 8.
            new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length) // 8 * 8 + 8
            # Re-assemble text with the few-shot prefix and response suffix.
            sentences_batch = self.tokenizer.batch_decode(inputs['input_ids'])
            sentences_batch = [self.prefix + sent + self.suffix for sent in sentences_batch]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=new_max_length,
                add_special_tokens=True
            ).to(self.device)

            outputs = self.model(**inputs, return_dict=True)
            embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])

            if self.normalize_embeddings:
                embeddings = F.normalize(embeddings, p=2, dim=1)
            all_embeddings.extend(embeddings.float().cpu())

        # Undo the length sort to restore the caller's ordering.
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        return np.asarray([emb.numpy() for emb in all_embeddings])

    @torch.no_grad()
    def encode_corpus(self,
                      corpus: Union[List[str], str],
                      batch_size: int = 256,
                      max_length: int = 512) -> np.ndarray:
        '''
        Encode corpus passages for retrieval.

        Passages are embedded as-is (no instruction, prefix, or suffix) with
        last-token pooling.

        :param corpus: a single passage or a list of passages.
        :param batch_size: number of passages encoded per forward pass.
        :param max_length: maximum token length per passage.
        :return: float32 array of shape (n_passages, hidden_size).
        '''
        self.model.eval()

        if isinstance(corpus, str):
            sentences = [corpus]
        else:
            sentences = corpus

        all_embeddings = []
        # Sort by descending length so each batch pads to a similar length.
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
                                disable=len(sentences_sorted) < 256):
            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
                add_special_tokens=True
            ).to(self.device)
            outputs = self.model(**inputs, return_dict=True)
            embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])

            if self.normalize_embeddings:
                embeddings = F.normalize(embeddings, p=2, dim=1)
            all_embeddings.extend(embeddings.float().cpu())

        # Undo the length sort to restore the caller's ordering.
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        return np.asarray([emb.numpy() for emb in all_embeddings])

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Helper to get the length of the input text. Text can be a string, a
        list of ints (a single tokenized text), a list of strings/lists, or a
        dict of features; objects without ``__len__`` count as length 1.
        """
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, '__len__'):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of length of individual strings
216+
8217
9218class FlagModel :
10219 def __init__ (
@@ -185,7 +394,7 @@ def encode_queries(self, queries: Union[List[str], str],
185394 max_length : int = 256 ,
186395 task : str = 'qa' ) -> np .ndarray :
187396 '''
188- Encode queries into dense vectors.
397+ Encode queries into dense vectors.
189398 Automatically add instructions according to given task.
190399 '''
191400 instruction = self .instructions [task ]["query" ]
@@ -202,7 +411,7 @@ def encode_keys(self, keys: Union[List[str], str],
202411 max_length : int = 512 ,
203412 task : str = 'qa' ) -> np .ndarray :
204413 '''
205- Encode keys into dense vectors.
414+ Encode keys into dense vectors.
206415 Automatically add instructions according to given task.
207416 '''
208417 instruction = self .instructions [task ]["key" ]
0 commit comments