@@ -32,7 +32,7 @@ def __init__(
3232 self ,
3333 model_name_or_path : str = None ,
3434 normalize_embeddings : bool = True ,
35- query_instruction_for_retrieval : str = 'Given a query, retrieval relevant passage that answer the query.' ,
35+ query_instruction_for_retrieval : str = 'Given a query, retrieval relevant passages that answer the query.' ,
3636 examples_for_task : List [dict ] = None ,
3737 use_fp16 : bool = True
3838 ) -> None :
@@ -215,6 +215,122 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
215215 return sum ([len (t ) for t in text ]) # Sum of length of individual strings
216216
217217
class FlagLLMModel:
    """Wrapper around a decoder-only (LLM-based) embedding model.

    Queries are prefixed with a task instruction (via ``get_detailed_instruct``),
    sentences are embedded with last-token pooling, and the resulting vectors
    are optionally L2-normalized.
    """

    def __init__(
            self,
            model_name_or_path: str = None,
            normalize_embeddings: bool = True,
            query_instruction_for_retrieval: str = 'Given a query, retrieval relevant passages that answer the query.',
            use_fp16: bool = True,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.normalize_embeddings = normalize_embeddings

        # Prefer CUDA, then Apple MPS, then Ascend NPU; fall back to CPU.
        # Half precision is disabled on CPU, where it is slow/unsupported.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False
        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    def encode_queries(self, queries: Union[List[str], str],
                       batch_size: int = 256,
                       max_length: int = 512,
                       convert_to_numpy: bool = True) -> np.ndarray:
        """Encode retrieval queries.

        Each query is wrapped with the configured retrieval instruction
        before being embedded.
        """
        if isinstance(queries, str):
            with_instruction = get_detailed_instruct(self.query_instruction_for_retrieval, queries)
        else:
            with_instruction = [
                get_detailed_instruct(self.query_instruction_for_retrieval, q)
                for q in queries
            ]
        return self.encode(with_instruction, batch_size=batch_size,
                           max_length=max_length, convert_to_numpy=convert_to_numpy)

    def encode_corpus(self,
                      corpus: Union[List[str], str],
                      batch_size: int = 256,
                      max_length: int = 512,
                      convert_to_numpy: bool = True) -> np.ndarray:
        """Encode corpus passages for retrieval (no instruction is added)."""
        return self.encode(corpus, batch_size=batch_size,
                           max_length=max_length, convert_to_numpy=convert_to_numpy)

    @torch.no_grad()
    def encode(self,
               sentences: Union[List[str], str],
               batch_size: int = 256,
               max_length: int = 512,
               convert_to_numpy: bool = True) -> np.ndarray:
        """Embed sentences and return one vector per input.

        A single string yields a single vector; a list yields a stacked
        array/tensor. Vectors are L2-normalized when ``normalize_embeddings``
        is set.
        """
        # Scale the per-step batch across data-parallel replicas.
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()

        single_input = isinstance(sentences, str)
        if single_input:
            sentences = [sentences]

        collected = []
        progress = tqdm(range(0, len(sentences), batch_size),
                        desc="Inference Embeddings",
                        disable=len(sentences) < 256)
        for batch_start in progress:
            chunk = sentences[batch_start:batch_start + batch_size]
            encoded = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
                pad_to_multiple_of=8,
            ).to(self.device)
            hidden = self.model(**encoded, return_dict=True).last_hidden_state
            vecs = self.last_token_pool(hidden, encoded['attention_mask'])
            if self.normalize_embeddings:
                vecs = torch.nn.functional.normalize(vecs, dim=-1)
            vecs = cast(torch.Tensor, vecs)
            if convert_to_numpy:
                vecs = vecs.cpu().numpy()
            collected.append(vecs)

        merged = (np.concatenate(collected, axis=0) if convert_to_numpy
                  else torch.cat(collected, dim=0))
        return merged[0] if single_input else merged

    def last_token_pool(self,
                        last_hidden_state: torch.Tensor,
                        attention_mask: torch.Tensor = None):
        """Pool by taking each sequence's final non-padded token state."""
        # When every row attends at the last position the batch is
        # left-padded, so the last column already holds each sequence's
        # final token.
        if attention_mask[:, -1].sum() == attention_mask.shape[0]:
            return last_hidden_state[:, -1]
        # Right padding: index each row at its last attended position.
        seq_lens = attention_mask.sum(dim=1) - 1
        rows = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
        return last_hidden_state[rows, seq_lens]
332+
333+
218334class FlagModel :
219335 def __init__ (
220336 self ,
@@ -315,7 +431,7 @@ def encode(self,
315431 if convert_to_numpy :
316432 all_embeddings = np .concatenate (all_embeddings , axis = 0 )
317433 else :
318- all_embeddings = torch .stack (all_embeddings )
434+ all_embeddings = torch .cat (all_embeddings , dim = 0 )
319435
320436 if input_was_string :
321437 return all_embeddings [0 ]
0 commit comments