Skip to content

Commit 96cdfa6

Browse files
QiuYuan Hanzhanheng1
authored andcommitted
Add the way to set the target evenif we use load_by_name
1 parent 2063ac1 commit 96cdfa6

21 files changed

+659
-128
lines changed

core/iwasm/common/wasm_native.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL;
2525
static void *g_wasi_context_key;
2626
#endif /* WASM_ENABLE_LIBC_WASI */
2727

28+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
29+
static void *g_wasi_nn_context_key;
30+
#endif
31+
2832
uint32
2933
get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis);
3034

@@ -473,6 +477,31 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
473477
}
474478
#endif /* end of WASM_ENABLE_LIBC_WASI */
475479

480+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
481+
WASINNGlobalContext *
482+
wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm)
483+
{
484+
return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key);
485+
}
486+
487+
void
488+
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm,
489+
WASINNGlobalContext *wasi_nn_ctx)
490+
{
491+
wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx);
492+
}
493+
494+
static void
495+
wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
496+
{
497+
if (ctx == NULL) {
498+
return;
499+
}
500+
501+
wasm_runtime_destroy_wasi_nn_global_ctx(inst);
502+
}
503+
#endif
504+
476505
#if WASM_ENABLE_QUICK_AOT_ENTRY != 0
477506
static bool
478507
quick_aot_entry_init(void);
@@ -582,6 +611,11 @@ wasm_native_init()
582611
#endif /* WASM_ENABLE_LIB_RATS */
583612

584613
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
614+
g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor);
615+
if (g_wasi_nn_context_key == NULL) {
616+
goto fail;
617+
}
618+
585619
if (!wasi_nn_initialize())
586620
goto fail;
587621

@@ -648,6 +682,10 @@ wasm_native_destroy()
648682
#endif
649683

650684
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
685+
if (g_wasi_nn_context_key != NULL) {
686+
wasm_native_destroy_context_key(g_wasi_nn_context_key);
687+
g_wasi_nn_context_key = NULL;
688+
}
651689
wasi_nn_destroy();
652690
#endif
653691

core/iwasm/common/wasm_runtime_common.c

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,67 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p)
16961696
wasm_runtime_free(p);
16971697
}
16981698

1699+
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
1700+
struct wasi_nn_graph_registry;
1701+
1702+
void
1703+
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(struct wasi_nn_graph_registry *args)
1704+
{
1705+
memset(args, 0, sizeof(*args));
1706+
}
1707+
1708+
bool
1709+
wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding,
1710+
const char* target, uint32_t n_graphs,
1711+
const char** graph_paths)
1712+
{
1713+
if (!registry || !encoding || !target || !graph_paths)
1714+
{
1715+
return false;
1716+
}
1717+
registry->encoding = strdup(encoding);
1718+
registry->target = strdup(target);
1719+
registry->n_graphs = n_graphs;
1720+
registry->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs);
1721+
memset(registry->graph_paths, 0, sizeof(uint32_t*) * n_graphs);
1722+
for (uint32_t i = 0; i < registry->n_graphs; i++)
1723+
registry->graph_paths[i] = strdup(graph_paths[i]);
1724+
1725+
return true;
1726+
}
1727+
1728+
int
1729+
wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp)
1730+
{
1731+
struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args));
1732+
if (args == NULL) {
1733+
return false;
1734+
}
1735+
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args);
1736+
*registryp = args;
1737+
return 0;
1738+
}
1739+
1740+
void
1741+
wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry)
1742+
{
1743+
if (registry)
1744+
{
1745+
for (uint32_t i = 0; i < registry->n_graphs; i++)
1746+
if (registry->graph_paths[i])
1747+
{
1748+
// wasi_nn_graph_registry_unregister_graph(registry, registry->name[i]);
1749+
free(registry->graph_paths[i]);
1750+
}
1751+
if (registry->encoding)
1752+
free(registry->encoding);
1753+
if (registry->target)
1754+
free(registry->target);
1755+
free(registry);
1756+
}
1757+
}
1758+
#endif
1759+
16991760
void
17001761
wasm_runtime_instantiation_args_set_default_stack_size(
17011762
struct InstantiationArgs2 *p, uint32 v)
@@ -1794,6 +1855,14 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
17941855
wasi_args->set_by_user = true;
17951856
}
17961857
#endif /* WASM_ENABLE_LIBC_WASI != 0 */
1858+
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
1859+
void
1860+
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
1861+
struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry)
1862+
{
1863+
p->nn_registry = *registry;
1864+
}
1865+
#endif
17971866

