Skip to content

Commit acd2bda

Browse files
committed
update stop_self_pool function to avoid Exception
- embedder and reranker
1 parent ce3a9f8 commit acd2bda

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@ def stop_self_pool(self):
8181
if self.pool is not None:
8282
self.stop_multi_process_pool(self.pool)
8383
self.pool = None
84-
self.model.to('cpu')
84+
try:
85+
self.model.to('cpu')
86+
torch.cuda.empty_cache()
87+
except:
88+
pass
8589
gc.collect()
86-
torch.cuda.empty_cache()
8790

8891
@staticmethod
8992
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,12 @@ def stop_self_pool(self):
8282
if self.pool is not None:
8383
self.stop_multi_process_pool(self.pool)
8484
self.pool = None
85-
self.model.to('cpu')
85+
try:
86+
self.model.to('cpu')
87+
torch.cuda.empty_cache()
88+
except:
89+
pass
8690
gc.collect()
87-
torch.cuda.empty_cache()
8891

8992
@staticmethod
9093
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:

0 commit comments

Comments
 (0)