Skip to content

Commit 3e9c603

Browse files
committed
update adjust batch size
1 parent abd9ae8 commit 3e9c603

4 files changed

Lines changed: 39 additions & 54 deletions

File tree

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -243,19 +243,16 @@ def encode_single_device(
243243

244244
# adjust batch size
245245
flag = False
246-
max_length_inputs = self.tokenizer.pad(
247-
all_inputs_sorted[:1],
248-
padding=True,
249-
return_tensors='pt',
250-
**kwargs
251-
).to(device)
252246
while flag is False:
253247
try:
254-
test_inputs_batch = {}
255-
for k, v in max_length_inputs.items():
256-
test_inputs_batch[k] = v.repeat(batch_size, 1)
257-
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
258-
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
248+
inputs_batch = self.tokenizer.pad(
249+
all_inputs_sorted[: batch_size],
250+
padding=True,
251+
return_tensors='pt',
252+
**kwargs
253+
).to(device)
254+
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
255+
embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
259256
flag = True
260257
except RuntimeError as e:
261258
batch_size = batch_size * 3 // 4

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -393,19 +393,16 @@ def encode_queries_single_device(
393393

394394
# adjust batch size
395395
flag = False
396-
max_length_inputs = self.tokenizer.pad(
397-
all_inputs_sorted[:1],
398-
padding=True,
399-
return_tensors='pt',
400-
**kwargs
401-
).to(device)
402396
while flag is False:
403397
try:
404-
test_inputs_batch = {}
405-
for k, v in max_length_inputs.items():
406-
test_inputs_batch[k] = v.repeat(batch_size, 1)
407-
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
408-
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
398+
inputs_batch = self.tokenizer.pad(
399+
all_inputs_sorted[: batch_size],
400+
padding=True,
401+
return_tensors='pt',
402+
**kwargs
403+
).to(device)
404+
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
405+
embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
409406
flag = True
410407
except RuntimeError as e:
411408
batch_size = batch_size * 3 // 4
@@ -505,19 +502,16 @@ def encode_single_device(
505502

506503
# adjust batch size
507504
flag = False
508-
max_length_inputs = self.tokenizer.pad(
509-
all_inputs_sorted[:1],
510-
padding=True,
511-
return_tensors='pt',
512-
**kwargs
513-
).to(device)
514505
while flag is False:
515506
try:
516-
test_inputs_batch = {}
517-
for k, v in max_length_inputs.items():
518-
test_inputs_batch[k] = v.repeat(batch_size, 1)
519-
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
520-
embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
507+
inputs_batch = self.tokenizer.pad(
508+
all_inputs_sorted[: batch_size],
509+
padding=True,
510+
return_tensors='pt',
511+
**kwargs
512+
).to(device)
513+
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
514+
embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
521515
flag = True
522516
except RuntimeError as e:
523517
batch_size = batch_size * 3 // 4

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,19 +224,16 @@ def encode_single_device(
224224

225225
# adjust batch size
226226
flag = False
227-
max_length_inputs = self.tokenizer.pad(
228-
all_inputs_sorted[:1],
229-
padding=True,
230-
return_tensors='pt',
231-
**kwargs
232-
).to(device)
233227
while flag is False:
234228
try:
235-
test_inputs_batch = {}
236-
for k, v in max_length_inputs.items():
237-
test_inputs_batch[k] = v.repeat(batch_size, 1)
238-
last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
239-
embeddings = self.pooling(last_hidden_state, test_inputs_batch['attention_mask'])
229+
inputs_batch = self.tokenizer.pad(
230+
all_inputs_sorted[: batch_size],
231+
padding=True,
232+
return_tensors='pt',
233+
**kwargs
234+
).to(device)
235+
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
236+
embeddings = self.pooling(last_hidden_state, inputs_batch['attention_mask'])
240237
flag = True
241238
except RuntimeError as e:
242239
batch_size = batch_size * 3 // 4

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,19 +388,16 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
388388

389389
# adjust batch size
390390
flag = False
391-
max_length_inputs = self.tokenizer.pad(
392-
all_inputs_sorted[:1],
393-
padding=True,
394-
return_tensors='pt',
395-
**kwargs
396-
).to(device)
397391
while flag is False:
398392
try:
399-
test_inputs_batch = {}
400-
for k, v in max_length_inputs.items():
401-
test_inputs_batch[k] = v.repeat(batch_size, 1)
393+
inputs_batch = self.tokenizer.pad(
394+
all_inputs_sorted[: batch_size],
395+
padding=True,
396+
return_tensors='pt',
397+
**kwargs
398+
).to(device)
402399
outputs = self.model(
403-
test_inputs_batch,
400+
inputs_batch,
404401
return_dense=return_dense,
405402
return_sparse=return_sparse,
406403
return_colbert_vecs=return_colbert_vecs

0 commit comments

Comments (0)