@@ -51,17 +51,6 @@ typedef struct {
5151 OnnxRuntimeExecCtx exec_ctxs[MAX_CONTEXTS];
5252} OnnxRuntimeContext;
5353
54- /* Helper functions */
55- static void
56- check_status_and_log (const OnnxRuntimeContext *ctx, OrtStatus *status)
57- {
58- if (status != nullptr ) {
59- const char *msg = ctx->ort_api ->GetErrorMessage (status);
60- NN_ERR_PRINTF (" ONNX Runtime error: %s" , msg);
61- ctx->ort_api ->ReleaseStatus (status);
62- }
63- }
64-
6554static wasi_nn_error
6655convert_ort_error_to_wasi_nn_error (const OnnxRuntimeContext *ctx,
6756 OrtStatus *status)
@@ -104,37 +93,6 @@ convert_ort_error_to_wasi_nn_error(const OnnxRuntimeContext *ctx,
10493 return err;
10594}
10695
107- static bool
108- convert_ort_type_to_wasi_nn_type (ONNXTensorElementDataType ort_type,
109- tensor_type *tensor_type)
110- {
111- switch (ort_type) {
112- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
113- *tensor_type = fp32;
114- break ;
115- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
116- *tensor_type = fp16;
117- break ;
118- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
119- *tensor_type = fp64;
120- break ;
121- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
122- *tensor_type = u8 ;
123- break ;
124- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
125- *tensor_type = i32 ;
126- break ;
127- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
128- *tensor_type = i64 ;
129- break ;
130- default :
131- NN_WARN_PRINTF (" Unsupported ONNX tensor type: %d" , ort_type);
132- return false ;
133- }
134-
135- return true ;
136- }
137-
13896static bool
13997convert_wasi_nn_type_to_ort_type (tensor_type type,
14098 ONNXTensorElementDataType *ort_type)
@@ -191,7 +149,6 @@ init_backend(void **onnx_ctx)
191149 err = convert_ort_error_to_wasi_nn_error (ctx, status);
192150 NN_ERR_PRINTF (" Failed to create ONNX Runtime environment: %s" ,
193151 error_message);
194- ctx->ort_api ->ReleaseStatus (status);
195152 goto fail;
196153 }
197154 NN_INFO_PRINTF (" ONNX Runtime environment created successfully" );
@@ -274,6 +231,17 @@ deinit_backend(void *onnx_ctx)
274231 for (auto &output : ctx->exec_ctxs [i].outputs ) {
275232 ctx->ort_api ->ReleaseValue (output.second );
276233 }
234+
235+ for (auto name : ctx->exec_ctxs [i].input_names ) {
236+ free ((void *)name);
237+ }
238+ ctx->exec_ctxs [i].input_names .clear ();
239+
240+ for (auto name : ctx->exec_ctxs [i].output_names ) {
241+ free ((void *)name);
242+ }
243+ ctx->exec_ctxs [i].output_names .clear ();
244+
277245 ctx->ort_api ->ReleaseMemoryInfo (ctx->exec_ctxs [i].memory_info );
278246 ctx->exec_ctxs [i].is_initialized = false ;
279247 }
@@ -293,6 +261,10 @@ __attribute__((visibility("default"))) wasi_nn_error
293261load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
294262 execution_target target, graph *g)
295263{
264+ if (!onnx_ctx) {
265+ return runtime_error;
266+ }
267+
296268 if (encoding != onnx) {
297269 NN_ERR_PRINTF (" Unsupported encoding: %d" , encoding);
298270 return invalid_encoding;
@@ -349,7 +321,6 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
349321 wasi_nn_error err = convert_ort_error_to_wasi_nn_error (ctx, status);
350322 NN_ERR_PRINTF (" Failed to create ONNX Runtime session: %s" ,
351323 error_message);
352- ctx->ort_api ->ReleaseStatus (status);
353324 return err;
354325 }
355326
@@ -365,6 +336,10 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
365336__attribute__ ((visibility(" default" ))) wasi_nn_error
366337load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g)
367338{
339+ if (!onnx_ctx) {
340+ return runtime_error;
341+ }
342+
368343 OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
369344 std::lock_guard<std::mutex> lock (ctx->mutex );
370345
0 commit comments