Skip to content

Commit 9fc0848

Browse files
committed
Check tflite and onnx first when auto detecting encoding.
1 parent cbf1650 commit 9fc0848

1 file changed

Lines changed: 17 additions & 8 deletions

File tree

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ struct backends_api_functions {
5656
NN_ERR_PRINTF("Error %s() -> %d", #func, wasi_error); \
5757
} while (0)
5858

59+
static graph_encoding auto_detect_encoding_order[] = {
60+
tensorflowlite, onnx, openvino, tensorflow,
61+
pytorch, ggml, autodetect, unknown_backend
62+
};
63+
static int auto_detect_encoding_num =
64+
sizeof(auto_detect_encoding_order) / sizeof(graph_encoding);
5965
static void *wasi_nn_key;
6066

6167
static void
@@ -492,21 +498,22 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
492498
}
493499

494500
if (encoding == autodetect) {
495-
for (graph_encoding e = openvino; e <= unknown_backend; e++) {
501+
for (int i = 0; i < auto_detect_encoding_num; i++) {
496502
if (wasi_nn_ctx->is_backend_ctx_initialized) {
497503
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
498504
wasi_nn_ctx->backend_ctx);
499505
}
500506

501-
res = ensure_backend(instance, e, wasi_nn_ctx);
507+
res = ensure_backend(instance, auto_detect_encoding_order[i],
508+
wasi_nn_ctx);
502509
if (res != success) {
503510
NN_ERR_PRINTF("continue trying the next");
504511
continue;
505512
}
506513

507514
call_wasi_nn_func(wasi_nn_ctx->backend, load, res,
508-
wasi_nn_ctx->backend_ctx, &builder_native, e,
509-
target, g);
515+
wasi_nn_ctx->backend_ctx, &builder_native,
516+
auto_detect_encoding_order[i], target, g);
510517
if (res != success) {
511518
NN_ERR_PRINTF("continue trying the next");
512519
continue;
@@ -595,13 +602,14 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
595602
goto fail;
596603
}
597604

598-
for (graph_encoding e = openvino; e <= unknown_backend; e++) {
605+
for (int i = 0; i < auto_detect_encoding_num; i++) {
599606
if (wasi_nn_ctx->is_backend_ctx_initialized) {
600607
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
601608
wasi_nn_ctx->backend_ctx);
602609
}
603610

604-
res = ensure_backend(instance, e, wasi_nn_ctx);
611+
res = ensure_backend(instance, auto_detect_encoding_order[i],
612+
wasi_nn_ctx);
605613
if (res != success) {
606614
NN_ERR_PRINTF("continue trying the next");
607615
continue;
@@ -669,13 +677,14 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
669677
goto fail;
670678
}
671679

672-
for (graph_encoding e = openvino; e <= unknown_backend; e++) {
680+
for (int i = 0; i < auto_detect_encoding_num; i++) {
673681
if (wasi_nn_ctx->is_backend_ctx_initialized) {
674682
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
675683
wasi_nn_ctx->backend_ctx);
676684
}
677685

678-
res = ensure_backend(instance, e, wasi_nn_ctx);
686+
res = ensure_backend(instance, auto_detect_encoding_order[i],
687+
wasi_nn_ctx);
679688
if (res != success) {
680689
NN_ERR_PRINTF("continue trying the next");
681690
continue;

0 commit comments

Comments
 (0)