Skip to content

Commit 9e89828

Browse files
committed
Make wasi_nn_load_by_name and wasi_nn_load_by_name_with_config share a common logic.
1 parent 4747d61 commit 9e89828

File tree

1 file changed

+65
-41
lines changed

1 file changed

+65
-41
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -535,38 +535,13 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len,
535535
return success;
536536
}
537537

538-
wasi_nn_error
539-
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
540-
graph *g)
538+
static wasi_nn_error
539+
load_by_name_with_optional_config(WASINNContext *wasi_nn_ctx,
540+
wasm_module_inst_t instance, bool use_config,
541+
graph *g, const char *model_name,
542+
const char *config, int32_t config_len)
541543
{
542-
WASINNContext *wasi_nn_ctx = NULL;
543-
char *nul_terminated_name = NULL;
544-
wasi_nn_error res;
545-
546-
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
547-
if (!instance) {
548-
return runtime_error;
549-
}
550-
551-
if (!wasm_runtime_validate_native_addr(instance, g,
552-
(uint64)sizeof(graph))) {
553-
NN_ERR_PRINTF("graph is invalid");
554-
return invalid_argument;
555-
}
556-
557-
res = copyin_and_nul_terminate(instance, name, name_len,
558-
&nul_terminated_name);
559-
if (res != success) {
560-
goto fail;
561-
}
562-
563-
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", nul_terminated_name);
564-
565-
wasi_nn_ctx = lock_ctx(instance);
566-
if (wasi_nn_ctx == NULL) {
567-
res = busy;
568-
goto fail;
569-
}
544+
wasi_nn_error res = success;
570545

571546
WASINNRegistry *wasi_nn_registry =
572547
wasm_runtime_get_wasi_nn_registry(instance);
@@ -580,9 +555,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
580555
uint32 model_idx = 0;
581556
uint32_t global_n_graphs = wasi_nn_registry->n_graphs;
582557
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) {
583-
char *model_name = wasi_nn_registry->model_names[model_idx];
558+
char *model_name_i = wasi_nn_registry->model_names[model_idx];
584559

585-
if (strcmp(nul_terminated_name, model_name) != 0) {
560+
if (strcmp(model_name, model_name_i) != 0) {
586561
continue;
587562
}
588563

@@ -594,7 +569,6 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
594569
execution_target target =
595570
(execution_target)(wasi_nn_registry->target[model_idx]);
596571

597-
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
598572
res = ensure_backend(instance, encoding, wasi_nn_ctx);
599573
if (res != success)
600574
goto fail;
@@ -603,9 +577,17 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
603577
&& (model_idx < global_n_graphs)) {
604578
NN_DBG_PRINTF(
605579
"Model is not yet loaded, will add to global context");
606-
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
607-
wasi_nn_ctx->backend_ctx, global_model_path_i,
608-
strlen(global_model_path_i), target, g);
580+
if (use_config && config && config_len > 0) {
581+
call_wasi_nn_func(
582+
wasi_nn_ctx->backend, load_by_name_with_config, res,
583+
wasi_nn_ctx->backend_ctx, global_model_path_i,
584+
strlen(global_model_path_i), config, config_len, g);
585+
}
586+
else {
587+
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
588+
wasi_nn_ctx->backend_ctx, global_model_path_i,
589+
strlen(global_model_path_i), target, g);
590+
}
609591
if (res != success)
610592
goto fail;
611593

@@ -627,9 +609,51 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
627609
else if (model_idx >= global_n_graphs) {
628610
NN_ERR_PRINTF("Model %s is not loaded, you should pass its path "
629611
"through --wasi-nn-graph",
630-
nul_terminated_name);
612+
model_name);
631613
res = not_found;
632614
}
615+
616+
fail:
617+
618+
return res;
619+
}
620+
621+
wasi_nn_error
622+
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
623+
graph *g)
624+
{
625+
WASINNContext *wasi_nn_ctx = NULL;
626+
char *nul_terminated_name = NULL;
627+
wasi_nn_error res;
628+
629+
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
630+
if (!instance) {
631+
return runtime_error;
632+
}
633+
634+
if (!wasm_runtime_validate_native_addr(instance, g,
635+
(uint64)sizeof(graph))) {
636+
NN_ERR_PRINTF("graph is invalid");
637+
return invalid_argument;
638+
}
639+
640+
res = copyin_and_nul_terminate(instance, name, name_len,
641+
&nul_terminated_name);
642+
if (res != success) {
643+
goto fail;
644+
}
645+
646+
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", nul_terminated_name);
647+
648+
wasi_nn_ctx = lock_ctx(instance);
649+
if (wasi_nn_ctx == NULL) {
650+
res = busy;
651+
goto fail;
652+
}
653+
654+
res = load_by_name_with_optional_config(wasi_nn_ctx, instance, false, g,
655+
nul_terminated_name, NULL, 0);
656+
633657
fail:
634658
if (nul_terminated_name != NULL) {
635659
wasm_runtime_free(nul_terminated_name);
@@ -686,9 +710,9 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
686710
goto fail;
687711
;
688712

689-
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
690-
wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
691-
nul_terminated_config, config_len, g);
713+
res = load_by_name_with_optional_config(wasi_nn_ctx, instance, true, g,
714+
nul_terminated_name,
715+
nul_terminated_config, config_len);
692716
if (res != success)
693717
goto fail;
694718

0 commit comments

Comments
 (0)