1717#include < jni.h>
1818
1919#include < condition_variable>
20+ #include < chrono>
2021#include < cstring>
2122#include < mutex>
2223#include < string>
@@ -679,6 +680,9 @@ struct LLMStreamCallbackContext {
679680 JavaVM* jvm = nullptr ;
680681 jobject callback = nullptr ;
681682 jmethodID onTokenMethod = nullptr ;
683+ bool onTokenExpectsBytes = true ;
684+ std::mutex mtx;
685+ std::condition_variable cv;
682686 std::string accumulated_text;
683687 int token_count = 0 ;
684688 bool is_complete = false ;
@@ -694,9 +698,12 @@ static rac_bool_t llm_stream_callback_token(const char* token, void* user_data)
694698
695699 auto * ctx = static_cast <LLMStreamCallbackContext*>(user_data);
696700
697- // Accumulate token
698- ctx->accumulated_text += token;
699- ctx->token_count ++;
701+ // Accumulate token (thread-safe)
702+ {
703+ std::lock_guard<std::mutex> lock (ctx->mtx );
704+ ctx->accumulated_text += token;
705+ ctx->token_count ++;
706+ }
700707
701708 // Call back to Kotlin
702709 if (ctx->jvm && ctx->callback && ctx->onTokenMethod ) {
@@ -714,21 +721,27 @@ static rac_bool_t llm_stream_callback_token(const char* token, void* user_data)
714721 }
715722
716723 if (env) {
717- jsize len = static_cast <jsize>(strlen (token));
718-
719- jbyteArray jToken = env->NewByteArray (len);
720- env->SetByteArrayRegion (
721- jToken,
722- 0 ,
723- len,
724- reinterpret_cast <const jbyte*>(token)
725- );
726-
727- jboolean continueGen =
728- env->CallBooleanMethod (ctx->callback , ctx->onTokenMethod , jToken);
729- env->DeleteLocalRef (jToken);
724+ jboolean continueGen = JNI_TRUE;
725+
726+ if (ctx->onTokenExpectsBytes ) {
727+ jsize len = static_cast <jsize>(strlen (token));
728+ jbyteArray jToken = env->NewByteArray (len);
729+ env->SetByteArrayRegion (
730+ jToken,
731+ 0 ,
732+ len,
733+ reinterpret_cast <const jbyte*>(token)
734+ );
735+ continueGen = env->CallBooleanMethod (ctx->callback , ctx->onTokenMethod , jToken);
736+ env->DeleteLocalRef (jToken);
737+ } else {
738+ jstring jToken = env->NewStringUTF (token);
739+ continueGen = env->CallBooleanMethod (ctx->callback , ctx->onTokenMethod , jToken);
740+ env->DeleteLocalRef (jToken);
741+ }
730742
731- if (env->ExceptionCheck ()) {
743+ const bool hadException = env->ExceptionCheck ();
744+ if (hadException) {
732745 env->ExceptionDescribe ();
733746 env->ExceptionClear ();
734747 }
@@ -737,6 +750,11 @@ static rac_bool_t llm_stream_callback_token(const char* token, void* user_data)
737750 ctx->jvm ->DetachCurrentThread ();
738751 }
739752
753+ if (hadException) {
754+ // Ignore callback return value when JNI exception was thrown.
755+ return RAC_TRUE;
756+ }
757+
740758 if (!continueGen) {
741759 LOGi (" Streaming cancelled by callback" );
742760 return RAC_FALSE; // Stop streaming
@@ -752,6 +770,7 @@ static void llm_stream_callback_complete(const rac_llm_result_t* result, void* u
752770 return ;
753771
754772 auto * ctx = static_cast <LLMStreamCallbackContext*>(user_data);
773+ std::lock_guard<std::mutex> lock (ctx->mtx );
755774
756775 LOGi (" Streaming with callback complete: %d tokens" , ctx->token_count );
757776
@@ -767,6 +786,7 @@ static void llm_stream_callback_complete(const rac_llm_result_t* result, void* u
767786 }
768787
769788 ctx->is_complete = true ;
789+ ctx->cv .notify_one ();
770790}
771791
772792static void llm_stream_callback_error (rac_result_t error_code, const char * error_message,
@@ -775,6 +795,7 @@ static void llm_stream_callback_error(rac_result_t error_code, const char* error
775795 return ;
776796
777797 auto * ctx = static_cast <LLMStreamCallbackContext*>(user_data);
798+ std::lock_guard<std::mutex> lock (ctx->mtx );
778799
779800 LOGe (" Streaming with callback error: %d - %s" , error_code,
780801 error_message ? error_message : " Unknown" );
@@ -783,6 +804,7 @@ static void llm_stream_callback_error(rac_result_t error_code, const char* error
783804 ctx->error_code = error_code;
784805 ctx->error_message = error_message ? error_message : " Unknown error" ;
785806 ctx->is_complete = true ;
807+ ctx->cv .notify_one ();
786808}
787809
788810JNIEXPORT jstring JNICALL
@@ -847,7 +869,12 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
847869 // Wait for streaming to complete
848870 {
849871 std::unique_lock<std::mutex> lock (ctx.mtx );
850- ctx.cv .wait (lock, [&ctx] { return ctx.is_complete ; });
872+ constexpr auto kStreamWaitTimeout = std::chrono::minutes (10 );
873+ if (!ctx.cv .wait_for (lock, kStreamWaitTimeout , [&ctx] { return ctx.is_complete ; })) {
874+ ctx.has_error = true ;
875+ ctx.error_message = " Streaming timed out waiting for completion callback" ;
876+ ctx.is_complete = true ;
877+ }
851878 }
852879
853880 if (ctx.has_error ) {
@@ -904,7 +931,13 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
904931 env->GetJavaVM (&jvm);
905932
906933 jclass callbackClass = env->GetObjectClass (tokenCallback);
934+ bool onTokenExpectsBytes = true ;
907935 jmethodID onTokenMethod = env->GetMethodID (callbackClass, " onToken" , " ([B)Z" );
936+ if (!onTokenMethod) {
937+ env->ExceptionClear ();
938+ onTokenMethod = env->GetMethodID (callbackClass, " onToken" , " (Ljava/lang/String;)Z" );
939+ onTokenExpectsBytes = false ;
940+ }
908941
909942 if (!onTokenMethod) {
910943 LOGe (" racLlmComponentGenerateStreamWithCallback: could not find onToken method" );
@@ -948,21 +981,34 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
948981 ctx.jvm = jvm;
949982 ctx.callback = globalCallback;
950983 ctx.onTokenMethod = onTokenMethod;
984+ ctx.onTokenExpectsBytes = onTokenExpectsBytes;
951985
952986 LOGi (" racLlmComponentGenerateStreamWithCallback calling rac_llm_component_generate_stream..." );
953987
954988 rac_result_t status = rac_llm_component_generate_stream (
955989 reinterpret_cast <rac_handle_t >(handle), promptStr.c_str (), &options,
956990 llm_stream_callback_token, llm_stream_callback_complete, llm_stream_callback_error, &ctx);
957991
958- // Clean up global ref
959- env->DeleteGlobalRef (globalCallback);
960-
961992 if (status != RAC_SUCCESS) {
993+ env->DeleteGlobalRef (globalCallback);
962994 LOGe (" rac_llm_component_generate_stream failed with status=%d" , status);
963995 return nullptr ;
964996 }
965997
998+ // Wait until completion/error before releasing callback/context.
999+ {
1000+ std::unique_lock<std::mutex> lock (ctx.mtx );
1001+ constexpr auto kStreamWaitTimeout = std::chrono::minutes (10 );
1002+ if (!ctx.cv .wait_for (lock, kStreamWaitTimeout , [&ctx] { return ctx.is_complete ; })) {
1003+ ctx.has_error = true ;
1004+ ctx.error_message = " Streaming timed out waiting for completion callback" ;
1005+ ctx.is_complete = true ;
1006+ }
1007+ }
1008+
1009+ // Clean up global ref after callbacks have finished.
1010+ env->DeleteGlobalRef (globalCallback);
1011+
9661012 if (ctx.has_error ) {
9671013 LOGe (" Streaming failed: %s" , ctx.error_message .c_str ());
9681014 return nullptr ;
0 commit comments