@@ -1806,23 +1806,27 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args)
18061806
18071807bool
18081808wasi_nn_graph_registry_set_args (WASINNArguments * registry ,
1809- const char * * encoding , const char * * target ,
1810- uint32_t n_graphs , const char * * graph_paths )
1809+ const char * * model_names , const char * * encoding ,
1810+ const char * * target , uint32_t n_graphs ,
1811+ const char * * graph_paths )
18111812{
1812- if (!registry || !encoding || !target || !graph_paths ) {
1813+ if (!registry || !model_names || ! encoding || !target || !graph_paths ) {
18131814 return false;
18141815 }
18151816
18161817 registry -> n_graphs = n_graphs ;
18171818 registry -> target = (uint32_t * * )malloc (sizeof (uint32_t * ) * n_graphs );
18181819 registry -> encoding = (uint32_t * * )malloc (sizeof (uint32_t * ) * n_graphs );
1820+ registry -> model_names = (uint32_t * * )malloc (sizeof (uint32_t * ) * n_graphs );
18191821 registry -> graph_paths = (uint32_t * * )malloc (sizeof (uint32_t * ) * n_graphs );
18201822 memset (registry -> target , 0 , sizeof (uint32_t * ) * n_graphs );
18211823 memset (registry -> encoding , 0 , sizeof (uint32_t * ) * n_graphs );
1824+ memset (registry -> model_names , 0 , sizeof (uint32_t * ) * n_graphs );
18221825 memset (registry -> graph_paths , 0 , sizeof (uint32_t * ) * n_graphs );
18231826
18241827 for (uint32_t i = 0 ; i < registry -> n_graphs ; i ++ ) {
18251828 registry -> graph_paths [i ] = strdup (graph_paths [i ]);
1829+ registry -> model_names [i ] = strdup (model_names [i ]);
18261830 registry -> encoding [i ] = strdup (encoding [i ]);
18271831 registry -> target [i ] = strdup (target [i ]);
18281832 }
@@ -1849,6 +1853,8 @@ wasi_nn_graph_registry_destroy(WASINNArguments *registry)
18491853 for (uint32_t i = 0 ; i < registry -> n_graphs ; i ++ )
18501854 if (registry -> graph_paths [i ]) {
18511855 free (registry -> graph_paths [i ]);
1856+ if (registry -> model_names [i ])
1857+ free (registry -> model_names [i ]);
18521858 if (registry -> encoding [i ])
18531859 free (registry -> encoding [i ]);
18541860 if (registry -> target [i ])
@@ -8155,6 +8161,7 @@ wasm_runtime_check_and_update_last_used_shared_heap(
81558161#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
81568162bool
81578163wasm_runtime_init_wasi_nn_global_ctx (WASMModuleInstanceCommon * module_inst ,
8164+ const char * * model_names ,
81588165 const char * * encoding , const char * * target ,
81598166 const uint32_t n_graphs ,
81608167 char * graph_paths [], char * error_buf ,
@@ -8175,11 +8182,14 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
81758182 memset (ctx -> target , 0 , sizeof (uint32_t ) * n_graphs );
81768183 ctx -> loaded = (uint32_t * )malloc (sizeof (uint32_t ) * n_graphs );
81778184 memset (ctx -> loaded , 0 , sizeof (uint32_t ) * n_graphs );
8185+ ctx -> model_names = (uint32_t * * )malloc (sizeof (uint32_t * ) * n_graphs );
8186+ memset (ctx -> model_names , 0 , sizeof (uint32_t * ) * n_graphs );
81788187 ctx -> graph_paths = (uint32_t * * )malloc (sizeof (uint32_t * ) * n_graphs );
81798188 memset (ctx -> graph_paths , 0 , sizeof (uint32_t * ) * n_graphs );
81808189
81818190 for (uint32_t i = 0 ; i < n_graphs ; i ++ ) {
81828191 ctx -> graph_paths [i ] = strdup (graph_paths [i ]);
8192+ ctx -> model_names [i ] = strdup (model_names [i ]);
81838193 ctx -> target [i ] = strdup (target [i ]);
81848194 ctx -> encoding [i ] = strdup (encoding [i ]);
81858195 }
@@ -8201,14 +8211,17 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
82018211 // All graphs will be unregistered in deinit()
82028212 if (wasi_nn_global_ctx -> graph_paths [i ])
82038213 free (wasi_nn_global_ctx -> graph_paths [i ]);
8214+ if (wasi_nn_global_ctx -> model_names [i ])
8215+ free (wasi_nn_global_ctx -> model_names [i ]);
82048216 if (wasi_nn_global_ctx -> encoding [i ])
82058217 free (wasi_nn_global_ctx -> encoding [i ]);
8206- if (wasi_nn_global_ctx -> encoding [i ])
8218+ if (wasi_nn_global_ctx -> target [i ])
82078219 free (wasi_nn_global_ctx -> target [i ]);
82088220 }
82098221 free (wasi_nn_global_ctx -> encoding );
82108222 free (wasi_nn_global_ctx -> target );
82118223 free (wasi_nn_global_ctx -> loaded );
8224+ free (wasi_nn_global_ctx -> model_names );
82128225 free (wasi_nn_global_ctx -> graph_paths );
82138226
82148227 if (wasi_nn_global_ctx ) {
@@ -8226,6 +8239,16 @@ wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
82268239 return -1 ;
82278240}
82288241
8242+ char *
8243+ wasm_runtime_get_wasi_nn_global_ctx_model_names_i (
8244+ WASINNGlobalContext * wasi_nn_global_ctx , uint32_t idx )
8245+ {
8246+ if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx -> n_graphs ))
8247+ return wasi_nn_global_ctx -> model_names [idx ];
8248+
8249+ return NULL ;
8250+ }
8251+
82298252char *
82308253wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i (
82318254 WASINNGlobalContext * wasi_nn_global_ctx , uint32_t idx )
0 commit comments