2222
2323/* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
2424
25- typedef error (* LOAD )(graph_builder_array * , graph_encoding , execution_target ,
26- graph * );
27- typedef error (* INIT_EXECUTION_CONTEXT )(graph , graph_execution_context * );
28- typedef error (* SET_INPUT )(graph_execution_context , uint32_t , tensor * );
29- typedef error (* COMPUTE )(graph_execution_context );
30- typedef error (* GET_OUTPUT )(graph_execution_context , uint32_t , tensor_data ,
31- uint32_t * );
25+ typedef error (* LOAD )(void * , graph_builder_array * , graph_encoding ,
26+ execution_target , graph * );
27+ typedef error (* INIT_EXECUTION_CONTEXT )(void * , graph ,
28+ graph_execution_context * );
29+ typedef error (* SET_INPUT )(void * , graph_execution_context , uint32_t , tensor * );
30+ typedef error (* COMPUTE )(void * , graph_execution_context );
31+ typedef error (* GET_OUTPUT )(void * , graph_execution_context , uint32_t ,
32+ tensor_data , uint32_t * );
3233
3334typedef struct {
3435 LOAD load ;
@@ -123,12 +124,12 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
123124 goto fail ;
124125 }
125126
126- res = lookup [encoding ].load (& builder_native , encoding , target , g );
127+ WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
128+ res = lookup [encoding ].load (wasi_nn_ctx -> tflite_ctx , & builder_native ,
129+ encoding , target , g );
127130
128131 NN_DBG_PRINTF ("wasi_nn_load finished with status %d [graph=%d]" , res , * g );
129132
130- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
131-
132133 wasi_nn_ctx -> current_encoding = encoding ;
133134 wasi_nn_ctx -> is_initialized = true;
134135
@@ -160,8 +161,9 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
160161 return invalid_argument ;
161162 }
162163
163- res = lookup [wasi_nn_ctx -> current_encoding ].init_execution_context (g , ctx );
164- * ctx = g ;
164+ res = lookup [wasi_nn_ctx -> current_encoding ].init_execution_context (
165+ wasi_nn_ctx -> tflite_ctx , g , ctx );
166+
165167 NN_DBG_PRINTF (
166168 "wasi_nn_init_execution_context finished with status %d [ctx=%d]" , res ,
167169 * ctx );
@@ -189,8 +191,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
189191 & input_tensor_native )))
190192 return res ;
191193
192- res = lookup [wasi_nn_ctx -> current_encoding ].set_input (ctx , index ,
193- & input_tensor_native );
194+ res = lookup [wasi_nn_ctx -> current_encoding ].set_input (
195+ wasi_nn_ctx -> tflite_ctx , ctx , index , & input_tensor_native );
194196
195197 // XXX: Free intermediate structure pointers
196198 if (input_tensor_native .dimensions )
@@ -213,7 +215,8 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
213215 if (success != (res = is_model_initialized (wasi_nn_ctx )))
214216 return res ;
215217
216- res = lookup [wasi_nn_ctx -> current_encoding ].compute (ctx );
218+ res = lookup [wasi_nn_ctx -> current_encoding ].compute (wasi_nn_ctx -> tflite_ctx ,
219+ ctx );
217220 NN_DBG_PRINTF ("wasi_nn_compute finished with status %d" , res );
218221 return res ;
219222}
@@ -241,7 +244,7 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
241244 }
242245
243246 res = lookup [wasi_nn_ctx -> current_encoding ].get_output (
244- ctx , index , output_tensor , output_tensor_size );
247+ wasi_nn_ctx -> tflite_ctx , ctx , index , output_tensor , output_tensor_size );
245248 NN_DBG_PRINTF ("wasi_nn_get_output finished with status %d [data_size=%d]" ,
246249 res , * output_tensor_size );
247250 return res ;
@@ -261,6 +264,7 @@ wasi_nn_initialize()
261264 }
262265 wasi_nn_ctx -> is_initialized = true;
263266 wasi_nn_ctx -> current_encoding = 3 ;
267+ tensorflowlite_initialize (& wasi_nn_ctx -> tflite_ctx );
264268 return wasi_nn_ctx ;
265269}
266270
@@ -275,7 +279,7 @@ wasi_nn_destroy(WASINNContext *wasi_nn_ctx)
275279 NN_DBG_PRINTF ("Freeing wasi-nn" );
276280 NN_DBG_PRINTF ("-> is_initialized: %d" , wasi_nn_ctx -> is_initialized );
277281 NN_DBG_PRINTF ("-> current_encoding: %d" , wasi_nn_ctx -> current_encoding );
278- tensorflowlite_destroy ();
282+ tensorflowlite_destroy (wasi_nn_ctx -> tflite_ctx );
279283 wasm_runtime_free (wasi_nn_ctx );
280284}
281285
0 commit comments