From 96cdfa63adef2036af42f450ef70de9a6f56ccf5 Mon Sep 17 00:00:00 2001 From: QiuYuan Han Date: Wed, 10 Dec 2025 13:52:50 +0800 Subject: [PATCH 1/9] Add the way to set the target evenif we use load_by_name --- core/iwasm/common/wasm_native.c | 38 ++++ core/iwasm/common/wasm_runtime_common.c | 180 ++++++++++++++++++ core/iwasm/common/wasm_runtime_common.h | 74 +++++++ core/iwasm/include/wasm_export.h | 51 +++++ core/iwasm/interpreter/wasm_runtime.c | 12 ++ .../wasi-nn/include/wasi_ephemeral_nn.h | 4 +- .../iwasm/libraries/wasi-nn/include/wasi_nn.h | 2 +- .../libraries/wasi-nn/include/wasi_nn_types.h | 3 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 119 +++++++++++- .../libraries/wasi-nn/src/wasi_nn_backend.h | 3 +- .../libraries/wasi-nn/src/wasi_nn_llamacpp.c | 3 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 3 +- .../libraries/wasi-nn/src/wasi_nn_openvino.c | 3 +- .../libraries/wasi-nn/src/wasi_nn_private.h | 3 +- .../wasi-nn/src/wasi_nn_tensorflowlite.cpp | 6 +- .../libraries/wasi-nn/test/requirements.txt | 2 +- .../libraries/wasi-nn/test/test_tensorflow.c | 66 +++---- .../wasi-nn/test/test_tensorflow_quantized.c | 26 +-- core/iwasm/libraries/wasi-nn/test/utils.c | 104 ++++++---- core/iwasm/libraries/wasi-nn/test/utils.h | 23 +-- product-mini/platforms/posix/main.c | 62 ++++++ 21 files changed, 659 insertions(+), 128 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 42aa55db28..8938524db8 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL; static void *g_wasi_context_key; #endif /* WASM_ENABLE_LIBC_WASI */ +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +static void *g_wasi_nn_context_key; +#endif + uint32 get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis); @@ -473,6 +477,31 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) } #endif /* end of WASM_ENABLE_LIBC_WASI */ +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +WASINNGlobalContext * +wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) +{ + return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key); +} + +void +wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, + WASINNGlobalContext *wasi_nn_ctx) +{ + wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx); +} + +static void +wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) +{ + if (ctx == NULL) { + return; + } + + wasm_runtime_destroy_wasi_nn_global_ctx(inst); +} +#endif + #if WASM_ENABLE_QUICK_AOT_ENTRY != 0 static bool quick_aot_entry_init(void); @@ -582,6 +611,11 @@ wasm_native_init() #endif /* WASM_ENABLE_LIB_RATS */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor); + if (g_wasi_nn_context_key == NULL) { + goto fail; + } + if (!wasi_nn_initialize()) goto fail; @@ -648,6 +682,10 @@ wasm_native_destroy() #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (g_wasi_nn_context_key != NULL) { + wasm_native_destroy_context_key(g_wasi_nn_context_key); + g_wasi_nn_context_key = NULL; + } wasi_nn_destroy(); #endif diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 259816e0b9..312c4b9c7b 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1696,6 +1696,67 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) wasm_runtime_free(p); } +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +struct wasi_nn_graph_registry; + +void +wasm_runtime_wasi_nn_graph_registry_args_set_defaults(struct wasi_nn_graph_registry *args) +{ + memset(args, 0, sizeof(*args)); +} + +bool +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, + const char* target, uint32_t n_graphs, + const char** graph_paths) +{ + if (!registry || !encoding || !target || !graph_paths) + { + return false; + } + registry->encoding = strdup(encoding); + registry->target = strdup(target); + registry->n_graphs = n_graphs; + registry->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); + memset(registry->graph_paths, 0, sizeof(uint32_t*) * n_graphs); + for (uint32_t i = 0; i < registry->n_graphs; i++) + registry->graph_paths[i] = strdup(graph_paths[i]); + + return true; +} + +int +wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) +{ + struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args)); + if (args == NULL) { + return false; + } + wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args); + *registryp = args; + return 0; +} + +void +wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry) +{ + if (registry) + { + for (uint32_t i = 0; i < registry->n_graphs; i++) + if (registry->graph_paths[i]) + { + // wasi_nn_graph_registry_unregister_graph(registry, registry->name[i]); + free(registry->graph_paths[i]); + } + if (registry->encoding) + free(registry->encoding); + if (registry->target) + free(registry->target); + free(registry); + } +} +#endif + void wasm_runtime_instantiation_args_set_default_stack_size( struct InstantiationArgs2 *p, uint32 v) @@ -1794,6 +1855,14 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( wasi_args->set_by_user = true; } #endif /* WASM_ENABLE_LIBC_WASI != 0 */ +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +void +wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( + struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry) +{ + p->nn_registry = *registry; +} +#endif WASMModuleInstanceCommon * wasm_runtime_instantiate_ex2(WASMModuleCommon *module, @@ -8080,3 +8149,114 @@ wasm_runtime_check_and_update_last_used_shared_heap( return false; } #endif + +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +bool +wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size) +{ + WASINNGlobalContext *ctx; + bool ret = false; + + ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size); + if (!ctx) + return false; + + ctx->encoding = strdup(encoding); + ctx->target = strdup(target); + ctx->n_graphs = n_graphs; + ctx->loaded = (uint32_t*)malloc(sizeof(uint32_t) * n_graphs); + memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); + + ctx->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); + memset(ctx->graph_paths, 0, sizeof(uint32_t*) * n_graphs); + for (uint32_t i = 0; i < n_graphs; i++) + { + ctx->graph_paths[i] = strdup(graph_paths[i]); + } + + wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx); + + ret = true; + + return ret; +} + +void +wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) +{ + WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst); + + for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) + { + // All graphs will be unregistered in deinit() + if (wasi_nn_global_ctx->graph_paths[i]) + free(wasi_nn_global_ctx->graph_paths[i]); + } + free(wasi_nn_global_ctx->encoding); + free(wasi_nn_global_ctx->target); + free(wasi_nn_global_ctx->loaded); + free(wasi_nn_global_ctx->graph_paths); + + if (wasi_nn_global_ctx) { + wasm_runtime_free(wasi_nn_global_ctx); + } +} + +uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx) +{ + if (wasi_nn_global_ctx) + return wasi_nn_global_ctx->n_graphs; + + return -1; +} + +char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +{ + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + return wasi_nn_global_ctx->graph_paths[idx]; + + return NULL; +} + +uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +{ + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + return wasi_nn_global_ctx->loaded[idx]; + + return -1; +} + +uint32_t +wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) +{ + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + wasi_nn_global_ctx->loaded[idx] = value; + + return 0; +} + +char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx) +{ + if (wasi_nn_global_ctx) + return wasi_nn_global_ctx->encoding; + + return NULL; +} + +char* +wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx) +{ + if (wasi_nn_global_ctx) + return wasi_nn_global_ctx->target; + + return NULL; +} + +#endif diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 88f23485e8..8d002bedcc 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -545,6 +545,17 @@ typedef struct WASMModuleInstMemConsumption { uint32 exports_size; } WASMModuleInstMemConsumption; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef struct WASINNGlobalContext { + char* encoding; + char* target; + + uint32_t n_graphs; + uint32_t *loaded; + char** graph_paths; +} WASINNGlobalContext; +#endif + #if WASM_ENABLE_LIBC_WASI != 0 #if WASM_ENABLE_UVWASI == 0 typedef struct WASIContext { @@ -612,11 +623,30 @@ WASMExecEnv * wasm_runtime_get_exec_env_tls(void); #endif +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +struct wasi_nn_graph_registry { + char* encoding; + char* target; + + char** graph_paths; + uint32_t n_graphs; +}; + +WASM_RUNTIME_API_EXTERN int +wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); + +WASM_RUNTIME_API_EXTERN void +wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); +#endif + struct InstantiationArgs2 { InstantiationArgs v1; #if WASM_ENABLE_LIBC_WASI != 0 WASIArguments wasi; #endif +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) + struct wasi_nn_graph_registry nn_registry; +#endif }; void @@ -775,6 +805,17 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32 ns_lookup_pool_size); +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +WASM_RUNTIME_API_EXTERN void +wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( + struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + +WASM_RUNTIME_API_EXTERN bool +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, + const char* target, uint32_t n_graphs, + const char** graph_paths); +#endif + /* See wasm_export.h for description */ WASM_RUNTIME_API_EXTERN WASMModuleInstanceCommon * wasm_runtime_instantiate_ex2(WASMModuleCommon *module, @@ -1427,6 +1468,39 @@ wasm_runtime_check_and_update_last_used_shared_heap( uint8 **shared_heap_base_addr_adj_p, bool is_memory64); #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +WASM_RUNTIME_API_EXTERN bool +wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + WASINNGlobalContext *wasi_ctx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx); +#endif + #ifdef __cplusplus } #endif diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 44a45dedfc..50263f1823 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -290,6 +290,8 @@ typedef struct InstantiationArgs { #endif /* INSTANTIATION_ARGS_OPTION_DEFINED */ struct InstantiationArgs2; +struct WASINNGlobalContext; +typedef struct WASINNGlobalContext *wasi_nn_global_context; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ -796,6 +798,55 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); +// WASM_RUNTIME_API_EXTERN int +// wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); + +// WASM_RUNTIME_API_EXTERN void +// wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); + +// WASM_RUNTIME_API_EXTERN void +// wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( +// struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + +// WASM_RUNTIME_API_EXTERN bool +// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, +// const char* target, uint32_t n_graphs, +// const char** graph_paths); + +WASM_RUNTIME_API_EXTERN bool +wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst, + wasi_nn_global_context wasi_ctx); + +WASM_RUNTIME_API_EXTERN wasi_nn_global_context +wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_context wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_context wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_context wasi_nn_global_ctx); + /** * Instantiate a WASM module, with specified instantiation arguments * diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index a59bc9257b..79d4c73c2e 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3300,6 +3300,18 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, } #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + /* Store graphs' path into ctx. Graphs will be loaded until user app calls load_by_name */ + // Do not consider load() for now + struct wasi_nn_graph_registry *nn_registry = &args->nn_registry; + if (!wasm_runtime_init_wasi_nn_global_ctx( + (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, + nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths, + error_buf, error_buf_size)) { + goto fail; + } +#endif + #if WASM_ENABLE_DEBUG_INTERP != 0 if (!is_sub_inst) { /* Add module instance into module's instance list */ diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h index f76295a1ee..83beba98f5 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h @@ -8,5 +8,5 @@ #include "wasi_nn.h" -#undef WASM_ENABLE_WASI_EPHEMERAL_NN -#undef WASI_NN_NAME +// #undef WASM_ENABLE_WASI_EPHEMERAL_NN +// #undef WASI_NN_NAME diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index cda26324eb..d76de3ffc0 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -21,7 +21,7 @@ #else #define WASI_NN_IMPORT(name) \ __attribute__((import_module("wasi_nn"), import_name(name))) -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) #endif /** diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index 952fb65e28..d77fe9a6cb 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -27,7 +27,7 @@ extern "C" { #define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name) #define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name) #define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name) -#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error); +#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error) #endif /** @@ -169,6 +169,7 @@ typedef enum WASI_NN_NAME(execution_target) { WASI_NN_TARGET_NAME(cpu) = 0, WASI_NN_TARGET_NAME(gpu), WASI_NN_TARGET_NAME(tpu), + WASI_NN_TARGET_NAME(unsupported_target), } WASI_NN_NAME(execution_target); // Bind a `graph` to the input and output tensors for an inference. diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 2282534b0f..9e3e741b69 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -21,7 +21,7 @@ #include "wasm_export.h" #if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) #endif #define HASHMAP_INITIAL_SIZE 20 @@ -35,6 +35,8 @@ #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION +#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instances + /* Global variables */ static korp_mutex wasi_nn_lock; /* @@ -208,6 +210,44 @@ wasi_nn_destroy() * - model file format * - on device ML framework */ +static graph_encoding str2encoding(char* str_encoding) +{ + if (!str_encoding) { + NN_ERR_PRINTF("Got empty string encoding"); + return -1; + } + + if (!strcmp(str_encoding, "openvino")) + return openvino; + else if (!strcmp(str_encoding, "tensorflowlite")) + return tensorflowlite; + else if (!strcmp(str_encoding, "ggml")) + return ggml; + else if (!strcmp(str_encoding, "onnx")) + return onnx; + else + return unknown_backend; + // return autodetect; +} + +static execution_target str2target(char* str_target) +{ + if (!str_target) { + NN_ERR_PRINTF("Got empty string target"); + return -1; + } + + if (!strcmp(str_target, "cpu")) + return cpu; + else if (!strcmp(str_target, "gpu")) + return gpu; + else if (!strcmp(str_target, "tpu")) + return tpu; + else + return unsupported_target; + // return autodetect; +} + static graph_encoding choose_a_backend() { @@ -565,17 +605,82 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - res = ensure_backend(instance, autodetect, wasi_nn_ctx); - if (res != success) + wasi_nn_global_context wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(instance); + if (!wasi_nn_global_ctx) { + NN_ERR_PRINTF("global context is invalid"); + res = not_found; goto fail; + } + graph_encoding encoding = str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx)); + execution_target target = str2target(wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx)); - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, - g); + // res = ensure_backend(instance, autodetect, wasi_nn_ctx); + res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; + + bool is_loaded = false; + uint32 model_idx = 0; + char *global_model_path_i; + // Assume filename got from user wasm app : max; sum; average; ... + // Assume file path got from user cmd opt: /your/path1/max.tflite; /your/path2/sum.tflite; ...... + for (model_idx = 0; model_idx < wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); model_idx++) + { + // Extract filename from file path + global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx); + char *model_file_name; + const char *slash = strrchr(global_model_path_i, '/'); + if (slash != NULL) { + model_file_name = (char*)(slash + 1); + } + else + model_file_name = global_model_path_i; + + // Extract modelname from filename + char* model_name = NULL; + size_t model_name_len = 0; + char* dot = strrchr(model_file_name, '.'); + if (dot) + { + model_name_len = dot - model_file_name; + model_name = malloc(model_name_len + 1); + strncpy(model_name, model_file_name, model_name_len); + model_name[model_name_len] = '\0'; + } + + if (model_name && strcmp(nul_terminated_name, model_name) == 0) { + is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx); + break; + } + } - res = success; + if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)) + { + NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); + call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, + wasi_nn_ctx->backend_ctx, global_model_path_i, strlen(global_model_path_i), + encoding, target, g); + if (res != success) + goto fail; + + wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx, 1); + res = success; + } + else + { + if (is_loaded) + { + NN_DBG_PRINTF("Model is already loaded"); + res = success; + } + else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) + { + // No enlarge for now + NN_ERR_PRINTF("No enough space for new model"); + res = too_large; + } + goto fail; + } fail: if (nul_terminated_name != NULL) { wasm_runtime_free(nul_terminated_name); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h index 8cd03f1214..3108f2eef0 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h @@ -17,7 +17,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, graph *g); +load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, + graph_encoding encoding, execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error load_by_name_with_config(void *ctx, const char *name, uint32_t namelen, diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c index 2e1e649365..fd09c2be08 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c @@ -338,7 +338,8 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g) } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g) +load_by_name(void *ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index 88587f68bc..e2283df0f3 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -334,7 +334,8 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) +load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { if (!onnx_ctx) { return runtime_error; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c index 899e06ee39..eec4f8190b 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c @@ -306,7 +306,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g) +load_by_name(void *ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx; struct OpenVINOGraph *graph; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index 1bff2c514d..5dcb173f42 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -21,7 +21,8 @@ typedef struct { typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding, execution_target, graph *); -typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph *); +typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph_encoding, + execution_target, graph *); typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *, uint32_t, void *, uint32_t, graph *); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 9ac54e6644..eb56a42f23 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -164,8 +164,8 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, - graph *g) +load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target,graph *g) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; @@ -183,7 +183,7 @@ load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, } // Use CPU as default - tfl_ctx->models[*g].target = cpu; + tfl_ctx->models[*g].target = target; return success; } diff --git a/core/iwasm/libraries/wasi-nn/test/requirements.txt b/core/iwasm/libraries/wasi-nn/test/requirements.txt index 1643b91b00..0c80fd6b12 100644 --- a/core/iwasm/libraries/wasi-nn/test/requirements.txt +++ b/core/iwasm/libraries/wasi-nn/test/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.12.1 +tensorflow==2.14.0 numpy==1.24.4 diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index 6a9e20702f..b3d6ba8037 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -13,16 +13,16 @@ #include "logger.h" void -test_sum(execution_target target) +test_sum() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/sum.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "sum", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 300.0) < EPSILON); free(input.dim); @@ -31,16 +31,16 @@ test_sum(execution_target target) } void -test_max(execution_target target) +test_max() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/max.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "max", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 24.0) < EPSILON); NN_INFO_PRINTF("Result: max is %f", output[0]); @@ -50,16 +50,16 @@ test_max(execution_target target) } void -test_average(execution_target target) +test_average() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/average.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "average", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 12.0) < EPSILON); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -69,16 +69,16 @@ test_average(execution_target target) } void -test_mult_dimensions(execution_target target) +test_mult_dimensions() { int dims[] = { 1, 3, 3, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/mult_dim.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "mult_dim", 1); - assert(output_size == 9); + assert((output_size / sizeof(float)) == 9); for (int i = 0; i < 9; i++) assert(fabs(output[i] - i) < EPSILON); @@ -88,16 +88,16 @@ test_mult_dimensions(execution_target target) } void -test_mult_outputs(execution_target target) +test_mult_outputs() { int dims[] = { 1, 4, 4, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/mult_out.tflite", 2); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "mult_out", 2); - assert(output_size == 8); + assert((output_size / sizeof(float)) == 8); // first tensor check for (int i = 0; i < 4; i++) assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); @@ -113,30 +113,18 @@ test_mult_outputs(execution_target target) int main() { - char *env = getenv("TARGET"); - if (env == NULL) { - NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu]\""); - return 1; - } - execution_target target; - if (strcmp(env, "cpu") == 0) - target = cpu; - else if (strcmp(env, "gpu") == 0) - target = gpu; - else { - NN_ERR_PRINTF("Wrong target!"); - return 1; - } + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("################### Testing sum..."); - test_sum(target); + test_sum(); NN_INFO_PRINTF("################### Testing max..."); - test_max(target); + test_max(); NN_INFO_PRINTF("################### Testing average..."); - test_average(target); + test_average(); NN_INFO_PRINTF("################### Testing multiple dimensions..."); - test_mult_dimensions(target); + test_mult_dimensions(); NN_INFO_PRINTF("################### Testing multiple outputs..."); - test_mult_outputs(target); + test_mult_outputs(); NN_INFO_PRINTF("Tests: passed!"); return 0; diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c index 3ed7c751e3..0898c7ae2a 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c @@ -16,15 +16,15 @@ #define EPSILON 1e-2 void -test_average_quantized(execution_target target) +test_average_quantized() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; float *output = - run_inference(target, input.input_tensor, input.dim, &output_size, - "./models/quantized_model.tflite", 1); + run_inference(input.input_tensor, input.dim, &output_size, + "quantized_model", 1); NN_INFO_PRINTF("Output size: %d", output_size); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -39,24 +39,10 @@ test_average_quantized(execution_target target) int main() { - char *env = getenv("TARGET"); - if (env == NULL) { - NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu|tpu]\""); - return 1; - } - execution_target target; - if (strcmp(env, "cpu") == 0) - target = cpu; - else if (strcmp(env, "gpu") == 0) - target = gpu; - else if (strcmp(env, "tpu") == 0) - target = tpu; - else { - NN_ERR_PRINTF("Wrong target!"); - return 1; - } + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("################### Testing quantized model..."); - test_average_quantized(target); + test_average_quantized(); NN_INFO_PRINTF("Tests: passed!"); return 0; diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 690c37f0e7..97ed08378e 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -5,17 +5,15 @@ #include "utils.h" #include "logger.h" -#include "wasi_nn.h" - #include #include -wasi_nn_error -wasm_load(char *model_name, graph *g, execution_target target) +WASI_NN_ERROR_TYPE +wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target) { FILE *pFile = fopen(model_name, "r"); if (pFile == NULL) - return invalid_argument; + return WASI_NN_ERROR_NAME(invalid_argument); uint8_t *buffer; size_t result; @@ -24,20 +22,29 @@ wasm_load(char *model_name, graph *g, execution_target target) buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE); if (buffer == NULL) { fclose(pFile); - return too_large; + return WASI_NN_ERROR_NAME(too_large); } result = fread(buffer, 1, MAX_MODEL_SIZE, pFile); if (result <= 0) { fclose(pFile); free(buffer); - return too_large; + return WASI_NN_ERROR_NAME(too_large); } - graph_builder_array arr; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + WASI_NN_NAME(graph_builder) arr; + + arr.buf = buffer; + arr.size = result; + + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); + // WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, 1, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); +#else + WASI_NN_NAME(graph_builder_array) arr; arr.size = 1; - arr.buf = (graph_builder *)malloc(sizeof(graph_builder)); + arr.buf = (WASI_NN_NAME(graph_builder) *)malloc(sizeof(WASI_NN_NAME(graph_builder))); if (arr.buf == NULL) { fclose(pFile); free(buffer); @@ -47,7 +54,8 @@ wasm_load(char *model_name, graph *g, execution_target target) arr.buf[0].size = result; arr.buf[0].buf = buffer; - wasi_nn_error res = load(&arr, tensorflowlite, target, g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); +#endif fclose(pFile); free(buffer); @@ -55,77 +63,97 @@ wasm_load(char *model_name, graph *g, execution_target target) return res; } -wasi_nn_error -wasm_load_by_name(const char *model_name, graph *g) +WASI_NN_ERROR_TYPE +wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) *g) { - wasi_nn_error res = load_by_name(model_name, strlen(model_name), g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); return res; } -wasi_nn_error -wasm_init_execution_context(graph g, graph_execution_context *ctx) +WASI_NN_ERROR_TYPE +wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx) { - return init_execution_context(g, ctx); + return WASI_NN_NAME(init_execution_context)(g, ctx); } -wasi_nn_error -wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim) +WASI_NN_ERROR_TYPE +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim) { - tensor_dimensions dims; + WASI_NN_NAME(tensor_dimensions) dims; dims.size = INPUT_TENSOR_DIMS; dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t)); if (dims.buf == NULL) - return too_large; - - tensor tensor; + return WASI_NN_ERROR_NAME(too_large); + + WASI_NN_NAME(tensor) tensor; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + tensor.dimensions = dims; + for (int i = 0; i < tensor.dimensions.size; ++i) + tensor.dimensions.buf[i] = dim[i]; + tensor.type = WASI_NN_TYPE_NAME(fp32); + tensor.data.buf = (uint8_t *)input_tensor; + + uint32_t tmp_size = 1; + if (dim) + for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) + tmp_size *= dim[i]; + + tensor.data.size = (tmp_size * sizeof(float)); +#else tensor.dimensions = &dims; for (int i = 0; i < tensor.dimensions->size; ++i) tensor.dimensions->buf[i] = dim[i]; - tensor.type = fp32; + tensor.type = WASI_NN_TYPE_NAME(fp32); tensor.data = (uint8_t *)input_tensor; - wasi_nn_error err = set_input(ctx, 0, &tensor); +#endif + + WASI_NN_ERROR_TYPE err = WASI_NN_NAME(set_input)(ctx, 0, &tensor); free(dims.buf); return err; } -wasi_nn_error -wasm_compute(graph_execution_context ctx) +WASI_NN_ERROR_TYPE +wasm_compute(WASI_NN_NAME(graph_execution_context) ctx) { - return compute(ctx); + return WASI_NN_NAME(compute)(ctx); } -wasi_nn_error -wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor, +WASI_NN_ERROR_TYPE +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, uint32_t *out_size) { - return get_output(ctx, index, (uint8_t *)out_tensor, out_size); +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, MAX_OUTPUT_TENSOR_SIZE, out_size); +#else + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, out_size); +#endif } float * -run_inference(execution_target target, float *input, uint32_t *input_size, +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, char *model_name, uint32_t num_output_tensors) { - graph graph; + WASI_NN_NAME(graph) graph; - if (wasm_load_by_name(model_name, &graph) != success) { + if (wasm_load_by_name(model_name, &graph) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when loading model."); exit(1); } - graph_execution_context ctx; - if (wasm_init_execution_context(graph, &ctx) != success) { + WASI_NN_NAME(graph_execution_context) ctx; + if (wasm_init_execution_context(graph, &ctx) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when initialixing execution context."); exit(1); } - if (wasm_set_input(ctx, input, input_size) != success) { + if (wasm_set_input(ctx, input, input_size) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when setting input tensor."); exit(1); } - if (wasm_compute(ctx) != success) { + if (wasm_compute(ctx) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when running inference."); exit(1); } @@ -140,7 +168,7 @@ run_inference(execution_target target, float *input, uint32_t *input_size, for (int i = 0; i < num_output_tensors; ++i) { *output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size; if (wasm_get_output(ctx, i, &out_tensor[offset], output_size) - != success) { + != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when getting index %d.", i); break; } diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index e0d2417724..45ba156a0f 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -8,6 +8,7 @@ #include +#include "wasi_ephemeral_nn.h" #include "wasi_nn_types.h" #define MAX_MODEL_SIZE 85000000 @@ -23,26 +24,26 @@ typedef struct { /* wasi-nn wrappers */ -wasi_nn_error -wasm_load(char *model_name, graph *g, execution_target target); +WASI_NN_ERROR_TYPE +wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); -wasi_nn_error -wasm_init_execution_context(graph g, graph_execution_context *ctx); +WASI_NN_ERROR_TYPE +wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); -wasi_nn_error -wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim); +WASI_NN_ERROR_TYPE +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); -wasi_nn_error -wasm_compute(graph_execution_context ctx); +WASI_NN_ERROR_TYPE +wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); -wasi_nn_error -wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor, +WASI_NN_ERROR_TYPE +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, uint32_t *out_size); /* Utils */ float * -run_inference(execution_target target, float *input, uint32_t *input_size, +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, char *model_name, uint32_t num_output_tensors); diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 2d7d3afeb8..ef99f2a842 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -18,6 +18,10 @@ #include "../common/libc_wasi.c" #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#include "wasi_ephemeral_nn.h" +#endif + #include "../common/wasm_proposal.c" #if BH_HAS_DLFCN @@ -115,6 +119,12 @@ print_help(void) #endif #if WASM_ENABLE_STATIC_PGO != 0 printf(" --gen-prof-file= Generate LLVM PGO (Profile-Guided Optimization) profile file\n"); +#endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + printf(" --wasi-nn-graph=encoding:target:::...:\n"); + printf(" Set encoding, target and model_paths for wasi-nn. target can be\n"); + printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n"); + printf(" tensorflow|pytorch|ggml|autodetect\n"); #endif printf(" --version Show version information\n"); return 1; @@ -635,6 +645,13 @@ main(int argc, char *argv[]) int timeout_ms = -1; #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + struct wasi_nn_graph_registry *nn_registry; + char *encoding, *target; + uint32_t n_models = 0; + char **model_paths; +#endif + #if WASM_ENABLE_LIBC_WASI != 0 memset(&wasi_parse_ctx, 0, sizeof(wasi_parse_ctx)); #endif @@ -825,6 +842,37 @@ main(int argc, char *argv[]) wasm_proposal_print_status(); return 0; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) { + char *token; + char *saveptr = NULL; + int token_count = 0; + char *tokens[12] = {0}; + + // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu + // --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:...... + token = strtok_r(argv[0] + 16, ":", &saveptr); + while (token) { + tokens[token_count] = token; + token_count++; + token = strtok_r(NULL, ":", &saveptr); + } + + if (token_count < 2) { + return print_help(); + } + + n_models = token_count - 2; + encoding = strdup(tokens[0]); + target = strdup(tokens[1]); + model_paths = malloc(n_models * sizeof(void*)); + for (int i = 0; i < n_models; i++) { + model_paths[i] = strdup(tokens[i + 2]); + } + if (token) + free(token); + } +#endif else { #if WASM_ENABLE_LIBC_WASI != 0 libc_wasi_parse_result_t result = @@ -974,6 +1022,11 @@ main(int argc, char *argv[]) libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx); #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_create(&nn_registry); + wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths); + wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, nn_registry); +#endif /* instantiate the module */ wasm_module_inst = wasm_runtime_instantiate_ex2( wasm_module, inst_args, error_buf, sizeof(error_buf)); @@ -1092,6 +1145,15 @@ main(int argc, char *argv[]) #endif #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: +#endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_destroy(nn_registry); + for (uint32_t i = 0; i < n_models; i++) + if (model_paths[i]) + free(model_paths[i]); + free(model_paths); + free(encoding); + free(target); #endif /* destroy the module instance */ wasm_runtime_deinstantiate(wasm_module_inst); From 60a80118992cdad2ab9e1548cf0b8589acbf189f Mon Sep 17 00:00:00 2001 From: QiuYuan Han Date: Wed, 10 Dec 2025 17:19:49 +0800 Subject: [PATCH 2/9] Add a new error check for wasi_nn_load_by_name --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 9e3e741b69..1afb07df07 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -622,9 +622,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; char *global_model_path_i; + uint32_t global_n_graphs = wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); // Assume filename got from user wasm app : max; sum; average; ... // Assume file path got from user cmd opt: /your/path1/max.tflite; /your/path2/sum.tflite; ...... - for (model_idx = 0; model_idx < wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); model_idx++) + for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { // Extract filename from file path global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx); @@ -654,7 +655,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, } } - if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)) + if (!is_loaded && \ + (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) && \ + (model_idx < global_n_graphs)) { NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, @@ -679,6 +682,12 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, NN_ERR_PRINTF("No enough space for new model"); res = too_large; } + else if (model_idx >= global_n_graphs) + { + NN_ERR_PRINTF("Cannot find model %s, you should pass its path through --wasi-nn-graph", + nul_terminated_name); + res = not_found; + } goto fail; } fail: From 6dc9d01d5f8b84f80726caf32e0adaa187a6130a Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 11 Dec 2025 10:11:21 +0800 Subject: [PATCH 3/9] Use clang-format-18 to format source files --- core/iwasm/common/wasm_native.c | 8 +- core/iwasm/common/wasm_runtime_common.c | 73 ++++++++-------- core/iwasm/common/wasm_runtime_common.h | 49 ++++++----- core/iwasm/include/wasm_export.h | 36 ++++---- core/iwasm/interpreter/wasm_runtime.c | 7 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 83 ++++++++++--------- .../libraries/wasi-nn/src/wasi_nn_backend.h | 2 +- .../libraries/wasi-nn/src/wasi_nn_llamacpp.c | 4 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 4 +- .../libraries/wasi-nn/src/wasi_nn_openvino.c | 2 +- .../libraries/wasi-nn/src/wasi_nn_private.h | 5 +- .../wasi-nn/src/wasi_nn_tensorflowlite.cpp | 4 +- .../libraries/wasi-nn/test/test_tensorflow.c | 24 +++--- .../wasi-nn/test/test_tensorflow_quantized.c | 9 +- core/iwasm/libraries/wasi-nn/test/utils.c | 44 ++++++---- core/iwasm/libraries/wasi-nn/test/utils.h | 18 ++-- product-mini/platforms/posix/main.c | 16 ++-- 17 files changed, 215 insertions(+), 173 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 8938524db8..7781843914 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -486,9 +486,10 @@ wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) void wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, - WASINNGlobalContext *wasi_nn_ctx) + WASINNGlobalContext *wasi_nn_ctx) { - wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx); + wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, + wasi_nn_ctx); } static void @@ -611,7 +612,8 @@ wasm_native_init() #endif /* WASM_ENABLE_LIB_RATS */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor); + g_wasi_nn_context_key = + wasm_native_create_context_key(wasi_nn_context_dtor); if (g_wasi_nn_context_key == NULL) { goto fail; } diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 312c4b9c7b..685c7de045 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1700,25 +1700,25 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) struct wasi_nn_graph_registry; void -wasm_runtime_wasi_nn_graph_registry_args_set_defaults(struct wasi_nn_graph_registry *args) +wasm_runtime_wasi_nn_graph_registry_args_set_defaults( + struct wasi_nn_graph_registry *args) { memset(args, 0, sizeof(*args)); } bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, - const char* target, uint32_t n_graphs, - const char** graph_paths) +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, + const char *encoding, const char *target, + uint32_t n_graphs, const char **graph_paths) { - if (!registry || !encoding || !target || !graph_paths) - { + if (!registry || !encoding || !target || !graph_paths) { return false; } registry->encoding = strdup(encoding); registry->target = strdup(target); registry->n_graphs = n_graphs; - registry->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); - memset(registry->graph_paths, 0, sizeof(uint32_t*) * n_graphs); + registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); for (uint32_t i = 0; i < registry->n_graphs; i++) registry->graph_paths[i] = strdup(graph_paths[i]); @@ -1740,12 +1740,11 @@ wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) void wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry) { - if (registry) - { + if (registry) { for (uint32_t i = 0; i < registry->n_graphs; i++) - if (registry->graph_paths[i]) - { - // wasi_nn_graph_registry_unregister_graph(registry, registry->name[i]); + if (registry->graph_paths[i]) { + // wasi_nn_graph_registry_unregister_graph(registry, + // registry->name[i]); free(registry->graph_paths[i]); } if (registry->encoding) @@ -8153,9 +8152,10 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char* encoding, const char* target, - const uint32_t n_graphs, char* graph_paths[], - char *error_buf, uint32_t error_buf_size) + const char *encoding, const char *target, + const uint32_t n_graphs, + char *graph_paths[], char *error_buf, + uint32_t error_buf_size) { WASINNGlobalContext *ctx; bool ret = false; @@ -8163,17 +8163,16 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size); if (!ctx) return false; - + ctx->encoding = strdup(encoding); ctx->target = strdup(target); ctx->n_graphs = n_graphs; - ctx->loaded = (uint32_t*)malloc(sizeof(uint32_t) * n_graphs); + ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); - - ctx->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); - memset(ctx->graph_paths, 0, sizeof(uint32_t*) * n_graphs); - for (uint32_t i = 0; i < n_graphs; i++) - { + + ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs); + for (uint32_t i = 0; i < n_graphs; i++) { ctx->graph_paths[i] = strdup(graph_paths[i]); } @@ -8187,10 +8186,10 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, void wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) { - WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst); + WASINNGlobalContext *wasi_nn_global_ctx = + wasm_runtime_get_wasi_nn_global_ctx(module_inst); - for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) - { + for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) { // All graphs will be unregistered in deinit() if (wasi_nn_global_ctx->graph_paths[i]) free(wasi_nn_global_ctx->graph_paths[i]); @@ -8206,7 +8205,8 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) } uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx) +wasm_runtime_get_wasi_nn_global_ctx_ngraphs( + WASINNGlobalContext *wasi_nn_global_ctx) { if (wasi_nn_global_ctx) return wasi_nn_global_ctx->n_graphs; @@ -8215,7 +8215,8 @@ wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ } char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) { if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) return wasi_nn_global_ctx->graph_paths[idx]; @@ -8224,7 +8225,8 @@ wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_g } uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) { if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) return wasi_nn_global_ctx->loaded[idx]; @@ -8233,7 +8235,8 @@ wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global } uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) +wasm_runtime_set_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) { if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) wasi_nn_global_ctx->loaded[idx] = value; @@ -8241,8 +8244,9 @@ wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global return 0; } -char* -wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx) +char * +wasm_runtime_get_wasi_nn_global_ctx_encoding( + WASINNGlobalContext *wasi_nn_global_ctx) { if (wasi_nn_global_ctx) return wasi_nn_global_ctx->encoding; @@ -8250,8 +8254,9 @@ wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global return NULL; } -char* -wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx) +char * +wasm_runtime_get_wasi_nn_global_ctx_target( + WASINNGlobalContext *wasi_nn_global_ctx) { if (wasi_nn_global_ctx) return wasi_nn_global_ctx->target; diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 8d002bedcc..98ea3b68e3 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -547,12 +547,12 @@ typedef struct WASMModuleInstMemConsumption { #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNGlobalContext { - char* encoding; - char* target; + char *encoding; + char *target; uint32_t n_graphs; uint32_t *loaded; - char** graph_paths; + char **graph_paths; } WASINNGlobalContext; #endif @@ -625,10 +625,10 @@ wasm_runtime_get_exec_env_tls(void); #if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) struct wasi_nn_graph_registry { - char* encoding; - char* target; + char *encoding; + char *target; - char** graph_paths; + char **graph_paths; uint32_t n_graphs; }; @@ -811,9 +811,9 @@ wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); WASM_RUNTIME_API_EXTERN bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, - const char* target, uint32_t n_graphs, - const char** graph_paths); +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, + const char *encoding, const char *target, + uint32_t n_graphs, const char **graph_paths); #endif /* See wasm_export.h for description */ @@ -1471,34 +1471,41 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char* encoding, const char* target, - const uint32_t n_graphs, char* graph_paths[], - char *error_buf, uint32_t error_buf_size); + const char *encoding, const char *target, + const uint32_t n_graphs, + char *graph_paths[], char *error_buf, + uint32_t error_buf_size); WASM_RUNTIME_API_EXTERN void wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - WASINNGlobalContext *wasi_ctx); + WASINNGlobalContext *wasi_ctx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_ngraphs( + WASINNGlobalContext *wasi_nn_global_ctx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_global_ctx_loaded_i( + WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_encoding( + WASINNGlobalContext *wasi_nn_global_ctx); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_target( + WASINNGlobalContext *wasi_nn_global_ctx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 50263f1823..16a9ad54bc 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -291,7 +291,7 @@ typedef struct InstantiationArgs { struct InstantiationArgs2; struct WASINNGlobalContext; -typedef struct WASINNGlobalContext *wasi_nn_global_context; +typedef struct WASINNGlobalContext *wasi_nn_global_context; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ -809,43 +809,51 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( // struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); // WASM_RUNTIME_API_EXTERN bool -// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, +// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, +// const char* encoding, // const char* target, uint32_t n_graphs, // const char** graph_paths); WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, - const char* encoding, const char* target, - const uint32_t n_graphs, char* graph_paths[], - char *error_buf, uint32_t error_buf_size); + const char *encoding, const char *target, + const uint32_t n_graphs, + char *graph_paths[], char *error_buf, + uint32_t error_buf_size); WASM_RUNTIME_API_EXTERN void wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst, - wasi_nn_global_context wasi_ctx); + wasi_nn_global_context wasi_ctx); WASM_RUNTIME_API_EXTERN wasi_nn_global_context wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_context wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_global_ctx_ngraphs( + wasi_nn_global_context wasi_nn_global_ctx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_global_ctx_loaded_i( + wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_context wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_encoding( + wasi_nn_global_context wasi_nn_global_ctx); -WASM_RUNTIME_API_EXTERN char* -wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_context wasi_nn_global_ctx); +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_target( + wasi_nn_global_context wasi_nn_global_ctx); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index 79d4c73c2e..6c8f92975c 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3301,13 +3301,14 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, #endif #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - /* Store graphs' path into ctx. Graphs will be loaded until user app calls load_by_name */ + /* Store graphs' path into ctx. Graphs will be loaded until user app calls + * load_by_name */ // Do not consider load() for now struct wasi_nn_graph_registry *nn_registry = &args->nn_registry; if (!wasm_runtime_init_wasi_nn_global_ctx( (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, - nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths, - error_buf, error_buf_size)) { + nn_registry->target, nn_registry->n_graphs, + nn_registry->graph_paths, error_buf, error_buf_size)) { goto fail; } #endif diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 1afb07df07..519c799454 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -35,7 +35,7 @@ #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION -#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instances +#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instance /* Global variables */ static korp_mutex wasi_nn_lock; @@ -210,7 +210,8 @@ wasi_nn_destroy() * - model file format * - on device ML framework */ -static graph_encoding str2encoding(char* str_encoding) +static graph_encoding +str2encoding(char *str_encoding) { if (!str_encoding) { NN_ERR_PRINTF("Got empty string encoding"); @@ -227,10 +228,11 @@ static graph_encoding str2encoding(char* str_encoding) return onnx; else return unknown_backend; - // return autodetect; + // return autodetect; } -static execution_target str2target(char* str_target) +static execution_target +str2target(char *str_target) { if (!str_target) { NN_ERR_PRINTF("Got empty string target"); @@ -245,7 +247,7 @@ static execution_target str2target(char* str_target) return tpu; else return unsupported_target; - // return autodetect; + // return autodetect; } static graph_encoding @@ -605,87 +607,88 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - wasi_nn_global_context wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(instance); + wasi_nn_global_context wasi_nn_global_ctx = + wasm_runtime_get_wasi_nn_global_ctx(instance); if (!wasi_nn_global_ctx) { NN_ERR_PRINTF("global context is invalid"); res = not_found; goto fail; } - graph_encoding encoding = str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx)); - execution_target target = str2target(wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx)); + graph_encoding encoding = str2encoding( + wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx)); + execution_target target = str2target( + wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx)); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; - + bool is_loaded = false; uint32 model_idx = 0; char *global_model_path_i; - uint32_t global_n_graphs = wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); + uint32_t global_n_graphs = + wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); // Assume filename got from user wasm app : max; sum; average; ... - // Assume file path got from user cmd opt: /your/path1/max.tflite; /your/path2/sum.tflite; ...... - for (model_idx = 0; model_idx < global_n_graphs; model_idx++) - { + // Assume file path got from user cmd opt: /your/path1/max.tflite; + // /your/path2/sum.tflite; ...... + for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { // Extract filename from file path - global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx); + global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( + wasi_nn_global_ctx, model_idx); char *model_file_name; - const char *slash = strrchr(global_model_path_i, '/'); + const char *slash = strrchr(global_model_path_i, '/'); if (slash != NULL) { - model_file_name = (char*)(slash + 1); + model_file_name = (char *)(slash + 1); } else model_file_name = global_model_path_i; // Extract modelname from filename - char* model_name = NULL; + char *model_name = NULL; size_t model_name_len = 0; - char* dot = strrchr(model_file_name, '.'); - if (dot) - { + char *dot = strrchr(model_file_name, '.'); + if (dot) { model_name_len = dot - model_file_name; model_name = malloc(model_name_len + 1); strncpy(model_name, model_file_name, model_name_len); model_name[model_name_len] = '\0'; } - + if (model_name && strcmp(nul_terminated_name, model_name) == 0) { - is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx); + is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( + wasi_nn_global_ctx, model_idx); break; } } - if (!is_loaded && \ - (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) && \ - (model_idx < global_n_graphs)) - { + if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) + && (model_idx < global_n_graphs)) { NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, global_model_path_i, strlen(global_model_path_i), - encoding, target, g); + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), encoding, target, g); if (res != success) goto fail; - - wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx, 1); + + wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, + model_idx, 1); res = success; } - else - { - if (is_loaded) - { + else { + if (is_loaded) { NN_DBG_PRINTF("Model is already loaded"); res = success; } - else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) - { + else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) { // No enlarge for now NN_ERR_PRINTF("No enough space for new model"); res = too_large; } - else if (model_idx >= global_n_graphs) - { - NN_ERR_PRINTF("Cannot find model %s, you should pass its path through --wasi-nn-graph", - nul_terminated_name); + else if (model_idx >= global_n_graphs) { + NN_ERR_PRINTF("Cannot find model %s, you should pass its path " + "through --wasi-nn-graph", + nul_terminated_name); res = not_found; } goto fail; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h index 3108f2eef0..344e66550b 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h @@ -18,7 +18,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, - graph_encoding encoding, execution_target target, graph *g); + graph_encoding encoding, execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error load_by_name_with_config(void *ctx, const char *name, uint32_t namelen, diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c index fd09c2be08..7042affa70 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c @@ -338,8 +338,8 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g) } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) +load_by_name(void *ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index e2283df0f3..947fa558e3 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -334,8 +334,8 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) +load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { if (!onnx_ctx) { return runtime_error; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c index eec4f8190b..4739953605 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c @@ -307,7 +307,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target, graph *g) + graph_encoding encoding, execution_target target, graph *g) { OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx; struct OpenVINOGraph *graph; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index 5dcb173f42..7ea76eddb1 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -21,8 +21,9 @@ typedef struct { typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding, execution_target, graph *); -typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph_encoding, - execution_target, graph *); +typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, + graph_encoding, execution_target, + graph *); typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *, uint32_t, void *, uint32_t, graph *); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index eb56a42f23..2b4832dc41 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -164,8 +164,8 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, - graph_encoding encoding, execution_target target,graph *g) +load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index b3d6ba8037..a34c3be4bc 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -19,8 +19,8 @@ test_sum() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "sum", 1); + float *output = + run_inference(input.input_tensor, input.dim, &output_size, "sum", 1); assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 300.0) < EPSILON); @@ -37,8 +37,8 @@ test_max() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "max", 1); + float *output = + run_inference(input.input_tensor, input.dim, &output_size, "max", 1); assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 24.0) < EPSILON); @@ -56,8 +56,8 @@ test_average() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "average", 1); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "average", 1); assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 12.0) < EPSILON); @@ -75,8 +75,8 @@ test_mult_dimensions() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "mult_dim", 1); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "mult_dim", 1); assert((output_size / sizeof(float)) == 9); for (int i = 0; i < 9; i++) @@ -94,8 +94,8 @@ test_mult_outputs() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(input.input_tensor, input.dim, - &output_size, "mult_out", 2); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "mult_out", 2); assert((output_size / sizeof(float)) == 8); // first tensor check @@ -113,7 +113,9 @@ test_mult_outputs() int main() { - NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " + "--wasi-nn-graph=encoding:target:model_path1:model_path2:..." + ":model_pathN test_tensorflow.wasm\""); NN_INFO_PRINTF("################### Testing sum..."); test_sum(); diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c index 0898c7ae2a..a55dc9bb89 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c @@ -22,9 +22,8 @@ test_average_quantized() input_info input = create_input(dims); uint32_t output_size = 0; - float *output = - run_inference(input.input_tensor, input.dim, &output_size, - "quantized_model", 1); + float *output = run_inference(input.input_tensor, input.dim, &output_size, + "quantized_model", 1); NN_INFO_PRINTF("Output size: %d", output_size); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -39,7 +38,9 @@ test_average_quantized() int main() { - NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so " + "--wasi-nn-graph=encoding:target:model_path1:model_path2:..." + ":model_pathN test_tensorflow.wasm\""); NN_INFO_PRINTF("################### Testing quantized model..."); test_average_quantized(); diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 97ed08378e..0a99a95e50 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -9,7 +9,8 @@ #include WASI_NN_ERROR_TYPE -wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target) +wasm_load(char *model_name, WASI_NN_NAME(graph) * g, + WASI_NN_NAME(execution_target) target) { FILE *pFile = fopen(model_name, "r"); if (pFile == NULL) @@ -38,13 +39,14 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_targe arr.buf = buffer; arr.size = result; - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); - // WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, 1, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)( + &arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); #else WASI_NN_NAME(graph_builder_array) arr; arr.size = 1; - arr.buf = (WASI_NN_NAME(graph_builder) *)malloc(sizeof(WASI_NN_NAME(graph_builder))); + arr.buf = (WASI_NN_NAME(graph_builder) *)malloc( + sizeof(WASI_NN_NAME(graph_builder))); if (arr.buf == NULL) { fclose(pFile); free(buffer); @@ -54,7 +56,8 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_targe arr.buf[0].size = result; arr.buf[0].buf = buffer; - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)( + &arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); #endif fclose(pFile); @@ -64,20 +67,23 @@ wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_targe } WASI_NN_ERROR_TYPE -wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) *g) +wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) * g) { - WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); + WASI_NN_ERROR_TYPE res = + WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); return res; } WASI_NN_ERROR_TYPE -wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx) +wasm_init_execution_context(WASI_NN_NAME(graph) g, + WASI_NN_NAME(graph_execution_context) * ctx) { return WASI_NN_NAME(init_execution_context)(g, ctx); } WASI_NN_ERROR_TYPE -wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim) +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, + uint32_t *dim) { WASI_NN_NAME(tensor_dimensions) dims; dims.size = INPUT_TENSOR_DIMS; @@ -103,7 +109,7 @@ wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, u tensor.dimensions = &dims; for (int i = 0; i < tensor.dimensions->size; ++i) tensor.dimensions->buf[i] = dim[i]; - tensor.type = WASI_NN_TYPE_NAME(fp32); + tensor.type = WASI_NN_TYPE_NAME(fp32); tensor.data = (uint8_t *)input_tensor; #endif @@ -120,20 +126,21 @@ wasm_compute(WASI_NN_NAME(graph_execution_context) ctx) } WASI_NN_ERROR_TYPE -wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, - uint32_t *out_size) +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, + float *out_tensor, uint32_t *out_size) { #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, MAX_OUTPUT_TENSOR_SIZE, out_size); + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, + MAX_OUTPUT_TENSOR_SIZE, out_size); #else - return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, out_size); + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, + out_size); #endif } float * -run_inference(float *input, uint32_t *input_size, - uint32_t *output_size, char *model_name, - uint32_t num_output_tensors) +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, + char *model_name, uint32_t num_output_tensors) { WASI_NN_NAME(graph) graph; @@ -143,7 +150,8 @@ run_inference(float *input, uint32_t *input_size, } WASI_NN_NAME(graph_execution_context) ctx; - if (wasm_init_execution_context(graph, &ctx) != WASI_NN_ERROR_NAME(success)) { + if (wasm_init_execution_context(graph, &ctx) + != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when initialixing execution context."); exit(1); } diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index 45ba156a0f..5a5c03c3d7 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -25,27 +25,29 @@ typedef struct { /* wasi-nn wrappers */ WASI_NN_ERROR_TYPE -wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); +wasm_load(char *model_name, WASI_NN_NAME(graph) * g, + WASI_NN_NAME(execution_target) target); WASI_NN_ERROR_TYPE -wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); +wasm_init_execution_context(WASI_NN_NAME(graph) g, + WASI_NN_NAME(graph_execution_context) * ctx); WASI_NN_ERROR_TYPE -wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, + uint32_t *dim); WASI_NN_ERROR_TYPE wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); WASI_NN_ERROR_TYPE -wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, - uint32_t *out_size); +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, + float *out_tensor, uint32_t *out_size); /* Utils */ float * -run_inference(float *input, uint32_t *input_size, - uint32_t *output_size, char *model_name, - uint32_t num_output_tensors); +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, + char *model_name, uint32_t num_output_tensors); input_info create_input(int *dims); diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index ef99f2a842..7226f5b507 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -847,9 +847,9 @@ main(int argc, char *argv[]) char *token; char *saveptr = NULL; int token_count = 0; - char *tokens[12] = {0}; + char *tokens[12] = { 0 }; - // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu + // encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu // --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:...... token = strtok_r(argv[0] + 16, ":", &saveptr); while (token) { @@ -865,11 +865,11 @@ main(int argc, char *argv[]) n_models = token_count - 2; encoding = strdup(tokens[0]); target = strdup(tokens[1]); - model_paths = malloc(n_models * sizeof(void*)); + model_paths = malloc(n_models * sizeof(void *)); for (int i = 0; i < n_models; i++) { model_paths[i] = strdup(tokens[i + 2]); } - if (token) + if (token) free(token); } #endif @@ -1024,8 +1024,10 @@ main(int argc, char *argv[]) #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_create(&nn_registry); - wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths); - wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, nn_registry); + wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, + model_paths); + wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, + nn_registry); #endif /* instantiate the module */ wasm_module_inst = wasm_runtime_instantiate_ex2( @@ -1149,7 +1151,7 @@ main(int argc, char *argv[]) #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_destroy(nn_registry); for (uint32_t i = 0; i < n_models; i++) - if (model_paths[i]) + if (model_paths[i]) free(model_paths[i]); free(model_paths); free(encoding); From d42581bb0b92e11b0ea2e25abccf17fac0c99476 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Thu, 11 Dec 2025 17:31:46 +0800 Subject: [PATCH 4/9] Free model_name --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 519c799454..7b39ca541c 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -658,8 +658,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, if (model_name && strcmp(nul_terminated_name, model_name) == 0) { is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( wasi_nn_global_ctx, model_idx); + free(model_name); break; } + free(model_name); } if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) From 1cca0b8789479f8c2749ad4faf94fa8f7aa98f33 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Tue, 16 Dec 2025 10:07:36 +0800 Subject: [PATCH 5/9] Add new errno for new test cases --- .../libraries/wasi-nn/include/wasi_nn_types.h | 1 + core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 4 +- .../libraries/wasi-nn/test/test_tensorflow.c | 46 +++++++++++-------- core/iwasm/libraries/wasi-nn/test/utils.c | 10 +++- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index d77fe9a6cb..aea6554b8d 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -48,6 +48,7 @@ typedef enum { WASI_NN_ERROR_NAME(unsupported_operation), WASI_NN_ERROR_NAME(too_large), WASI_NN_ERROR_NAME(not_found), + WASI_NN_ERROR_NAME(not_loaded), // for WasmEdge-wasi-nn WASI_NN_ERROR_NAME(end_of_sequence) = 100, // End of Sequence Found. diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 7b39ca541c..8bb7e57cd6 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -688,10 +688,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, res = too_large; } else if (model_idx >= global_n_graphs) { - NN_ERR_PRINTF("Cannot find model %s, you should pass its path " + NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", nul_terminated_name); - res = not_found; + res = not_loaded; } goto fail; } diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index a34c3be4bc..d5147b2fbd 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -22,8 +22,10 @@ test_sum() float *output = run_inference(input.input_tensor, input.dim, &output_size, "sum", 1); - assert((output_size / sizeof(float)) == 1); - assert(fabs(output[0] - 300.0) < EPSILON); + if (output) { + assert((output_size / sizeof(float)) == 1); + assert(fabs(output[0] - 300.0) < EPSILON); + } free(input.dim); free(input.input_tensor); @@ -40,9 +42,11 @@ test_max() float *output = run_inference(input.input_tensor, input.dim, &output_size, "max", 1); - assert((output_size / sizeof(float)) == 1); - assert(fabs(output[0] - 24.0) < EPSILON); - NN_INFO_PRINTF("Result: max is %f", output[0]); + if (output) { + assert((output_size / sizeof(float)) == 1); + assert(fabs(output[0] - 24.0) < EPSILON); + NN_INFO_PRINTF("Result: max is %f", output[0]); + } free(input.dim); free(input.input_tensor); @@ -59,9 +63,11 @@ test_average() float *output = run_inference(input.input_tensor, input.dim, &output_size, "average", 1); - assert((output_size / sizeof(float)) == 1); - assert(fabs(output[0] - 12.0) < EPSILON); - NN_INFO_PRINTF("Result: average is %f", output[0]); + if (output) { + assert((output_size / sizeof(float)) == 1); + assert(fabs(output[0] - 12.0) < EPSILON); + NN_INFO_PRINTF("Result: average is %f", output[0]); + } free(input.dim); free(input.input_tensor); @@ -78,9 +84,11 @@ test_mult_dimensions() float *output = run_inference(input.input_tensor, input.dim, &output_size, "mult_dim", 1); - assert((output_size / sizeof(float)) == 9); - for (int i = 0; i < 9; i++) - assert(fabs(output[i] - i) < EPSILON); + if (output) { + assert((output_size / sizeof(float)) == 9); + for (int i = 0; i < 9; i++) + assert(fabs(output[i] - i) < EPSILON); + } free(input.dim); free(input.input_tensor); @@ -97,13 +105,15 @@ test_mult_outputs() float *output = run_inference(input.input_tensor, input.dim, &output_size, "mult_out", 2); - assert((output_size / sizeof(float)) == 8); - // first tensor check - for (int i = 0; i < 4; i++) - assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); - // second tensor check - for (int i = 0; i < 4; i++) - assert(fabs(output[i + 4] - (i + 6)) < EPSILON); + if (output) { + assert((output_size / sizeof(float)) == 8); + // first tensor check + for (int i = 0; i < 4; i++) + assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); + // second tensor check + for (int i = 0; i < 4; i++) + assert(fabs(output[i + 4] - (i + 6)) < EPSILON); + } free(input.dim); free(input.input_tensor); diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 0a99a95e50..f6f9bd0961 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -144,7 +144,15 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size, { WASI_NN_NAME(graph) graph; - if (wasm_load_by_name(model_name, &graph) != WASI_NN_ERROR_NAME(success)) { + WASI_NN_ERROR_TYPE res = wasm_load_by_name(model_name, &graph); + + if (res == WASI_NN_ERROR_NAME(not_loaded)) { + NN_INFO_PRINTF("Model %s is not loaded, you should pass its path " + "through --wasi-nn-graph", + model_name); + return NULL; + } + else if (res != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when loading model."); exit(1); } From 9e039707c2b2ea9e040704e126ae25f5e2c845c6 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Wed, 17 Dec 2025 14:59:26 +0800 Subject: [PATCH 6/9] Fix bugs --- core/iwasm/common/wasm_runtime_common.c | 4 ++-- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 2 +- core/iwasm/libraries/wasi-nn/test/requirements.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 685c7de045..51ecf23f95 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1725,12 +1725,12 @@ wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, return true; } -int +static int wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) { struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args)); if (args == NULL) { - return false; + return -1; } wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args); *registryp = args; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 8bb7e57cd6..892ca5dd79 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -35,7 +35,7 @@ #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION -#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instance +#define MAX_GLOBAL_GRAPHS_PER_INST 4 /* Global variables */ static korp_mutex wasi_nn_lock; diff --git a/core/iwasm/libraries/wasi-nn/test/requirements.txt b/core/iwasm/libraries/wasi-nn/test/requirements.txt index 0c80fd6b12..2145e3736a 100644 --- a/core/iwasm/libraries/wasi-nn/test/requirements.txt +++ b/core/iwasm/libraries/wasi-nn/test/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.14.0 +tensorflow==2.12.0 numpy==1.24.4 From cd1c7f9134989ec7075e1e5b6261489a9c18deed Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 19 Dec 2025 14:47:38 +0800 Subject: [PATCH 7/9] Rename some parameters --- core/iwasm/common/wasm_native.c | 4 +-- core/iwasm/common/wasm_runtime_common.c | 27 +++++++++--------- core/iwasm/common/wasm_runtime_common.h | 28 +++++++++---------- core/iwasm/include/wasm_export.h | 16 ----------- core/iwasm/interpreter/wasm_runtime.c | 5 ++-- .../wasi-nn/include/wasi_ephemeral_nn.h | 3 -- core/iwasm/libraries/wasi-nn/test/build.sh | 2 ++ .../libraries/wasi-nn/test/test_tensorflow.c | 20 +++++++++++++ core/iwasm/libraries/wasi-nn/test/utils.h | 4 +++ product-mini/platforms/posix/main.c | 14 ++++++---- 10 files changed, 65 insertions(+), 58 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 7781843914..b8430520af 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -25,7 +25,7 @@ static NativeSymbolsList g_native_symbols_list = NULL; static void *g_wasi_context_key; #endif /* WASM_ENABLE_LIBC_WASI */ -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 static void *g_wasi_nn_context_key; #endif @@ -477,7 +477,7 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) } #endif /* end of WASM_ENABLE_LIBC_WASI */ -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASINNGlobalContext * wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) { diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 51ecf23f95..88c5b18e0b 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1696,20 +1696,19 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) wasm_runtime_free(p); } -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) -struct wasi_nn_graph_registry; +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +WASINNArguments; void -wasm_runtime_wasi_nn_graph_registry_args_set_defaults( - struct wasi_nn_graph_registry *args) +wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) { memset(args, 0, sizeof(*args)); } bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, - const char *encoding, const char *target, - uint32_t n_graphs, const char **graph_paths) +wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, + const char *target, uint32_t n_graphs, + const char **graph_paths) { if (!registry || !encoding || !target || !graph_paths) { return false; @@ -1725,10 +1724,10 @@ wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, return true; } -static int -wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) +int +wasi_nn_graph_registry_create(WASINNArguments **registryp) { - struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args)); + WASINNArguments *args = wasm_runtime_malloc(sizeof(*args)); if (args == NULL) { return -1; } @@ -1738,7 +1737,7 @@ wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) } void -wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry) +wasi_nn_graph_registry_destroy(WASINNArguments *registry) { if (registry) { for (uint32_t i = 0; i < registry->n_graphs; i++) @@ -1854,10 +1853,10 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( wasi_args->set_by_user = true; } #endif /* WASM_ENABLE_LIBC_WASI != 0 */ -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 void wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry) + struct InstantiationArgs2 *p, WASINNArguments *registry) { p->nn_registry = *registry; } @@ -8149,7 +8148,7 @@ wasm_runtime_check_and_update_last_used_shared_heap( } #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, const char *encoding, const char *target, diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 98ea3b68e3..f06cca5da6 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -545,7 +545,7 @@ typedef struct WASMModuleInstMemConsumption { uint32 exports_size; } WASMModuleInstMemConsumption; -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNGlobalContext { char *encoding; char *target; @@ -623,20 +623,20 @@ WASMExecEnv * wasm_runtime_get_exec_env_tls(void); #endif -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) -struct wasi_nn_graph_registry { +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef struct WASINNArguments { char *encoding; char *target; char **graph_paths; uint32_t n_graphs; -}; +} WASINNArguments; WASM_RUNTIME_API_EXTERN int -wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); +wasi_nn_graph_registry_create(WASINNArguments **registryp); WASM_RUNTIME_API_EXTERN void -wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); +wasi_nn_graph_registry_destroy(WASINNArguments *registry); #endif struct InstantiationArgs2 { @@ -644,8 +644,8 @@ struct InstantiationArgs2 { #if WASM_ENABLE_LIBC_WASI != 0 WASIArguments wasi; #endif -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) - struct wasi_nn_graph_registry nn_registry; +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + WASINNArguments nn_registry; #endif }; @@ -805,15 +805,15 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32 ns_lookup_pool_size); -#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN void wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + struct InstantiationArgs2 *p, WASINNArguments *registry); WASM_RUNTIME_API_EXTERN bool -wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, - const char *encoding, const char *target, - uint32_t n_graphs, const char **graph_paths); +wasi_nn_graph_registry_set_args(WASINNArguments *registry, const char *encoding, + const char *target, uint32_t n_graphs, + const char **graph_paths); #endif /* See wasm_export.h for description */ @@ -1468,7 +1468,7 @@ wasm_runtime_check_and_update_last_used_shared_heap( uint8 **shared_heap_base_addr_adj_p, bool is_memory64); #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, const char *encoding, const char *target, diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 16a9ad54bc..17a15688cf 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -798,22 +798,6 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); -// WASM_RUNTIME_API_EXTERN int -// wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); - -// WASM_RUNTIME_API_EXTERN void -// wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); - -// WASM_RUNTIME_API_EXTERN void -// wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( -// struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); - -// WASM_RUNTIME_API_EXTERN bool -// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, -// const char* encoding, -// const char* target, uint32_t n_graphs, -// const char** graph_paths); - WASM_RUNTIME_API_EXTERN bool wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, const char *encoding, const char *target, diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index 6c8f92975c..4dc6bd2537 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3300,11 +3300,10 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, } #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 /* Store graphs' path into ctx. Graphs will be loaded until user app calls * load_by_name */ - // Do not consider load() for now - struct wasi_nn_graph_registry *nn_registry = &args->nn_registry; + WASINNArguments *nn_registry = &args->nn_registry; if (!wasm_runtime_init_wasi_nn_global_ctx( (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, nn_registry->target, nn_registry->n_graphs, diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h index 83beba98f5..86afc42674 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h @@ -7,6 +7,3 @@ #define WASI_NN_NAME(name) wasi_ephemeral_nn_##name #include "wasi_nn.h" - -// #undef WASM_ENABLE_WASI_EPHEMERAL_NN -// #undef WASI_NN_NAME diff --git a/core/iwasm/libraries/wasi-nn/test/build.sh b/core/iwasm/libraries/wasi-nn/test/build.sh index 79d65d730c..5c99706d04 100755 --- a/core/iwasm/libraries/wasi-nn/test/build.sh +++ b/core/iwasm/libraries/wasi-nn/test/build.sh @@ -22,6 +22,8 @@ CURR_PATH=$(cd $(dirname $0) && pwd -P) /opt/wasi-sdk/bin/clang \ --target=wasm32-wasi \ + -DWASM_ENABLE_WASI_NN=1 \ + -DWASM_ENABLE_WASI_EPHEMERAL_NN=1 \ -DNN_LOG_LEVEL=1 \ -Wl,--allow-undefined \ -I../include -I../src/utils \ diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index d5147b2fbd..d276dd0ac8 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -23,7 +23,11 @@ test_sum() run_inference(input.input_tensor, input.dim, &output_size, "sum", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 1); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 1); +#endif assert(fabs(output[0] - 300.0) < EPSILON); } @@ -43,7 +47,11 @@ test_max() run_inference(input.input_tensor, input.dim, &output_size, "max", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 1); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 1); +#endif assert(fabs(output[0] - 24.0) < EPSILON); NN_INFO_PRINTF("Result: max is %f", output[0]); } @@ -64,7 +72,11 @@ test_average() "average", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 1); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 1); +#endif assert(fabs(output[0] - 12.0) < EPSILON); NN_INFO_PRINTF("Result: average is %f", output[0]); } @@ -85,7 +97,11 @@ test_mult_dimensions() "mult_dim", 1); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 9); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 9); +#endif for (int i = 0; i < 9; i++) assert(fabs(output[i] - i) < EPSILON); } @@ -106,7 +122,11 @@ test_mult_outputs() "mult_out", 2); if (output) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 assert((output_size / sizeof(float)) == 8); +#elif WASM_ENABLE_WASI_NN != 0 + assert(output_size == 8); +#endif // first tensor check for (int i = 0; i < 4; i++) assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index 5a5c03c3d7..8d2683fff4 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -8,7 +8,11 @@ #include +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 #include "wasi_ephemeral_nn.h" +#elif WASM_ENABLE_WASI_NN != 0 +#include "wasi_nn.h" +#endif #include "wasi_nn_types.h" #define MAX_MODEL_SIZE 85000000 diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 7226f5b507..7765bbef54 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -20,6 +20,8 @@ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 #include "wasi_ephemeral_nn.h" +#elif WASM_ENABLE_WASI_NN != 0 +#include "wasi_nn.h" #endif #include "../common/wasm_proposal.c" @@ -120,7 +122,7 @@ print_help(void) #if WASM_ENABLE_STATIC_PGO != 0 printf(" --gen-prof-file= Generate LLVM PGO (Profile-Guided Optimization) profile file\n"); #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 printf(" --wasi-nn-graph=encoding:target:::...:\n"); printf(" Set encoding, target and model_paths for wasi-nn. target can be\n"); printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n"); @@ -645,8 +647,8 @@ main(int argc, char *argv[]) int timeout_ms = -1; #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - struct wasi_nn_graph_registry *nn_registry; +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + struct WASINNArguments *nn_registry; char *encoding, *target; uint32_t n_models = 0; char **model_paths; @@ -842,7 +844,7 @@ main(int argc, char *argv[]) wasm_proposal_print_status(); return 0; } -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) { char *token; char *saveptr = NULL; @@ -1022,7 +1024,7 @@ main(int argc, char *argv[]) libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx); #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_create(&nn_registry); wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths); @@ -1148,7 +1150,7 @@ main(int argc, char *argv[]) #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: #endif -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_graph_registry_destroy(nn_registry); for (uint32_t i = 0; i < n_models; i++) if (model_paths[i]) From c73e4aa43259a8db0438852ce2756a58ad78a26c Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 19 Dec 2025 14:55:42 +0800 Subject: [PATCH 8/9] Revert tensorflow version --- core/iwasm/libraries/wasi-nn/test/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/iwasm/libraries/wasi-nn/test/requirements.txt b/core/iwasm/libraries/wasi-nn/test/requirements.txt index 2145e3736a..1643b91b00 100644 --- a/core/iwasm/libraries/wasi-nn/test/requirements.txt +++ b/core/iwasm/libraries/wasi-nn/test/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.12.0 +tensorflow==2.12.1 numpy==1.24.4 From 0234fe08b9842b5e27123e4ad75a65d8ad6a4c42 Mon Sep 17 00:00:00 2001 From: qinzh Date: Tue, 6 Jan 2026 10:02:52 +0800 Subject: [PATCH 9/9] CICD: retrigger checks