Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions core/iwasm/common/wasm_native.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL;
static void *g_wasi_context_key;
#endif /* WASM_ENABLE_LIBC_WASI */

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
static void *g_wasi_nn_context_key;
#endif

uint32
get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis);

Expand Down Expand Up @@ -473,6 +477,32 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
}
#endif /* end of WASM_ENABLE_LIBC_WASI */

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASINNGlobalContext *
wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm)
{
return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key);
}

void
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm,
WASINNGlobalContext *wasi_nn_ctx)
{
wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key,
wasi_nn_ctx);
}

static void
wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
{
if (ctx == NULL) {
return;
}

wasm_runtime_destroy_wasi_nn_global_ctx(inst);
}
#endif

#if WASM_ENABLE_QUICK_AOT_ENTRY != 0
static bool
quick_aot_entry_init(void);
Expand Down Expand Up @@ -582,6 +612,12 @@ wasm_native_init()
#endif /* WASM_ENABLE_LIB_RATS */

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
g_wasi_nn_context_key =
wasm_native_create_context_key(wasi_nn_context_dtor);
if (g_wasi_nn_context_key == NULL) {
goto fail;
}

if (!wasi_nn_initialize())
goto fail;

Expand Down Expand Up @@ -648,6 +684,10 @@ wasm_native_destroy()
#endif

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (g_wasi_nn_context_key != NULL) {
wasm_native_destroy_context_key(g_wasi_nn_context_key);
g_wasi_nn_context_key = NULL;
}
wasi_nn_destroy();
#endif

Expand Down
184 changes: 184 additions & 0 deletions core/iwasm/common/wasm_runtime_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,65 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p)
wasm_runtime_free(p);
}

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASINNArguments;

void
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args)
{
memset(args, 0, sizeof(*args));
}

bool
wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding,
const char *target, uint32_t n_graphs,
const char **graph_paths)
{
if (!registry || !encoding || !target || !graph_paths) {
return false;
}
registry->encoding = strdup(encoding);
registry->target = strdup(target);
registry->n_graphs = n_graphs;
registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
for (uint32_t i = 0; i < registry->n_graphs; i++)
registry->graph_paths[i] = strdup(graph_paths[i]);

return true;
}

int
wasi_nn_graph_registry_create(WASINNArguments **registryp)
{
WASINNArguments *args = wasm_runtime_malloc(sizeof(*args));
if (args == NULL) {
return -1;
}
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args);
*registryp = args;
return 0;
}

void
wasi_nn_graph_registry_destroy(WASINNArguments *registry)
{
if (registry) {
for (uint32_t i = 0; i < registry->n_graphs; i++)
if (registry->graph_paths[i]) {
// wasi_nn_graph_registry_unregister_graph(registry,
// registry->name[i]);
free(registry->graph_paths[i]);
}
if (registry->encoding)
free(registry->encoding);
if (registry->target)
free(registry->target);
free(registry);
}
}
#endif

void
wasm_runtime_instantiation_args_set_default_stack_size(
struct InstantiationArgs2 *p, uint32 v)
Expand Down Expand Up @@ -1794,6 +1853,14 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
wasi_args->set_by_user = true;
}
#endif /* WASM_ENABLE_LIBC_WASI != 0 */
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
void
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
struct InstantiationArgs2 *p, WASINNArguments *registry)
{
p->nn_registry = *registry;
}
#endif

WASMModuleInstanceCommon *
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
Expand Down Expand Up @@ -8080,3 +8147,120 @@ wasm_runtime_check_and_update_last_used_shared_heap(
return false;
}
#endif

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
bool
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
const char *encoding, const char *target,
const uint32_t n_graphs,
char *graph_paths[], char *error_buf,
uint32_t error_buf_size)
{
WASINNGlobalContext *ctx;
bool ret = false;

ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size);
if (!ctx)
return false;

ctx->encoding = strdup(encoding);
ctx->target = strdup(target);
ctx->n_graphs = n_graphs;
ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);

ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
for (uint32_t i = 0; i < n_graphs; i++) {
ctx->graph_paths[i] = strdup(graph_paths[i]);
}

wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx);

ret = true;

return ret;
}

