|
4 | 4 | import queue |
5 | 5 | from multiprocessing import Queue |
6 | 6 |
|
| 7 | +import gc |
7 | 8 | import torch |
8 | 9 | import numpy as np |
9 | 10 | from transformers import AutoModel, AutoTokenizer |
@@ -121,10 +122,8 @@ def __init__( |
121 | 122 | self.query_pool = None |
122 | 123 |
|
123 | 124 | def __del__(self): |
124 | | - if self.pool is not None: |
125 | | - self.stop_multi_process_pool(self.pool) |
126 | | - if self.query_pool is not None: |
127 | | - self.stop_multi_process_pool(self.query_pool) |
| 125 | + self.stop_self_pool() |
| 126 | + self.stop_self_query_pool() |
128 | 127 |
|
129 | 128 | def set_examples(self, examples_for_task: Optional[List[dict]] = None): |
130 | 129 | """Set the prefix to the provided examples. |
@@ -175,6 +174,14 @@ def get_detailed_example(instruction_format: str, instruction: str, query: str, |
175 | 174 | """ |
176 | 175 | return instruction_format.format(instruction, query, response) |
177 | 176 |
|
| 177 | + def stop_self_query_pool(self): |
| 178 | + if self.query_pool is not None: |
| 179 | + self.stop_multi_process_pool(self.query_pool) |
| 180 | + self.query_pool = None |
| 181 | + self.model.to('cpu') |
| 182 | + gc.collect() |
| 183 | + torch.cuda.empty_cache() |
| 184 | + |
178 | 185 | def encode_queries( |
179 | 186 | self, |
180 | 187 | queries: Union[List[str], str], |
@@ -209,9 +216,7 @@ def encode_queries( |
209 | 216 | **kwargs |
210 | 217 | ) |
211 | 218 |
|
212 | | - if self.pool is not None: |
213 | | - self.stop_multi_process_pool(self.pool) |
214 | | - self.pool = None |
| 219 | + self.stop_self_pool() |
215 | 220 | if self.query_pool is None: |
216 | 221 | self.query_pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker) |
217 | 222 | embeddings = self.encode_multi_process( |
@@ -244,9 +249,7 @@ def encode_corpus( |
244 | 249 | Returns: |
245 | 250 | Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor. |
246 | 251 | """ |
247 | | - if self.query_pool is not None: |
248 | | - self.stop_multi_process_pool(self.query_pool) |
249 | | - self.query_pool = None |
| 252 | + self.stop_self_query_pool() |
250 | 253 | return super().encode_corpus( |
251 | 254 | corpus, |
252 | 255 | batch_size=batch_size, |
|
0 commit comments