1- from tqdm import tqdm
1+ from tqdm import tqdm , trange
22from typing import cast , Any , List , Union , Optional
33
44import queue
@@ -69,6 +69,7 @@ def __init__(
6969 use_fp16 : bool = True ,
7070 query_instruction_for_retrieval : Optional [str ] = None ,
7171 query_instruction_format : str = "<instruct>{}\n <query>{}" , # specify the format of query_instruction_for_retrieval
72+ suffix : str = '\n <response>' ,
7273 devices : Optional [Union [str , List [str ]]] = None , # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
7374 # Additional parameters for ICLLLMEmbedder
7475 examples_for_task : Optional [List [dict ]] = None ,
@@ -82,6 +83,8 @@ def __init__(
8283 convert_to_numpy : bool = True ,
8384 ** kwargs : Any ,
8485 ):
86+ query_instruction_format = query_instruction_format .replace ('\\ n' , '\n ' )
87+ examples_instruction_format = examples_instruction_format .replace ('\\ n' , '\n ' )
8588 super ().__init__ (
8689 model_name_or_path ,
8790 normalize_embeddings = normalize_embeddings ,
@@ -113,7 +116,15 @@ def __init__(
113116 raise ValueError ("Pooling method must be 'last_token' for LLM-based models." )
114117
115118 self .set_examples ()
116- self .suffix = '\n <response>'
119+ self .suffix = suffix
120+
121+ self .query_pool = None
122+
123+ def __del__ (self ):
124+ if self .pool is not None :
125+ self .stop_multi_process_pool (self .pool )
126+ if self .query_pool is not None :
127+ self .stop_multi_process_pool (self .query_pool )
117128
118129 def set_examples (self , examples_for_task : Optional [List [dict ]] = None ):
119130 """Set the prefix to the provided examples.
@@ -198,16 +209,19 @@ def encode_queries(
198209 ** kwargs
199210 )
200211
201- pool = self .start_multi_process_pool (ICLLLMEmbedder ._encode_queries_multi_process_worker )
212+ if self .pool is not None :
213+ self .stop_multi_process_pool (self .pool )
214+ self .pool = None
215+ if self .query_pool is None :
216+ self .query_pool = self .start_multi_process_pool (ICLLLMEmbedder ._encode_queries_multi_process_worker )
202217 embeddings = self .encode_multi_process (
203218 queries ,
204- pool ,
219+ self . query_pool ,
205220 batch_size = batch_size ,
206221 max_length = max_length ,
207222 convert_to_numpy = convert_to_numpy ,
208223 ** kwargs
209224 )
210- self .stop_multi_process_pool (pool )
211225 return embeddings
212226
213227 def encode_corpus (
@@ -230,6 +244,9 @@ def encode_corpus(
230244 Returns:
231245 Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
232246 """
247+ if self .query_pool is not None :
248+ self .stop_multi_process_pool (self .query_pool )
249+ self .query_pool = None
233250 return super ().encode_corpus (
234251 corpus ,
235252 batch_size = batch_size ,
@@ -338,16 +355,27 @@ def encode_queries_single_device(
338355 suffix_ids = self .tokenizer (self .suffix , add_special_tokens = False )['input_ids' ]
339356
340357 _len_1 = len (self .tokenizer ('<s>' , add_special_tokens = False )['input_ids' ])
341- _len_2 = len (self .tokenizer ('\n <response></s>' , add_special_tokens = False )['input_ids' ])
358+ _len_2 = len (self .tokenizer (f'{ self .suffix } </s>' , add_special_tokens = False )['input_ids' ])
359+ new_max_length = (len (prefix_ids ) + len (suffix_ids ) + max_length + 8 ) // 8 * 8 + 8
342360
343361 # tokenize without padding to get the correct length
344362 all_inputs = []
345- for start_index in range (0 , len (input_texts ), batch_size ):
363+ for start_index in trange (0 , len (input_texts ), batch_size , desc = 'pre tokenize' ):
346364 sentences_batch = input_texts [start_index :start_index + batch_size ]
347365 inputs_batch = self .tokenizer (
348366 sentences_batch ,
349367 truncation = True ,
350- max_length = max_length ,
368+ max_length = max_length - _len_1 - _len_2 ,
369+ add_special_tokens = False ,
370+ ** kwargs
371+ )
372+ sentences_batch = self .tokenizer .batch_decode (inputs_batch ['input_ids' ])
373+ for i in range (len (sentences_batch )):
374+ sentences_batch [i ] = self .prefix + sentences_batch [i ] + self .suffix
375+ inputs_batch = self .tokenizer (
376+ sentences_batch ,
377+ truncation = True ,
378+ max_length = new_max_length ,
351379 ** kwargs
352380 )
353381 inputs_batch = [{
@@ -385,30 +413,16 @@ def encode_queries_single_device(
385413 all_embeddings = []
386414 for start_index in tqdm (range (0 , len (sentences_sorted ), batch_size ), desc = "Inference Embeddings" ,
387415 disable = len (sentences_sorted ) < 256 ):
388- sentences_batch = sentences_sorted [start_index :start_index + batch_size ]
389- inputs = self .tokenizer (
390- sentences_batch ,
391- max_length = max_length - _len_1 - _len_2 ,
392- return_token_type_ids = False ,
393- truncation = True ,
394- return_tensors = None ,
395- add_special_tokens = False
396- )
397- new_max_length = (len (prefix_ids ) + len (suffix_ids ) + max_length + 8 ) // 8 * 8 + 8
398- sentences_batch = self .tokenizer .batch_decode (inputs ['input_ids' ])
399- for i in range (len (sentences_batch )):
400- sentences_batch [i ] = self .prefix + sentences_batch [i ] + self .suffix
401- inputs = self .tokenizer (
402- sentences_batch ,
416+ inputs_batch = all_inputs_sorted [start_index :start_index + batch_size ]
417+ inputs_batch = self .tokenizer .pad (
418+ inputs_batch ,
403419 padding = True ,
404- truncation = True ,
405420 return_tensors = 'pt' ,
406- max_length = new_max_length ,
407- add_special_tokens = True
421+ ** kwargs
408422 ).to (device )
409423
410- last_hidden_state = self .model (** inputs , return_dict = True ).last_hidden_state
411- embeddings = last_token_pool (last_hidden_state , inputs ['attention_mask' ])
424+ last_hidden_state = self .model (** inputs_batch , return_dict = True ).last_hidden_state
425+ embeddings = last_token_pool (last_hidden_state , inputs_batch ['attention_mask' ])
412426 if self .normalize_embeddings :
413427 embeddings = torch .nn .functional .normalize (embeddings , dim = - 1 )
414428 embeddings = cast (torch .Tensor , embeddings )
@@ -469,7 +483,7 @@ def encode_single_device(
469483
470484 # tokenize without padding to get the correct length
471485 all_inputs = []
472- for start_index in range (0 , len (sentences ), batch_size ):
486+ for start_index in trange (0 , len (sentences ), batch_size , desc = 'pre tokenize' ):
473487 sentences_batch = sentences [start_index :start_index + batch_size ]
474488 inputs_batch = self .tokenizer (
475489 sentences_batch ,
0 commit comments