@@ -393,19 +393,16 @@ def encode_queries_single_device(
393393
394394 # adjust batch size
395395 flag = False
396- max_length_inputs = self .tokenizer .pad (
397- all_inputs_sorted [:1 ],
398- padding = True ,
399- return_tensors = 'pt' ,
400- ** kwargs
401- ).to (device )
402396 while flag is False :
403397 try :
404- test_inputs_batch = {}
405- for k , v in max_length_inputs .items ():
406- test_inputs_batch [k ] = v .repeat (batch_size , 1 )
407- last_hidden_state = self .model (** test_inputs_batch , return_dict = True ).last_hidden_state
408- embeddings = last_token_pool (last_hidden_state , test_inputs_batch ['attention_mask' ])
398+ inputs_batch = self .tokenizer .pad (
399+ all_inputs_sorted [: batch_size ],
400+ padding = True ,
401+ return_tensors = 'pt' ,
402+ ** kwargs
403+ ).to (device )
404+ last_hidden_state = self .model (** inputs_batch , return_dict = True ).last_hidden_state
405+ embeddings = last_token_pool (last_hidden_state , inputs_batch ['attention_mask' ])
409406 flag = True
410407 except RuntimeError as e :
411408 batch_size = batch_size * 3 // 4
@@ -505,19 +502,16 @@ def encode_single_device(
505502
506503 # adjust batch size
507504 flag = False
508- max_length_inputs = self .tokenizer .pad (
509- all_inputs_sorted [:1 ],
510- padding = True ,
511- return_tensors = 'pt' ,
512- ** kwargs
513- ).to (device )
514505 while flag is False :
515506 try :
516- test_inputs_batch = {}
517- for k , v in max_length_inputs .items ():
518- test_inputs_batch [k ] = v .repeat (batch_size , 1 )
519- last_hidden_state = self .model (** test_inputs_batch , return_dict = True ).last_hidden_state
520- embeddings = last_token_pool (last_hidden_state , test_inputs_batch ['attention_mask' ])
507+ inputs_batch = self .tokenizer .pad (
508+ all_inputs_sorted [: batch_size ],
509+ padding = True ,
510+ return_tensors = 'pt' ,
511+ ** kwargs
512+ ).to (device )
513+ last_hidden_state = self .model (** inputs_batch , return_dict = True ).last_hidden_state
514+ embeddings = last_token_pool (last_hidden_state , inputs_batch ['attention_mask' ])
521515 flag = True
522516 except RuntimeError as e :
523517 batch_size = batch_size * 3 // 4
0 commit comments