Skip to content

Commit bacb277

Browse files
AmanSwar authored and sanchitmonga22 committed
fix: address CodeRabbit review — stop-sequence, format specifiers, race condition, token counts
- Stop-sequence sliding window (llamacpp_backend.cpp): port the Utf8State/stop_window approach from generate_stream to the timing variant generate_stream_with_timing, matching the non-timing variant's behavior exactly. - PRId32 format (rac_benchmark_log.cpp): change %d to PRId32 for all int32_t fields to match the PRId64 convention for int64_t fields. - Mutex guard (rac_benchmark_metrics.cpp): add static std::mutex around the double-buffer write path in rac_benchmark_set_metrics_provider to prevent torn fn/user_data pairs. Reader side remains lock-free. - Actual token counts (llamacpp_backend.cpp + llm_component.cpp): write tokens_generated to timing_out->output_tokens in the backend; read backend-populated prompt_tokens/output_tokens in the component layer instead of overwriting with estimate_tokens() heuristics.
1 parent c2acc46 commit bacb277

4 files changed

Lines changed: 85 additions & 38 deletions

File tree

sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -900,10 +900,27 @@ bool LlamaCppTextGeneration::generate_stream_with_timing(const TextGenerationReq
900900
llama_sampler_reset(sampler_);
901901

902902
const auto vocab = llama_model_get_vocab(model_);
903-
std::string cached_token_chars;
904-
std::string accumulated_text;
903+
904+
static const std::vector<std::string> STOP_SEQUENCES = {
905+
"<|im_end|>", "<|eot_id|>", "</s>", "<|end|>", "<|endoftext|>",
906+
"\n\nUser:", "\n\nHuman:",
907+
};
908+
909+
static const size_t MAX_STOP_LEN = []{
910+
size_t m = 0;
911+
for (const auto& s : STOP_SEQUENCES) m = std::max(m, s.size());
912+
return m;
913+
}();
914+
915+
std::string stop_window;
916+
stop_window.reserve(MAX_STOP_LEN * 2);
917+
918+
std::string partial_utf8_buffer;
919+
partial_utf8_buffer.reserve(8);
920+
905921
int n_cur = batch.n_tokens;
906922
int tokens_generated = 0;
923+
bool stop_sequence_hit = false;
907924

908925
while (tokens_generated < effective_max_tokens && !cancel_requested_.load()) {
909926
const llama_token new_token_id = llama_sampler_sample(sampler_, context_, -1);
@@ -915,41 +932,55 @@ bool LlamaCppTextGeneration::generate_stream_with_timing(const TextGenerationReq
915932
break;
916933
}
917934

918-
auto new_token_chars = common_token_to_piece(context_, new_token_id);
919-
cached_token_chars += new_token_chars;
920-
accumulated_text += new_token_chars;
921-
922-
static const std::vector<std::string> stop_sequences = {
923-
"<|im_end|>",
924-
"<|eot_id|>",
925-
"</s>",
926-
"<|end|>",
927-
"<|endoftext|>",
928-
"\n\nUser:",
929-
"\n\nHuman:",
930-
};
935+
const std::string new_token_chars =
936+
common_token_to_piece(context_, new_token_id);
931937

932-
bool hit_stop_sequence = false;
933-
for (const auto& stop_seq : stop_sequences) {
934-
size_t pos = accumulated_text.find(stop_seq);
935-
if (pos != std::string::npos) {
936-
LOGI("Stop sequence detected: %s", stop_seq.c_str());
937-
hit_stop_sequence = true;
938-
break;
938+
partial_utf8_buffer.append(new_token_chars);
939+
940+
Utf8State scanner_state;
941+
size_t valid_upto = 0;
942+
for (size_t i = 0; i < partial_utf8_buffer.size(); ++i) {
943+
scanner_state.process(static_cast<uint8_t>(partial_utf8_buffer[i]));
944+
if (scanner_state.state == 0) {
945+
valid_upto = i + 1;
939946
}
940947
}
941948

942-
if (hit_stop_sequence) {
943-
break;
944-
}
949+
if (valid_upto > 0) {
950+
std::string valid_chunk = partial_utf8_buffer.substr(0, valid_upto);
951+
stop_window.append(valid_chunk);
952+
partial_utf8_buffer.erase(0, valid_upto);
945953

946-
if (is_valid_utf8(cached_token_chars.c_str())) {
947-
if (!callback(cached_token_chars)) {
948-
LOGI("Generation cancelled by callback");
949-
cancel_requested_.store(true);
954+
size_t found_stop_pos = std::string::npos;
955+
for (const auto& stop_seq : STOP_SEQUENCES) {
956+
size_t pos = stop_window.find(stop_seq);
957+
if (pos != std::string::npos) {
958+
if (found_stop_pos == std::string::npos || pos < found_stop_pos) {
959+
found_stop_pos = pos;
960+
}
961+
}
962+
}
963+
964+
if (found_stop_pos != std::string::npos) {
965+
LOGI("Stop sequence detected");
966+
stop_sequence_hit = true;
967+
if (found_stop_pos > 0) {
968+
if (!callback(stop_window.substr(0, found_stop_pos))) {
969+
cancel_requested_.store(true);
970+
}
971+
}
950972
break;
951973
}
952-
cached_token_chars.clear();
974+
975+
if (stop_window.size() > MAX_STOP_LEN) {
976+
size_t safe_len = stop_window.size() - MAX_STOP_LEN;
977+
if (!callback(stop_window.substr(0, safe_len))) {
978+
LOGI("Generation cancelled by callback");
979+
cancel_requested_.store(true);
980+
break;
981+
}
982+
stop_window.erase(0, safe_len);
983+
}
953984
}
954985

955986
batch.n_tokens = 0;
@@ -967,10 +998,11 @@ bool LlamaCppTextGeneration::generate_stream_with_timing(const TextGenerationReq
967998
// t5: Record last token time (decode loop exit)
968999
if (timing_out != nullptr) {
9691000
timing_out->t5_last_token_ms = rac_monotonic_now_ms();
1001+
timing_out->output_tokens = static_cast<int32_t>(tokens_generated);
9701002
}
9711003

972-
if (!cached_token_chars.empty() && is_valid_utf8(cached_token_chars.c_str())) {
973-
callback(cached_token_chars);
1004+
if (!cancel_requested_.load() && !stop_sequence_hit && !stop_window.empty()) {
1005+
callback(stop_window);
9741006
}
9751007

9761008
llama_memory_clear(llama_get_memory(context_), true);

sdk/runanywhere-commons/src/core/rac_benchmark_log.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ char* rac_benchmark_timing_to_csv(const rac_benchmark_timing_t* timing, rac_bool
117117
char buf[512];
118118
snprintf(buf, sizeof(buf),
119119
"%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64
120-
",%d,%d,%d,%d,%.2f,%.2f,%.2f,%.2f,%.2f",
120+
",%" PRId32 ",%" PRId32 ",%" PRId32 ",%" PRId32 ",%.2f,%.2f,%.2f,%.2f,%.2f",
121121
timing->t0_request_start_ms, timing->t2_prefill_start_ms,
122122
timing->t3_prefill_end_ms, timing->t4_first_token_ms, timing->t5_last_token_ms,
123123
timing->t6_request_end_ms, timing->prompt_tokens, timing->output_tokens,

sdk/runanywhere-commons/src/core/rac_benchmark_metrics.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <atomic>
1313
#include <cstring>
14+
#include <mutex>
1415

1516
namespace {
1617

@@ -45,12 +46,15 @@ void rac_benchmark_extended_metrics_init(rac_benchmark_extended_metrics_t* metri
4546

4647
void rac_benchmark_set_metrics_provider(rac_benchmark_metrics_provider_fn provider,
4748
void* user_data) {
49+
static std::mutex write_mutex;
50+
4851
if (provider == nullptr) {
4952
g_provider.store(nullptr, std::memory_order_release);
5053
return;
5154
}
5255

53-
// Use double-buffering to avoid data races on the provider struct
56+
// Serialize the rare registration path to prevent torn fn/user_data pairs
57+
std::lock_guard<std::mutex> lock(write_mutex);
5458
int idx = g_provider_index.load(std::memory_order_relaxed);
5559
int next = 1 - idx;
5660
g_provider_storage[next].fn = provider;

sdk/runanywhere-commons/src/features/llm/llm_component.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -947,8 +947,20 @@ extern "C" rac_result_t rac_llm_component_generate_stream_with_timing(
947947

948948
rac_llm_result_t final_result = {};
949949
final_result.text = strdup(ctx.full_text.c_str());
950-
final_result.prompt_tokens = ctx.prompt_tokens;
951-
final_result.completion_tokens = estimate_tokens(ctx.full_text.c_str());
950+
951+
// Use actual backend token counts if available, fall back to estimates
952+
if (timing_out != nullptr && timing_out->prompt_tokens > 0) {
953+
final_result.prompt_tokens = timing_out->prompt_tokens;
954+
} else {
955+
final_result.prompt_tokens = ctx.prompt_tokens;
956+
}
957+
958+
if (timing_out != nullptr && timing_out->output_tokens > 0) {
959+
final_result.completion_tokens = timing_out->output_tokens;
960+
} else {
961+
final_result.completion_tokens = estimate_tokens(ctx.full_text.c_str());
962+
}
963+
952964
final_result.total_tokens = final_result.prompt_tokens + final_result.completion_tokens;
953965
final_result.total_time_ms = total_time_ms;
954966

@@ -972,8 +984,7 @@ extern "C" rac_result_t rac_llm_component_generate_stream_with_timing(
972984
// Record t6 (request end) before complete callback
973985
if (timing_out != nullptr) {
974986
timing_out->t6_request_end_ms = rac_monotonic_now_ms();
975-
timing_out->prompt_tokens = final_result.prompt_tokens;
976-
timing_out->output_tokens = final_result.completion_tokens;
987+
// prompt_tokens and output_tokens already set by backend
977988
timing_out->status = RAC_BENCHMARK_STATUS_SUCCESS;
978989
timing_out->error_code = RAC_SUCCESS;
979990
}

0 commit comments

Comments (0)