Skip to content

Commit 12d8a18

Browse files
committed
Allow multiple --wasi-nn-graphs
1 parent 6bc7512 commit 12d8a18

8 files changed

Lines changed: 129 additions & 98 deletions

File tree

core/iwasm/common/wasm_runtime_common.c

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,20 +1805,28 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args)
18051805
}
18061806

18071807
bool
1808-
wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding,
1809-
const char *target, uint32_t n_graphs,
1808+
wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char **encoding,
1809+
const char **target, uint32_t n_graphs,
18101810
const char **graph_paths)
18111811
{
18121812
if (!registry || !encoding || !target || !graph_paths) {
18131813
return false;
18141814
}
1815-
registry->encoding = strdup(encoding);
1816-
registry->target = strdup(target);
1815+
18171816
registry->n_graphs = n_graphs;
1817+
registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
1818+
registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
18181819
registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
1820+
memset(registry->target, 0, sizeof(uint32_t *) * n_graphs);
1821+
memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs);
18191822
memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
1823+
18201824
for (uint32_t i = 0; i < registry->n_graphs; i++)
1825+
{
18211826
registry->graph_paths[i] = strdup(graph_paths[i]);
1827+
registry->encoding[i] = strdup(encoding[i]);
1828+
registry->target[i] = strdup(target[i]);
1829+
}
18221830

18231831
return true;
18241832
}
@@ -1841,14 +1849,12 @@ wasi_nn_graph_registry_destroy(WASINNArguments *registry)
18411849
if (registry) {
18421850
for (uint32_t i = 0; i < registry->n_graphs; i++)
18431851
if (registry->graph_paths[i]) {
1844-
// wasi_nn_graph_registry_unregister_graph(registry,
1845-
// registry->name[i]);
18461852
free(registry->graph_paths[i]);
1853+
if (registry->encoding[i])
1854+
free(registry->encoding[i]);
1855+
if (registry->target[i])
1856+
free(registry->target[i]);
18471857
}
1848-
if (registry->encoding)
1849-
free(registry->encoding);
1850-
if (registry->target)
1851-
free(registry->target);
18521858
free(registry);
18531859
}
18541860
}
@@ -8150,7 +8156,7 @@ wasm_runtime_check_and_update_last_used_shared_heap(
81508156
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
81518157
bool
81528158
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
8153-
const char *encoding, const char *target,
8159+
const char **encoding, const char **target,
81548160
const uint32_t n_graphs,
81558161
char *graph_paths[], char *error_buf,
81568162
uint32_t error_buf_size)
@@ -8162,16 +8168,21 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
81628168
if (!ctx)
81638169
return false;
81648170

8165-
ctx->encoding = strdup(encoding);
8166-
ctx->target = strdup(target);
81678171
ctx->n_graphs = n_graphs;
8172+
8173+
ctx->encoding = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
8174+
memset(ctx->encoding, 0, sizeof(uint32_t) * n_graphs);
8175+
ctx->target = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
8176+
memset(ctx->target, 0, sizeof(uint32_t) * n_graphs);
81688177
ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
81698178
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);
8170-
81718179
ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
81728180
memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
8181+
81738182
for (uint32_t i = 0; i < n_graphs; i++) {
81748183
ctx->graph_paths[i] = strdup(graph_paths[i]);
8184+
ctx->target[i] = strdup(target[i]);
8185+
ctx->encoding[i] = strdup(encoding[i]);
81758186
}
81768187

81778188
wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx);
@@ -8191,6 +8202,10 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
81918202
// All graphs will be unregistered in deinit()
81928203
if (wasi_nn_global_ctx->graph_paths[i])
81938204
free(wasi_nn_global_ctx->graph_paths[i]);
8205+
if (wasi_nn_global_ctx->encoding[i])
8206+
free(wasi_nn_global_ctx->encoding[i]);
8207+
if (wasi_nn_global_ctx->encoding[i])
8208+
free(wasi_nn_global_ctx->target[i]);
81948209
}
81958210
free(wasi_nn_global_ctx->encoding);
81968211
free(wasi_nn_global_ctx->target);
@@ -8243,21 +8258,21 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i(
82438258
}
82448259

