Skip to content

Commit a15a731

Browse files
authored
wasi-nn: Support multiple TFLite models (#2002)
Remove restrictions: - Only 1 WASM app at a time - Only 1 model at a time - `graph` and `graph-execution-context` are ignored Refer to previous document: https://github.com/bytecodealliance/wasm-micro-runtime/blob/e8d718096dc56d4b1aa66ec6cd04d6024ca1c6e2/core/iwasm/libraries/wasi-nn/README.md
1 parent f279ba8 commit a15a731

16 files changed

Lines changed: 568 additions & 347 deletions

build-scripts/config_common.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,11 @@ if (WAMR_BUILD_SGX_IPFS EQUAL 1)
333333
endif ()
334334
if (WAMR_BUILD_WASI_NN EQUAL 1)
335335
message (" WASI-NN enabled")
336+
add_definitions (-DWASM_ENABLE_WASI_NN=1)
337+
if (WASI_NN_ENABLE_GPU EQUAL 1)
338+
message (" WASI-NN: GPU enabled")
339+
add_definitions (-DWASI_NN_ENABLE_GPU=1)
340+
endif ()
336341
endif ()
337342
if (WAMR_BUILD_ALLOC_WITH_USER_DATA EQUAL 1)
338343
add_definitions(-DWASM_MEM_ALLOC_WITH_USER_DATA=1)

build-scripts/runtime_lib.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
109109
message("Tensorflow is already downloaded.")
110110
endif()
111111
set(TENSORFLOW_SOURCE_DIR "${WAMR_ROOT_DIR}/core/deps/tensorflow-src")
112+
113+
if (WASI_NN_ENABLE_GPU EQUAL 1)
114+
# Tensorflow specific:
115+
# * https://www.tensorflow.org/lite/guide/build_cmake#available_options_to_build_tensorflow_lite
116+
set (TFLITE_ENABLE_GPU ON)
117+
endif ()
118+
112119
include_directories (${CMAKE_CURRENT_BINARY_DIR}/flatbuffers/include)
113120
include_directories (${TENSORFLOW_SOURCE_DIR})
114121
add_subdirectory(

core/iwasm/libraries/wasi-nn/README.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@ To run the tests we assume that the current directory is the root of the reposit
1919

2020
### Build the runtime
2121

22-
Build the runtime base image,
23-
24-
```
25-
docker build -t wasi-nn-base -f core/iwasm/libraries/wasi-nn/test/Dockerfile.base .
26-
```
27-
2822
Build the runtime image for your execution target type.
2923

3024
`EXECUTION_TYPE` can be:
@@ -84,9 +78,6 @@ Requirements:
8478

8579
Supported:
8680

87-
* Only 1 WASM app at a time.
88-
* Only 1 model at a time.
89-
* `graph` and `graph-execution-context` are ignored.
9081
* Graph encoding: `tensorflowlite`.
9182
* Execution target: `cpu` and `gpu`.
9283
* Tensor type: `fp32`.

core/iwasm/libraries/wasi-nn/src/utils/logger.h

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,51 +13,57 @@
1313
(strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__)
1414

1515
/* Disable a level by removing the define */
16-
#define ENABLE_ERR_LOG
17-
#define ENABLE_WARN_LOG
18-
#define ENABLE_DBG_LOG
19-
#define ENABLE_INFO_LOG
16+
#ifndef NN_LOG_LEVEL
17+
/*
18+
0 -> debug, info, warn, err
19+
1 -> info, warn, err
20+
2 -> warn, err
21+
3 -> err
22+
4 -> NO LOGS
23+
*/
24+
#define NN_LOG_LEVEL 0
25+
#endif
2026

2127
// Definition of the levels
22-
#ifdef ENABLE_ERR_LOG
23-
#define NN_ERR_PRINTF(fmt, ...) \
24-
do { \
25-
printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
26-
printf("\n"); \
27-
fflush(stdout); \
28+
#if NN_LOG_LEVEL <= 3
29+
#define NN_ERR_PRINTF(fmt, ...) \
30+
do { \
31+
printf("[%s:%d ERROR] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
32+
printf("\n"); \
33+
fflush(stdout); \
2834
} while (0)
2935
#else
3036
#define NN_ERR_PRINTF(fmt, ...)
3137
#endif
32-
#ifdef ENABLE_WARN_LOG
33-
#define NN_WARN_PRINTF(fmt, ...) \
34-
do { \
35-
printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
36-
printf("\n"); \
37-
fflush(stdout); \
38+
#if NN_LOG_LEVEL <= 2
39+
#define NN_WARN_PRINTF(fmt, ...) \
40+
do { \
41+
printf("[%s:%d WARNING] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
42+
printf("\n"); \
43+
fflush(stdout); \
3844
} while (0)
3945
#else
4046
#define NN_WARN_PRINTF(fmt, ...)
4147
#endif
42-
#ifdef ENABLE_DBG_LOG
43-
#define NN_DBG_PRINTF(fmt, ...) \
44-
do { \
45-
printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
46-
printf("\n"); \
47-
fflush(stdout); \
48+
#if NN_LOG_LEVEL <= 1
49+
#define NN_INFO_PRINTF(fmt, ...) \
50+
do { \
51+
printf("[%s:%d INFO] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
52+
printf("\n"); \
53+
fflush(stdout); \
4854
} while (0)
4955
#else
50-
#define NN_DBG_PRINTF(fmt, ...)
56+
#define NN_INFO_PRINTF(fmt, ...)
5157
#endif
52-
#ifdef ENABLE_INFO_LOG
53-
#define NN_INFO_PRINTF(fmt, ...) \
54-
do { \
55-
printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
56-
printf("\n"); \
57-
fflush(stdout); \
58+
#if NN_LOG_LEVEL <= 0
59+
#define NN_DBG_PRINTF(fmt, ...) \
60+
do { \
61+
printf("[%s:%d DEBUG] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
62+
printf("\n"); \
63+
fflush(stdout); \
5864
} while (0)
5965
#else
60-
#define NN_INFO_PRINTF(fmt, ...)
66+
#define NN_DBG_PRINTF(fmt, ...)
6167
#endif
6268

6369
#endif

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222

2323
/* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
2424

25-
typedef error (*LOAD)(graph_builder_array *, graph_encoding, execution_target,
26-
graph *);
27-
typedef error (*INIT_EXECUTION_CONTEXT)(graph, graph_execution_context *);
28-
typedef error (*SET_INPUT)(graph_execution_context, uint32_t, tensor *);
29-
typedef error (*COMPUTE)(graph_execution_context);
30-
typedef error (*GET_OUTPUT)(graph_execution_context, uint32_t, tensor_data,
31-
uint32_t *);
25+
typedef error (*LOAD)(void *, graph_builder_array *, graph_encoding,
26+
execution_target, graph *);
27+
typedef error (*INIT_EXECUTION_CONTEXT)(void *, graph,
28+
graph_execution_context *);
29+
typedef error (*SET_INPUT)(void *, graph_execution_context, uint32_t, tensor *);
30+
typedef error (*COMPUTE)(void *, graph_execution_context);
31+
typedef error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t,
32+
tensor_data, uint32_t *);
3233

3334
typedef struct {
3435
LOAD load;
@@ -123,12 +124,12 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
123124
goto fail;
124125
}
125126

126-
res = lookup[encoding].load(&builder_native, encoding, target, g);
127+
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
128+
res = lookup[encoding].load(wasi_nn_ctx->tflite_ctx, &builder_native,
129+
encoding, target, g);
127130

128131
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
129132

130-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
131-
132133
wasi_nn_ctx->current_encoding = encoding;
133134
wasi_nn_ctx->is_initialized = true;
134135

@@ -160,8 +161,9 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
160161
return invalid_argument;
161162
}
162163

163-
res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(g, ctx);
164-
*ctx = g;
164+
res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(
165+
wasi_nn_ctx->tflite_ctx, g, ctx);
166+
165167
NN_DBG_PRINTF(
166168
"wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
167169
*ctx);
@@ -189,8 +191,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
189191
&input_tensor_native)))
190192
return res;
191193

192-
res = lookup[wasi_nn_ctx->current_encoding].set_input(ctx, index,
193-
&input_tensor_native);
194+
res = lookup[wasi_nn_ctx->current_encoding].set_input(
195+
wasi_nn_ctx->tflite_ctx, ctx, index, &input_tensor_native);
194196

195197
// XXX: Free intermediate structure pointers
196198
if (input_tensor_native.dimensions)
@@ -213,7 +215,8 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
213215
if (success != (res = is_model_initialized(wasi_nn_ctx)))
214216
return res;
215217

216-
res = lookup[wasi_nn_ctx->current_encoding].compute(ctx);
218+
res = lookup[wasi_nn_ctx->current_encoding].compute(wasi_nn_ctx->tflite_ctx,
219+
ctx);
217220
NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
218221
return res;
219222
}
@@ -241,7 +244,7 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
241244
}
242245

243246
res = lookup[wasi_nn_ctx->current_encoding].get_output(
244-
ctx, index, output_tensor, output_tensor_size);
247+
wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size);
245248
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
246249
res, *output_tensor_size);
247250
return res;
@@ -261,6 +264,7 @@ wasi_nn_initialize()
261264
}
262265
wasi_nn_ctx->is_initialized = true;
263266
wasi_nn_ctx->current_encoding = 3;
267+
tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);
264268
return wasi_nn_ctx;
265269
}
266270

@@ -275,7 +279,7 @@ wasi_nn_destroy(WASINNContext *wasi_nn_ctx)
275279
NN_DBG_PRINTF("Freeing wasi-nn");
276280
NN_DBG_PRINTF("-> is_initialized: %d", wasi_nn_ctx->is_initialized);
277281
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
278-
tensorflowlite_destroy();
282+
tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
279283
wasm_runtime_free(wasi_nn_ctx);
280284
}
281285

core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
typedef struct {
1212
bool is_initialized;
1313
graph_encoding current_encoding;
14+
void *tflite_ctx;
1415
} WASINNContext;
1516

1617
/**

0 commit comments

Comments
 (0)