Skip to content

Commit 5357fb5

Browse files
committed
Move the error checks to an earlier stage.
1 parent ccee194 commit 5357fb5

File tree

4 files changed

+100
-78
lines changed

4 files changed

+100
-78
lines changed

core/iwasm/common/wasm_runtime_common.c

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,8 +1804,8 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args)
18041804

18051805
bool
18061806
wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
1807-
const char **model_names, const char **encoding,
1808-
const char **target, uint32_t n_graphs,
1807+
const char **model_names, const uint32_t **encoding,
1808+
const uint32_t **target, uint32_t n_graphs,
18091809
const char **graph_paths)
18101810
{
18111811
if (!registry || !model_names || !encoding || !target || !graph_paths) {
@@ -1832,8 +1832,8 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
18321832
for (uint32_t i = 0; i < registry->n_graphs; i++) {
18331833
registry->graph_paths[i] = bh_strdup(graph_paths[i]);
18341834
registry->model_names[i] = bh_strdup(model_names[i]);
1835-
registry->encoding[i] = bh_strdup(encoding[i]);
1836-
registry->target[i] = bh_strdup(target[i]);
1835+
registry->encoding[i] = encoding[i];
1836+
registry->target[i] = target[i];
18371837
}
18381838

18391839
return true;
@@ -1860,13 +1860,13 @@ wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry)
18601860
wasm_runtime_free(registry->graph_paths[i]);
18611861
if (registry->model_names[i])
18621862
wasm_runtime_free(registry->model_names[i]);
1863-
if (registry->encoding[i])
1864-
wasm_runtime_free(registry->encoding[i]);
1865-
if (registry->target[i])
1866-
wasm_runtime_free(registry->target[i]);
18671863
}
1868-
if (registry->loaded)
1869-
wasm_runtime_free(registry->loaded);
1864+
if (registry->encoding)
1865+
wasm_runtime_free(registry->encoding);
1866+
if (registry->target)
1867+
wasm_runtime_free(registry->target);
1868+
if (registry->loaded)
1869+
wasm_runtime_free(registry->loaded);
18701870
wasm_runtime_free(registry);
18711871
}
18721872
}
@@ -1881,16 +1881,13 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry(
18811881

18821882
wasi_nn_registry->n_graphs = registry->n_graphs;
18831883

1884-
if (registry->model_names)
1885-
wasi_nn_registry->model_names = bh_strdup(registry->model_names);
1886-
if (registry->encoding)
1887-
wasi_nn_registry->encoding = bh_strdup(registry->encoding);
1888-
if (registry->target)
1889-
wasi_nn_registry->target = bh_strdup(registry->target);
1890-
if (registry->loaded)
1891-
wasi_nn_registry->loaded = bh_strdup(registry->loaded);
1892-
if (registry->graph_paths)
1893-
wasi_nn_registry->graph_paths = bh_strdup(registry->graph_paths);
1884+
for (uint32_t i = 0; i < registry->n_graphs; i++) {
1885+
registry->graph_paths[i] = bh_strdup(registry->graph_paths[i]);
1886+
registry->model_names[i] = bh_strdup(registry->model_names[i]);
1887+
registry->encoding[i] = registry->encoding[i];
1888+
registry->target[i] = registry->target[i];
1889+
wasi_nn_registry->loaded = registry->loaded;
1890+
}
18941891
}
18951892
#endif
18961893

core/iwasm/common/wasm_runtime_common.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,11 +548,11 @@ typedef struct WASMModuleInstMemConsumption {
548548
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
549549
typedef struct WASINNRegistry {
550550
char **model_names;
551-
char **encoding;
552-
char **target;
551+
uint32_t **encoding;
552+
uint32_t **target;
553553

554554
uint32_t n_graphs;
555-
uint32_t *loaded;
555+
uint32_t **loaded;
556556
char **graph_paths;
557557
} WASINNRegistry;
558558
#endif
@@ -805,8 +805,8 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry(
805805

806806
WASM_RUNTIME_API_EXTERN bool
807807
wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
808-
const char **model_names, const char **encoding,
809-
const char **target, uint32_t n_graphs,
808+
const char **model_names, const uint32_t **encoding,
809+
const uint32_t **target, uint32_t n_graphs,
810810
const char **graph_paths);
811811
#endif
812812

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -211,46 +211,6 @@ wasi_nn_destroy()
211211
* - model file format
212212
* - on device ML framework
213213
*/
214-
static graph_encoding
215-
str2encoding(char *str_encoding)
216-
{
217-
if (!str_encoding) {
218-
NN_ERR_PRINTF("Got empty string encoding");
219-
return -1;
220-
}
221-
222-
if (!strcmp(str_encoding, "openvino"))
223-
return openvino;
224-
else if (!strcmp(str_encoding, "tensorflowlite"))
225-
return tensorflowlite;
226-
else if (!strcmp(str_encoding, "ggml"))
227-
return ggml;
228-
else if (!strcmp(str_encoding, "onnx"))
229-
return onnx;
230-
else
231-
return unknown_backend;
232-
// return autodetect;
233-
}
234-
235-
static execution_target
236-
str2target(char *str_target)
237-
{
238-
if (!str_target) {
239-
NN_ERR_PRINTF("Got empty string target");
240-
return -1;
241-
}
242-
243-
if (!strcmp(str_target, "cpu"))
244-
return cpu;
245-
else if (!strcmp(str_target, "gpu"))
246-
return gpu;
247-
else if (!strcmp(str_target, "tpu"))
248-
return tpu;
249-
else
250-
return unsupported_target;
251-
// return autodetect;
252-
}
253-
254214
static graph_encoding
255215
choose_a_backend()
256216
{
@@ -630,10 +590,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
630590
is_loaded = wasi_nn_registry->loaded[model_idx];
631591
char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx];
632592

633-
graph_encoding encoding =
634-
str2encoding(wasi_nn_registry->encoding[model_idx]);
635-
execution_target target =
636-
str2target(wasi_nn_registry->target[model_idx]);
593+
graph_encoding encoding = (graph_encoding)(wasi_nn_registry->encoding[model_idx]);
594+
execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]);
637595

