Skip to content

Commit 7da40e1

Browse files
committed
Add auto detecting encoding by model file extension.
1 parent 9fc0848 commit 7da40e1

1 file changed

Lines changed: 87 additions & 18 deletions

File tree

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 87 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
#define OPENVINO_BACKEND_LIB "libwasi_nn_openvino" LIB_EXTENTION
3535
#define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
3636
#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
37+
#define TFLITE_MODEL_FILE_EXT ".tflite"
38+
#define ONNX_MODEL_FILE_EXT ".onnx"
3739

3840
/* Global variables */
3941
static korp_mutex wasi_nn_lock;
@@ -569,12 +571,25 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len,
569571
return success;
570572
}
571573

574+
static bool
575+
ends_with(const char *str, const char *suffix)
576+
{
577+
if (!str || !suffix)
578+
return false;
579+
uint32_t lenstr = strlen(str);
580+
uint32_t lensuf = strlen(suffix);
581+
if (lensuf > lenstr)
582+
return false;
583+
return strcmp(str + lenstr - lensuf, suffix) == 0;
584+
}
585+
572586
wasi_nn_error
573587
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
574588
graph *g)
575589
{
576590
WASINNContext *wasi_nn_ctx = NULL;
577591
char *nul_terminated_name = NULL;
592+
graph_encoding encoding = unknown_backend;
578593
wasi_nn_error res;
579594

580595
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
@@ -602,28 +617,54 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
602617
goto fail;
603618
}
604619

605-
for (int i = 0; i < auto_detect_encoding_num; i++) {
620+
if (ends_with(nul_terminated_name, TFLITE_MODEL_FILE_EXT)) {
621+
encoding = tensorflowlite;
622+
}
623+
else if (ends_with(nul_terminated_name, ONNX_MODEL_FILE_EXT)) {
624+
encoding = onnx;
625+
}
626+
if (encoding == unknown_backend) {
627+
for (int i = 0; i < auto_detect_encoding_num; i++) {
628+
if (wasi_nn_ctx->is_backend_ctx_initialized) {
629+
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
630+
wasi_nn_ctx->backend_ctx);
631+
}
632+
633+
res = ensure_backend(instance, auto_detect_encoding_order[i],
634+
wasi_nn_ctx);
635+
if (res != success) {
636+
NN_ERR_PRINTF("continue trying the next");
637+
continue;
638+
}
639+
640+
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
641+
wasi_nn_ctx->backend_ctx, nul_terminated_name,
642+
name_len, g);
643+
if (res != success) {
644+
NN_ERR_PRINTF("continue trying the next");
645+
continue;
646+
}
647+
648+
break;
649+
}
650+
}
651+
else {
606652
if (wasi_nn_ctx->is_backend_ctx_initialized) {
607653
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
608654
wasi_nn_ctx->backend_ctx);
609655
}
610656

611-
res = ensure_backend(instance, auto_detect_encoding_order[i],
612-
wasi_nn_ctx);
657+
res = ensure_backend(instance, encoding, wasi_nn_ctx);
613658
if (res != success) {
614-
NN_ERR_PRINTF("continue trying the next");
615-
continue;
659+
goto fail;
616660
}
617661

618662
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
619663
wasi_nn_ctx->backend_ctx, nul_terminated_name,
620664
name_len, g);
621665
if (res != success) {
622-
NN_ERR_PRINTF("continue trying the next");
623-
continue;
666+
goto fail;
624667
}
625-
626-
break;
627668
}
628669

629670
fail:
@@ -644,6 +685,7 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
644685
WASINNContext *wasi_nn_ctx = NULL;
645686
char *nul_terminated_name = NULL;
646687
char *nul_terminated_config = NULL;
688+
graph_encoding encoding = unknown_backend;
647689
wasi_nn_error res;
648690

649691
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
@@ -677,28 +719,55 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
677719
goto fail;
678720
}
679721

680-
for (int i = 0; i < auto_detect_encoding_num; i++) {
722+
if (ends_with(nul_terminated_name, TFLITE_MODEL_FILE_EXT)) {
723+
encoding = tensorflowlite;
724+
}
725+
else if (ends_with(nul_terminated_name, ONNX_MODEL_FILE_EXT)) {
726+
encoding = onnx;
727+
}
728+
if (encoding == unknown_backend) {
729+
for (int i = 0; i < auto_detect_encoding_num; i++) {
730+
if (wasi_nn_ctx->is_backend_ctx_initialized) {
731+
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
732+
wasi_nn_ctx->backend_ctx);
733+
}
734+
735+
res = ensure_backend(instance, auto_detect_encoding_order[i],
736+
wasi_nn_ctx);
737+
if (res != success) {
738+
NN_ERR_PRINTF("continue trying the next");
739+
continue;
740+
}
741+
742+
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config,
743+
res, wasi_nn_ctx->backend_ctx,
744+
nul_terminated_name, name_len,
745+
nul_terminated_config, config_len, g);
746+
if (res != success) {
747+
NN_ERR_PRINTF("continue trying the next");
748+
continue;
749+
}
750+
751+
break;
752+
}
753+
}
754+
else {
681755
if (wasi_nn_ctx->is_backend_ctx_initialized) {
682756
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
683757
wasi_nn_ctx->backend_ctx);
684758
}
685759

686-
res = ensure_backend(instance, auto_detect_encoding_order[i],
687-
wasi_nn_ctx);
760+
res = ensure_backend(instance, encoding, wasi_nn_ctx);
688761
if (res != success) {
689-
NN_ERR_PRINTF("continue trying the next");
690-
continue;
762+
goto fail;
691763
}
692764

693765
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
694766
wasi_nn_ctx->backend_ctx, nul_terminated_name,
695767
name_len, nul_terminated_config, config_len, g);
696768
if (res != success) {
697-
NN_ERR_PRINTF("continue trying the next");
698-
continue;
769+
goto fail;
699770
}
700-
701-
break;
702771
}
703772

704773
fail:

0 commit comments

Comments
 (0)