@@ -39,22 +39,23 @@ class AbsEmbedder(ABC):
3939 Default: `True`.
4040 kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
4141 """
42+
4243 def __init__ (
43- self ,
44- model_name_or_path : str ,
45- normalize_embeddings : bool = True ,
46- use_fp16 : bool = True ,
47- query_instruction_for_retrieval : Optional [str ] = None ,
48- query_instruction_format : str = "{}{}" , # specify the format of query_instruction_for_retrieval
49- devices : Optional [Union [str , int , List [str ], List [int ]]] = None ,
50- # inference
51- batch_size : int = 256 ,
52- query_max_length : int = 512 ,
53- passage_max_length : int = 512 ,
54- instruction : Optional [str ] = None ,
55- instruction_format : str = "{}{}" ,
56- convert_to_numpy : bool = True ,
57- ** kwargs : Any ,
44+ self ,
45+ model_name_or_path : str ,
46+ normalize_embeddings : bool = True ,
47+ use_fp16 : bool = True ,
48+ query_instruction_for_retrieval : Optional [str ] = None ,
49+ query_instruction_format : str = "{}{}" , # specify the format of query_instruction_for_retrieval
50+ devices : Optional [Union [str , int , List [str ], List [int ]]] = None ,
51+ # inference
52+ batch_size : int = 256 ,
53+ query_max_length : int = 512 ,
54+ passage_max_length : int = 512 ,
55+ instruction : Optional [str ] = None ,
56+ instruction_format : str = "{}{}" ,
57+ convert_to_numpy : bool = True ,
58+ ** kwargs : Any ,
5859 ):
5960 self .model_name_or_path = model_name_or_path
6061 self .normalize_embeddings = normalize_embeddings
@@ -78,6 +79,7 @@ def __init__(
7879 # tokenizer and model are initialized in the child class
7980 self .tokenizer = None
8081 self .model = None
82+ self .pool = None
8183
8284 @staticmethod
8385 def get_target_devices (devices : Union [str , int , List [str ], List [int ]]) -> List [str ]:
@@ -109,12 +111,12 @@ def get_detailed_instruct(instruction_format: str, instruction: str, sentence: s
109111 return instruction_format .format (instruction , sentence )
110112
111113 def encode_queries (
112- self ,
113- queries : Union [List [str ], str ],
114- batch_size : Optional [int ] = None ,
115- max_length : Optional [int ] = None ,
116- convert_to_numpy : Optional [bool ] = None ,
117- ** kwargs : Any
114+ self ,
115+ queries : Union [List [str ], str ],
116+ batch_size : Optional [int ] = None ,
117+ max_length : Optional [int ] = None ,
118+ convert_to_numpy : Optional [bool ] = None ,
119+ ** kwargs : Any
118120 ):
119121 if batch_size is None : batch_size = self .batch_size
120122 if max_length is None : max_length = self .query_max_length
@@ -131,12 +133,12 @@ def encode_queries(
131133 )
132134
133135 def encode_corpus (
134- self ,
135- corpus : Union [List [str ], str ],
136- batch_size : Optional [int ] = None ,
137- max_length : Optional [int ] = None ,
138- convert_to_numpy : Optional [bool ] = None ,
139- ** kwargs : Any
136+ self ,
137+ corpus : Union [List [str ], str ],
138+ batch_size : Optional [int ] = None ,
139+ max_length : Optional [int ] = None ,
140+ convert_to_numpy : Optional [bool ] = None ,
141+ ** kwargs : Any
140142 ):
141143 passage_instruction_for_retrieval = self .kwargs .get ("passage_instruction_for_retrieval" , None )
142144 passage_instruction_format = self .kwargs .get ("passage_instruction_format" , "{}{}" )
@@ -156,23 +158,27 @@ def encode_corpus(
156158 )
157159
158160 def encode (
159- self ,
160- sentences : Union [List [str ], str ],
161- batch_size : Optional [int ] = None ,
162- max_length : Optional [int ] = None ,
163- convert_to_numpy : Optional [bool ] = None ,
164- instruction : Optional [str ] = None ,
165- instruction_format : Optional [str ] = None ,
166- ** kwargs : Any
161+ self ,
162+ sentences : Union [List [str ], str ],
163+ batch_size : Optional [int ] = None ,
164+ max_length : Optional [int ] = None ,
165+ convert_to_numpy : Optional [bool ] = None ,
166+ instruction : Optional [str ] = None ,
167+ instruction_format : Optional [str ] = None ,
168+ ** kwargs : Any
167169 ):
168170 if instruction is None : instruction = self .instruction
169171 if instruction_format is None : instruction_format = self .instruction_format
172+ if batch_size is None : batch_size = self .batch_size
173+ if max_length is None : max_length = self .passage_max_length
174+ if convert_to_numpy is None : convert_to_numpy = self .convert_to_numpy
170175
171176 if instruction is not None :
172177 if isinstance (sentences , str ):
173178 sentences = self .get_detailed_instruct (instruction_format , instruction , sentences )
174179 else :
175- sentences = [self .get_detailed_instruct (instruction_format , instruction , sentence ) for sentence in sentences ]
180+ sentences = [self .get_detailed_instruct (instruction_format , instruction , sentence ) for sentence in
181+ sentences ]
176182
177183 if isinstance (sentences , str ) or len (self .target_devices ) == 1 :
178184 return self .encode_single_device (
@@ -184,27 +190,31 @@ def encode(
184190 ** kwargs
185191 )
186192
187- pool = self .start_multi_process_pool (AbsEmbedder ._encode_multi_process_worker )
193+ if self .pool is None :
194+ self .pool = self .start_multi_process_pool (AbsEmbedder ._encode_multi_process_worker )
188195 embeddings = self .encode_multi_process (
189196 sentences ,
190- pool ,
197+ self . pool ,
191198 batch_size = batch_size ,
192199 max_length = max_length ,
193200 convert_to_numpy = convert_to_numpy ,
194201 ** kwargs
195202 )
196- self .stop_multi_process_pool (pool )
197203 return embeddings
198204
def __del__(self):
    """Finalizer: shut down the persistent multi-process encoding pool, if any.

    The pool is created lazily in ``encode`` and kept alive across calls
    (see ``self.pool``), so it must be torn down when the embedder is
    garbage-collected.

    Uses ``getattr`` instead of a direct attribute access because
    ``__del__`` can run on a partially constructed instance — e.g. when
    ``__init__`` raised before ``self.pool = None`` was assigned — and an
    ``AttributeError`` escaping a finalizer is reported to stderr at
    interpreter shutdown.
    """
    pool = getattr(self, "pool", None)
    if pool is not None:
        # stop_multi_process_pool terminates the worker processes and
        # closes the input/output queues (defined elsewhere in this class).
        self.stop_multi_process_pool(pool)
        # Clear the reference so a second finalization attempt is a no-op.
        self.pool = None
199209 @abstractmethod
200210 def encode_single_device (
201- self ,
202- sentences : Union [List [str ], str ],
203- batch_size : int = 256 ,
204- max_length : int = 512 ,
205- convert_to_numpy : bool = True ,
206- device : Optional [str ] = None ,
207- ** kwargs : Any ,
211+ self ,
212+ sentences : Union [List [str ], str ],
213+ batch_size : int = 256 ,
214+ max_length : int = 512 ,
215+ convert_to_numpy : bool = True ,
216+ device : Optional [str ] = None ,
217+ ** kwargs : Any ,
208218 ):
209219 """
210220 This method should encode sentences and return embeddings on a single device.
@@ -213,8 +223,8 @@ def encode_single_device(
213223
214224 # adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
215225 def start_multi_process_pool (
216- self ,
217- process_target_func : Any ,
226+ self ,
227+ process_target_func : Any ,
218228 ) -> Dict [Literal ["input" , "output" , "processes" ], Any ]:
219229 """
220230 Starts a multi-process pool to process the encoding with several independent processes
@@ -253,7 +263,7 @@ def start_multi_process_pool(
253263 # adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L976
254264 @staticmethod
255265 def _encode_multi_process_worker (
256- target_device : str , model : 'AbsEmbedder' , input_queue : Queue , results_queue : Queue
266+ target_device : str , model : 'AbsEmbedder' , input_queue : Queue , results_queue : Queue
257267 ) -> None :
258268 """
259269 Internal working process to encode sentences in multi-process setup
@@ -297,10 +307,10 @@ def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"],
297307
298308 # adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L877
299309 def encode_multi_process (
300- self ,
301- sentences : List [str ],
302- pool : Dict [Literal ["input" , "output" , "processes" ], Any ],
303- ** kwargs
310+ self ,
311+ sentences : List [str ],
312+ pool : Dict [Literal ["input" , "output" , "processes" ], Any ],
313+ ** kwargs
304314 ):
305315 chunk_size = math .ceil (len (sentences ) / len (pool ["processes" ]))
306316
0 commit comments