Skip to content

Commit 7f02137

Browse files
committed
fix OOM bug: add torch.OutOfMemoryError exception
1 parent 9b75543 commit 7f02137

8 files changed

Lines changed: 27 additions & 9 deletions

File tree

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def encode_single_device(
180180
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
181181
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
182182
flag = True
183-
except (RuntimeError, torch.OutofMemoryError) as e:
183+
except RuntimeError as e:
184+
batch_size = batch_size * 3 // 4
185+
except torch.OutOfMemoryError as e:
184186
batch_size = batch_size * 3 // 4
185187

186188
# encode

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,9 @@ def encode_queries_single_device(
278278
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
279279
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
280280
flag = True
281-
except (RuntimeError, torch.OutofMemoryError) as e:
281+
except RuntimeError as e:
282+
batch_size = batch_size * 3 // 4
283+
except torch.OutOfMemoryError as e:
282284
batch_size = batch_size * 3 // 4
283285

284286
# encode
@@ -389,7 +391,9 @@ def encode_single_device(
389391
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
390392
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
391393
flag = True
392-
except (RuntimeError, torch.OutofMemoryError) as e:
394+
except RuntimeError as e:
395+
batch_size = batch_size * 3 // 4
396+
except torch.OutOfMemoryError as e:
393397
batch_size = batch_size * 3 // 4
394398

395399
# encode

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def encode_single_device(
170170
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
171171
embeddings = self.pooling(last_hidden_state, test_inputs_batch['attention_mask'])
172172
flag = True
173-
except (RuntimeError, torch.OutofMemoryError) as e:
173+
except RuntimeError as e:
174+
batch_size = batch_size * 3 // 4
175+
except torch.OutOfMemoryError as e:
174176
batch_size = batch_size * 3 // 4
175177

176178
# encode

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,9 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
304304
return_colbert_vecs=return_colbert_vecs
305305
)
306306
flag = True
307-
except (RuntimeError, torch.OutofMemoryError) as e:
307+
except RuntimeError as e:
308+
batch_size = batch_size * 3 // 4
309+
except torch.OutOfMemoryError as e:
308310
batch_size = batch_size * 3 // 4
309311

310312
# encode

FlagEmbedding/inference/reranker/decoder_only/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ def compute_score_single_gpu(
317317

318318
self.model(**batch_inputs, output_hidden_states=True)
319319
flag = True
320-
except (RuntimeError, torch.OutofMemoryError) as e:
320+
except RuntimeError as e:
321+
batch_size = batch_size * 3 // 4
322+
except torch.OutOfMemoryError as e:
321323
batch_size = batch_size * 3 // 4
322324

323325
dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/layerwise.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ def compute_score_single_gpu(
225225

226226
self.model(**batch_inputs, output_hidden_states=True, cutoff_layers=cutoff_layers)
227227
flag = True
228-
except (RuntimeError, torch.OutofMemoryError) as e:
228+
except RuntimeError as e:
229+
batch_size = batch_size * 3 // 4
230+
except torch.OutOfMemoryError as e:
229231
batch_size = batch_size * 3 // 4
230232

231233
dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/lightweight.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,9 @@ def compute_score_single_gpu(
300300
cutoff_layers=cutoff_layers
301301
)
302302
flag = True
303-
except (RuntimeError, torch.OutofMemoryError) as e:
303+
except RuntimeError as e:
304+
batch_size = batch_size * 3 // 4
305+
except torch.OutOfMemoryError as e:
304306
batch_size = batch_size * 3 // 4
305307

306308
all_scores = []

FlagEmbedding/inference/reranker/encoder_only/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ def compute_score_single_gpu(
134134
).to(device)
135135
scores = self.model(**test_inputs_batch, return_dict=True).logits.view(-1, ).float()
136136
flag = True
137-
except (RuntimeError, torch.OutofMemoryError) as e:
137+
except RuntimeError as e:
138+
batch_size = batch_size * 3 // 4
139+
except torch.OutOfMemoryError as e:
138140
batch_size = batch_size * 3 // 4
139141

140142
all_scores = []

0 commit comments

Comments
 (0)