17981867
WASMModuleInstanceCommon *
17991868
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
@@ -8080,3 +8149,114 @@ wasm_runtime_check_and_update_last_used_shared_heap(
80808149
return false;
80818150
}
80828151
#endif
8152+
8153+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
8154+
bool
8155+
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
8156+
const char* encoding, const char* target,
8157+
const uint32_t n_graphs, char* graph_paths[],
8158+
char *error_buf, uint32_t error_buf_size)
8159+
{
8160+
WASINNGlobalContext *ctx;
8161+
bool ret = false;
8162+
8163+
ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size);
8164+
if (!ctx)
8165+
return false;
8166+
8167+
ctx->encoding = strdup(encoding);
8168+
ctx->target = strdup(target);
8169+
ctx->n_graphs = n_graphs;
8170+
ctx->loaded = (uint32_t*)malloc(sizeof(uint32_t) * n_graphs);
8171+
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);
8172+
8173+
ctx->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs);
8174+
memset(ctx->graph_paths, 0, sizeof(uint32_t*) * n_graphs);
8175+
for (uint32_t i = 0; i < n_graphs; i++)
8176+
{
8177+
ctx->graph_paths[i] = strdup(graph_paths[i]);
8178+
}
8179+
8180+
wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx);
8181+
8182+
ret = true;
8183+
8184+
return ret;
8185+
}
8186+
8187+
void
8188+
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
8189+
{
8190+
WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst);
8191+
8192+
for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++)
8193+
{
8194+
// All graphs will be unregistered in deinit()
8195+
if (wasi_nn_global_ctx->graph_paths[i])
8196+
free(wasi_nn_global_ctx->graph_paths[i]);
8197+
}
8198+
free(wasi_nn_global_ctx->encoding);
8199+
free(wasi_nn_global_ctx->target);
8200+
free(wasi_nn_global_ctx->loaded);
8201+
free(wasi_nn_global_ctx->graph_paths);
8202+
8203+
if (wasi_nn_global_ctx) {
8204+
wasm_runtime_free(wasi_nn_global_ctx);
8205+
}
8206+
}
8207+
8208+
uint32_t
8209+
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx)
8210+
{
8211+
if (wasi_nn_global_ctx)
8212+
return wasi_nn_global_ctx->n_graphs;
8213+
8214+
return -1;
8215+
}
8216+
8217+
char *
8218+
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
8219+
{
8220+
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
8221+
return wasi_nn_global_ctx->graph_paths[idx];
8222+
8223+
return NULL;
8224+
}
8225+
8226+
uint32_t
8227+
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
8228+
{
8229+
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
8230+
return wasi_nn_global_ctx->loaded[idx];
8231+
8232+
return -1;
8233+
}
8234+
8235+
uint32_t
8236+
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value)
8237+
{
8238+
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
8239+
wasi_nn_global_ctx->loaded[idx] = value;
8240+
8241+
return 0;
8242+
}
8243+
8244+
char*
8245+
wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx)
8246+
{
8247+
if (wasi_nn_global_ctx)
8248+
return wasi_nn_global_ctx->encoding;
8249+
8250+
return NULL;
8251+
}
8252+
8253+
char*
8254+
wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx)
8255+
{
8256+
if (wasi_nn_global_ctx)
8257+
return wasi_nn_global_ctx->target;
8258+
8259+
return NULL;
8260+
}
8261+
8262+
#endif

core/iwasm/common/wasm_runtime_common.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,17 @@ typedef struct WASMModuleInstMemConsumption {
545545
uint32 exports_size;
546546
} WASMModuleInstMemConsumption;
547547

548+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
549+
typedef struct WASINNGlobalContext {
550+
char* encoding;
551+
char* target;
552+
553+
uint32_t n_graphs;
554+
uint32_t *loaded;
555+
char** graph_paths;
556+
} WASINNGlobalContext;
557+
#endif
558+
548559
#if WASM_ENABLE_LIBC_WASI != 0
549560
#if WASM_ENABLE_UVWASI == 0
550561
typedef struct WASIContext {
@@ -612,11 +623,30 @@ WASMExecEnv *
612623
wasm_runtime_get_exec_env_tls(void);
613624
#endif
614625

626+
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
627+
struct wasi_nn_graph_registry {
628+
char* encoding;
629+
char* target;
630+
631+
char** graph_paths;
632+
uint32_t n_graphs;
633+
};
634+
635+
WASM_RUNTIME_API_EXTERN int
636+
wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp);
637+
638+
WASM_RUNTIME_API_EXTERN void
639+
wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry);
640+
#endif
641+
615642
struct InstantiationArgs2 {
616643
InstantiationArgs v1;
617644
#if WASM_ENABLE_LIBC_WASI != 0
618645
WASIArguments wasi;
619646
#endif
647+
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
648+
struct wasi_nn_graph_registry nn_registry;
649+
#endif
620650
};
621651

622652
void
@@ -775,6 +805,17 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
775805
struct InstantiationArgs2 *p, const char *ns_lookup_pool[],
776806
uint32 ns_lookup_pool_size);
777807

808+
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
809+
WASM_RUNTIME_API_EXTERN void
810+
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
811+
struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry);
812+
813+
WASM_RUNTIME_API_EXTERN bool
814+
wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding,
815+
const char* target, uint32_t n_graphs,
816+
const char** graph_paths);
817+
#endif
818+
778819
/* See wasm_export.h for description */
779820
WASM_RUNTIME_API_EXTERN WASMModuleInstanceCommon *
780821
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
@@ -1427,6 +1468,39 @@ wasm_runtime_check_and_update_last_used_shared_heap(
14271468
uint8 **shared_heap_base_addr_adj_p, bool is_memory64);
14281469
#endif
14291470

1471+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
1472+
WASM_RUNTIME_API_EXTERN bool
1473+
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
1474+
const char* encoding, const char* target,
1475+
const uint32_t n_graphs, char* graph_paths[],
1476+
char *error_buf, uint32_t error_buf_size);
1477+
1478+
WASM_RUNTIME_API_EXTERN void
1479+
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst);
1480+
1481+
WASM_RUNTIME_API_EXTERN void
1482+
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
1483+
WASINNGlobalContext *wasi_ctx);
1484+
1485+
WASM_RUNTIME_API_EXTERN uint32_t
1486+
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx);
1487+
1488+
WASM_RUNTIME_API_EXTERN char *
1489+
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
1490+
1491+
WASM_RUNTIME_API_EXTERN uint32_t
1492+
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
1493+
1494+
WASM_RUNTIME_API_EXTERN uint32_t
1495+
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value);
1496+
1497+
WASM_RUNTIME_API_EXTERN char*
1498+
wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx);
1499+
1500+
WASM_RUNTIME_API_EXTERN char*
1501+
wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx);
1502+
#endif
1503+
14301504
#ifdef __cplusplus
14311505
}
14321506
#endif

0 commit comments

Comments
 (0)