@@ -83,9 +83,9 @@ class SimpleTokenizer {
8383 return true ;
8484 }
8585
86- std::vector<int64_t > encode (const std::string& text, size_t max_length = 512 ) {
86+ std::vector<int64_t > encode_unpadded (const std::string& text, size_t max_length = 512 ) {
8787 std::vector<int64_t > token_ids;
88- token_ids.reserve (max_length);
88+ token_ids.reserve (std::min ( max_length, static_cast < size_t >( 128 )) );
8989 token_ids.push_back (cls_id_); // [CLS]
9090
9191 const auto words = basic_tokenize (text);
@@ -104,12 +104,18 @@ class SimpleTokenizer {
104104 }
105105
106106 token_ids.push_back (sep_id_); // [SEP]
107-
108- // Pad to max_length
109- while (token_ids.size () < max_length) {
110- token_ids.push_back (pad_id_); // [PAD]
107+ return token_ids;
108+ }
109+
110+ void pad_to (std::vector<int64_t >& token_ids, size_t target_length) {
111+ while (token_ids.size () < target_length) {
112+ token_ids.push_back (pad_id_);
111113 }
112-
114+ }
115+
116+ std::vector<int64_t > encode (const std::string& text, size_t max_length = 512 ) {
117+ auto token_ids = encode_unpadded (text, max_length);
118+ pad_to (token_ids, max_length);
113119 return token_ids;
114120 }
115121
@@ -553,12 +559,22 @@ class ONNXEmbeddingProvider::Impl {
553559 std::lock_guard<std::mutex> lock (embed_mutex_);
554560
555561 try {
556- // 1. Tokenize input
557- auto token_ids = tokenizer_.encode (text, max_seq_length_);
562+ auto token_ids = tokenizer_.encode_unpadded (text, max_seq_length_);
563+ const size_t pad_length = align_up (token_ids.size (), 8 );
564+ tokenizer_.pad_to (token_ids, pad_length);
565+
558566 auto attention_mask = tokenizer_.create_attention_mask (token_ids);
559567
560- std::memcpy (input_ids_buf_.data (), token_ids.data (), max_seq_length_ * sizeof (int64_t ));
561- std::memcpy (attention_mask_buf_.data (), attention_mask.data (), max_seq_length_ * sizeof (int64_t ));
568+ std::memcpy (input_ids_buf_.data (), token_ids.data (), pad_length * sizeof (int64_t ));
569+ std::memcpy (attention_mask_buf_.data (), attention_mask.data (), pad_length * sizeof (int64_t ));
570+ std::memset (token_type_ids_buf_.data (), 0 , pad_length * sizeof (int64_t ));
571+
572+ input_shape_ = {1 , static_cast <int64_t >(pad_length)};
573+
574+ LOGI (" Single embed: %zu real tokens, padded to %zu (max %zu)" ,
575+ token_ids.size () - std::count (token_ids.begin (), token_ids.end (), 0 ),
576+ pad_length, max_seq_length_);
577+
562578 OrtStatusGuard status_guard (ort_api_);
563579 OrtValueGuard input_ids_guard (ort_api_);
564580 OrtValueGuard attention_mask_guard (ort_api_);
@@ -568,7 +584,7 @@ class ONNXEmbeddingProvider::Impl {
568584 status_guard.reset (ort_api_->CreateTensorWithDataAsOrtValue (
569585 memory_info_,
570586 input_ids_buf_.data (),
571- max_seq_length_ * sizeof (int64_t ),
587+ pad_length * sizeof (int64_t ),
572588 input_shape_.data (),
573589 input_shape_.size (),
574590 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
@@ -583,7 +599,7 @@ class ONNXEmbeddingProvider::Impl {
583599 status_guard.reset (ort_api_->CreateTensorWithDataAsOrtValue (
584600 memory_info_,
585601 attention_mask_buf_.data (),
586- max_seq_length_ * sizeof (int64_t ),
602+ pad_length * sizeof (int64_t ),
587603 input_shape_.data (),
588604 input_shape_.size (),
589605 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
@@ -598,7 +614,7 @@ class ONNXEmbeddingProvider::Impl {
598614 status_guard.reset (ort_api_->CreateTensorWithDataAsOrtValue (
599615 memory_info_,
600616 token_type_ids_buf_.data (),
601- max_seq_length_ * sizeof (int64_t ),
617+ pad_length * sizeof (int64_t ),
602618 input_shape_.data (),
603619 input_shape_.size (),
604620 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
@@ -671,11 +687,10 @@ class ONNXEmbeddingProvider::Impl {
671687 ort_api_->ReleaseTensorTypeAndShapeInfo (shape_info);
672688 }
673689
674- // 5. Mean pooling
675690 auto pooled = mean_pooling (
676691 output_data,
677692 attention_mask,
678- max_seq_length_ ,
693+ pad_length ,
679694 actual_hidden_dim
680695 );
681696
@@ -707,34 +722,90 @@ class ONNXEmbeddingProvider::Impl {
707722 return {};
708723 }
709724
710- const size_t batch_size = texts.size ();
711-
712725 std::lock_guard<std::mutex> lock (embed_mutex_);
713726
727+ std::vector<std::vector<float >> all_results;
728+ all_results.reserve (texts.size ());
729+
730+ for (size_t offset = 0 ; offset < texts.size (); offset += kMaxSubBatchSize ) {
731+ size_t sub_batch_size = std::min (kMaxSubBatchSize , texts.size () - offset);
732+
733+ LOGI (" Embedding sub-batch %zu/%zu (size=%zu)" ,
734+ offset / kMaxSubBatchSize + 1 ,
735+ (texts.size () + kMaxSubBatchSize - 1 ) / kMaxSubBatchSize ,
736+ sub_batch_size);
737+
738+ auto sub_results = embed_sub_batch (texts, offset, sub_batch_size);
739+ if (sub_results.empty ()) {
740+ LOGE (" Sub-batch embedding failed at offset %zu" , offset);
741+ return {};
742+ }
743+
744+ for (auto & r : sub_results) {
745+ all_results.push_back (std::move (r));
746+ }
747+ }
748+
749+ LOGI (" Generated batch embeddings: count=%zu, dim=%zu" , all_results.size (), embedding_dim_);
750+ return all_results;
751+ }
752+
// Returns the dimensionality of the embedding vectors produced by this
// provider (the value cached in embedding_dim_).
size_t dimension () const noexcept {
    return embedding_dim_;
}
756+
// Reports whether the provider is usable. Simply reflects the ready_ flag —
// presumably set once model/session initialization succeeds; confirm in the
// initialization path.
bool is_ready () const noexcept {
    return ready_;
}
760+
761+ private:
762+ static constexpr size_t kMaxSubBatchSize = 50 ;
763+
// Rounds `value` up to the next multiple of `alignment`, then clamps the
// result to `cap`.
//
// `cap` generalizes the previously hard-coded 512 (the tokenizer's default
// max sequence length); existing callers are unaffected by the defaulted
// parameter. alignment == 0 is treated as "no alignment" instead of
// dividing by zero.
//
// NOTE(review): when value > cap the clamped result is smaller than value;
// callers rely on cap matching the tokenizer's max_length so sequences are
// already truncated before this is applied.
static size_t align_up (size_t value, size_t alignment, size_t cap = 512 ) {
    if (alignment == 0 ) {
        return std::min (value, cap);
    }
    const size_t aligned = ((value + alignment - 1 ) / alignment) * alignment;
    return std::min (aligned, cap);
}
768+
769+ std::vector<std::vector<float >> embed_sub_batch (
770+ const std::vector<std::string>& texts,
771+ size_t offset,
772+ size_t count
773+ ) {
714774 try {
715- // 1. Tokenize all texts into flat contiguous buffers
716- std::vector<int64_t > flat_input_ids (batch_size * max_seq_length_, 0 );
717- std::vector<int64_t > flat_attention_mask (batch_size * max_seq_length_, 0 );
718- std::vector<int64_t > flat_token_type_ids (batch_size * max_seq_length_, 0 );
775+ std::vector<std::vector<int64_t >> all_token_ids (count);
776+ size_t max_actual_len = 0 ;
777+
778+ for (size_t i = 0 ; i < count; ++i) {
779+ all_token_ids[i] = tokenizer_.encode_unpadded (texts[offset + i], max_seq_length_);
780+ max_actual_len = std::max (max_actual_len, all_token_ids[i].size ());
781+ }
719782
720- std::vector<std::vector< int64_t >> attention_masks (batch_size );
783+ const size_t pad_length = align_up (max_actual_len, 8 );
721784
722- for (size_t i = 0 ; i < batch_size; ++i) {
723- auto token_ids = tokenizer_.encode (texts[i], max_seq_length_);
724- auto attn_mask = tokenizer_.create_attention_mask (token_ids);
785+ LOGI (" Sub-batch dynamic padding: max_actual=%zu, pad_length=%zu (was %zu)" ,
786+ max_actual_len, pad_length, max_seq_length_);
725787
726- std::memcpy (flat_input_ids.data () + i * max_seq_length_,
727- token_ids.data (), max_seq_length_ * sizeof (int64_t ));
728- std::memcpy (flat_attention_mask.data () + i * max_seq_length_,
729- attn_mask.data (), max_seq_length_ * sizeof (int64_t ));
788+ std::vector<int64_t > flat_input_ids (count * pad_length, 0 );
789+ std::vector<int64_t > flat_attention_mask (count * pad_length, 0 );
790+ std::vector<int64_t > flat_token_type_ids (count * pad_length, 0 );
791+
792+ std::vector<std::vector<int64_t >> attention_masks (count);
793+
794+ for (size_t i = 0 ; i < count; ++i) {
795+ tokenizer_.pad_to (all_token_ids[i], pad_length);
796+ auto attn_mask = tokenizer_.create_attention_mask (all_token_ids[i]);
797+
798+ std::memcpy (flat_input_ids.data () + i * pad_length,
799+ all_token_ids[i].data (), pad_length * sizeof (int64_t ));
800+ std::memcpy (flat_attention_mask.data () + i * pad_length,
801+ attn_mask.data (), pad_length * sizeof (int64_t ));
730802
731803 attention_masks[i] = std::move (attn_mask);
732804 }
733805
734- // 2. Create tensors with batch shape {batch_size, max_seq_length_}
735806 std::vector<int64_t > batch_shape = {
736- static_cast <int64_t >(batch_size ),
737- static_cast <int64_t >(max_seq_length_ )
807+ static_cast <int64_t >(count ),
808+ static_cast <int64_t >(pad_length )
738809 };
739810
740811 OrtStatusGuard status_guard (ort_api_);
@@ -745,40 +816,39 @@ class ONNXEmbeddingProvider::Impl {
745816 status_guard.reset (ort_api_->CreateTensorWithDataAsOrtValue (
746817 memory_info_,
747818 flat_input_ids.data (),
748- batch_size * max_seq_length_ * sizeof (int64_t ),
819+ count * pad_length * sizeof (int64_t ),
749820 batch_shape.data (), batch_shape.size (),
750821 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
751822 input_ids_guard.ptr ()));
752823 if (status_guard.is_error ()) {
753- LOGE (" Batch : CreateTensor (input_ids) failed: %s" , status_guard.error_message ());
824+ LOGE (" Sub-batch : CreateTensor (input_ids) failed: %s" , status_guard.error_message ());
754825 return {};
755826 }
756827
757828 status_guard.reset (ort_api_->CreateTensorWithDataAsOrtValue (
758829 memory_info_,
759830 flat_attention_mask.data (),
760- batch_size * max_seq_length_ * sizeof (int64_t ),
831+ count * pad_length * sizeof (int64_t ),
761832 batch_shape.data (), batch_shape.size (),
762833 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
763834 attention_mask_guard.ptr ()));
764835 if (status_guard.is_error ()) {
765- LOGE (" Batch : CreateTensor (attention_mask) failed: %s" , status_guard.error_message ());
836+ LOGE (" Sub-batch : CreateTensor (attention_mask) failed: %s" , status_guard.error_message ());
766837 return {};
767838 }
768839
769840 status_guard.reset (ort_api_->CreateTensorWithDataAsOrtValue (
770841 memory_info_,
771842 flat_token_type_ids.data (),
772- batch_size * max_seq_length_ * sizeof (int64_t ),
843+ count * pad_length * sizeof (int64_t ),
773844 batch_shape.data (), batch_shape.size (),
774845 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
775846 token_type_ids_guard.ptr ()));
776847 if (status_guard.is_error ()) {
777- LOGE (" Batch : CreateTensor (token_type_ids) failed: %s" , status_guard.error_message ());
848+ LOGE (" Sub-batch : CreateTensor (token_type_ids) failed: %s" , status_guard.error_message ());
778849 return {};
779850 }
780851
781- // 3. Run ONNX once for the entire batch
782852 const char * input_names[] = {" input_ids" , " attention_mask" , " token_type_ids" };
783853 const OrtValue* inputs[] = {
784854 input_ids_guard.get (), attention_mask_guard.get (), token_type_ids_guard.get ()
@@ -793,22 +863,21 @@ class ONNXEmbeddingProvider::Impl {
793863 output_names, 1 ,
794864 &output_ptr));
795865 if (status_guard.is_error ()) {
796- LOGE (" Batch ONNX inference failed: %s" , status_guard.error_message ());
866+ LOGE (" Sub-batch ONNX inference failed: %s" , status_guard.error_message ());
797867 return {};
798868 }
799869 *output_guard.ptr () = output_ptr;
800870
801- // 4. Extract output and get hidden_dim from shape
802871 float * output_data = nullptr ;
803872 OrtStatusGuard data_status (ort_api_);
804873 data_status.reset (ort_api_->GetTensorMutableData (output_guard.get (), (void **)&output_data));
805874 if (data_status.is_error () || output_data == nullptr ) {
806- LOGE (" Batch : Failed to get output tensor data" );
875+ LOGE (" Sub-batch : Failed to get output tensor data" );
807876 return {};
808877 }
809878
810879 size_t actual_hidden_dim = embedding_dim_;
811- size_t actual_seq_len = max_seq_length_;
880+ size_t actual_seq_len = pad_length; // Default to what we sent
812881 OrtTensorTypeAndShapeInfo* shape_info = nullptr ;
813882 OrtStatusGuard shape_status (ort_api_);
814883 shape_status.reset (ort_api_->GetTensorTypeAndShape (output_guard.get (), &shape_info));
@@ -829,11 +898,10 @@ class ONNXEmbeddingProvider::Impl {
829898 ort_api_->ReleaseTensorTypeAndShapeInfo (shape_info);
830899 }
831900
832- // 5. Extract per-sentence embeddings via mean pooling + normalize
833- std::vector<std::vector<float >> results (batch_size);
901+ std::vector<std::vector<float >> results (count);
834902 const size_t stride = actual_seq_len * actual_hidden_dim;
835903
836- for (size_t i = 0 ; i < batch_size ; ++i) {
904+ for (size_t i = 0 ; i < count ; ++i) {
837905 const float * sentence_data = output_data + i * stride;
838906 auto pooled = mean_pooling (
839907 sentence_data, attention_masks[i],
@@ -842,24 +910,14 @@ class ONNXEmbeddingProvider::Impl {
842910 results[i] = std::move (pooled);
843911 }
844912
845- LOGI (" Generated batch embeddings: count=%zu, dim=%zu" , batch_size, actual_hidden_dim);
846913 return results;
847914
848915 } catch (const std::exception& e) {
849- LOGE (" Batch embedding generation failed: %s" , e.what ());
850- return std::vector<std::vector< float >>(batch_size, std::vector< float >(embedding_dim_, 0 . 0f )) ;
916+ LOGE (" Sub-batch embedding failed: %s" , e.what ());
917+ return {} ;
851918 }
852919 }
853920
854- size_t dimension () const noexcept {
855- return embedding_dim_;
856- }
857-
858- bool is_ready () const noexcept {
859- return ready_;
860- }
861-
862- private:
863921 bool initialize_onnx_runtime () {
864922 const OrtApiBase* ort_api_base = OrtGetApiBase ();
865923 const char * ort_version = ort_api_base ? ort_api_base->GetVersionString () : " unknown" ;
0 commit comments