Skip to content

Commit 3fafd41

Browse files
fix(android): support String/ByteArray onToken callbacks in JNI streaming (#376)
* fix(android): support String and ByteArray JNI stream callbacks Handle both onToken signatures in JNI streaming callback path and avoid stale callback/context usage by waiting for completion before releasing JNI global refs. Also add native lib version marker checks in downloadJniLibs to prevent stale JNI binary reuse across version changes. Refs #353 * fix(android): guard JNI stream token accumulation with mutex * fix(android): ignore onToken return when JNI exception occurs * fix(android): add timeout for JNI stream completion wait Replace unbounded condition variable waits with wait_for(10m) and propagate timeout as stream error to avoid indefinite JNI thread blocking.
1 parent 3e31dfc commit 3fafd41

2 files changed

Lines changed: 88 additions & 24 deletions

File tree

sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <jni.h>
1818

1919
#include <condition_variable>
20+
#include <chrono>
2021
#include <cstring>
2122
#include <mutex>
2223
#include <string>
@@ -679,6 +680,9 @@ struct LLMStreamCallbackContext {
679680
JavaVM* jvm = nullptr;
680681
jobject callback = nullptr;
681682
jmethodID onTokenMethod = nullptr;
683+
bool onTokenExpectsBytes = true;
684+
std::mutex mtx;
685+
std::condition_variable cv;
682686
std::string accumulated_text;
683687
int token_count = 0;
684688
bool is_complete = false;
@@ -694,9 +698,12 @@ static rac_bool_t llm_stream_callback_token(const char* token, void* user_data)
694698

695699
auto* ctx = static_cast<LLMStreamCallbackContext*>(user_data);
696700

697-
// Accumulate token
698-
ctx->accumulated_text += token;
699-
ctx->token_count++;
701+
// Accumulate token (thread-safe)
702+
{
703+
std::lock_guard<std::mutex> lock(ctx->mtx);
704+
ctx->accumulated_text += token;
705+
ctx->token_count++;
706+
}
700707

701708
// Call back to Kotlin
702709
if (ctx->jvm && ctx->callback && ctx->onTokenMethod) {
@@ -714,21 +721,27 @@ static rac_bool_t llm_stream_callback_token(const char* token, void* user_data)
714721
}
715722

716723
if (env) {
717-
jsize len = static_cast<jsize>(strlen(token));
718-
719-
jbyteArray jToken = env->NewByteArray(len);
720-
env->SetByteArrayRegion(
721-
jToken,
722-
0,
723-
len,
724-
reinterpret_cast<const jbyte*>(token)
725-
);
726-
727-
jboolean continueGen =
728-
env->CallBooleanMethod(ctx->callback, ctx->onTokenMethod, jToken);
729-
env->DeleteLocalRef(jToken);
724+
jboolean continueGen = JNI_TRUE;
725+
726+
if (ctx->onTokenExpectsBytes) {
727+
jsize len = static_cast<jsize>(strlen(token));
728+
jbyteArray jToken = env->NewByteArray(len);
729+
env->SetByteArrayRegion(
730+
jToken,
731+
0,
732+
len,
733+
reinterpret_cast<const jbyte*>(token)
734+
);
735+
continueGen = env->CallBooleanMethod(ctx->callback, ctx->onTokenMethod, jToken);
736+
env->DeleteLocalRef(jToken);
737+
} else {
738+
jstring jToken = env->NewStringUTF(token);
739+
continueGen = env->CallBooleanMethod(ctx->callback, ctx->onTokenMethod, jToken);
740+
env->DeleteLocalRef(jToken);
741+
}
730742

731-
if (env->ExceptionCheck()) {
743+
const bool hadException = env->ExceptionCheck();
744+
if (hadException) {
732745
env->ExceptionDescribe();
733746
env->ExceptionClear();
734747
}
@@ -737,6 +750,11 @@ static rac_bool_t llm_stream_callback_token(const char* token, void* user_data)
737750
ctx->jvm->DetachCurrentThread();
738751
}
739752

753+
if (hadException) {
754+
// Ignore callback return value when JNI exception was thrown.
755+
return RAC_TRUE;
756+
}
757+
740758
if (!continueGen) {
741759
LOGi("Streaming cancelled by callback");
742760
return RAC_FALSE; // Stop streaming
@@ -752,6 +770,7 @@ static void llm_stream_callback_complete(const rac_llm_result_t* result, void* u
752770
return;
753771

754772
auto* ctx = static_cast<LLMStreamCallbackContext*>(user_data);
773+
std::lock_guard<std::mutex> lock(ctx->mtx);
755774

756775
LOGi("Streaming with callback complete: %d tokens", ctx->token_count);
757776

@@ -767,6 +786,7 @@ static void llm_stream_callback_complete(const rac_llm_result_t* result, void* u
767786
}
768787

769788
ctx->is_complete = true;
789+
ctx->cv.notify_one();
770790
}
771791

772792
static void llm_stream_callback_error(rac_result_t error_code, const char* error_message,
@@ -775,6 +795,7 @@ static void llm_stream_callback_error(rac_result_t error_code, const char* error
775795
return;
776796

777797
auto* ctx = static_cast<LLMStreamCallbackContext*>(user_data);
798+
std::lock_guard<std::mutex> lock(ctx->mtx);
778799

779800
LOGe("Streaming with callback error: %d - %s", error_code,
780801
error_message ? error_message : "Unknown");
@@ -783,6 +804,7 @@ static void llm_stream_callback_error(rac_result_t error_code, const char* error
783804
ctx->error_code = error_code;
784805
ctx->error_message = error_message ? error_message : "Unknown error";
785806
ctx->is_complete = true;
807+
ctx->cv.notify_one();
786808
}
787809

788810
JNIEXPORT jstring JNICALL
@@ -847,7 +869,12 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
847869
// Wait for streaming to complete
848870
{
849871
std::unique_lock<std::mutex> lock(ctx.mtx);
850-
ctx.cv.wait(lock, [&ctx] { return ctx.is_complete; });
872+
constexpr auto kStreamWaitTimeout = std::chrono::minutes(10);
873+
if (!ctx.cv.wait_for(lock, kStreamWaitTimeout, [&ctx] { return ctx.is_complete; })) {
874+
ctx.has_error = true;
875+
ctx.error_message = "Streaming timed out waiting for completion callback";
876+
ctx.is_complete = true;
877+
}
851878
}
852879

853880
if (ctx.has_error) {
@@ -904,7 +931,13 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
904931
env->GetJavaVM(&jvm);
905932

906933
jclass callbackClass = env->GetObjectClass(tokenCallback);
934+
bool onTokenExpectsBytes = true;
907935
jmethodID onTokenMethod = env->GetMethodID(callbackClass, "onToken", "([B)Z");
936+
if (!onTokenMethod) {
937+
env->ExceptionClear();
938+
onTokenMethod = env->GetMethodID(callbackClass, "onToken", "(Ljava/lang/String;)Z");
939+
onTokenExpectsBytes = false;
940+
}
908941

909942
if (!onTokenMethod) {
910943
LOGe("racLlmComponentGenerateStreamWithCallback: could not find onToken method");
@@ -948,21 +981,34 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
948981
ctx.jvm = jvm;
949982
ctx.callback = globalCallback;
950983
ctx.onTokenMethod = onTokenMethod;
984+
ctx.onTokenExpectsBytes = onTokenExpectsBytes;
951985

952986
LOGi("racLlmComponentGenerateStreamWithCallback calling rac_llm_component_generate_stream...");
953987

954988
rac_result_t status = rac_llm_component_generate_stream(
955989
reinterpret_cast<rac_handle_t>(handle), promptStr.c_str(), &options,
956990
llm_stream_callback_token, llm_stream_callback_complete, llm_stream_callback_error, &ctx);
957991

958-
// Clean up global ref
959-
env->DeleteGlobalRef(globalCallback);
960-
961992
if (status != RAC_SUCCESS) {
993+
env->DeleteGlobalRef(globalCallback);
962994
LOGe("rac_llm_component_generate_stream failed with status=%d", status);
963995
return nullptr;
964996
}
965997

998+
// Wait until completion/error before releasing callback/context.
999+
{
1000+
std::unique_lock<std::mutex> lock(ctx.mtx);
1001+
constexpr auto kStreamWaitTimeout = std::chrono::minutes(10);
1002+
if (!ctx.cv.wait_for(lock, kStreamWaitTimeout, [&ctx] { return ctx.is_complete; })) {
1003+
ctx.has_error = true;
1004+
ctx.error_message = "Streaming timed out waiting for completion callback";
1005+
ctx.is_complete = true;
1006+
}
1007+
}
1008+
1009+
// Clean up global ref after callbacks have finished.
1010+
env->DeleteGlobalRef(globalCallback);
1011+
9661012
if (ctx.has_error) {
9671013
LOGe("Streaming failed: %s", ctx.error_message.c_str());
9681014
return nullptr;

sdk/runanywhere-kotlin/build.gradle.kts

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ tasks.register("downloadJniLibs") {
399399
onlyIf { !testLocal }
400400

401401
val outputDir = file("src/androidMain/jniLibs")
402+
val nativeLibVersionMarker = file("$outputDir/.native_lib_version")
402403
val tempDir = file("${layout.buildDirectory.get()}/jni-temp")
403404

404405
val releaseBaseUrl = "https://github.com/RunanywhereAI/runanywhere-sdks/releases/download/v$nativeLibVersion"
@@ -424,12 +425,25 @@ tasks.register("downloadJniLibs") {
424425
return@doLast
425426
}
426427

428+
// Check if libs already exist (CI pre-populates build/jniLibs/).
429+
// Guard against stale libs from a different native version.
427430
val existingLibs = outputDir.walkTopDown().filter { it.extension == "so" }.count()
428-
if (existingLibs > 0) {
429-
logger.lifecycle("Skipping JNI download: $existingLibs .so files already in $outputDir")
431+
val existingVersion = nativeLibVersionMarker.takeIf { it.exists() }?.readText()?.trim()
432+
if (existingLibs > 0 && existingVersion == nativeLibVersion) {
433+
logger.lifecycle(
434+
"Skipping JNI download: $existingLibs .so files already in $outputDir " +
435+
"(native version v$nativeLibVersion)",
436+
)
430437
return@doLast
431438
}
439+
if (existingLibs > 0 && existingVersion != nativeLibVersion) {
440+
logger.lifecycle(
441+
"Refreshing JNI libs: found $existingLibs existing .so files " +
442+
"with version '${existingVersion ?: "unknown"}', expected '$nativeLibVersion'",
443+
)
444+
}
432445

446+
// Clean output directories for a fresh download
433447
outputDir.deleteRecursively()
434448
tempDir.deleteRecursively()
435449
outputDir.mkdirs()
@@ -497,6 +511,11 @@ tasks.register("downloadJniLibs") {
497511
logger.lifecycle(" Output: $outputDir")
498512
logger.lifecycle("═══════════════════════════════════════════════════════════════")
499513

514+
// Record native lib version to avoid reusing stale JNI binaries.
515+
nativeLibVersionMarker.parentFile.mkdirs()
516+
nativeLibVersionMarker.writeText(nativeLibVersion)
517+
518+
// List libraries per ABI
500519
abiDirs.forEach { abi ->
501520
val libs = file("$outputDir/$abi").listFiles()?.filter { it.extension == "so" }?.map { it.name } ?: emptyList()
502521
logger.lifecycle("$abi (${libs.size} libs):")
@@ -662,4 +681,3 @@ tasks.withType<PublishToMavenRepository>().configureEach {
662681
!dominated
663682
}
664683
}
665-

0 commit comments

Comments
 (0)