82458260
char *
8246-
wasm_runtime_get_wasi_nn_global_ctx_encoding(
8247-
WASINNGlobalContext *wasi_nn_global_ctx)
8261+
wasm_runtime_get_wasi_nn_global_ctx_encoding_i(
8262+
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
82488263
{
8249-
if (wasi_nn_global_ctx)
8250-
return wasi_nn_global_ctx->encoding;
8264+
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
8265+
return wasi_nn_global_ctx->encoding[idx];
82518266

82528267
return NULL;
82538268
}
82548269

82558270
char *
8256-
wasm_runtime_get_wasi_nn_global_ctx_target(
8257-
WASINNGlobalContext *wasi_nn_global_ctx)
8271+
wasm_runtime_get_wasi_nn_global_ctx_target_i(
8272+
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
82588273
{
8259-
if (wasi_nn_global_ctx)
8260-
return wasi_nn_global_ctx->target;
8274+
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
8275+
return wasi_nn_global_ctx->target[idx];
82618276

82628277
return NULL;
82638278
}

core/iwasm/common/wasm_runtime_common.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,8 @@ typedef struct WASMModuleInstMemConsumption {
547547

548548
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
549549
typedef struct WASINNGlobalContext {
550-
char *encoding;
551-
char *target;
550+
char **encoding;
551+
char **target;
552552

553553
uint32_t n_graphs;
554554
uint32_t *loaded;
@@ -625,8 +625,8 @@ wasm_runtime_get_exec_env_tls(void);
625625

626626
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
627627
typedef struct WASINNArguments {
628-
char *encoding;
629-
char *target;
628+
char **encoding;
629+
char **target;
630630

631631
char **graph_paths;
632632
uint32_t n_graphs;
@@ -811,8 +811,8 @@ wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
811811
struct InstantiationArgs2 *p, WASINNArguments *registry);
812812

813813
WASM_RUNTIME_API_EXTERN bool
814-
wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding,
815-
const char *target, uint32_t n_graphs,
814+
wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char **encoding,
815+
const char **target, uint32_t n_graphs,
816816
const char **graph_paths);
817817
#endif
818818

@@ -1471,7 +1471,7 @@ wasm_runtime_check_and_update_last_used_shared_heap(
14711471
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
14721472
WASM_RUNTIME_API_EXTERN bool
14731473
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
1474-
const char *encoding, const char *target,
1474+
const char **encoding, const char **target,
14751475
const uint32_t n_graphs,
14761476
char *graph_paths[], char *error_buf,
14771477
uint32_t error_buf_size);
@@ -1507,12 +1507,12 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i(
15071507
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value);
15081508

15091509
WASM_RUNTIME_API_EXTERN char *
1510-
wasm_runtime_get_wasi_nn_global_ctx_encoding(
1511-
WASINNGlobalContext *wasi_nn_global_ctx);
1510+
wasm_runtime_get_wasi_nn_global_ctx_encoding_i(
1511+
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
15121512

15131513
WASM_RUNTIME_API_EXTERN char *
1514-
wasm_runtime_get_wasi_nn_global_ctx_target(
1515-
WASINNGlobalContext *wasi_nn_global_ctx);
1514+
wasm_runtime_get_wasi_nn_global_ctx_target_i(
1515+
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
15161516
#endif
15171517

15181518
#ifdef __cplusplus

core/iwasm/include/wasm_export.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -818,12 +818,12 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i(
818818
WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx, uint32_t value);
819819

820820
WASM_RUNTIME_API_EXTERN char *
821-
wasm_runtime_get_wasi_nn_global_ctx_encoding(
822-
WASINNGlobalContext * wasi_nn_global_ctx);
821+
wasm_runtime_get_wasi_nn_global_ctx_encoding_i(
822+
WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx);
823823

824824
WASM_RUNTIME_API_EXTERN char *
825-
wasm_runtime_get_wasi_nn_global_ctx_target(
826-
WASINNGlobalContext * wasi_nn_global_ctx);
825+
wasm_runtime_get_wasi_nn_global_ctx_target_i(
826+
WASINNGlobalContext * wasi_nn_global_ctx, uint32_t idx);
827827

828828
/**
829829
* Instantiate a WASM module, with specified instantiation arguments

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -614,24 +614,15 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
614614
res = not_found;
615615
goto fail;
616616
}
617-
graph_encoding encoding = str2encoding(
618-
wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx));
619-
execution_target target = str2target(
620-
wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx));
621-
622-
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
623-
res = ensure_backend(instance, encoding, wasi_nn_ctx);
624-
if (res != success)
625-
goto fail;
626617

627618
bool is_loaded = false;
628619
uint32 model_idx = 0;
629620
char *global_model_path_i;
630621
uint32_t global_n_graphs =
631622
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx);
632-
// Assume filename got from user wasm app : max; sum; average; ...
633-
// Assume file path got from user cmd opt: /your/path1/max.tflite;
634-
// /your/path2/sum.tflite; ......
623+
// Model got from user wasm app : modelA; modelB...
624+
// Filelist got from user cmd opt: /path1/modelA.tflite;
625+
// /path/modelB.tflite; ......
635626
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) {
636627
// Extract filename from file path
637628
global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
@@ -655,45 +646,54 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
655646
model_name[model_name_len] = '\0';
656647
}
657648

658-
if (model_name && strcmp(nul_terminated_name, model_name) == 0) {
659-
is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
660-
wasi_nn_global_ctx, model_idx);
649+
if (model_name && strcmp(nul_terminated_name, model_name) != 0) {
661650
free(model_name);
662-
break;
651+
continue;
663652
}
653+
is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
654+
wasi_nn_global_ctx, model_idx);
664655
free(model_name);
665-
}
666656

667-
if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)
668-
&& (model_idx < global_n_graphs)) {
669-
NN_DBG_PRINTF("Model is not yet loaded, will add to global context");
670-
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
671-
wasi_nn_ctx->backend_ctx, global_model_path_i,
672-
strlen(global_model_path_i), encoding, target, g);
657+
graph_encoding encoding = str2encoding(
658+
wasm_runtime_get_wasi_nn_global_ctx_encoding_i(wasi_nn_global_ctx, model_idx));
659+
execution_target target = str2target(
660+
wasm_runtime_get_wasi_nn_global_ctx_target_i(wasi_nn_global_ctx, model_idx));
661+
662+
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
663+
res = ensure_backend(instance, encoding, wasi_nn_ctx);
673664
if (res != success)
674665
goto fail;
675666

676-
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx,
677-
model_idx, 1);
678-
res = success;
679-
}
680-
else {
681-
if (is_loaded) {
682-
NN_DBG_PRINTF("Model is already loaded");
667+
if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)
668+
&& (model_idx < global_n_graphs)) {
669+
NN_DBG_PRINTF("Model is not yet loaded, will add to global context");
670+
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
671+
wasi_nn_ctx->backend_ctx, global_model_path_i,
672+
strlen(global_model_path_i), encoding, target, g);
673+
if (res != success)
674+
goto fail;
675+
676+
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx,
677+
model_idx, 1);
683678
res = success;
679+
break;
684680
}
685-
else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) {
686-
// No enlarge for now
687-
NN_ERR_PRINTF("No enough space for new model");
688-
res = too_large;
689-
}
690-
else if (model_idx >= global_n_graphs) {
691-
NN_ERR_PRINTF("Model %s is not loaded, you should pass its path "
692-
"through --wasi-nn-graph",
693-
nul_terminated_name);
694-
res = not_found;
695-
}
696-
goto fail;
681+
}
682+
683+
if (is_loaded) {
684+
NN_DBG_PRINTF("Model is already loaded");
685+
res = success;
686+
}
687+
else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) {
688+
// No enlarge for now
689+
NN_ERR_PRINTF("No enough space for new model");
690+
res = too_large;
691+
}
692+
else if (model_idx >= global_n_graphs) {
693+
NN_ERR_PRINTF("Model %s is not loaded, you should pass its path "
694+
"through --wasi-nn-graph",
695+
nul_terminated_name);
696+
res = not_found;
697697
}
698698
fail:
699699
if (nul_terminated_name != NULL) {

core/iwasm/libraries/wasi-nn/test/test_tensorflow.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,9 @@ int
144144
main()
145145
{
146146
NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so "
147-
"--wasi-nn-graph=encoding:target:model_path1:model_path2:..."
148-
":model_pathN test_tensorflow.wasm\"");
147+
"--wasi-nn-graph=encodingA:targetA:<modelA_path> "
148+
"--wasi-nn-graph=encodingB:targetB:<modelB_path>..."
149+
" test_tensorflow.wasm");
149150

150151
NN_INFO_PRINTF("################### Testing sum...");
151152
test_sum();

core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ int
3939
main()
4040
{
4141
NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so "
42-
"--wasi-nn-graph=encoding:target:model_path1:model_path2:..."
43-
":model_pathN test_tensorflow.wasm\"");
42+
"--wasi-nn-graph=encodingA:targetA:<modelA_path> "
43+
"--wasi-nn-graph=encodingB:targetB:<modelB_path>..."
44+
" test_tensorflow_quantized.wasm");
4445

4546
NN_INFO_PRINTF("################### Testing quantized model...");
4647
test_average_quantized();

0 commit comments

Comments
 (0)