@@ -102,13 +102,16 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
102102 NN_DBG_PRINTF ("-> is_model_loaded: %d" , wasi_nn_ctx -> is_model_loaded );
103103 NN_DBG_PRINTF ("-> current_encoding: %d" , wasi_nn_ctx -> backend );
104104
105+ bh_assert (!wasi_nn_ctx -> busy );
106+
105107 /* deinit() the backend */
106108 if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
107109 wasi_nn_error res ;
108110 call_wasi_nn_func (wasi_nn_ctx -> backend , deinit , res ,
109111 wasi_nn_ctx -> backend_ctx );
110112 }
111113
114+ os_mutex_destroy (& wasi_nn_ctx -> lock );
112115 wasm_runtime_free (wasi_nn_ctx );
113116}
114117
@@ -154,6 +157,11 @@ wasi_nn_initialize_context()
154157 }
155158
156159 memset (wasi_nn_ctx , 0 , sizeof (WASINNContext ));
160+ if (os_mutex_init (& wasi_nn_ctx -> lock )) {
161+ NN_ERR_PRINTF ("Error when initializing a lock for WASI-NN context" );
162+ wasm_runtime_free (wasi_nn_ctx );
163+ return NULL ;
164+ }
157165 return wasi_nn_ctx ;
158166}
159167
@@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
180188 return wasi_nn_ctx ;
181189}
182190
191+ static WASINNContext *
192+ lock_ctx (wasm_module_inst_t instance )
193+ {
194+ WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
195+ if (wasi_nn_ctx == NULL ) {
196+ return NULL ;
197+ }
198+ os_mutex_lock (& wasi_nn_ctx -> lock );
199+ if (wasi_nn_ctx -> busy ) {
200+ os_mutex_unlock (& wasi_nn_ctx -> lock );
201+ return NULL ;
202+ }
203+ wasi_nn_ctx -> busy = true;
204+ os_mutex_unlock (& wasi_nn_ctx -> lock );
205+ return wasi_nn_ctx ;
206+ }
207+
208+ static void
209+ unlock_ctx (WASINNContext * wasi_nn_ctx )
210+ {
211+ if (wasi_nn_ctx == NULL ) {
212+ return ;
213+ }
214+ os_mutex_lock (& wasi_nn_ctx -> lock );
215+ bh_assert (wasi_nn_ctx -> busy );
216+ wasi_nn_ctx -> busy = false;
217+ os_mutex_unlock (& wasi_nn_ctx -> lock );
218+ }
219+
183220void
184221wasi_nn_destroy ()
185222{
@@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint,
405442
406443static wasi_nn_error
407444ensure_backend (wasm_module_inst_t instance , graph_encoding encoding ,
408- WASINNContext * * wasi_nn_ctx_ptr )
445+ WASINNContext * wasi_nn_ctx )
409446{
410447 wasi_nn_error res ;
411448
@@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
416453 goto fail ;
417454 }
418455
419- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
420456 if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
421457 if (wasi_nn_ctx -> backend != loaded_backend ) {
422458 res = unsupported_operation ;
@@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
434470
435471 wasi_nn_ctx -> is_backend_ctx_initialized = true;
436472 }
437- * wasi_nn_ctx_ptr = wasi_nn_ctx ;
438473 return success ;
439474fail :
440475 return res ;
@@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
462497 if (!instance )
463498 return runtime_error ;
464499
500+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
501+ if (wasi_nn_ctx == NULL ) {
502+ res = busy ;
503+ goto fail ;
504+ }
505+
465506 graph_builder_array builder_native = { 0 };
466507#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
467508 if (success
468509 != (res = graph_builder_array_app_native (
469510 instance , builder , builder_wasm_size , & builder_native )))
470- return res ;
511+ goto fail ;
471512#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
472513 if (success
473514 != (res = graph_builder_array_app_native (instance , builder ,
474515 & builder_native )))
475- return res ;
516+ goto fail ;
476517#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
477518
478519 if (!wasm_runtime_validate_native_addr (instance , g ,
@@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
482523 goto fail ;
483524 }
484525
485- WASINNContext * wasi_nn_ctx ;
486- res = ensure_backend (instance , encoding , & wasi_nn_ctx );
526+ res = ensure_backend (instance , encoding , wasi_nn_ctx );
487527 if (res != success )
488528 goto fail ;
489529
@@ -498,6 +538,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
498538 // XXX: Free intermediate structure pointers
499539 if (builder_native .buf )
500540 wasm_runtime_free (builder_native .buf );
541+ unlock_ctx (wasi_nn_ctx );
501542
502543 return res ;
503544}
@@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
531572
532573 NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME %s..." , name );
533574
534- WASINNContext * wasi_nn_ctx ;
535- res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
575+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
576+ if (wasi_nn_ctx == NULL ) {
577+ res = busy ;
578+ goto fail ;
579+ }
580+
581+ res = ensure_backend (instance , autodetect , wasi_nn_ctx );
536582 if (res != success )
537- return res ;
583+ goto fail ;
538584
539585 call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name , res ,
540586 wasi_nn_ctx -> backend_ctx , name , name_len , g );
541587 if (res != success )
542- return res ;
588+ goto fail ;
543589
544590 wasi_nn_ctx -> is_model_loaded = true;
545- return success ;
591+ res = success ;
592+ fail :
593+ unlock_ctx (wasi_nn_ctx );
594+ return res ;
546595}
547596
548597wasi_nn_error
@@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
580629
581630 NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s..." , name , config );
582631
583- WASINNContext * wasi_nn_ctx ;
584- res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
632+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
633+ if (wasi_nn_ctx == NULL ) {
634+ res = busy ;
635+ goto fail ;
636+ }
637+
638+ res = ensure_backend (instance , autodetect , wasi_nn_ctx );
585639 if (res != success )
586- return res ;
640+ goto fail ;
641+ ;
587642
588643 call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name_with_config , res ,
589644 wasi_nn_ctx -> backend_ctx , name , name_len , config ,
590645 config_len , g );
591646 if (res != success )
592- return res ;
647+ goto fail ;
593648
594649 wasi_nn_ctx -> is_model_loaded = true;
595- return success ;
650+ res = success ;
651+ fail :
652+ unlock_ctx (wasi_nn_ctx );
653+ return res ;
596654}
597655
598656wasi_nn_error
@@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
606664 return runtime_error ;
607665 }
608666
609- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
610-
611667 wasi_nn_error res ;
668+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
669+ if (wasi_nn_ctx == NULL ) {
670+ res = busy ;
671+ goto fail ;
672+ }
673+
612674 if (success != (res = is_model_initialized (wasi_nn_ctx )))
613- return res ;
675+ goto fail ;
614676
615677 if (!wasm_runtime_validate_native_addr (
616678 instance , ctx , (uint64 )sizeof (graph_execution_context ))) {
617679 NN_ERR_PRINTF ("ctx is invalid" );
618- return invalid_argument ;
680+ res = invalid_argument ;
681+ goto fail ;
619682 }
620683
621684 call_wasi_nn_func (wasi_nn_ctx -> backend , init_execution_context , res ,
622685 wasi_nn_ctx -> backend_ctx , g , ctx );
686+ fail :
687+ unlock_ctx (wasi_nn_ctx );
623688 return res ;
624689}
625690
@@ -634,25 +699,30 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
634699 return runtime_error ;
635700 }
636701
637- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
638-
639702 wasi_nn_error res ;
703+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
704+ if (wasi_nn_ctx == NULL ) {
705+ res = busy ;
706+ goto fail ;
707+ }
708+
640709 if (success != (res = is_model_initialized (wasi_nn_ctx )))
641- return res ;
710+ goto fail ;
642711
643712 tensor input_tensor_native = { 0 };
644713 if (success
645714 != (res = tensor_app_native (instance , input_tensor ,
646715 & input_tensor_native )))
647- return res ;
716+ goto fail ;
648717
649718 call_wasi_nn_func (wasi_nn_ctx -> backend , set_input , res ,
650719 wasi_nn_ctx -> backend_ctx , ctx , index ,
651720 & input_tensor_native );
652721 // XXX: Free intermediate structure pointers
653722 if (input_tensor_native .dimensions )
654723 wasm_runtime_free (input_tensor_native .dimensions );
655-
724+ fail :
725+ unlock_ctx (wasi_nn_ctx );
656726 return res ;
657727}
658728
@@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
666736 return runtime_error ;
667737 }
668738
669- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
670-
671739 wasi_nn_error res ;
740+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
741+ if (wasi_nn_ctx == NULL ) {
742+ res = busy ;
743+ goto fail ;
744+ }
745+
672746 if (success != (res = is_model_initialized (wasi_nn_ctx )))
673- return res ;
747+ goto fail ;
674748
675749 call_wasi_nn_func (wasi_nn_ctx -> backend , compute , res ,
676750 wasi_nn_ctx -> backend_ctx , ctx );
751+ fail :
752+ unlock_ctx (wasi_nn_ctx );
677753 return res ;
678754}
679755
@@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
696772 return runtime_error ;
697773 }
698774
699- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
700-
701775 wasi_nn_error res ;
776+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
777+ if (wasi_nn_ctx == NULL ) {
778+ res = busy ;
779+ goto fail ;
780+ }
781+
702782 if (success != (res = is_model_initialized (wasi_nn_ctx )))
703- return res ;
783+ goto fail ;
704784
705785 if (!wasm_runtime_validate_native_addr (instance , output_tensor_size ,
706786 (uint64 )sizeof (uint32_t ))) {
707787 NN_ERR_PRINTF ("output_tensor_size is invalid" );
708- return invalid_argument ;
788+ res = invalid_argument ;
789+ goto fail ;
709790 }
710791
711792#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
718799 wasi_nn_ctx -> backend_ctx , ctx , index , output_tensor ,
719800 output_tensor_size );
720801#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
802+ fail :
803+ unlock_ctx (wasi_nn_ctx );
721804 return res ;
722805}
723806
0 commit comments