638596
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
639597
res = ensure_backend(instance, encoding, wasi_nn_ctx);

product-mini/platforms/common/libc_wasi.c

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,77 @@ typedef struct {
2121
uint32 ns_lookup_pool_size;
2222
} libc_wasi_parse_context_t;
2323

24+
typedef enum {
25+
LIBC_WASI_PARSE_RESULT_OK = 0,
26+
LIBC_WASI_PARSE_RESULT_NEED_HELP,
27+
LIBC_WASI_PARSE_RESULT_BAD_PARAM
28+
} libc_wasi_parse_result_t;
29+
2430
typedef struct {
2531
const char *model_names[10];
26-
const char *encoding[10];
27-
const char *target[10];
32+
const uint32_t *encoding[10];
33+
const uint32_t *target[10];
2834
const char *graph_paths[10];
2935
uint32 n_graphs;
3036
} wasi_nn_parse_context_t;
3137

3238
typedef enum {
33-
LIBC_WASI_PARSE_RESULT_OK = 0,
34-
LIBC_WASI_PARSE_RESULT_NEED_HELP,
35-
LIBC_WASI_PARSE_RESULT_BAD_PARAM
36-
} libc_wasi_parse_result_t;
39+
wasi_nn_openvino = 0,
40+
wasi_nn_onnx,
41+
wasi_nn_tensorflow,
42+
wasi_nn_pytorch,
43+
wasi_nn_tensorflowlite,
44+
wasi_nn_ggml,
45+
wasi_nn_autodetect,
46+
wasi_nn_unknown_backend,
47+
} wasi_nn_encoding;
48+
49+
typedef enum wasi_nn_target {
50+
wasi_nn_cpu = 0,
51+
wasi_nn_gpu,
52+
wasi_nn_tpu,
53+
wasi_nn_unsupported_target,
54+
} wasi_nn_target;
55+
56+
static wasi_nn_encoding
57+
str2encoding(char *str_encoding)
58+
{
59+
if (!str_encoding) {
60+
printf("Got empty string encoding");
61+
return -1;
62+
}
63+
64+
if (!strcmp(str_encoding, "openvino"))
65+
return wasi_nn_openvino;
66+
else if (!strcmp(str_encoding, "tensorflowlite"))
67+
return wasi_nn_tensorflowlite;
68+
else if (!strcmp(str_encoding, "ggml"))
69+
return wasi_nn_ggml;
70+
else if (!strcmp(str_encoding, "onnx"))
71+
return wasi_nn_onnx;
72+
else
73+
return wasi_nn_unknown_backend;
74+
// return autodetect;
75+
}
76+
77+
static wasi_nn_target
78+
str2target(char *str_target)
79+
{
80+
if (!str_target) {
81+
printf("Got empty string target");
82+
return -1;
83+
}
84+
85+
if (!strcmp(str_target, "cpu"))
86+
return wasi_nn_cpu;
87+
else if (!strcmp(str_target, "gpu"))
88+
return wasi_nn_gpu;
89+
else if (!strcmp(str_target, "tpu"))
90+
return wasi_nn_tpu;
91+
else
92+
return wasi_nn_unsupported_target;
93+
// return autodetect;
94+
}
3795

3896
static void
3997
libc_wasi_print_help(void)
@@ -223,10 +281,19 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx)
223281
}
224282

225283
ctx->model_names[ctx->n_graphs] = tokens[0];
226-
ctx->encoding[ctx->n_graphs] = tokens[1];
227-
ctx->target[ctx->n_graphs] = tokens[2];
228-
ctx->graph_paths[ctx->n_graphs++] = tokens[3];
284+
ctx->encoding[ctx->n_graphs] = (uint32_t)str2encoding(tokens[1]);
285+
ctx->target[ctx->n_graphs] = (uint32_t)str2target(tokens[2]);
286+
ctx->graph_paths[ctx->n_graphs] = tokens[3];
287+
288+
if ((!ctx->model_names[ctx->n_graphs]) ||
289+
(ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) ||
290+
(ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) {
291+
ret = LIBC_WASI_PARSE_RESULT_NEED_HELP;
292+
printf("Invalid arguments for wasi-nn.\n");
293+
goto fail;
294+
}
229295

296+
ctx->n_graphs++;
230297
fail:
231298

232299
return ret;

0 commit comments

Comments
 (0)