3434#define OPENVINO_BACKEND_LIB "libwasi_nn_openvino" LIB_EXTENTION
3535#define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
3636#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
37+ #define TFLITE_MODEL_FILE_EXT ".tflite"
38+ #define ONNX_MODEL_FILE_EXT ".onnx"
3739
3840/* Global variables */
3941static korp_mutex wasi_nn_lock ;
@@ -569,12 +571,25 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len,
569571 return success ;
570572}
571573
574+ static bool
575+ ends_with (const char * str , const char * suffix )
576+ {
577+ if (!str || !suffix )
578+ return false;
579+ uint32_t lenstr = strlen (str );
580+ uint32_t lensuf = strlen (suffix );
581+ if (lensuf > lenstr )
582+ return false;
583+ return strcmp (str + lenstr - lensuf , suffix ) == 0 ;
584+ }
585+
572586wasi_nn_error
573587wasi_nn_load_by_name (wasm_exec_env_t exec_env , char * name , uint32_t name_len ,
574588 graph * g )
575589{
576590 WASINNContext * wasi_nn_ctx = NULL ;
577591 char * nul_terminated_name = NULL ;
592+ graph_encoding encoding = unknown_backend ;
578593 wasi_nn_error res ;
579594
580595 wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
@@ -602,28 +617,54 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
602617 goto fail ;
603618 }
604619
605- for (int i = 0 ; i < auto_detect_encoding_num ; i ++ ) {
620+ if (ends_with (nul_terminated_name , TFLITE_MODEL_FILE_EXT )) {
621+ encoding = tensorflowlite ;
622+ }
623+ else if (ends_with (nul_terminated_name , ONNX_MODEL_FILE_EXT )) {
624+ encoding = onnx ;
625+ }
626+ if (encoding == unknown_backend ) {
627+ for (int i = 0 ; i < auto_detect_encoding_num ; i ++ ) {
628+ if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
629+ call_wasi_nn_func (wasi_nn_ctx -> backend , deinit , res ,
630+ wasi_nn_ctx -> backend_ctx );
631+ }
632+
633+ res = ensure_backend (instance , auto_detect_encoding_order [i ],
634+ wasi_nn_ctx );
635+ if (res != success ) {
636+ NN_ERR_PRINTF ("continue trying the next" );
637+ continue ;
638+ }
639+
640+ call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name , res ,
641+ wasi_nn_ctx -> backend_ctx , nul_terminated_name ,
642+ name_len , g );
643+ if (res != success ) {
644+ NN_ERR_PRINTF ("continue trying the next" );
645+ continue ;
646+ }
647+
648+ break ;
649+ }
650+ }
651+ else {
606652 if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
607653 call_wasi_nn_func (wasi_nn_ctx -> backend , deinit , res ,
608654 wasi_nn_ctx -> backend_ctx );
609655 }
610656
611- res = ensure_backend (instance , auto_detect_encoding_order [i ],
612- wasi_nn_ctx );
657+ res = ensure_backend (instance , encoding , wasi_nn_ctx );
613658 if (res != success ) {
614- NN_ERR_PRINTF ("continue trying the next" );
615- continue ;
659+ goto fail ;
616660 }
617661
618662 call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name , res ,
619663 wasi_nn_ctx -> backend_ctx , nul_terminated_name ,
620664 name_len , g );
621665 if (res != success ) {
622- NN_ERR_PRINTF ("continue trying the next" );
623- continue ;
666+ goto fail ;
624667 }
625-
626- break ;
627668 }
628669
629670fail :
@@ -644,6 +685,7 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
644685 WASINNContext * wasi_nn_ctx = NULL ;
645686 char * nul_terminated_name = NULL ;
646687 char * nul_terminated_config = NULL ;
688+ graph_encoding encoding = unknown_backend ;
647689 wasi_nn_error res ;
648690
649691 wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
@@ -677,28 +719,55 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
677719 goto fail ;
678720 }
679721
680- for (int i = 0 ; i < auto_detect_encoding_num ; i ++ ) {
722+ if (ends_with (nul_terminated_name , TFLITE_MODEL_FILE_EXT )) {
723+ encoding = tensorflowlite ;
724+ }
725+ else if (ends_with (nul_terminated_name , ONNX_MODEL_FILE_EXT )) {
726+ encoding = onnx ;
727+ }
728+ if (encoding == unknown_backend ) {
729+ for (int i = 0 ; i < auto_detect_encoding_num ; i ++ ) {
730+ if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
731+ call_wasi_nn_func (wasi_nn_ctx -> backend , deinit , res ,
732+ wasi_nn_ctx -> backend_ctx );
733+ }
734+
735+ res = ensure_backend (instance , auto_detect_encoding_order [i ],
736+ wasi_nn_ctx );
737+ if (res != success ) {
738+ NN_ERR_PRINTF ("continue trying the next" );
739+ continue ;
740+ }
741+
742+ call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name_with_config ,
743+ res , wasi_nn_ctx -> backend_ctx ,
744+ nul_terminated_name , name_len ,
745+ nul_terminated_config , config_len , g );
746+ if (res != success ) {
747+ NN_ERR_PRINTF ("continue trying the next" );
748+ continue ;
749+ }
750+
751+ break ;
752+ }
753+ }
754+ else {
681755 if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
682756 call_wasi_nn_func (wasi_nn_ctx -> backend , deinit , res ,
683757 wasi_nn_ctx -> backend_ctx );
684758 }
685759
686- res = ensure_backend (instance , auto_detect_encoding_order [i ],
687- wasi_nn_ctx );
760+ res = ensure_backend (instance , encoding , wasi_nn_ctx );
688761 if (res != success ) {
689- NN_ERR_PRINTF ("continue trying the next" );
690- continue ;
762+ goto fail ;
691763 }
692764
693765 call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name_with_config , res ,
694766 wasi_nn_ctx -> backend_ctx , nul_terminated_name ,
695767 name_len , nul_terminated_config , config_len , g );
696768 if (res != success ) {
697- NN_ERR_PRINTF ("continue trying the next" );
698- continue ;
769+ goto fail ;
699770 }
700-
701- break ;
702771 }
703772
704773fail :
0 commit comments