Skip to content

Commit c65b55b

Browse files
committed
fix OOM bug: add torch.OutOfMemoryError exception
1 parent aa72ea6 commit c65b55b

8 files changed

Lines changed: 19 additions & 1 deletion

File tree

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def encode_single_device(
182182
flag = True
183183
except RuntimeError as e:
184184
batch_size = batch_size * 3 // 4
185+
except torch.OutOfMemoryError as e:
186+
batch_size = batch_size * 3 // 4
185187

186188
# encode
187189
all_embeddings = []

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ def encode_queries_single_device(
280280
flag = True
281281
except RuntimeError as e:
282282
batch_size = batch_size * 3 // 4
283+
except torch.OutOfMemoryError as e:
284+
batch_size = batch_size * 3 // 4
283285

284286
# encode
285287
all_embeddings = []
@@ -391,6 +393,8 @@ def encode_single_device(
391393
flag = True
392394
except RuntimeError as e:
393395
batch_size = batch_size * 3 // 4
396+
except torch.OutOfMemoryError as e:
397+
batch_size = batch_size * 3 // 4
394398

395399
# encode
396400
all_embeddings = []

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def encode_single_device(
172172
flag = True
173173
except RuntimeError as e:
174174
batch_size = batch_size * 3 // 4
175+
except torch.OutOfMemoryError as e:
176+
batch_size = batch_size * 3 // 4
175177

176178
# encode
177179
all_embeddings = []

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
306306
flag = True
307307
except RuntimeError as e:
308308
batch_size = batch_size * 3 // 4
309+
except torch.OutOfMemoryError as e:
310+
batch_size = batch_size * 3 // 4
309311

310312
# encode
311313
all_dense_embeddings, all_lexical_weights, all_colbert_vecs = [], [], []

FlagEmbedding/inference/reranker/decoder_only/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def compute_score_single_gpu(
319319
flag = True
320320
except RuntimeError as e:
321321
batch_size = batch_size * 3 // 4
322+
except torch.OutOfMemoryError as e:
323+
batch_size = batch_size * 3 // 4
322324

323325
dataset, dataloader = None, None
324326
if use_dataloader:

FlagEmbedding/inference/reranker/decoder_only/layerwise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ def compute_score_single_gpu(
227227
flag = True
228228
except RuntimeError as e:
229229
batch_size = batch_size * 3 // 4
230+
except torch.OutOfMemoryError as e:
231+
batch_size = batch_size * 3 // 4
230232

231233
dataset, dataloader = None, None
232234
if use_dataloader:

FlagEmbedding/inference/reranker/decoder_only/lightweight.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,9 @@ def compute_score_single_gpu(
302302
flag = True
303303
except RuntimeError as e:
304304
batch_size = batch_size * 3 // 4
305-
305+
except torch.OutOfMemoryError as e:
306+
batch_size = batch_size * 3 // 4
307+
306308
all_scores = []
307309
for batch_start in trange(0, len(all_queries_inputs_sorted), batch_size):
308310
queries_inputs = all_queries_inputs_sorted[batch_start:batch_start+batch_size]

FlagEmbedding/inference/reranker/encoder_only/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ def compute_score_single_gpu(
136136
flag = True
137137
except RuntimeError as e:
138138
batch_size = batch_size * 3 // 4
139+
except torch.OutOfMemoryError as e:
140+
batch_size = batch_size * 3 // 4
139141

140142
all_scores = []
141143
for start_index in tqdm(range(0, len(all_inputs_sorted), batch_size), desc="Compute Scores",

0 commit comments

Comments
 (0)