void
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
{
WASINNGlobalContext *wasi_nn_global_ctx =
wasm_runtime_get_wasi_nn_global_ctx(module_inst);

for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) {
// All graphs will be unregistered in deinit()
if (wasi_nn_global_ctx->graph_paths[i])
free(wasi_nn_global_ctx->graph_paths[i]);
}
free(wasi_nn_global_ctx->encoding);
free(wasi_nn_global_ctx->target);
free(wasi_nn_global_ctx->loaded);
free(wasi_nn_global_ctx->graph_paths);

if (wasi_nn_global_ctx) {
wasm_runtime_free(wasi_nn_global_ctx);
}
}

uint32_t
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
WASINNGlobalContext *wasi_nn_global_ctx)
{
if (wasi_nn_global_ctx)
return wasi_nn_global_ctx->n_graphs;

return -1;
}

char *
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->graph_paths[idx];

return NULL;
}

uint32_t
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->loaded[idx];

return -1;
}

uint32_t
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
wasi_nn_global_ctx->loaded[idx] = value;

return 0;
}

char *
wasm_runtime_get_wasi_nn_global_ctx_encoding(
WASINNGlobalContext *wasi_nn_global_ctx)
{
if (wasi_nn_global_ctx)
return wasi_nn_global_ctx->encoding;

return NULL;
}

char *
wasm_runtime_get_wasi_nn_global_ctx_target(
WASINNGlobalContext *wasi_nn_global_ctx)
{
if (wasi_nn_global_ctx)
return wasi_nn_global_ctx->target;

return NULL;
}

#endif
81 changes: 81 additions & 0 deletions core/iwasm/common/wasm_runtime_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,17 @@ typedef struct WASMModuleInstMemConsumption {
uint32 exports_size;
} WASMModuleInstMemConsumption;

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
typedef struct WASINNGlobalContext {
char *encoding;
char *target;

uint32_t n_graphs;
uint32_t *loaded;
char **graph_paths;
} WASINNGlobalContext;
#endif

#if WASM_ENABLE_LIBC_WASI != 0
#if WASM_ENABLE_UVWASI == 0
typedef struct WASIContext {
Expand Down Expand Up @@ -612,11 +623,30 @@ WASMExecEnv *
wasm_runtime_get_exec_env_tls(void);
#endif

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
typedef struct WASINNArguments {
char *encoding;
char *target;

char **graph_paths;
uint32_t n_graphs;
} WASINNArguments;

WASM_RUNTIME_API_EXTERN int
wasi_nn_graph_registry_create(WASINNArguments **registryp);

WASM_RUNTIME_API_EXTERN void
wasi_nn_graph_registry_destroy(WASINNArguments *registry);
#endif

struct InstantiationArgs2 {
InstantiationArgs v1;
#if WASM_ENABLE_LIBC_WASI != 0
WASIArguments wasi;
#endif
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASINNArguments nn_registry;
#endif
};

void
Expand Down Expand Up @@ -775,6 +805,17 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
struct InstantiationArgs2 *p, const char *ns_lookup_pool[],
uint32 ns_lookup_pool_size);

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASM_RUNTIME_API_EXTERN void
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
struct InstantiationArgs2 *p, WASINNArguments *registry);

WASM_RUNTIME_API_EXTERN bool
wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding,
const char *target, uint32_t n_graphs,
const char **graph_paths);
#endif

/* See wasm_export.h for description */
WASM_RUNTIME_API_EXTERN WASMModuleInstanceCommon *
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
Expand Down Expand Up @@ -1427,6 +1468,46 @@ wasm_runtime_check_and_update_last_used_shared_heap(
uint8 **shared_heap_base_addr_adj_p, bool is_memory64);
#endif

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASM_RUNTIME_API_EXTERN bool
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
const char *encoding, const char *target,
const uint32_t n_graphs,
char *graph_paths[], char *error_buf,
uint32_t error_buf_size);

WASM_RUNTIME_API_EXTERN void
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst);

WASM_RUNTIME_API_EXTERN void
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
WASINNGlobalContext *wasi_ctx);

WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
WASINNGlobalContext *wasi_nn_global_ctx);

WASM_RUNTIME_API_EXTERN char *
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);

WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);

WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value);

WASM_RUNTIME_API_EXTERN char *
wasm_runtime_get_wasi_nn_global_ctx_encoding(
WASINNGlobalContext *wasi_nn_global_ctx);

WASM_RUNTIME_API_EXTERN char *
wasm_runtime_get_wasi_nn_global_ctx_target(
WASINNGlobalContext *wasi_nn_global_ctx);
#endif

#ifdef __cplusplus
}
#endif
Expand Down
Loading
Loading