Skip to content

Commit 509e805

Browse files
VyasGurushubhammalhotra28
authored and committed
Changed batching parameters and similarity threshold; optimised embedding memory usage and speed
1 parent e06388f commit 509e805

8 files changed

Lines changed: 141 additions & 74 deletions

File tree

sdk/runanywhere-commons/include/rac/features/rag/rac_rag_pipeline.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ static inline rac_rag_pipeline_config_t rac_rag_pipeline_config_default(void) {
9090
rac_rag_pipeline_config_t cfg = {0};
9191
cfg.embedding_dimension = 384;
9292
cfg.top_k = 10;
93-
cfg.similarity_threshold = 0.15f;
93+
cfg.similarity_threshold = 0.12f;
9494
cfg.max_context_tokens = 2048;
9595
cfg.chunk_size = 180;
9696
cfg.chunk_overlap = 30;
@@ -121,7 +121,7 @@ static inline rac_rag_config_t rac_rag_config_default(void) {
121121
cfg.llm_model_path = NULL;
122122
cfg.embedding_dimension = 384;
123123
cfg.top_k = 10;
124-
cfg.similarity_threshold = 0.15f;
124+
cfg.similarity_threshold = 0.12f;
125125
cfg.max_context_tokens = 2048;
126126
cfg.chunk_size = 180;
127127
cfg.chunk_overlap = 30;

sdk/runanywhere-commons/src/features/rag/onnx_embedding_provider.cpp

Lines changed: 117 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ class SimpleTokenizer {
8383
return true;
8484
}
8585

86-
std::vector<int64_t> encode(const std::string& text, size_t max_length = 512) {
86+
std::vector<int64_t> encode_unpadded(const std::string& text, size_t max_length = 512) {
8787
std::vector<int64_t> token_ids;
88-
token_ids.reserve(max_length);
88+
token_ids.reserve(std::min(max_length, static_cast<size_t>(128)));
8989
token_ids.push_back(cls_id_); // [CLS]
9090

9191
const auto words = basic_tokenize(text);
@@ -104,12 +104,18 @@ class SimpleTokenizer {
104104
}
105105

106106
token_ids.push_back(sep_id_); // [SEP]
107-
108-
// Pad to max_length
109-
while (token_ids.size() < max_length) {
110-
token_ids.push_back(pad_id_); // [PAD]
107+
return token_ids;
108+
}
109+
110+
void pad_to(std::vector<int64_t>& token_ids, size_t target_length) {
111+
while (token_ids.size() < target_length) {
112+
token_ids.push_back(pad_id_);
111113
}
112-
114+
}
115+
116+
std::vector<int64_t> encode(const std::string& text, size_t max_length = 512) {
117+
auto token_ids = encode_unpadded(text, max_length);
118+
pad_to(token_ids, max_length);
113119
return token_ids;
114120
}
115121

@@ -553,12 +559,22 @@ class ONNXEmbeddingProvider::Impl {
553559
std::lock_guard<std::mutex> lock(embed_mutex_);
554560

555561
try {
556-
// 1. Tokenize input
557-
auto token_ids = tokenizer_.encode(text, max_seq_length_);
562+
auto token_ids = tokenizer_.encode_unpadded(text, max_seq_length_);
563+
const size_t pad_length = align_up(token_ids.size(), 8);
564+
tokenizer_.pad_to(token_ids, pad_length);
565+
558566
auto attention_mask = tokenizer_.create_attention_mask(token_ids);
559567

560-
std::memcpy(input_ids_buf_.data(), token_ids.data(), max_seq_length_ * sizeof(int64_t));
561-
std::memcpy(attention_mask_buf_.data(), attention_mask.data(), max_seq_length_ * sizeof(int64_t));
568+
std::memcpy(input_ids_buf_.data(), token_ids.data(), pad_length * sizeof(int64_t));
569+
std::memcpy(attention_mask_buf_.data(), attention_mask.data(), pad_length * sizeof(int64_t));
570+
std::memset(token_type_ids_buf_.data(), 0, pad_length * sizeof(int64_t));
571+
572+
input_shape_ = {1, static_cast<int64_t>(pad_length)};
573+
574+
LOGI("Single embed: %zu real tokens, padded to %zu (max %zu)",
575+
token_ids.size() - std::count(token_ids.begin(), token_ids.end(), 0),
576+
pad_length, max_seq_length_);
577+
562578
OrtStatusGuard status_guard(ort_api_);
563579
OrtValueGuard input_ids_guard(ort_api_);
564580
OrtValueGuard attention_mask_guard(ort_api_);
@@ -568,7 +584,7 @@ class ONNXEmbeddingProvider::Impl {
568584
status_guard.reset(ort_api_->CreateTensorWithDataAsOrtValue(
569585
memory_info_,
570586
input_ids_buf_.data(),
571-
max_seq_length_ * sizeof(int64_t),
587+
pad_length * sizeof(int64_t),
572588
input_shape_.data(),
573589
input_shape_.size(),
574590
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
@@ -583,7 +599,7 @@ class ONNXEmbeddingProvider::Impl {
583599
status_guard.reset(ort_api_->CreateTensorWithDataAsOrtValue(
584600
memory_info_,
585601
attention_mask_buf_.data(),
586-
max_seq_length_ * sizeof(int64_t),
602+
pad_length * sizeof(int64_t),
587603
input_shape_.data(),
588604
input_shape_.size(),
589605
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
@@ -598,7 +614,7 @@ class ONNXEmbeddingProvider::Impl {
598614
status_guard.reset(ort_api_->CreateTensorWithDataAsOrtValue(
599615
memory_info_,
600616
token_type_ids_buf_.data(),
601-
max_seq_length_ * sizeof(int64_t),
617+
pad_length * sizeof(int64_t),
602618
input_shape_.data(),
603619
input_shape_.size(),
604620
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
@@ -671,11 +687,10 @@ class ONNXEmbeddingProvider::Impl {
671687
ort_api_->ReleaseTensorTypeAndShapeInfo(shape_info);
672688
}
673689

674-
// 5. Mean pooling
675690
auto pooled = mean_pooling(
676691
output_data,
677692
attention_mask,
678-
max_seq_length_,
693+
pad_length,
679694
actual_hidden_dim
680695
);
681696

@@ -707,34 +722,90 @@ class ONNXEmbeddingProvider::Impl {
707722
return {};
708723
}
709724

710-
const size_t batch_size = texts.size();
711-
712725
std::lock_guard<std::mutex> lock(embed_mutex_);
713726

727+
std::vector<std::vector<float>> all_results;
728+
all_results.reserve(texts.size());
729+
730+
for (size_t offset = 0; offset < texts.size(); offset += kMaxSubBatchSize) {
731+
size_t sub_batch_size = std::min(kMaxSubBatchSize, texts.size() - offset);
732+
733+
LOGI("Embedding sub-batch %zu/%zu (size=%zu)",
734+
offset / kMaxSubBatchSize + 1,
735+
(texts.size() + kMaxSubBatchSize - 1) / kMaxSubBatchSize,
736+
sub_batch_size);
737+
738+
auto sub_results = embed_sub_batch(texts, offset, sub_batch_size);
739+
if (sub_results.empty()) {
740+
LOGE("Sub-batch embedding failed at offset %zu", offset);
741+
return {};
742+
}
743+
744+
for (auto& r : sub_results) {
745+
all_results.push_back(std::move(r));
746+
}
747+
}
748+
749+
LOGI("Generated batch embeddings: count=%zu, dim=%zu", all_results.size(), embedding_dim_);
750+
return all_results;
751+
}
752+
753+
size_t dimension() const noexcept {
754+
return embedding_dim_;
755+
}
756+
757+
bool is_ready() const noexcept {
758+
return ready_;
759+
}
760+
761+
private:
762+
static constexpr size_t kMaxSubBatchSize = 50;
763+
764+
static size_t align_up(size_t value, size_t alignment) {
765+
const size_t aligned = ((value + alignment - 1) / alignment) * alignment;
766+
return std::min(aligned, static_cast<size_t>(512));
767+
}
768+
769+
std::vector<std::vector<float>> embed_sub_batch(
770+
const std::vector<std::string>& texts,
771+
size_t offset,
772+
size_t count
773+
) {
714774
try {
715-
// 1. Tokenize all texts into flat contiguous buffers
716-
std::vector<int64_t> flat_input_ids(batch_size * max_seq_length_, 0);
717-
std::vector<int64_t> flat_attention_mask(batch_size * max_seq_length_, 0);
718-
std::vector<int64_t> flat_token_type_ids(batch_size * max_seq_length_, 0);
775+
std::vector<std::vector<int64_t>> all_token_ids(count);
776+
size_t max_actual_len = 0;
777+
778+
for (size_t i = 0; i < count; ++i) {
779+
all_token_ids[i] = tokenizer_.encode_unpadded(texts[offset + i], max_seq_length_);
780+
max_actual_len = std::max(max_actual_len, all_token_ids[i].size());
781+
}
719782

720-
std::vector<std::vector<int64_t>> attention_masks(batch_size);
783+
const size_t pad_length = align_up(max_actual_len, 8);
721784

722-
for (size_t i = 0; i < batch_size; ++i) {
723-
auto token_ids = tokenizer_.encode(texts[i], max_seq_length_);
724-
auto attn_mask = tokenizer_.create_attention_mask(token_ids);
785+
LOGI("Sub-batch dynamic padding: max_actual=%zu, pad_length=%zu (was %zu)",
786+
max_actual_len, pad_length, max_seq_length_);
725787

726-
std::memcpy(flat_input_ids.data() + i * max_seq_length_,
727-
token_ids.data(), max_seq_length_ * sizeof(int64_t));
728-
std::memcpy(flat_attention_mask.data() + i * max_seq_length_,
729-
attn_mask.data(), max_seq_length_ * sizeof(int64_t));
788+
std::vector<int64_t> flat_input_ids(count * pad_length, 0);
789+
std::vector<int64_t> flat_attention_mask(count * pad_length, 0);
790+
std::vector<int64_t> flat_token_type_ids(count * pad_length, 0);
791+
792+
std::vector<std::vector<int64_t>> attention_masks(count);
793+
794+
for (size_t i = 0; i < count; ++i) {
795+
tokenizer_.pad_to(all_token_ids[i], pad_length);
796+
auto attn_mask = tokenizer_.create_attention_mask(all_token_ids[i]);
797+
798+
std::memcpy(flat_input_ids.data() + i * pad_length,
799+
all_token_ids[i].data(), pad_length * sizeof(int64_t));
800+
std::memcpy(flat_attention_mask.data() + i * pad_length,
801+
attn_mask.data(), pad_length * sizeof(int64_t));
730802

731803
attention_masks[i] = std::move(attn_mask);
732804
}
733805

734-
// 2. Create tensors with batch shape {batch_size, max_seq_length_}
735806
std::vector<int64_t> batch_shape = {
736-
static_cast<int64_t>(batch_size),
737-
static_cast<int64_t>(max_seq_length_)
807+
static_cast<int64_t>(count),
808+
static_cast<int64_t>(pad_length)
738809
};
739810

740811
OrtStatusGuard status_guard(ort_api_);
@@ -745,40 +816,39 @@ class ONNXEmbeddingProvider::Impl {
745816
status_guard.reset(ort_api_->CreateTensorWithDataAsOrtValue(
746817
memory_info_,
747818
flat_input_ids.data(),
748-
batch_size * max_seq_length_ * sizeof(int64_t),
819+
count * pad_length * sizeof(int64_t),
749820
batch_shape.data(), batch_shape.size(),
750821
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
751822
input_ids_guard.ptr()));
752823
if (status_guard.is_error()) {
753-
LOGE("Batch: CreateTensor (input_ids) failed: %s", status_guard.error_message());
824+
LOGE("Sub-batch: CreateTensor (input_ids) failed: %s", status_guard.error_message());
754825
return {};
755826
}
756827

757828
status_guard.reset(ort_api_->CreateTensorWithDataAsOrtValue(
758829
memory_info_,
759830
flat_attention_mask.data(),
760-
batch_size * max_seq_length_ * sizeof(int64_t),
831+
count * pad_length * sizeof(int64_t),
761832
batch_shape.data(), batch_shape.size(),
762833
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
763834
attention_mask_guard.ptr()));
764835
if (status_guard.is_error()) {
765-
LOGE("Batch: CreateTensor (attention_mask) failed: %s", status_guard.error_message());
836+
LOGE("Sub-batch: CreateTensor (attention_mask) failed: %s", status_guard.error_message());
766837
return {};
767838
}
768839

769840
status_guard.reset(ort_api_->CreateTensorWithDataAsOrtValue(
770841
memory_info_,
771842
flat_token_type_ids.data(),
772-
batch_size * max_seq_length_ * sizeof(int64_t),
843+
count * pad_length * sizeof(int64_t),
773844
batch_shape.data(), batch_shape.size(),
774845
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
775846
token_type_ids_guard.ptr()));
776847
if (status_guard.is_error()) {
777-
LOGE("Batch: CreateTensor (token_type_ids) failed: %s", status_guard.error_message());
848+
LOGE("Sub-batch: CreateTensor (token_type_ids) failed: %s", status_guard.error_message());
778849
return {};
779850
}
780851

781-
// 3. Run ONNX once for the entire batch
782852
const char* input_names[] = {"input_ids", "attention_mask", "token_type_ids"};
783853
const OrtValue* inputs[] = {
784854
input_ids_guard.get(), attention_mask_guard.get(), token_type_ids_guard.get()
@@ -793,22 +863,21 @@ class ONNXEmbeddingProvider::Impl {
793863
output_names, 1,
794864
&output_ptr));
795865
if (status_guard.is_error()) {
796-
LOGE("Batch ONNX inference failed: %s", status_guard.error_message());
866+
LOGE("Sub-batch ONNX inference failed: %s", status_guard.error_message());
797867
return {};
798868
}
799869
*output_guard.ptr() = output_ptr;
800870

801-
// 4. Extract output and get hidden_dim from shape
802871
float* output_data = nullptr;
803872
OrtStatusGuard data_status(ort_api_);
804873
data_status.reset(ort_api_->GetTensorMutableData(output_guard.get(), (void**)&output_data));
805874
if (data_status.is_error() || output_data == nullptr) {
806-
LOGE("Batch: Failed to get output tensor data");
875+
LOGE("Sub-batch: Failed to get output tensor data");
807876
return {};
808877
}
809878

810879
size_t actual_hidden_dim = embedding_dim_;
811-
size_t actual_seq_len = max_seq_length_;
880+
size_t actual_seq_len = pad_length; // Default to what we sent
812881
OrtTensorTypeAndShapeInfo* shape_info = nullptr;
813882
OrtStatusGuard shape_status(ort_api_);
814883
shape_status.reset(ort_api_->GetTensorTypeAndShape(output_guard.get(), &shape_info));
@@ -829,11 +898,10 @@ class ONNXEmbeddingProvider::Impl {
829898
ort_api_->ReleaseTensorTypeAndShapeInfo(shape_info);
830899
}
831900

832-
// 5. Extract per-sentence embeddings via mean pooling + normalize
833-
std::vector<std::vector<float>> results(batch_size);
901+
std::vector<std::vector<float>> results(count);
834902
const size_t stride = actual_seq_len * actual_hidden_dim;
835903

836-
for (size_t i = 0; i < batch_size; ++i) {
904+
for (size_t i = 0; i < count; ++i) {
837905
const float* sentence_data = output_data + i * stride;
838906
auto pooled = mean_pooling(
839907
sentence_data, attention_masks[i],
@@ -842,24 +910,14 @@ class ONNXEmbeddingProvider::Impl {
842910
results[i] = std::move(pooled);
843911
}
844912

845-
LOGI("Generated batch embeddings: count=%zu, dim=%zu", batch_size, actual_hidden_dim);
846913
return results;
847914

848915
} catch (const std::exception& e) {
849-
LOGE("Batch embedding generation failed: %s", e.what());
850-
return std::vector<std::vector<float>>(batch_size, std::vector<float>(embedding_dim_, 0.0f));
916+
LOGE("Sub-batch embedding failed: %s", e.what());
917+
return {};
851918
}
852919
}
853920

854-
size_t dimension() const noexcept {
855-
return embedding_dim_;
856-
}
857-
858-
bool is_ready() const noexcept {
859-
return ready_;
860-
}
861-
862-
private:
863921
bool initialize_onnx_runtime() {
864922
const OrtApiBase* ort_api_base = OrtGetApiBase();
865923
const char* ort_version = ort_api_base ? ort_api_base->GetVersionString() : "unknown";

sdk/runanywhere-commons/src/features/rag/rac_onnx_embeddings_register.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ static rac_result_t onnx_embed_vtable_embed_batch(void* impl, const char* const*
106106
}
107107

108108
auto batch_results = h->provider->embed_batch(texts_vec);
109+
if (batch_results.size() != num_texts) {
110+
RAC_LOG_ERROR(LOG_CAT, "Batch embedding returned %zu results, expected %zu",
111+
batch_results.size(), num_texts);
112+
return RAC_ERROR_INFERENCE_FAILED;
113+
}
109114

110115
size_t dim = h->provider->dimension();
111116
out_result->num_embeddings = num_texts;

sdk/runanywhere-commons/src/features/rag/rag_backend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace rag {
3131
struct RAGBackendConfig {
3232
size_t embedding_dimension = 384;
3333
size_t top_k = 10;
34-
float similarity_threshold = 0.15f;
34+
float similarity_threshold = 0.12f;
3535
size_t max_context_tokens = 2048;
3636
size_t chunk_size = 180;
3737
size_t chunk_overlap = 30;

sdk/runanywhere-commons/src/features/rag/vector_store_usearch.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ class VectorStoreUSearch::Impl {
5656
usearch_config.expansion_add = config.expansion_add;
5757
usearch_config.expansion_search = config.expansion_search;
5858

59-
// Create metric for cosine similarity. Using i8 instead of float to save on RAM(quality isnt affected much)
59+
// Create metric for cosine similarity. Quantize further for RAM, switch to f32 for precision
6060
metric_punned_t metric(
6161
static_cast<std::size_t>(config.dimension),
6262
metric_kind_t::cos_k,
63-
scalar_kind_t::i8_k
63+
scalar_kind_t::f16_k
6464
);
6565

6666
// Create index
@@ -73,7 +73,7 @@ class VectorStoreUSearch::Impl {
7373

7474
// Reserve capacity
7575
index_.reserve(config.max_elements);
76-
LOGI("Created vector store: dim=%zu, max=%zu, connectivity=%zu, quantization=i8",
76+
LOGI("Created vector store: dim=%zu, max=%zu, connectivity=%zu, quantization=f16",
7777
config.dimension, config.max_elements, config.connectivity);
7878
}
7979

@@ -135,6 +135,7 @@ class VectorStoreUSearch::Impl {
135135
LOGE("Failed to add chunk to batch: %s", add_result.error.what());
136136
continue;
137137
}
138+
// Store metadata
138139
DocumentChunk metadata_copy = chunk;
139140
metadata_copy.embedding.clear();
140141
metadata_copy.embedding.shrink_to_fit();

0 commit comments

Comments (0)