Skip to content

Commit 736f357

Browse files
committed
Add model_name option for --wasi-nn-graphs to make it more flexible and simpler
1 parent c024c94 commit 736f357

7 files changed

Lines changed: 67 additions & 45 deletions

File tree

core/iwasm/common/wasm_runtime_common.c

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,23 +1806,27 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args)
18061806

18071807
bool
18081808
wasi_nn_graph_registry_set_args(WASINNArguments *registry,
1809-
const char **encoding, const char **target,
1810-
uint32_t n_graphs, const char **graph_paths)
1809+
const char **model_names, const char **encoding,
1810+
const char **target, uint32_t n_graphs,
1811+
const char **graph_paths)
18111812
{
1812-
if (!registry || !encoding || !target || !graph_paths) {
1813+
if (!registry || !model_names || !encoding || !target || !graph_paths) {
18131814
return false;
18141815
}
18151816

18161817
registry->n_graphs = n_graphs;
18171818
registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
18181819
registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
1820+
registry->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
18191821
registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
18201822
memset(registry->target, 0, sizeof(uint32_t *) * n_graphs);
18211823
memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs);
1824+
memset(registry->model_names, 0, sizeof(uint32_t *) * n_graphs);
18221825
memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
18231826

18241827
for (uint32_t i = 0; i < registry->n_graphs; i++) {
18251828
registry->graph_paths[i] = strdup(graph_paths[i]);
1829+
registry->model_names[i] = strdup(model_names[i]);
18261830
registry->encoding[i] = strdup(encoding[i]);
18271831
registry->target[i] = strdup(target[i]);
18281832
}
@@ -1849,6 +1853,8 @@ wasi_nn_graph_registry_destroy(WASINNArguments *registry)
18491853
for (uint32_t i = 0; i < registry->n_graphs; i++)
18501854
if (registry->graph_paths[i]) {
18511855
free(registry->graph_paths[i]);
1856+
if (registry->model_names[i])
1857+
free(registry->model_names[i]);
18521858
if (registry->encoding[i])
18531859
free(registry->encoding[i]);
18541860
if (registry->target[i])
@@ -8155,6 +8161,7 @@ wasm_runtime_check_and_update_last_used_shared_heap(
81558161
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
81568162
bool
81578163
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
8164+
const char **model_names,
81588165
const char **encoding, const char **target,
81598166
const uint32_t n_graphs,
81608167
char *graph_paths[], char *error_buf,
@@ -8175,11 +8182,14 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
81758182
memset(ctx->target, 0, sizeof(uint32_t) * n_graphs);
81768183
ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
81778184
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);
8185+
ctx->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
8186+
memset(ctx->model_names, 0, sizeof(uint32_t *) * n_graphs);
81788187
ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
81798188
memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
81808189

81818190
for (uint32_t i = 0; i < n_graphs; i++) {
81828191
ctx->graph_paths[i] = strdup(graph_paths[i]);
8192+
ctx->model_names[i] = strdup(model_names[i]);
81838193
ctx->target[i] = strdup(target[i]);
81848194
ctx->encoding[i] = strdup(encoding[i]);
81858195
}
@@ -8201,14 +8211,17 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
82018211
// All graphs will be unregistered in deinit()
82028212
if (wasi_nn_global_ctx->graph_paths[i])
82038213
free(wasi_nn_global_ctx->graph_paths[i]);
8214+
if (wasi_nn_global_ctx->model_names[i])
8215+
free(wasi_nn_global_ctx->model_names[i]);
82048216
if (wasi_nn_global_ctx->encoding[i])
82058217
free(wasi_nn_global_ctx->encoding[i]);
8206-
if (wasi_nn_global_ctx->encoding[i])
8218+
if (wasi_nn_global_ctx->target[i])
82078219
free(wasi_nn_global_ctx->target[i]);
82088220
}
82098221
free(wasi_nn_global_ctx->encoding);
82108222
free(wasi_nn_global_ctx->target);
82118223
free(wasi_nn_global_ctx->loaded);
8224+
free(wasi_nn_global_ctx->model_names);
82128225
free(wasi_nn_global_ctx->graph_paths);
82138226

82148227
if (wasi_nn_global_ctx) {
@@ -8226,6 +8239,16 @@ wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
82268239
return -1;
82278240
}
82288241

8242+
char *
8243+
wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
8244+
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
8245+
{
8246+
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
8247+
return wasi_nn_global_ctx->model_names[idx];
8248+
8249+
return NULL;
8250+
}
8251+
82298252
char *
82308253
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
82318254
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)

