Skip to content

Commit 9b75543

Browse files
committed
fix OOM bug: add torch.OutOfMemoryError exception
1 parent e2ca52e commit 9b75543

8 files changed

Lines changed: 9 additions & 9 deletions

File tree

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def encode_single_device(
180180
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
181181
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
182182
flag = True
183-
except:
183+
except (RuntimeError, torch.OutOfMemoryError) as e:
184184
batch_size = batch_size * 3 // 4
185185

186186
# encode

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def encode_queries_single_device(
278278
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
279279
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
280280
flag = True
281-
except:
281+
except (RuntimeError, torch.OutOfMemoryError) as e:
282282
batch_size = batch_size * 3 // 4
283283

284284
# encode
@@ -389,7 +389,7 @@ def encode_single_device(
389389
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
390390
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
391391
flag = True
392-
except:
392+
except (RuntimeError, torch.OutOfMemoryError) as e:
393393
batch_size = batch_size * 3 // 4
394394

395395
# encode

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def encode_single_device(
170170
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
171171
embeddings = self.pooling(last_hidden_state, test_inputs_batch['attention_mask'])
172172
flag = True
173-
except:
173+
except (RuntimeError, torch.OutOfMemoryError) as e:
174174
batch_size = batch_size * 3 // 4
175175

176176
# encode

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
304304
return_colbert_vecs=return_colbert_vecs
305305
)
306306
flag = True
307-
except:
307+
except (RuntimeError, torch.OutOfMemoryError) as e:
308308
batch_size = batch_size * 3 // 4
309309

310310
# encode

FlagEmbedding/inference/reranker/decoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def compute_score_single_gpu(
317317

318318
self.model(**batch_inputs, output_hidden_states=True)
319319
flag = True
320-
except:
320+
except (RuntimeError, torch.OutOfMemoryError) as e:
321321
batch_size = batch_size * 3 // 4
322322

323323
dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/layerwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def compute_score_single_gpu(
225225

226226
self.model(**batch_inputs, output_hidden_states=True, cutoff_layers=cutoff_layers)
227227
flag = True
228-
except:
228+
except (RuntimeError, torch.OutOfMemoryError) as e:
229229
batch_size = batch_size * 3 // 4
230230

231231
dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/lightweight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def compute_score_single_gpu(
300300
cutoff_layers=cutoff_layers
301301
)
302302
flag = True
303-
except:
303+
except (RuntimeError, torch.OutOfMemoryError) as e:
304304
batch_size = batch_size * 3 // 4
305305

306306
all_scores = []

FlagEmbedding/inference/reranker/encoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def compute_score_single_gpu(
134134
).to(device)
135135
scores = self.model(**test_inputs_batch, return_dict=True).logits.view(-1, ).float()
136136
flag = True
137-
except:
137+
except (RuntimeError, torch.OutOfMemoryError) as e:
138138
batch_size = batch_size * 3 // 4
139139

140140
all_scores = []

0 commit comments

Comments
 (0)