@@ -535,38 +535,13 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len,
535535 return success ;
536536}
537537
538- wasi_nn_error
539- wasi_nn_load_by_name (wasm_exec_env_t exec_env , char * name , uint32_t name_len ,
540- graph * g )
538+ static wasi_nn_error
539+ load_by_name_with_optional_config (WASINNContext * wasi_nn_ctx ,
540+ wasm_module_inst_t instance , bool use_config ,
541+ graph * g , const char * model_name ,
542+ const char * config , int32_t config_len )
541543{
542- WASINNContext * wasi_nn_ctx = NULL ;
543- char * nul_terminated_name = NULL ;
544- wasi_nn_error res ;
545-
546- wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
547- if (!instance ) {
548- return runtime_error ;
549- }
550-
551- if (!wasm_runtime_validate_native_addr (instance , g ,
552- (uint64 )sizeof (graph ))) {
553- NN_ERR_PRINTF ("graph is invalid" );
554- return invalid_argument ;
555- }
556-
557- res = copyin_and_nul_terminate (instance , name , name_len ,
558- & nul_terminated_name );
559- if (res != success ) {
560- goto fail ;
561- }
562-
563- NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME %s..." , nul_terminated_name );
564-
565- wasi_nn_ctx = lock_ctx (instance );
566- if (wasi_nn_ctx == NULL ) {
567- res = busy ;
568- goto fail ;
569- }
544+ wasi_nn_error res = success ;
570545
571546 WASINNRegistry * wasi_nn_registry =
572547 wasm_runtime_get_wasi_nn_registry (instance );
@@ -580,9 +555,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
580555 uint32 model_idx = 0 ;
581556 uint32_t global_n_graphs = wasi_nn_registry -> n_graphs ;
582557 for (model_idx = 0 ; model_idx < global_n_graphs ; model_idx ++ ) {
583- char * model_name = wasi_nn_registry -> model_names [model_idx ];
558+ char * model_name_i = wasi_nn_registry -> model_names [model_idx ];
584559
585- if (strcmp (nul_terminated_name , model_name ) != 0 ) {
560+ if (strcmp (model_name , model_name_i ) != 0 ) {
586561 continue ;
587562 }
588563
@@ -594,7 +569,6 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
594569 execution_target target =
595570 (execution_target )(wasi_nn_registry -> target [model_idx ]);
596571
597- // res = ensure_backend(instance, autodetect, wasi_nn_ctx);
598572 res = ensure_backend (instance , encoding , wasi_nn_ctx );
599573 if (res != success )
600574 goto fail ;
@@ -603,9 +577,17 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
603577 && (model_idx < global_n_graphs )) {
604578 NN_DBG_PRINTF (
605579 "Model is not yet loaded, will add to global context" );
606- call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name , res ,
607- wasi_nn_ctx -> backend_ctx , global_model_path_i ,
608- strlen (global_model_path_i ), target , g );
580+ if (use_config && config && config_len > 0 ) {
581+ call_wasi_nn_func (
582+ wasi_nn_ctx -> backend , load_by_name_with_config , res ,
583+ wasi_nn_ctx -> backend_ctx , global_model_path_i ,
584+ strlen (global_model_path_i ), config , config_len , g );
585+ }
586+ else {
587+ call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name , res ,
588+ wasi_nn_ctx -> backend_ctx , global_model_path_i ,
589+ strlen (global_model_path_i ), target , g );
590+ }
609591 if (res != success )
610592 goto fail ;
611593
@@ -627,9 +609,51 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
627609 else if (model_idx >= global_n_graphs ) {
628610 NN_ERR_PRINTF ("Model %s is not loaded, you should pass its path "
629611 "through --wasi-nn-graph" ,
630- nul_terminated_name );
612+ model_name );
631613 res = not_found ;
632614 }
615+
616+ fail :
617+
618+ return res ;
619+ }
620+
621+ wasi_nn_error
622+ wasi_nn_load_by_name (wasm_exec_env_t exec_env , char * name , uint32_t name_len ,
623+ graph * g )
624+ {
625+ WASINNContext * wasi_nn_ctx = NULL ;
626+ char * nul_terminated_name = NULL ;
627+ wasi_nn_error res ;
628+
629+ wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
630+ if (!instance ) {
631+ return runtime_error ;
632+ }
633+
634+ if (!wasm_runtime_validate_native_addr (instance , g ,
635+ (uint64 )sizeof (graph ))) {
636+ NN_ERR_PRINTF ("graph is invalid" );
637+ return invalid_argument ;
638+ }
639+
640+ res = copyin_and_nul_terminate (instance , name , name_len ,
641+ & nul_terminated_name );
642+ if (res != success ) {
643+ goto fail ;
644+ }
645+
646+ NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME %s..." , nul_terminated_name );
647+
648+ wasi_nn_ctx = lock_ctx (instance );
649+ if (wasi_nn_ctx == NULL ) {
650+ res = busy ;
651+ goto fail ;
652+ }
653+
654+ res = load_by_name_with_optional_config (wasi_nn_ctx , instance , false, g ,
655+ nul_terminated_name , NULL , 0 );
656+
633657fail :
634658 if (nul_terminated_name != NULL ) {
635659 wasm_runtime_free (nul_terminated_name );
@@ -686,9 +710,9 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
686710 goto fail ;
687711 ;
688712
689- call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name_with_config , res ,
690- wasi_nn_ctx -> backend_ctx , nul_terminated_name , name_len ,
691- nul_terminated_config , config_len , g );
713+ res = load_by_name_with_optional_config (wasi_nn_ctx , instance , true, g ,
714+ nul_terminated_name ,
715+ nul_terminated_config , config_len );
692716 if (res != success )
693717 goto fail ;
694718
0 commit comments