core/iwasm/common/wasm_runtime_common.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ typedef struct WASMModuleInstMemConsumption {
547547

548548
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
549549
typedef struct WASINNGlobalContext {
550+
char **model_names;
550551
char **encoding;
551552
char **target;
552553

@@ -625,6 +626,7 @@ wasm_runtime_get_exec_env_tls(void);
625626

626627
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
627628
typedef struct WASINNArguments {
629+
char **model_names;
628630
char **encoding;
629631
char **target;
630632

@@ -812,8 +814,9 @@ wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
812814

813815
WASM_RUNTIME_API_EXTERN bool
814816
wasi_nn_graph_registry_set_args(WASINNArguments *registry,
815-
const char **encoding, const char **target,
816-
uint32_t n_graphs, const char **graph_paths);
817+
const char **model_names, const char **encoding,
818+
const char **target, uint32_t n_graphs,
819+
const char **graph_paths);
817820
#endif
818821

819822
/* See wasm_export.h for description */
@@ -1471,6 +1474,7 @@ wasm_runtime_check_and_update_last_used_shared_heap(
14711474
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
14721475
WASM_RUNTIME_API_EXTERN bool
14731476
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
1477+
const char **model_names,
14741478
const char **encoding, const char **target,
14751479
const uint32_t n_graphs,
14761480
char *graph_paths[], char *error_buf,
@@ -1494,6 +1498,10 @@ WASM_RUNTIME_API_EXTERN uint32_t
14941498
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
14951499
WASINNGlobalContext *wasi_nn_global_ctx);
14961500

1501+
WASM_RUNTIME_API_EXTERN char *
1502+
wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
1503+
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
1504+
14971505
WASM_RUNTIME_API_EXTERN char *
14981506
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
14991507
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);

core/iwasm/include/wasm_export.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,10 @@ WASM_RUNTIME_API_EXTERN uint32_t
805805
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
806806
WASINNGlobalContext *wasi_nn_global_ctx);
807807

808+
WASM_RUNTIME_API_EXTERN char *
809+
wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
810+
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
811+
808812
WASM_RUNTIME_API_EXTERN char *
809813
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
810814
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);

core/iwasm/interpreter/wasm_runtime.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3305,8 +3305,8 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent,
33053305
* load_by_name */
33063306
WASINNArguments *nn_registry = &args->nn_registry;
33073307
if (!wasm_runtime_init_wasi_nn_global_ctx(
3308-
(WASMModuleInstanceCommon *)module_inst, nn_registry->encoding,
3309-
nn_registry->target, nn_registry->n_graphs,
3308+
(WASMModuleInstanceCommon *)module_inst, nn_registry->model_names,
3309+
nn_registry->encoding, nn_registry->target, nn_registry->n_graphs,
33103310
nn_registry->graph_paths, error_buf, error_buf_size)) {
33113311
goto fail;
33123312
}

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -617,42 +617,21 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
617617

618618
bool is_loaded = false;
619619
uint32 model_idx = 0;
620-
char *global_model_path_i;
621620
uint32_t global_n_graphs =
622621
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx);
623-
// Model got from user wasm app : modelA; modelB...
624-
// Filelist got from user cmd opt: /path1/modelA.tflite;
625-
// /path/modelB.tflite; ......
626622
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) {
627-
// Extract filename from file path
628-
global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
623+
char *model_name = wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
629624
wasi_nn_global_ctx, model_idx);
630-
char *model_file_name;
631-
const char *slash = strrchr(global_model_path_i, '/');
632-
if (slash != NULL) {
633-
model_file_name = (char *)(slash + 1);
634-
}
635-
else
636-
model_file_name = global_model_path_i;
637-
638-
// Extract modelname from filename
639-
char *model_name = NULL;
640-
size_t model_name_len = 0;
641-
char *dot = strrchr(model_file_name, '.');
642-
if (dot) {
643-
model_name_len = dot - model_file_name;
644-
model_name = malloc(model_name_len + 1);
645-
strncpy(model_name, model_file_name, model_name_len);
646-
model_name[model_name_len] = '\0';
647-
}
648625

649626
if (model_name && strcmp(nul_terminated_name, model_name) != 0) {
650-
free(model_name);
651627
continue;
652628
}
629+
653630
is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
654631
wasi_nn_global_ctx, model_idx);
655-
free(model_name);
632+
char *global_model_path_i =
633+
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
634+
wasi_nn_global_ctx, model_idx);
656635

657636
graph_encoding encoding =
658637
str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding_i(

product-mini/platforms/common/libc_wasi.c

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ typedef struct {
2222
} libc_wasi_parse_context_t;
2323

2424
typedef struct {
25+
const char *model_names[10];
2526
const char *encoding[10];
2627
const char *target[10];
2728
const char *graph_paths[10];
@@ -208,19 +209,23 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx)
208209
// --wasi-nn-graph=encoding2:target2:model_file_path2 ...
209210
token = strtok_r(argv[0] + 16, ":", &saveptr);
210211
while (token) {
211-
tokens[token_count] = token;
212-
token_count++;
213-
token = strtok_r(NULL, ":", &saveptr);
212+
if (strlen(token) > 0) {
213+
tokens[token_count] = token;
214+
token_count++;
215+
token = strtok_r(NULL, ":", &saveptr);
216+
}
214217
}
215218

216-
if (token_count != 3) {
219+
if (token_count != 4) {
217220
ret = LIBC_WASI_PARSE_RESULT_NEED_HELP;
221+
printf("4 arguments are needed for wasi-nn.\n");
218222
goto fail;
219223
}
220224

221-
ctx->encoding[ctx->n_graphs] = strdup(tokens[0]);
222-
ctx->target[ctx->n_graphs] = strdup(tokens[1]);
223-
ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[2]);
225+
ctx->model_names[ctx->n_graphs] = strdup(tokens[0]);
226+
ctx->encoding[ctx->n_graphs] = strdup(tokens[1]);
227+
ctx->target[ctx->n_graphs] = strdup(tokens[2]);
228+
ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[3]);
224229

225230
fail:
226231
if (token)
@@ -234,12 +239,15 @@ wasi_nn_set_init_args(struct InstantiationArgs2 *args,
234239
struct WASINNArguments *nn_registry,
235240
wasi_nn_parse_context_t *ctx)
236241
{
237-
wasi_nn_graph_registry_set_args(nn_registry, ctx->encoding, ctx->target,
238-
ctx->n_graphs, ctx->graph_paths);
242+
wasi_nn_graph_registry_set_args(nn_registry, ctx->model_names,
243+
ctx->encoding, ctx->target, ctx->n_graphs,
244+
ctx->graph_paths);
239245
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args,
240246
nn_registry);
241247

242248
for (uint32_t i = 0; i < ctx->n_graphs; i++) {
249+
if (ctx->model_names[i])
250+
free(ctx->model_names[i]);
243251
if (ctx->graph_paths[i])
244252
free(ctx->graph_paths[i]);
245253
if (ctx->encoding[i])

product-mini/platforms/posix/main.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ print_help(void)
123123
printf(" --gen-prof-file=<path> Generate LLVM PGO (Profile-Guided Optimization) profile file\n");
124124
#endif
125125
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
126-
printf(" --wasi-nn-graph=encodingA:targetB:<modelA_path>\n");
127-
printf(" --wasi-nn-graph=encodingA:targetB:<modelB_path>...\n");
126+
printf(" --wasi-nn-graph=modelA_name:encodingA:targetA:<modelA_path>\n");
127+
printf(" --wasi-nn-graph=modelB_name:encodingB:targetB:<modelB_path>...\n");
128128
printf(" Set encoding, target and model_paths for wasi-nn. target can be\n");
129129
printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n");
130130
printf(" tensorflow|pytorch|ggml|autodetect\n");

0 commit comments

Comments
 (0)