diff --git a/.gitignore b/.gitignore index a41dfe0..ab5165b 100644 --- a/.gitignore +++ b/.gitignore @@ -24,5 +24,7 @@ xcschememanagement.plist Makefile *.ort onnxruntime*.tar.gz -ThirdParty/onnxruntime +ThirdParty/onnxruntime* /Installers/Windows/Output/ +Lib/ModelData/features_model.required_operators.config +Lib/ModelData/features_model.required_operators.with_runtime_opt.config \ No newline at end of file diff --git a/CMake/CPM.cmake b/CMake/CPM.cmake new file mode 100644 index 0000000..9598ee7 --- /dev/null +++ b/CMake/CPM.cmake @@ -0,0 +1,21 @@ +set(CPM_DOWNLOAD_VERSION 0.35.5) + +if(CPM_SOURCE_CACHE) + # Expand relative path. This is important if the provided path contains a tilde (~) + get_filename_component(CPM_SOURCE_CACHE ${CPM_SOURCE_CACHE} ABSOLUTE) + set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +elseif(DEFINED ENV{CPM_SOURCE_CACHE}) + set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +else() + set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +endif() + +if(NOT (EXISTS ${CPM_DOWNLOAD_LOCATION})) + message(STATUS "Downloading CPM.cmake to ${CPM_DOWNLOAD_LOCATION}") + file(DOWNLOAD + https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake + ${CPM_DOWNLOAD_LOCATION} + ) +endif() + +include(${CPM_DOWNLOAD_LOCATION}) diff --git a/CMake/FindJUCE.cmake b/CMake/FindJUCE.cmake new file mode 100644 index 0000000..8e3547d --- /dev/null +++ b/CMake/FindJUCE.cmake @@ -0,0 +1,6 @@ +include(CPM) + +CPMAddPackage( + NAME JUCE + GITHUB_REPOSITORY juce-framework/JUCE + GIT_TAG origin/master) \ No newline at end of file diff --git a/CMake/FindRTNeural.cmake b/CMake/FindRTNeural.cmake new file mode 100644 index 0000000..42c0038 --- /dev/null +++ b/CMake/FindRTNeural.cmake @@ -0,0 +1,6 @@ +include(CPM) + +CPMAddPackage( + NAME RTNeural + GITHUB_REPOSITORY jatinchowdhury18/RTNeural + GIT_TAG origin/main) \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index e132cbb..5805198 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.16) project(NeuralNotePlugin VERSION 1.1.0) # C++ settings -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED True) enable_language(CXX) #Compile commands, useful for some IDEs like VS-Code @@ -10,7 +11,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS TRUE) set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) #Minimum MacOS target, set globally -set(CMAKE_OSX_DEPLOYMENT_TARGET "10.11" CACHE STRING "Minimum OS X deployment version" FORCE) +set(CMAKE_OSX_DEPLOYMENT_TARGET "10.15" CACHE STRING "Minimum OS X deployment version" FORCE) option(UniversalBinary "Build universal binary for mac" OFF) option(RTNeural_Release "When CMAKE_BUILD_TYPE=Debug, overwrite it to Release for RTNeural only" OFF) @@ -24,10 +25,13 @@ endif () #static linking in Windows set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") -add_subdirectory(ThirdParty/JUCE) -add_subdirectory(ThirdParty/RTNeural) +include(CMake/CPM.cmake) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/CMake) +find_package(JUCE) +find_package(RTNeural) set(BaseTargetName NeuralNote) +set(CLITargetName NeuralNoteCLI) # Set COPY_PLUGIN_AFTER_BUILD to TRUE on Mac and to FALSE on Windows due to permission issues if (APPLE) @@ -115,14 +119,15 @@ target_compile_definitions(${BaseTargetName} JUCE_USE_CURL=0 JUCE_VST3_CAN_REPLACE_VST2=0) -add_library(onnxruntime STATIC IMPORTED) +add_library(onnxruntime SHARED IMPORTED) if (APPLE) - set_property(TARGET onnxruntime PROPERTY IMPORTED_LOCATION ${CMAKE_CURRENT_LIST_DIR}/ThirdParty/onnxruntime/lib/libonnxruntime.a) + set_property(TARGET onnxruntime PROPERTY IMPORTED_LOCATION ${CMAKE_CURRENT_LIST_DIR}/ThirdParty/onnxruntime-osx-arm64-1.22.0/lib/libonnxruntime.dylib) elseif (WIN32) set_property(TARGET onnxruntime APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) + #TODO: Fix Windows set_target_properties(onnxruntime PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX" IMPORTED_LOCATION_RELEASE "${CMAKE_CURRENT_LIST_DIR}/ThirdParty/onnxruntime/lib/onnxruntime.lib" @@ -136,10 +141,8 @@ elseif (UNIX) set_property(TARGET onnxruntime PROPERTY IMPORTED_LOCATION ${CMAKE_CURRENT_LIST_DIR}/ThirdParty/onnxruntime/lib/libonnxruntime.a) endif () - target_include_directories("${BaseTargetName}" PRIVATE ${CMAKE_CURRENT_LIST_DIR}/ThirdParty/onnxruntime/include) - target_link_libraries(${BaseTargetName} PRIVATE juce::juce_audio_utils @@ -170,3 +173,37 @@ if (WIN32) target_include_directories(${BaseTargetName} PRIVATE ThirdParty/ASIO/common) endif () +juce_add_console_app(${CLITargetName} + PRODUCT_NAME ${CLITargetName}) + +juce_generate_juce_header(${CLITargetName}) + +target_sources(${CLITargetName} PRIVATE + NeuralNote/Source/MidiFile/MidiFileWriter.cpp + Lib/DSP/Resampler.cpp + Lib/MidiPostProcessing/NoteOptions.cpp + Lib/Model/BasicPitch.cpp + Lib/Model/Features.cpp + Lib/Model/Notes.cpp + NeuralNoteCLI/main.cpp) + +target_include_directories("${CLITargetName}" PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/ThirdParty/onnxruntime/include + NeuralNote/Source + NeuralNote/Source/MidiFile + Lib/Model + Lib/Utils + ) + +target_compile_definitions(${CLITargetName} PRIVATE + JUCE_WEB_BROWSER=0 + JUCE_USE_CURL=0) + +target_link_libraries(${CLITargetName} PRIVATE + juce_recommended_config_flags + juce_recommended_lto_flags + juce_recommended_warning_flags + juce::juce_audio_utils + juce::juce_dsp + BasicPitchCNN + onnxruntime) \ No newline at end of file diff --git a/Lib/Model/Features.cpp b/Lib/Model/Features.cpp index b3163da..3986932 100644 --- a/Lib/Model/Features.cpp +++ b/Lib/Model/Features.cpp @@ -10,8 +10,11 @@ Features::Features() { mMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + mSessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); + mSessionOptions.SetIntraOpNumThreads(0); mSessionOptions.SetInterOpNumThreads(1); - mSessionOptions.SetIntraOpNumThreads(1); + + mSessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); mSession = Ort::Session(mEnv, BinaryData::features_model_ort, BinaryData::features_model_ortSize, mSessionOptions); } diff --git a/Lib/ModelData/features_model.ort b/Lib/ModelData/features_model.ort new file mode 100644 index 0000000..3fcd20d Binary files /dev/null and b/Lib/ModelData/features_model.ort differ diff --git a/NeuralNoteCLI/main.cpp b/NeuralNoteCLI/main.cpp new file mode 100644 index 0000000..d0dd976 --- /dev/null +++ b/NeuralNoteCLI/main.cpp @@ -0,0 +1,140 @@ +#include +#include +#include + +#include "BasicPitch.h" +#include "MidiFileWriter.h" +#include "TimeQuantizeOptions.h" + +std::string convertMillis(long long milliseconds) { + auto duration = std::chrono::milliseconds(milliseconds); + auto hours = std::chrono::duration_cast(duration); + duration -= hours; + auto minutes = std::chrono::duration_cast(duration); + duration -= minutes; + auto seconds = std::chrono::duration_cast(duration); + + if (hours.count() > 0) + return std::to_string(hours.count()) + ":" + + std::to_string(minutes.count()) + ":" + + std::to_string(seconds.count()); + else + return std::to_string(minutes.count()) + ":" + + std::to_string(seconds.count()); +} + +struct App : juce::JUCEApplication +{ + const juce::String getApplicationName() override + { + return "NeuralNoteCLI"; + } + + const juce::String getApplicationVersion() override { return "0.1"; } + void initialise(const juce::String& inputArgs) override + { + juce::StringArray tokens; + tokens.addTokens (inputArgs, " ", "\""); + + if (tokens.size() != 2) + { + std::cerr << "Usage: " + << std::endl; + exit(1); + } + + // load audio passed as argument + std::string wav_file = tokens[0].toStdString(); + juce::File input { wav_file }; + + juce::AudioFormatManager formatManager; + formatManager.registerBasicFormats(); + + std::unique_ptr inputStream = input.createInputStream(); + + std::unique_ptr> buffer = std::make_unique>(); + + juce::AudioFormatReader* reader = formatManager.createReaderFor (input); + + if (reader != nullptr) + { + buffer->setSize(reader->numChannels, (int)reader->lengthInSamples); + reader->read(buffer.get(), + 0, + (int)reader->lengthInSamples, + 0, + true, + true); + } + + // strip extension to make prefix for output filenames + std::filesystem::path output_file_prefix = wav_file; + output_file_prefix.replace_extension(); + + // output dir passed as argument + std::string out_dir = tokens[1].toStdString(); + + // Check if the output directory exists, and create it if not + std::filesystem::path output_dir_path(out_dir); + if (!std::filesystem::exists(output_dir_path)) + { + std::cerr << "Directory does not exist: " << out_dir << ". Creating it." + << std::endl; + if (!std::filesystem::create_directories(output_dir_path)) + { + std::cerr << "Error: Unable to create directory: " << out_dir + << std::endl; + return; + } + } + else if (!std::filesystem::is_directory(output_dir_path)) + { + std::cerr << "Error: " << out_dir << " exists but is not a directory!" + << std::endl; + return; + } + + BasicPitch basicPitch; + + float inNoteSensitivity = 0.7; + float inPitchSensitivity = 0.5; + float inMinNoteDurationMs = 50; + + basicPitch.setParameters(inNoteSensitivity, inPitchSensitivity, inMinNoteDurationMs); + + auto start = std::chrono::high_resolution_clock::now(); + basicPitch.transcribeToMIDI(buffer->getWritePointer(0), + buffer->getNumSamples()); + auto stop = std::chrono::high_resolution_clock::now(); + auto processingTime = std::chrono::duration_cast(stop - start); + std::cout << "Audio converted to MIDI in " << convertMillis(processingTime.count()) << "s" << std::endl; + + auto noteEvents = basicPitch.getNoteEvents(); + + auto outFileName = output_file_prefix.string() + ".mid"; + + auto fullOutputPath = output_dir_path / outFileName; + + juce::File outFile { fullOutputPath.string() }; + + MidiFileWriter midiFileWriter; + TimeQuantizeOptions::TimeQuantizeInfo quantizeInfo; + double bpm = 120; + PitchBendModes pitchBendMode = PitchBendModes::NoPitchBend; + + if (!midiFileWriter.writeMidiFile(noteEvents, outFile, quantizeInfo, bpm, pitchBendMode)) { + std::cout << "Writing MIDI file failed" << std::endl; + } + + delete(reader); + + quit(); + } + + void shutdown() override + { + + } +}; + +START_JUCE_APPLICATION(App) \ No newline at end of file diff --git a/ThirdParty/JUCE b/ThirdParty/JUCE deleted file mode 160000 index 51a8a6d..0000000 --- a/ThirdParty/JUCE +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 51a8a6d7aeae7326956d747737ccf1575e61e209 diff --git a/ThirdParty/RTNeural b/ThirdParty/RTNeural deleted file mode 160000 index 2ca066e..0000000 --- a/ThirdParty/RTNeural +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2ca066e5e529ea3ed4120c368ee16526eeaf2cec diff --git a/ThirdParty/onnxruntime/include/coreml_provider_factory.h b/ThirdParty/onnxruntime/include/coreml_provider_factory.h new file mode 100644 index 0000000..f7857a3 --- /dev/null +++ b/ThirdParty/onnxruntime/include/coreml_provider_factory.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "onnxruntime_c_api.h" + +// COREMLFlags are bool options we want to set for CoreML EP +// This enum is defined as bit flags, and cannot have negative value +// To generate an uint32_t coreml_flags for using with OrtSessionOptionsAppendExecutionProvider_CoreML below, +// uint32_t coreml_flags = 0; +// coreml_flags |= COREML_FLAG_USE_CPU_ONLY; +enum COREMLFlags { + COREML_FLAG_USE_NONE = 0x000, + + // Using CPU only in CoreML EP, this may decrease the perf but will provide + // reference output value without precision loss, which is useful for validation + COREML_FLAG_USE_CPU_ONLY = 0x001, + + // Enable CoreML EP on subgraph + COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002, + + // By default CoreML Execution provider will be enabled for all compatible Apple devices + // Enable this option will only enable CoreML EP for Apple devices with ANE (Apple Neural Engine) + // Please note, enable this option does not guarantee the entire model to be executed using ANE only + COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004, + + // Keep COREML_FLAG_MAX at the end of the enum definition + // And assign the last COREMLFlag to it + COREML_FLAG_LAST = COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE, +}; + +#ifdef __cplusplus +extern "C" { +#endif + +ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML, + _In_ OrtSessionOptions* options, uint32_t coreml_flags); + +#ifdef __cplusplus +} +#endif diff --git a/ThirdParty/onnxruntime/include/cpu_provider_factory.h b/ThirdParty/onnxruntime/include/cpu_provider_factory.h new file mode 100644 index 0000000..2926786 --- /dev/null +++ b/ThirdParty/onnxruntime/include/cpu_provider_factory.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "onnxruntime_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \param use_arena zero: false. non-zero: true. + */ +ORT_EXPORT +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) +ORT_ALL_ARGS_NONNULL; + +#ifdef __cplusplus +} +#endif diff --git a/ThirdParty/onnxruntime/include/onnxruntime_c_api.h b/ThirdParty/onnxruntime/include/onnxruntime_c_api.h new file mode 100644 index 0000000..44875b0 --- /dev/null +++ b/ThirdParty/onnxruntime/include/onnxruntime_c_api.h @@ -0,0 +1,3987 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// See docs\c_cxx\README.md on generating the Doxygen documentation from this file + +/** \mainpage C & C++ APIs + * + *

C

+ * + * ::OrtApi - Click here to go to the structure with all C API functions. + * + *

C++

+ * + * ::Ort - Click here to go to the namespace holding all of the C++ wrapper classes + * + * It is a set of header only wrapper classes around the C API. The goal is to turn the C style return value error codes into C++ exceptions, and to + * automate memory management through standard C++ RAII principles. + * + * \addtogroup Global + * ONNX Runtime C API + * @{ + */ + +#pragma once +#include +#include +#include + +/** \brief The API version defined in this header + * + * This value is used by some API functions to behave as this version of the header expects. + */ +#define ORT_API_VERSION 14 + +#ifdef __cplusplus +extern "C" { +#endif + +//! @} +// SAL2 Definitions +#ifndef _WIN32 +#define _In_ +#define _In_z_ +#define _In_opt_ +#define _In_opt_z_ +#define _Out_ +#define _Outptr_ +#define _Out_opt_ +#define _Inout_ +#define _Inout_opt_ +#define _Frees_ptr_opt_ +#define _Ret_maybenull_ +#define _Ret_notnull_ +#define _Check_return_ +#define _Outptr_result_maybenull_ +#define _In_reads_(X) +#define _Inout_updates_all_(X) +#define _Out_writes_bytes_all_(X) +#define _Out_writes_all_(X) +#define _Success_(X) +#define _Outptr_result_buffer_maybenull_(X) +#define ORT_ALL_ARGS_NONNULL __attribute__((nonnull)) +#else +#include +#define ORT_ALL_ARGS_NONNULL +#endif + +#ifdef _WIN32 +// Define ORT_DLL_IMPORT if your program is dynamically linked to Ort. +// dllexport is not used, we use a .def file. +#ifdef ORT_DLL_IMPORT +#define ORT_EXPORT __declspec(dllimport) +#else +#define ORT_EXPORT +#endif +#define ORT_API_CALL _stdcall +#define ORT_MUST_USE_RESULT +#define ORTCHAR_T wchar_t +#else +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define ORT_EXPORT __attribute__((visibility("default"))) +#else +#define ORT_EXPORT +#endif +#define ORT_API_CALL +#define ORT_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define ORTCHAR_T char +#endif + +#ifndef ORT_TSTR +#ifdef _WIN32 +#define ORT_TSTR(X) L##X +#else +#define ORT_TSTR(X) X +#endif +#endif + +// Any pointer marked with _In_ or _Out_, cannot be NULL. + +// Windows users should use unicode paths when possible to bypass the MAX_PATH limitation +// Every pointer marked with _In_ or _Out_, cannot be NULL. Caller should ensure that. +// for ReleaseXXX(...) functions, they can accept NULL pointer. + +#ifdef __cplusplus +// For any compiler with C++11 support, MSVC 2015 and greater, or Clang version supporting noexcept. +// Such complex condition is needed because compilers set __cplusplus value differently. +#ifndef __has_feature +#define __has_feature(x) 0 +#endif +#if ((__cplusplus >= 201103L) || (_MSC_VER >= 1900) || (defined(__has_feature) && __has_feature(cxx_noexcept))) +#define NO_EXCEPTION noexcept +#else +#define NO_EXCEPTION throw() +#endif +#else +#define NO_EXCEPTION +#endif + +// __VA_ARGS__ on Windows and Linux are different +#define ORT_API(RETURN_TYPE, NAME, ...) RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_API_STATUS(NAME, ...) \ + _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) \ + NO_EXCEPTION ORT_MUST_USE_RESULT + +// XXX: Unfortunately, SAL annotations are known to not work with function pointers +#define ORT_API2_STATUS(NAME, ...) \ + _Check_return_ _Ret_maybenull_ OrtStatusPtr(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT + +// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT and ORT_EXPORT +#define ORT_API_STATUS_IMPL(NAME, ...) \ + _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) + +#ifdef __DOXYGEN__ +#undef ORT_API_STATUS +#define ORT_API_STATUS(NAME, ...) OrtStatus* NAME(__VA_ARGS__) +#undef ORT_API2_STATUS +#define ORT_API2_STATUS(NAME, ...) OrtStatus* NAME(__VA_ARGS__) +#undef ORT_CLASS_RELEASE +#define ORT_CLASS_RELEASE(X) void Release##X(Ort##X* input) +#undef NO_EXCEPTION +#define NO_EXCEPTION +#endif +/** \addtogroup Global + * ONNX Runtime C API + * @{ + */ + +/** Copied from TensorProto::DataType + * Currently, Ort doesn't support complex64, complex128 + */ +typedef enum ONNXTensorElementDataType { + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, // maps to c type float + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, // maps to c type uint8_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, // maps to c type int8_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, // maps to c type uint16_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, // maps to c type int16_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string + ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision +} ONNXTensorElementDataType; + +// Synced with onnx TypeProto oneof +typedef enum ONNXType { + ONNX_TYPE_UNKNOWN, + ONNX_TYPE_TENSOR, + ONNX_TYPE_SEQUENCE, + ONNX_TYPE_MAP, + ONNX_TYPE_OPAQUE, + ONNX_TYPE_SPARSETENSOR, + ONNX_TYPE_OPTIONAL +} ONNXType; + +// These types are synced with internal +// SparseFormatFlags +typedef enum OrtSparseFormat { + ORT_SPARSE_UNDEFINED = 0, + ORT_SPARSE_COO = 0x1, + ORT_SPARSE_CSRC = 0x2, + ORT_SPARSE_BLOCK_SPARSE = 0x4 +} OrtSparseFormat; + +// Enum allows to query sparse tensor indices +enum OrtSparseIndicesFormat { + ORT_SPARSE_COO_INDICES, + ORT_SPARSE_CSR_INNER_INDICES, + ORT_SPARSE_CSR_OUTER_INDICES, + ORT_SPARSE_BLOCK_SPARSE_INDICES +}; + +/** \brief Logging severity levels + * + * In typical API usage, specifying a logging severity level specifies the minimum severity of log messages to show. + */ +typedef enum OrtLoggingLevel { + ORT_LOGGING_LEVEL_VERBOSE, ///< Verbose informational messages (least severe). + ORT_LOGGING_LEVEL_INFO, ///< Informational messages. + ORT_LOGGING_LEVEL_WARNING, ///< Warning messages. + ORT_LOGGING_LEVEL_ERROR, ///< Error messages. + ORT_LOGGING_LEVEL_FATAL, ///< Fatal error messages (most severe). +} OrtLoggingLevel; + +typedef enum OrtErrorCode { + ORT_OK, + ORT_FAIL, + ORT_INVALID_ARGUMENT, + ORT_NO_SUCHFILE, + ORT_NO_MODEL, + ORT_ENGINE_ERROR, + ORT_RUNTIME_EXCEPTION, + ORT_INVALID_PROTOBUF, + ORT_MODEL_LOADED, + ORT_NOT_IMPLEMENTED, + ORT_INVALID_GRAPH, + ORT_EP_FAIL, +} OrtErrorCode; + +typedef enum OrtOpAttrType { + ORT_OP_ATTR_UNDEFINED = 0, + ORT_OP_ATTR_INT, + ORT_OP_ATTR_INTS, + ORT_OP_ATTR_FLOAT, + ORT_OP_ATTR_FLOATS, + ORT_OP_ATTR_STRING, + ORT_OP_ATTR_STRINGS, +} OrtOpAttrType; + +//! @} +#define ORT_RUNTIME_CLASS(X) \ + struct Ort##X; \ + typedef struct Ort##X Ort##X; + +/** \addtogroup Global + * ONNX Runtime C API + * @{ + */ +// The actual types defined have an Ort prefix +ORT_RUNTIME_CLASS(Env); +ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success +ORT_RUNTIME_CLASS(MemoryInfo); +ORT_RUNTIME_CLASS(IoBinding); +ORT_RUNTIME_CLASS(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) +ORT_RUNTIME_CLASS(Value); +ORT_RUNTIME_CLASS(RunOptions); +ORT_RUNTIME_CLASS(TypeInfo); +ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); +ORT_RUNTIME_CLASS(SessionOptions); +ORT_RUNTIME_CLASS(CustomOpDomain); +ORT_RUNTIME_CLASS(MapTypeInfo); +ORT_RUNTIME_CLASS(SequenceTypeInfo); +ORT_RUNTIME_CLASS(ModelMetadata); +ORT_RUNTIME_CLASS(ThreadPoolParams); +ORT_RUNTIME_CLASS(ThreadingOptions); +ORT_RUNTIME_CLASS(ArenaCfg); +ORT_RUNTIME_CLASS(PrepackedWeightsContainer); +ORT_RUNTIME_CLASS(TensorRTProviderOptionsV2); +ORT_RUNTIME_CLASS(CUDAProviderOptionsV2); +ORT_RUNTIME_CLASS(CANNProviderOptions); +ORT_RUNTIME_CLASS(Op); +ORT_RUNTIME_CLASS(OpAttr); + +#ifdef _WIN32 +typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; +#else +typedef OrtStatus* OrtStatusPtr; +#endif + +/** \brief Memory allocation interface + * + * Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators. + * + * When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed. + */ +typedef struct OrtAllocator { + uint32_t version; ///< Must be initialized to ORT_API_VERSION + void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes + void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc + const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator +} OrtAllocator; + +typedef void(ORT_API_CALL* OrtLoggingFunction)( + void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message); + +/** \brief Graph optimization level + * + * Refer to https://www.onnxruntime.ai/docs/resources/graph-optimizations.html + * for an in-depth understanding of Graph Optimizations + */ +typedef enum GraphOptimizationLevel { + ORT_DISABLE_ALL = 0, + ORT_ENABLE_BASIC = 1, + ORT_ENABLE_EXTENDED = 2, + ORT_ENABLE_ALL = 99 +} GraphOptimizationLevel; + +typedef enum ExecutionMode { + ORT_SEQUENTIAL = 0, + ORT_PARALLEL = 1, +} ExecutionMode; + +/** \brief Language projection identifiers + * /see OrtApi::SetLanguageProjection + */ +typedef enum OrtLanguageProjection { + ORT_PROJECTION_C = 0, + ORT_PROJECTION_CPLUSPLUS = 1, + ORT_PROJECTION_CSHARP = 2, + ORT_PROJECTION_PYTHON = 3, + ORT_PROJECTION_JAVA = 4, + ORT_PROJECTION_WINML = 5, + ORT_PROJECTION_NODEJS = 6, +} OrtLanguageProjection; + +struct OrtKernelInfo; +typedef struct OrtKernelInfo OrtKernelInfo; +struct OrtKernelContext; +typedef struct OrtKernelContext OrtKernelContext; +struct OrtCustomOp; +typedef struct OrtCustomOp OrtCustomOp; + +typedef enum OrtAllocatorType { + OrtInvalidAllocator = -1, + OrtDeviceAllocator = 0, + OrtArenaAllocator = 1 +} OrtAllocatorType; + +/** \brief Memory types for allocated memory, execution provider specific types should be extended in each provider. + */ +// Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc +typedef enum OrtMemType { + OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider + OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED + OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED + OrtMemTypeDefault = 0, ///< The default allocator for execution provider +} OrtMemType; + +/** \brief This mimics OrtDevice type constants so they can be returned in the API + */ +typedef enum OrtMemoryInfoDeviceType { + OrtMemoryInfoDeviceType_CPU = 0, + OrtMemoryInfoDeviceType_GPU = 1, + OrtMemoryInfoDeviceType_FPGA = 2 +} OrtMemoryInfoDeviceType; + +/** \brief Algorithm to use for cuDNN Convolution Op + */ +typedef enum OrtCudnnConvAlgoSearch { + OrtCudnnConvAlgoSearchExhaustive, // expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx + OrtCudnnConvAlgoSearchHeuristic, // lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7 + OrtCudnnConvAlgoSearchDefault, // default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM +} OrtCudnnConvAlgoSearch; + +/** \brief CUDA Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_CUDA + */ +typedef struct OrtCUDAProviderOptions { +#ifdef __cplusplus + OrtCUDAProviderOptions() + : device_id{}, + cudnn_conv_algo_search{OrtCudnnConvAlgoSearchExhaustive}, + gpu_mem_limit{SIZE_MAX}, + arena_extend_strategy{}, + do_copy_in_default_stream{1}, + has_user_compute_stream{}, + user_compute_stream{}, + default_memory_arena_cfg{}, + tunable_op_enabled{false} {} +#endif + + /** \brief CUDA device Id + * Defaults to 0. + */ + int device_id; + + /** \brief CUDA Convolution algorithm search configuration. + * See enum OrtCudnnConvAlgoSearch for more details. + * Defaults to OrtCudnnConvAlgoSearchExhaustive. + */ + OrtCudnnConvAlgoSearch cudnn_conv_algo_search; + + /** \brief CUDA memory limit (To use all possible memory pass in maximum size_t) + * Defaults to SIZE_MAX. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + size_t gpu_mem_limit; + + /** \brief Strategy used to grow the memory arena + * 0 = kNextPowerOfTwo
+ * 1 = kSameAsRequested
+ * Defaults to 0. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + int arena_extend_strategy; + + /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the CUDA EP + * 0 = Use separate streams for copying and compute. + * 1 = Use the same stream for copying and compute. + * Defaults to 1. + * WARNING: Setting this to 0 may result in data races for some models. + * Please see issue #4829 for more details. + */ + int do_copy_in_default_stream; + + /** \brief Flag indicating if there is a user provided compute stream + * Defaults to 0. + */ + int has_user_compute_stream; + + /** \brief User provided compute stream. + * If provided, please set `has_user_compute_stream` to 1. + */ + void* user_compute_stream; + + /** \brief CUDA memory arena configuration parameters + */ + OrtArenaCfg* default_memory_arena_cfg; + + /** \brief Enable TunableOp. + * Set it to 1 to enable TunableOp. Otherwise, it is disabled by default. + * This option can be superseded by environment variable ORT_CUDA_TUNABLE_OP_ENABLED. + */ + int tunable_op_enabled; + +} OrtCUDAProviderOptions; + +/** \brief ROCM Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_ROCM + */ +typedef struct OrtROCMProviderOptions { +#ifdef __cplusplus + OrtROCMProviderOptions() + : device_id{}, + miopen_conv_exhaustive_search{0}, + gpu_mem_limit{SIZE_MAX}, + arena_extend_strategy{}, + do_copy_in_default_stream{1}, + has_user_compute_stream{}, + user_compute_stream{}, + default_memory_arena_cfg{}, + tunable_op_enabled{false} {} +#endif + + /** \brief ROCM device Id + * Defaults to 0. + */ + int device_id; + + /** \brief ROCM MIOpen Convolution algorithm exaustive search option. + * Defaults to 0 (false). + */ + int miopen_conv_exhaustive_search; + + /** \brief ROCM memory limit (To use all possible memory pass in maximum size_t) + * Defaults to SIZE_MAX. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + size_t gpu_mem_limit; + + /** \brief Strategy used to grow the memory arena + * 0 = kNextPowerOfTwo
+ * 1 = kSameAsRequested
+ * Defaults to 0. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + int arena_extend_strategy; + + /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the ROCM EP + * 0 = Use separate streams for copying and compute. + * 1 = Use the same stream for copying and compute. + * Defaults to 1. + * WARNING: Setting this to 0 may result in data races for some models. + * Please see issue #4829 for more details. + */ + int do_copy_in_default_stream; + + /** \brief Flag indicating if there is a user provided compute stream + * Defaults to 0. + */ + int has_user_compute_stream; + + /** \brief User provided compute stream. + * If provided, please set `has_user_compute_stream` to 1. + */ + void* user_compute_stream; + + /** \brief ROCM memory arena configuration parameters + */ + OrtArenaCfg* default_memory_arena_cfg; + + /** \brief Enable TunableOp. + * Set it to 1 to enable TunableOp. Otherwise, it is disabled by default. + * This option can be superseded by environment variable ORT_ROCM_TUNABLE_OP_ENABLED. + */ + int tunable_op_enabled; + +} OrtROCMProviderOptions; + +/** \brief TensorRT Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + */ +typedef struct OrtTensorRTProviderOptions { + int device_id; ///< CUDA device id (0 = default device) + int has_user_compute_stream; // indicator of user specified CUDA compute stream. + void* user_compute_stream; // user specified CUDA compute stream. + int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability + int trt_min_subgraph_size; // minimum size of TensorRT subgraphs + size_t trt_max_workspace_size; // maximum workspace size for TensorRT. + int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true + int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true + const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name. + int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true + int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true + int trt_dla_core; // DLA core number. Default 0 + int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true + int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true + const char* trt_engine_cache_path; // specify engine cache path + int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true + const char* trt_engine_decryption_lib_path; // specify engine decryption library path + int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true + // This is the legacy struct and don't add new fields here. + // For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h + // For non-string field, need to create a new separate api to handle it. +} OrtTensorRTProviderOptions; + +/** \brief MIGraphX Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX + */ +typedef struct OrtMIGraphXProviderOptions { + int device_id; // hip device id. + int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true +} OrtMIGraphXProviderOptions; + +/** \brief OpenVINO Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + */ +typedef struct OrtOpenVINOProviderOptions { +#ifdef __cplusplus + OrtOpenVINOProviderOptions() : device_type{}, enable_vpu_fast_compile{}, device_id{}, + num_of_threads{}, cache_dir{}, + context{}, enable_opencl_throttling{}, enable_dynamic_shapes{} {} +#endif + /** \brief Device type string + * + * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16", "MYRIAD_FP16", "VAD-M_FP16" or "VAD-F_FP32" + */ + const char* device_type; + unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + const char* device_id; + size_t num_of_threads; ///< 0 = Use default number of threads + const char* cache_dir; // path is set to empty by default + void* context; + unsigned char enable_opencl_throttling; ///< 0 = disabled, nonzero = enabled + unsigned char enable_dynamic_shapes; ///< 0 = disabled, nonzero = enabled +} OrtOpenVINOProviderOptions; + +struct OrtApi; +typedef struct OrtApi OrtApi; + +struct OrtTrainingApi; +typedef struct OrtTrainingApi OrtTrainingApi; + +/** \brief The helper interface to get the right version of OrtApi + * + * Get a pointer to this structure through ::OrtGetApiBase + */ +struct OrtApiBase { + /** \brief Get a pointer to the requested version of the ::OrtApi + * + * \param[in] version Must be ::ORT_API_VERSION + * \return The ::OrtApi for the version requested, nullptr will be returned if this version is unsupported, for example when using a runtime + * older than the version created with this header file. + */ + const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; + const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") +}; +typedef struct OrtApiBase OrtApiBase; + +/** \brief The Onnxruntime library's entry point to access the C API + * + * Call this to get the a pointer to an ::OrtApiBase + */ +ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; + +/** \brief Thread work loop function + * + * Onnxruntime will provide the working loop on custom thread creation + * Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn + */ +typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param); + +typedef const struct OrtCustomHandleType { + char __place_holder; +}* OrtCustomThreadHandle; + +/** \brief Ort custom thread creation function + * + * The function should return a thread handle to be used in onnxruntime thread pools + * Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread + */ +typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param); + +/** \brief Custom thread join function + * + * Onnxruntime thread pool destructor will call the function to join a custom thread. + * Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn + */ +typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle); + +typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api); + +/** \brief The C API + * + * All C API functions are defined inside this structure as pointers to functions. + * Call OrtApiBase::GetApi to get a pointer to it + * + * \nosubgrouping + */ +struct OrtApi { + /// \name OrtStatus + /// @{ + + /** + * \brief Create an OrtStatus from a null terminated string + * + * \param[in] code + * \param[in] msg A null-terminated string. Its contents will be copied. + * \return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus + */ + OrtStatus*(ORT_API_CALL* CreateStatus)(OrtErrorCode code, _In_ const char* msg)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /** \brief Get OrtErrorCode from OrtStatus + * + * \param[in] status + * \return OrtErrorCode that \p status was created with + */ + OrtErrorCode(ORT_API_CALL* GetErrorCode)(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /** \brief Get error string from OrtStatus + * + * \param[in] status + * \return The error message inside the `status`. Do not free the returned value. + */ + const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /// @} + /// \name OrtEnv + /// @{ + + /** \brief Create an OrtEnv + * + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnv, OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + + /** \brief Create an OrtEnv + * + * \param[in] logging_function A pointer to a logging function. + * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `logging_function`. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + + /** \brief Enable Telemetry + * + * \note Telemetry events are on by default since they are lightweight + * \param[in] env + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableTelemetryEvents, _In_ const OrtEnv* env); + /** \brief Disable Telemetry + * + * \see OrtApi::EnableTelemetryEvents + * \param[in] env + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableTelemetryEvents, _In_ const OrtEnv* env); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Create an OrtSession from a model file + * + * \param[in] env + * \param[in] model_path + * \param[in] options + * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + // TODO: document the path separator convention? '/' vs '\' + // TODO: should specify the access characteristics of model_path. Is this read only during the + // execution of CreateSession, or does the OrtSession retain a handle to the file/directory + // and continue to access throughout the OrtSession lifetime? + // What sort of access is needed to model_path : read or read/write? + ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession from memory + * + * \param[in] env + * \param[in] model_data + * \param[in] model_data_length + * \param[in] options + * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Run the model in an ::OrtSession + * + * Will not return until the model run has completed. Multiple threads might be used to run the model based on + * the options in the ::OrtSession and settings used when creating the ::OrtEnv + * + * \param[in] session + * \param[in] run_options If nullptr, will use a default ::OrtRunOptions + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] inputs Array of ::OrtValue%s of the input values + * \param[in] input_len Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[in] output_names_len Number of elements in the output_names and outputs array + * \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be + * an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers + * to them will be set into the `outputs` array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(Run, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* inputs, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** outputs); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Create an ::OrtSessionOptions object + * + * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these + * functions to enable them in the session:
+ * OrtSessionOptionsAppendExecutionProvider_CPU
+ * OrtSessionOptionsAppendExecutionProvider_CUDA
+ * OrtSessionOptionsAppendExecutionProvider_(remaining providers...)
+ * The order they are called indicates the preference order as well. In other words call this method + * on your most preferred execution provider first followed by the less preferred ones. + * If none are called Ort will use its internal CPU execution provider. + * + * \param[out] options The newly created OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions** options); + + /** \brief Set filepath to save optimized model after graph level transformations + * + * \param[in] options + * \param[in] optimized_model_filepath + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, + _In_ const ORTCHAR_T* optimized_model_filepath); + + /** \brief Create a copy of an existing ::OrtSessionOptions + * + * \param[in] in_options OrtSessionOptions to copy + * \param[out] out_options Returned newly created ::OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CloneSessionOptions, _In_ const OrtSessionOptions* in_options, + _Outptr_ OrtSessionOptions** out_options); + + /** \brief Set execution mode + * + * Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model + * has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. + * See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. + * + * \param[in] options + * \param[in] execution_mode + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionExecutionMode, _Inout_ OrtSessionOptions* options, ExecutionMode execution_mode); + + /** \brief Enable profiling for a session + * + * \param[in] options + * \param[in] profile_file_prefix + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); + + /** \brief Disable profiling for a session + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableProfiling, _Inout_ OrtSessionOptions* options); + + /** \brief Enable the memory pattern optimization + * + * The idea is if the input shapes are the same, we could trace the internal memory allocation + * and generate a memory pattern for future request. So next time we could just do one allocation + * with a big chunk for all the internal memory allocation. + * \note Memory pattern optimization is only available when Sequential Execution mode is enabled (see OrtApi::SetSessionExecutionMode) + * + * \see OrtApi::DisableMemPattern + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableMemPattern, _Inout_ OrtSessionOptions* options); + + /** \brief Disable the memory pattern optimization + * + * \see OrtApi::EnableMemPattern + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableMemPattern, _Inout_ OrtSessionOptions* options); + + /** \brief Enable the memory arena on CPU + * + * Arena may pre-allocate memory for future usage. + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableCpuMemArena, _Inout_ OrtSessionOptions* options); + + /** \brief Disable the memory arena on CPU + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableCpuMemArena, _Inout_ OrtSessionOptions* options); + + /** \brief Set session log id + * + * \param[in] options + * \param[in] logid The log identifier. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid); + + /** \brief Set session log verbosity level + * + * Applies to session load, initialization, etc + * + * \param[in] options + * \param[in] session_log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); + + /** \brief Set session log severity level + * + * \param[in] options + * \param[in] session_log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); + + /** \brief Set the optimization level to apply when loading a graph + * + * Please see https://www.onnxruntime.ai/docs/resources/graph-optimizations.html for an in-depth explanation + * \param[in,out] options The session options object + * \param[in] graph_optimization_level The optimization level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, + GraphOptimizationLevel graph_optimization_level); + + /** \brief Sets the number of threads used to parallelize the execution within nodes + * + * When running a single node operation, ex. add, this sets the maximum number of threads to use. + * + * \note If built with OpenMP, this has no effect on the number of threads used. In this case + * use the OpenMP env variables to configure the number of intra op num threads. + * + * \param[in] options + * \param[in] intra_op_num_threads Number of threads to use
+ * A value of 0 will use the default number of threads
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads); + + /** \brief Sets the number of threads used to parallelize the execution of the graph + * + * If nodes can be run in parallel, this sets the maximum number of threads to use to run them in parallel. + * + * \note If sequential execution is enabled this value is ignored, it acts as if it was set to 1. + * + * \param[in] options + * \param[in] inter_op_num_threads Number of threads to use
+ * A value of 0 will use the default number of threads
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads); + + /// @} + /// \name OrtCustomOpDomain + /// @{ + + /** \brief Create a custom op domain + * + * \param[in] domain + * \param[out] out Newly created domain. Must be freed with OrtApi::ReleaseCustomOpDomain + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out); + + /** \brief Add a custom op to a custom op domain + * + * \note The OrtCustomOp* pointer must remain valid until the ::OrtCustomOpDomain using it is released + * + * \param[in] custom_op_domain + * \param[in] op + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Add custom op domain to a session options + * + * \note The OrtCustomOpDomain* must not be deleted until all sessions using it are released + * + * \param[in] options + * \param[in] custom_op_domain + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain); + + /** \deprecated Use OrtApi::RegisterCustomOpsLibrary_V2. + * + * Registers custom ops from a shared library. + * + * Loads a shared library (dll on windows, so on linux, etc) named 'library_path' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. + * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in + * session options are destroyed, or if an error occurs and it is non null. + * + * \param[in] options + * \param[in] library_path + * \param[out] library_handle OS specific handle to the loaded library (Use FreeLibrary on Windows, dlclose on Linux, etc.. to unload) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Get input count for a session + * + * This number must also match the number of inputs passed to OrtApi::Run + * + * \see OrtApi::SessionGetInputTypeInfo, OrtApi::SessionGetInputName, OrtApi::Session + * + * \param[in] session + * \param[out] out Number of inputs + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession* session, _Out_ size_t* out); + + /** \brief Get output count for a session + * + * This number must also match the number of outputs returned by OrtApi::Run + * + * \see OrtApi::SessionGetOutputTypeInfo, OrtApi::SessionGetOutputName, OrtApi::Session + * + * \param[in] session + * \param[out] out Number of outputs + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* session, _Out_ size_t* out); + + /** \brief Get overridable initializer count + * + * \see OrtApi::SessionGetOverridableInitializerTypeInfo, OrtApi::SessionGetOverridableInitializerName + * + * \param[in] session + * \param[in] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession* session, _Out_ size_t* out); + + /** \brief Get input type information + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get output type information + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get overridable initializer type information + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get input name + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get output name + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get overridable initializer name + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOverridableInitializerName, _In_ const OrtSession* session, size_t index, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /// @} + /// \name OrtRunOptions + /// @{ + + /** \brief Create an OrtRunOptions + * + * \param[out] out Returned newly created ::OrtRunOptions. Must be freed with OrtApi::ReleaseRunOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions** out); + + /** \brief Set per-run log verbosity level + * + * \see OrtApi::RunOptionsGetRunLogVerbosityLevel + * + * \param[in] options + * \param[in] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int log_verbosity_level); + + /** \brief Set per-run log severity level + * + * \see OrtApi::RunOptionsGetRunLogSeverityLevel + * + * \param[in] options + * \param[in] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). + */ + ORT_API2_STATUS(RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int log_severity_level); + + /** \brief Set per-run tag + * + * This is used in a per-run log identifier. + * + * \see OrtApi::RunOptionsGetRunTag + * + * \param[in] options + * \param[in] run_tag The run tag. + */ + ORT_API2_STATUS(RunOptionsSetRunTag, _Inout_ OrtRunOptions* options, _In_ const char* run_tag); + + /** \brief Get per-run log verbosity level + * + * \see OrtApi::RunOptionsSetRunLogVerbosityLevel + * + * \param[in] options + * \param[out] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, + _Out_ int* log_verbosity_level); + + /** \brief Get per-run log severity level + * + * \see OrtApi::RunOptionsSetRunLogSeverityLevel + * + * \param[in] options + * \param[out] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). + */ + ORT_API2_STATUS(RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* log_severity_level); + + /** \brief Get per-run tag + * + * This is used in a per-run log identifier. + * + * \see OrtApi::RunOptionsSetRunTag + * + * \param[in] options + * \param[out] run_tag The run tag. + * Do not free this value, it is owned by `options`. It will be invalidated if the run tag + * changes (i.e., with OrtApi::RunOptionsSetRunTag) or `options` is freed. + */ + ORT_API2_STATUS(RunOptionsGetRunTag, _In_ const OrtRunOptions* options, _Out_ const char** run_tag); + + /** \brief Set terminate flag + * + * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error. + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); + + /** \brief Clears the terminate flag + * + * Used so the OrtRunOptions instance can be used in a new OrtApi::Run call without it instantly terminating + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Create a tensor + * + * Create a tensor using a supplied ::OrtAllocator + * + * \param[in] allocator + * \param[in] shape Pointer to the tensor shape dimensions. + * \param[in] shape_len The number of tensor shape dimensions. + * \param[in] type + * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** \brief Create a tensor backed by a user supplied buffer + * + * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. + * p_data is owned by caller. ReleaseValue won't release p_data. + * + * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param[in] p_data Pointer to the data buffer. + * \param[in] p_data_len The number of bytes in the data buffer. + * \param[in] shape Pointer to the tensor shape dimensions. + * \param[in] shape_len The number of tensor shape dimensions. + * \param[in] type The data type. + * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, + size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + + /** \brief Return if an ::OrtValue is a tensor type + * + * \param[in] value A tensor type (string tensors are not supported) + * \param[out] out Set to 1 iff ::OrtValue is a tensor, 0 otherwise + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(IsTensor, _In_ const OrtValue* value, _Out_ int* out); + + /** \brief Get a pointer to the raw data inside a tensor + * + * Used to read/write/modify the internal tensor data directly. + * \note The returned pointer is valid until the \p value is destroyed. + * + * \param[in] value A tensor type (string tensors are not supported) + * \param[out] out Filled in with a pointer to the internal storage + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorMutableData, _In_ OrtValue* value, _Outptr_ void** out); + + /** \brief Set all strings at once in a string tensor + * + * \param[in,out] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[in] s An array of strings. Each string in this array must be null terminated. + * \param[in] s_len Count of strings in s (Must match the size of \p value's tensor shape) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); + + /** \brief Get total byte length for all strings in a string tensor + * + * Typically used with OrtApi::GetStringTensorContent + * + * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[out] len Total byte length of all strings (does not include trailing nulls) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); + + /** \brief Get all strings from a string tensor + * + * An example of the results:
+ * Given \p value is a string tensor with the strings { "This" "is" "a" "test" }
+ * \p s must have a size of 11 bytes
+ * \p offsets must have 4 elements
+ * After the call, these values will be filled in:
+ * \p s will contain "Thisisatest"
+ * \p offsets will contain { 0, 4, 6, 7 }
+ * The length of the last string is just s_len - offsets[last] + * + * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[in] s Buffer to sequentially write all tensor strings to. Each string is NOT null-terminated. + * \param[in] s_len Number of bytes of buffer pointed to by \p s (Get it from OrtApi::GetStringTensorDataLength) + * \param[out] offsets Array of start offsets into the strings written to \p s + * \param[in] offsets_len Number of elements in offsets + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, + size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); + + /// @} + /// \name OrtTypeInfo + /// @{ + + /** \brief Get ::OrtTensorTypeAndShapeInfo from an ::OrtTypeInfo + * + * \param[in] type_info + * \param[out] out Do not free this value, it will be valid until type_info is freed. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); + + /** \brief Get ::ONNXType from ::OrtTypeInfo + * + * \param[in] type_info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ enum ONNXType* out); + + /// @} + /// \name OrtTensorTypeAndShapeInfo + /// @{ + + /** \brief Create an ::OrtTensorTypeAndShapeInfo object + * + * \param[out] out Returns newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Set element type in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[in] type + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* info, enum ONNXTensorElementDataType type); + + /** \brief Set shape information in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[in] dim_values Array with `dim_count` elements. Can contain negative values. + * \param[in] dim_count Number of elements in `dim_values` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + + /** \brief Get element type in ::OrtTensorTypeAndShapeInfo + * + * \see OrtApi::SetTensorElementType + * + * \param[in] info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo* info, + _Out_ enum ONNXTensorElementDataType* out); + + /** \brief Get dimension count in ::OrtTensorTypeAndShapeInfo + * + * \see OrtApi::GetDimensions + * + * \param[in] info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); + + /** \brief Get dimensions in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[out] dim_values Array with `dim_values_length` elements. On return, filled with the dimensions stored in the ::OrtTensorTypeAndShapeInfo + * \param[in] dim_values_length Number of elements in `dim_values`. Use OrtApi::GetDimensionsCount to get this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, + size_t dim_values_length); + + /** \brief Get symbolic dimension names in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[in] dim_params Array with `dim_params_length` elements. On return filled with pointers to null terminated strings of the dimension names + * \param[in] dim_params_length Number of elements in `dim_params`. Use OrtApi::GetDimensionsCount to get this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, + _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); + + /** \brief Get total number of elements in a tensor shape from an ::OrtTensorTypeAndShapeInfo + * + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ * + * \param[in] info + * \param[out] out Number of elements + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Get type and shape information from a tensor ::OrtValue + * + * \param[in] value Must be a tensor (not a map/sequence/etc) or will return failure + * \param[out] out Newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Get type information of an OrtValue + * + * \param[in] value + * \param[out] out Newly created ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); + + /** \brief Get ONNXType of an ::OrtValue + * + * \param[in] value + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); + + /// @} + /// \name OrtMemoryInfo + /// @{ + + /** \brief Create an ::OrtMemoryInfo + * + * \param[in] name + * \param[in] type + * \param[in] id + * \param[in] mem_type + * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name, enum OrtAllocatorType type, int id, + enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out); + + /** \brief Create an ::OrtMemoryInfo for CPU memory + * + * Special case version of OrtApi::CreateMemoryInfo for CPU based memory. Same as using OrtApi::CreateMemoryInfo with name = "Cpu" and id = 0. + * + * \param[in] type + * \param[in] mem_type + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, + _Outptr_ OrtMemoryInfo** out); + + /** \brief Compare ::OrtMemoryInfo objects for equality + * + * Compares all settings of each ::OrtMemoryInfo for equality + * + * \param[in] info1 + * \param[in] info2 + * \param[out] out Set to 0 if equal, -1 if not equal + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out); + + /** \brief Get name from ::OrtMemoryInfo + * + * \param[in] ptr + * \param[out] out Writes null terminated string to this pointer. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtMemoryInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); + + /** \brief Get the id from ::OrtMemoryInfo + */ + ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); + + /** \brief Get the ::OrtMemType from ::OrtMemoryInfo + */ + ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out); + + /** \brief Get the ::OrtAllocatorType from ::OrtMemoryInfo + */ + ORT_API2_STATUS(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out); + + /// @} + /// \name OrtAllocator + /// @{ + + /// \brief Calls OrtAllocator::Alloc function + ORT_API2_STATUS(AllocatorAlloc, _Inout_ OrtAllocator* ort_allocator, size_t size, _Outptr_ void** out); + /// \brief Calls OrtAllocator::Free function + ORT_API2_STATUS(AllocatorFree, _Inout_ OrtAllocator* ort_allocator, void* p); + /// \brief Calls OrtAllocator::Info function + ORT_API2_STATUS(AllocatorGetInfo, _In_ const OrtAllocator* ort_allocator, _Outptr_ const struct OrtMemoryInfo** out); + + /** \brief Get the default allocator + * + * The default allocator is a CPU based, non-arena. Always returns the same pointer to the same default allocator. + * + * \param[out] out Returned value should NOT be freed + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Override session symbolic dimensions + * + * Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable + * optimizations that can take advantage of fixed values (such as memory planning, etc) + * + * \param[in] options + * \param[in] dim_denotation + * \param[in] dim_value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, + _In_ int64_t dim_value); + + /// @} + /// \name OrtValue + /// @{ + + /* Internal information (not seen in Doxygen) + * + * APIs to support non-tensor types - map and sequence. + * Currently only the following types are supported + * Note: the following types should be kept in sync with data_types.h + * Map types + * ========= + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * + * Sequence types + * ============== + * std::vector + * std::vector + * std::vector + * std::vector + * std::vector> + * std::vector + */ + + /** \brief Get non tensor data from an ::OrtValue + * + * If `value` is of type ONNX_TYPE_MAP, you need to retrieve the keys and values + * separately. Use index=0 to retrieve keys and index=1 to retrieve values. + * If `value` is of type ONNX_TYPE_SEQUENCE, use index to retrieve the index'th element + * of the sequence. + * + * \param[in] value + * \param[in] index See above for usage based on `value` type + * \param[in] allocator Allocator used to allocate ::OrtValue + * \param[out] out Created ::OrtValue that holds the element requested. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out); + + /** \brief Get non tensor value count from an ::OrtValue + * + * If `value` is of type ONNX_TYPE_MAP 2 will always be returned. For ONNX_TYPE_SEQUENCE + * the number of elements in the sequence will be returned + * + * \param[in] value + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); + + /** \brief Create a map or sequence ::OrtValue + * + * To construct a map (ONNX_TYPE_MAP), use num_values = 2 and `in` should be an array of 2 ::OrtValue%s + * representing keys and values.
+ * + * To construct a sequence (ONNX_TYPE_SEQUENCE), use num_values = N where N is the number of the elements in the + * sequence. 'in' should be an array of N ::OrtValue%s. + * + * \param[in] in See above for details + * \param[in] num_values + * \param[in] value_type Must be either ONNX_TYPE_MAP or ONNX_TYPE_SEQUENCE + * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, + enum ONNXType value_type, _Outptr_ OrtValue** out); + + /** \brief Create an opaque (custom user defined type) ::OrtValue + * + * Constructs an ::OrtValue that contains a value of non-standard type created for + * experiments or while awaiting standardization. ::OrtValue in this case would contain + * an internal representation of the Opaque type. Opaque types are distinguished from + * each other by two strings 1) domain and 2) type name. The combination of the two + * must be unique, so the type representation is properly identified internally. The combination + * must be properly registered from within ORT at both compile/run time or by another API. + * + * To construct the ::OrtValue pass domain and type names, also a pointer to a data container + * the type of which must be known to both ORT and the client program. That data container may or may + * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for + * verification purposes. + * + * \param[in] domain_name Null terminated string of the domain name + * \param[in] type_name Null terminated string of the type name + * \param[in] data_container User pointer Data to populate ::OrtValue + * \param[in] data_container_size Size in bytes of what `data_container` points to + * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, + _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); + + /** \brief Get internal data from an opaque (custom user defined type) ::OrtValue + * + * Copies internal data from an opaque value into a user provided buffer + * + * \see OrtApi::CreateOpaqueValue + * + * \param[in] domain_name Null terminated string of the domain name + * \param[in] type_name Null terminated string of the type name + * \param[in] in The opaque ::OrtValue + * \param[out] data_container Buffer to copy data into + * \param[out] data_container_size Size in bytes of the buffer pointed to by data_container. Must match the size of the internal buffer. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, + _Out_ void* data_container, size_t data_container_size); + + /// @} + /// \name OrtKernelInfo + /// Custom operator APIs. + /// @{ + + /** \brief Get a float stored as an attribute in the graph node + * + * \param[in] info ::OrtKernelInfo instance + * \param[in] name Null terminated string of the name of the attribute + * \param[out] out Pointer to memory where the attribute will be stored + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out); + + /** \brief Fetch a 64-bit int stored as an attribute in the graph node + * + * \param[in] info ::OrtKernelInfo instance + * \param[in] name Null terminated string of the name of the attribute + * \param[out] out Pointer to memory where the attribute will be stored + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out); + + /** \brief Fetch a string stored as an attribute in the graph node + * + * If `out` is nullptr, the value of `size` is set to the true size of the string + * attribute, and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual string attribute's size, + * the value of `size` is set to the true size of the string attribute, the provided memory + * is filled with the attribute's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string attribute's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string attribute + * and a failure status is returned.) + * + * \param[in] info ::OrtKernelInfo instance + * \param[in] name Null terminated string of the name of the attribute + * \param[out] out Pointer to memory where the attribute will be stored + * \param[in,out] size See above comments for details + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, + _Inout_ size_t* size); + + /// @} + /// \name OrtKernelContext + /// Custom operator APIs. + /// @{ + + /** \brief Used for custom operators, get the input count of a kernel + * + * \see ::OrtCustomOp + */ + ORT_API2_STATUS(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); + + /** \brief Used for custom operators, get the output count of a kernel + * + * \see ::OrtCustomOp + */ + ORT_API2_STATUS(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); + + /** \brief Used for custom operators, get an input of a kernel + * + * \see ::OrtCustomOp + */ + ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, + _Out_ const OrtValue** out); + + /** \brief Used for custom operators, get an output of a kernel + * + * \see ::OrtCustomOp + */ + ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, + _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); + + /// @} + /// \name OrtEnv + /// @{ + ORT_CLASS_RELEASE(Env); + /// @} + /// \name OrtStatus + /// @{ + ORT_CLASS_RELEASE(Status); + /// @} + /// \name OrtMemoryInfo + /// @{ + ORT_CLASS_RELEASE(MemoryInfo); + /// @} + /// \name OrtSession + /// @{ + ORT_CLASS_RELEASE(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) + /// @} + /// \name OrtValue + /// @{ + ORT_CLASS_RELEASE(Value); + /// @} + /// \name OrtRunOptions + /// @{ + ORT_CLASS_RELEASE(RunOptions); + /// @} + /// \name OrtTypeInfo + /// @{ + ORT_CLASS_RELEASE(TypeInfo); + /// @} + /// \name OrtTensorTypeAndShapeInfo + /// @{ + ORT_CLASS_RELEASE(TensorTypeAndShapeInfo); + /// @} + /// \name OrtSessionOptions + /// @{ + ORT_CLASS_RELEASE(SessionOptions); + /// @} + /// \name OrtCustomOpDomain + /// @{ + ORT_CLASS_RELEASE(CustomOpDomain); + + /// @} + /// \name OrtTypeInfo + /// @{ + + /** \brief Get denotation from type information + * + * Augments ::OrtTypeInfo to return denotations on the type. + * + * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. + * + * \param[in] type_info + * \param[out] denotation Pointer to the null terminated denotation string is written to this pointer. This pointer is valid until the object is destroyed or the name is changed, do not free. + * \param[out] len Length in bytes of the string returned in `denotation` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const denotation, + _Out_ size_t* len); + + /** \brief Get detailed map information from an ::OrtTypeInfo + * + * This augments ::OrtTypeInfo to return an ::OrtMapTypeInfo when the type is a map. + * The OrtMapTypeInfo has additional information about the map's key type and value type. + * + * This is used by WinML to support model reflection APIs. + * + * \param[out] type_info + * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtMapTypeInfo** out); + + /** \brief Cast ::OrtTypeInfo to an ::OrtSequenceTypeInfo + * + * This api augments ::OrtTypeInfo to return an ::OrtSequenceTypeInfo when the type is a sequence. + * The ::OrtSequenceTypeInfo has additional information about the sequence's element type. + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] type_info + * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); + + /// @} + /// \name OrtMapTypeInfo + /// @{ + + /** \brief Get key type from an ::OrtMapTypeInfo + * + * Key types are restricted to being scalar types. + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] map_type_info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); + + /** \brief Get the value type from an ::OrtMapTypeInfo + * + * \param[in] map_type_info + * \param[out] type_info + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); + + /// @} + /// \name OrtSequenceTypeInfo + /// @{ + + /** \brief Get element type from an ::OrtSequenceTypeInfo + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] sequence_type_info + * \param[out] type_info + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, + _Outptr_ OrtTypeInfo** type_info); + + /// @} + /// \name OrtMapTypeInfo + /// @{ + ORT_CLASS_RELEASE(MapTypeInfo); + /// @} + /// \name OrtSequenceTypeInfo + /// @{ + ORT_CLASS_RELEASE(SequenceTypeInfo); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief End profiling and return filename of the profile data + * + * Profiling is turned on through OrtApi::EnableProfiling + * + * \param[in] session + * \param[in] allocator + * \param[out] out Null terminated string of the filename, allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession* session, _Inout_ OrtAllocator* allocator, _Outptr_ char** out); + + /** \brief Get ::OrtModelMetadata from an ::OrtSession + * + * \param[in] session + * \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession* session, _Outptr_ OrtModelMetadata** out); + + /// @} + /// \name OrtModelMetadata + /// @{ + + /** \brief Get `producer name` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get `graph name` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get `domain` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + + /** \brief Get `description` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Return data for a key in the custom metadata map in an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[in] key Null terminated string + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * `value` will be set to nullptr if the given key is not found in the custom metadata map. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); + + /** \brief Get version number from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[out] value Set to the version number + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); + + ORT_CLASS_RELEASE(ModelMetadata); + + /// @} + /// \name OrtEnv + /// @{ + + /** \brief Create an OrtEnv + * + * Create an environment with global threadpools that will be shared across sessions. + * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use + * its own thread pools. + * + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[in] tp_options + * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel log_severity_level, _In_ const char* logid, + _In_ const OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Use global thread pool on a session + * + * Disable using per session thread pool and use the shared global threadpool. + * This should be used in conjunction with OrtApi::CreateEnvWithGlobalThreadPools. + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions* options); + + /// @} + /// \name OrtThreadingOptions + /// @{ + + /** \brief Create an ::OrtThreadingOptions + * + * \param[out] out Newly created ::OrtThreadingOptions. Must be freed with OrtApi::ReleaseThreadingOptions + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); + + ORT_CLASS_RELEASE(ThreadingOptions); + + /// @} + /// \name OrtModelMetadata + /// @{ + + /** + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] keys Array of null terminated strings (array count = num_keys) allocated using `allocator`. + * The strings and the pointer array must be freed using `allocator` + * `keys` will be set to nullptr if the custom metadata map is empty. + * \param[out] num_keys Set to the number of elements in the `keys` array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** + * + * Override symbolic dimensions (by specific name strings) with actual values + * if known at session initialization time to enable optimizations that can + * take advantage of fixed values (such as memory planning, etc) + * + */ + ORT_API2_STATUS(AddFreeDimensionOverrideByName, + _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, + _In_ int64_t dim_value); + + /// @} + /// \name Misc + /// @{ + + /** \brief Get the names of all available providers + * + * \note The providers in the list are not guaranteed to be usable. They may fail to load due to missing system dependencies. + * For example, if the CUDA/cuDNN libraries are not installed, the CUDA provider will report an error when it is added to the session options. + * + * \param[out] out_ptr Set to a pointer to an array of null terminated strings of the available providers. The entries and the + * array itself must be freed using OrtApi::ReleaseAvailableProviders + * \param[out] provider_length Set to the number of entries in the `out_ptr` array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char*** out_ptr, _Out_ int* provider_length); + + /** \brief Release data from OrtApi::GetAvailableProviders + * + * \param[in] ptr The `out_ptr` result from OrtApi::GetAvailableProviders. + * \param[in] providers_length The `provider_length` result from OrtApi::GetAvailableProviders + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char** ptr, + _In_ int providers_length); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Get the length of a single string in a string tensor + * + * \param[in] value A string tensor + * \param[in] index Index of the string in the tensor + * \param[out] out Set to number of bytes of the string element + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out); + + /** \brief Get a single string from a string tensor + * + * \param[in] value A string tensor + * \param[in] s_len Number of bytes in the `s` buffer. Must match the value returned by OrtApi::GetStringTensorElementLength. + * \param[in] index Index of the string in the tensor + * \param[out] s The string element contents in UTF-8 encoding. The string is NOT null-terminated. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorElement, _In_ const OrtValue* value, size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s); + + /** \brief Set a single string in a string tensor + * + * \param[in] value A string tensor + * \param[in] s A null terminated UTF-8 encoded string + * \param[in] index Index of the string in the tensor to set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillStringTensorElement, _Inout_ OrtValue* value, _In_ const char* s, size_t index); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Set a session configuration entry as a pair of strings + * + * If a configuration with same key exists, this will overwrite the configuration with the given config_value. + * + * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h + * + * \param[in] options + * \param[in] config_key A null terminated string representation of the config key + * \param[in] config_value A null terminated string representation of the config value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options, + _In_z_ const char* config_key, _In_z_ const char* config_value); + + /// @} + /// \name OrtAllocator + /// @{ + + /** \brief Create an allocator for an ::OrtSession following an ::OrtMemoryInfo + * + * \param[in] session + * \param[in] mem_info valid ::OrtMemoryInfo instance + * \param[out] out Newly created ::OrtAllocator. Must be freed with OrtApi::ReleaseAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateAllocator, _In_ const OrtSession* session, _In_ const OrtMemoryInfo* mem_info, + _Outptr_ OrtAllocator** out); + + /** \brief Release an ::OrtAllocator obtained from OrtApi::CreateAllocator + */ + ORT_CLASS_RELEASE(Allocator); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Run a model using Io Bindings for the inputs & outputs + * + * \see OrtApi::Run + * + * \param[in] session + * \param[in] run_options + * \param[in] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunWithBinding, _Inout_ OrtSession* session, _In_ const OrtRunOptions* run_options, _In_ const OrtIoBinding* binding_ptr); + + /** \brief Create an ::OrtIoBinding instance + * + * An IoBinding object allows one to bind pre-allocated ::OrtValue%s to input names. + * Thus if you want to use a raw on device buffer as input or output you can avoid + * extra copy during runtime. + * + * \param[in] session + * \param[out] out Newly created ::OrtIoBinding. Must be freed with OrtApi::ReleaseIoBinding + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateIoBinding, _Inout_ OrtSession* session, _Outptr_ OrtIoBinding** out); + + /// @} + /// \name OrtIoBinding + /// @{ + + /** \brief Release an ::OrtIoBinding obtained from OrtApi::CreateIoBinding + */ + ORT_CLASS_RELEASE(IoBinding); + + /** \brief Bind an ::OrtValue to an ::OrtIoBinding input + * + * When using OrtApi::RunWithBinding this value is used for the named input + * + * \param[in] binding_ptr + * \param[in] name Name for the model input + * \param[in] val_ptr ::OrtValue of Tensor type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(BindInput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); + + /** \brief Bind an ::OrtValue to an ::OrtIoBinding output + * + * When using OrtApi::RunWithBinding this value is used for the named output + * + * \param[in] binding_ptr + * \param[in] name Null terminated string of the model output name + * \param[in] val_ptr ::OrtValue of Tensor type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(BindOutput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); + + /** \brief Bind an ::OrtIoBinding output to a device + * + * Binds the ::OrtValue to a device which is specified by ::OrtMemoryInfo. + * You can either create an instance of ::OrtMemoryInfo with a device id or obtain one from the allocator that you have created/are using + * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocate and bind a chunk of + * memory within ::OrtValue ahead of time. + * + * \see OrtApi::RunWithBinding + * + * \param[in] binding_ptr + * \param[in] name Null terminated string of the device name + * \param[in] mem_info_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(BindOutputToDevice, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtMemoryInfo* mem_info_ptr); + + /** \brief Get the names of an ::OrtIoBinding's outputs + * + * Returns the names of the outputs in the order they were bound. This is useful after running the model + * with bound outputs because the returned names are in order in which output ::OrtValue are returned. This is useful if + * the order of outputs and their names is not known. + * + * \param[in] binding_ptr + * \param[in] allocator Allocator used to allocate continuous buffers for output strings and lengths. + * \param[out] buffer Returns an array of non-null terminated UTF-8 strings. The number of strings stored is returned in the count parameter. + * This buffer is allocated using `allocator` and must be freed using it. + * \param[out] lengths Returns an array of `count` lengths of the strings returned in `buffer` + * This buffer is allocated using `allocator` and must be freed using it. + * \param[out] count Number of strings returned. If `binding_ptr` has no bound outputs, zero is returned, + * no memory allocation is performed and buffer and lengths are set to nullptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetBoundOutputNames, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, + _Out_ char** buffer, _Out_writes_all_(count) size_t** lengths, _Out_ size_t* count); + + /** \brief Get the output ::OrtValue objects from an ::OrtIoBinding + * + * Returns an array of pointers to individually allocated ::OrtValue%s that contain results of a model execution with OrtApi::RunWithBinding + * The array contains the same number of ::OrtValue%s and they are in the same order as they were bound with OrtApi::BindOutput + * or OrtApi::BindOutputToDevice. + * + * The returned ::OrtValue%s must be released using OrtApi::ReleaseValue after they are no longer needed. + * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after + * all the ::OrtValue%s contained therein are individually released. + * + * \param[in] binding_ptr + * \param[in] allocator Allocator used to allocate output array + * \param[out] output Set to the allocated array of allocated ::OrtValue outputs. Set to nullptr if there are 0 outputs. + * \param[out] output_count Set to number of ::OrtValue%s returned + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetBoundOutputValues, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, + _Out_writes_all_(output_count) OrtValue*** output, _Out_ size_t* output_count); + + /** \brief Clears any previously set Inputs for an ::OrtIoBinding + */ + void(ORT_API_CALL* ClearBoundInputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /** \brief Clears any previously set Outputs for an ::OrtIoBinding + */ + void(ORT_API_CALL* ClearBoundOutputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Direct memory access to a specified tensor element + * + * For example, given a tensor with shape of [3,224,224], a pointer to the element at location [2,150,128] can be retrieved + * + * This function only works for numeric type tensors (No strings, etc). + * This is a no-copy method whose returned pointer is valid until the passed in ::OrtValue is free'd. + * + * \param[in] value + * \param[in] location_values Pointer to an array of index values that specify an element's location relative to its shape + * \param[in] location_values_count Number of elements in location_values. Must match the number of elements in the tensor's shape. + * \param[out] out Set to a pointer to the element specified + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); + + /// @} + /// \name OrtEnv + /// @{ + + /** \brief Create an allocator and register it with the ::OrtEnv + * + * Enables sharing the allocator between multiple sessions that use the same env instance. + * Lifetime of the created allocator will be valid for the duration of the environment. + * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. + * + * See https://onnxruntime.ai/docs/reference/api/c-api.html for details. + * + * \param[in] env ::OrtEnv instance + * \param[in] mem_info + * \param[in] arena_cfg Pass nullptr for defaults + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, + _In_ const OrtArenaCfg* arena_cfg); + + /** \brief Set language projection + * + * Set the language projection for collecting telemetry data when Env is created. + * + * The default is ORT_PROJECTION_C, which means it will classify the language not in the list to C also. + * + * \param[in] ort_env + * \param[in] projection + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Return the time that profiling was started + * + * \note The timer precision varies per platform. On Windows and MacOS, the precision will be ~100ns + * + * \param[in] session + * \param[out] out nanoseconds of profiling's start time + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* session, _Outptr_ uint64_t* out); + + /// @} + /// \name OrtThreadingOptions + /// @{ + + /** \brief Set global intra-op thread count + * + * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools + * + * \param[in] tp_options + * \param[in] intra_op_num_threads Number of threads, special values:
+ * 0 = Use default thread count
+ * 1 = The invoking thread will be used; no threads will be created in the thread pool. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads); + + /** \brief Set global inter-op thread count + * + * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools + * + * \param[in] tp_options + * \param[in] inter_op_num_threads Number of threads, special values:
+ * 0 = Use default thread count
+ * 1 = The invoking thread will be used; no threads will be created in the thread pool. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads); + + /** \brief Set global spin control options + * + * This will configure the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. + * Allow spinning of thread pools when their queues are empty. This will set the value for both + * inter_op and intra_op threadpools. + * + * \param[in] tp_options + * \param[in] allow_spinning Valid values are 0 or 1.
+ * 0 = It won't spin (recommended if CPU usage is high)
+ * 1 = Threadpool will spin to wait for queue to become non-empty + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Add a pre-allocated initializer to a session + * + * If a model contains an initializer with a name that is same as the name passed to this call, + * ORT will use this initializer instance instead of deserializing one from the model file. This + * is useful when you want to share the same initializer across sessions. + * + * \param[in] options + * \param[in] name Null terminated string of the initializer name + * \param[in] val ::OrtValue containing the initializer. Its lifetime and the underlying initializer buffer must be + * managed by the user (created using the OrtApi::CreateTensorWithDataAsOrtValue) and it must outlive the session object + * to which it is added. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name, + _In_ const OrtValue* val); + + /// @} + /// \name OrtEnv + /// @{ + + /** + * Create a custom environment with global threadpools and logger that will be shared across sessions. + * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use + * its own thread pools. + * + * \param[in] logging_function A pointer to a logging function. + * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `logging_function`. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[in] tp_options + * \param[out] out Newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel log_severity_level, + _In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Append CUDA provider to session options + * + * If CUDA is not available (due to a non CUDA enabled build, or if CUDA is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] cuda_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA, + _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); + + /** \brief Append ROCM execution provider to the session options + * + * If ROCM is not available (due to a non ROCM enabled build, or if ROCM is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] rocm_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM, + _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); + + /** \brief Append OpenVINO execution provider to the session options + * + * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. + * + * \param[in] options + * \param[in] provider_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO, + _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options); + + /// @} + /// \name OrtThreadingOptions + /// @{ + + /** \brief Set threading flush-to-zero and denormal-as-zero + * + * Sets global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. + * Flush-to-zero and denormal-as-zero are applied to threads in both intra and inter global thread pool. + * \note This option is not needed if the models used have no denormals. Having no denormals is recommended as this option may hurt model accuracy. + * + * \param[in] tp_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options); + + /// @} + /// \name OrtArenaCfg + /// @{ + + /** \deprecated Use OrtApi::CreateArenaCfgV2 + * + * This will create the configuration of an arena that can eventually be used to define an arena based allocator's behavior + * + * \param[in] max_mem Use 0 to allow ORT to choose the default + * \param[in] arena_extend_strategy Use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + * \param[in] initial_chunk_size_bytes Use -1 to allow ORT to choose the default + * \param[in] max_dead_bytes_per_chunk Use -1 to allow ORT to choose the default + * \param[in] out A pointer to an OrtArenaCfg instance + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, + int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out); + + ORT_CLASS_RELEASE(ArenaCfg); + + /// @} + /// \name OrtModelMetadata + /// @{ + + /** + * Use this to obtain the description of the graph present in the model + * (doc_string field of the GraphProto message within the ModelProto message). + * If it doesn't exist, an empty string will be returned. + * + * \param[in] model_metadata An instance of ::OrtModelMetadata + * \param[in] allocator Allocator used to allocate the string that will be returned back + * \param[out] value Set to a null terminated string allocated using `allocator`. The caller is responsible for freeing it using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Append TensorRT provider to session options + * + * If TensorRT is not available (due to a non TensorRT enabled build, or if TensorRT is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] tensorrt_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT, + _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options); + + /// @} + /// \name Misc + /// @{ + + /** \brief Set current GPU device ID + * + * Set the current device id of the GPU execution provider (CUDA/tensorrt/rocm). The device id should be less + * than the total number of devices available. This is only useful when multiple-GPUs are installed and it is + * required to restrict execution to a single GPU. + * + * \param[in] device_id + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetCurrentGpuDeviceId, _In_ int device_id); + + /** \brief Get current GPU device ID + * + * Get the current device id of the GPU execution provider (CUDA/tensorrt/rocm). + * + * \see OrtApi::SetCurrentGpuDeviceId + * + * \param[out] device_id + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id); + + /// @} + /// \name OrtKernelInfo + /// Custom operator APIs. + /// @{ + + /** \brief Fetch an array of int64_t values stored as an attribute in the graph node + * + * + * If `out` is nullptr, the value of `size` is set to the true size of the attribute + * array's size, and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual attribute array's size, + * the value of `size` is set to the true size of the attribute array's size, + * the provided memory is filled with the attribute's contents, + * and a success status is returned. + * + * If the `size` parameter is less than the actual attribute array's size and `out` + * is not nullptr, the value of `size` is set to the true size of the attribute array's size + * and a failure status is returned.) + * + * \param[in] info instance + * \param[in] name name of the attribute to be parsed + * \param[out] out pointer to memory where the attribute's contents are to be stored + * \param[in, out] size actual size of attribute array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out, _Inout_ size_t* size); + + /** \brief Fetch an array of int64_t values stored as an attribute in the graph node + * + * If `out` is nullptr, the value of `size` is set to the true size of the attribute + * array's size, and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual attribute array's size, + * the value of `size` is set to the true size of the attribute array's size, + * the provided memory is filled with the attribute's contents, + * and a success status is returned. + * + * If the `size` parameter is less than the actual attribute array's size and `out` + * is not nullptr, the value of `size` is set to the true size of the attribute array's size + * and a failure status is returned.) + * + * \param[in] info instance + * \param[in] name name of the attribute to be parsed + * \param[out] out pointer to memory where the attribute's contents are to be stored + * \param[in, out] size actual size of attribute array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out, _Inout_ size_t* size); + + /// @} + /// \name OrtArenaCfg + /// @{ + + /** \brief Create an ::OrtArenaCfg + * + * Create the configuration of an arena that can eventually be used to define an arena based allocator's behavior. + * + * Supported keys are (See https://onnxruntime.ai/docs/reference/api/c-api.html for details on what the + * following parameters mean and how to choose these values.): + * "max_mem": Maximum memory that can be allocated by the arena based allocator. + * Use 0 for ORT to pick the best value. Default is 0. + * "arena_extend_strategy": 0 = kNextPowerOfTwo, 1 = kSameAsRequested. + * Use -1 to allow ORT to choose the default. + * "initial_chunk_size_bytes": (Possible) Size of the first allocation in the arena. + * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. + * Ultimately, the first allocation size is determined by the allocation memory request. + * "max_dead_bytes_per_chunk": Threshold of unused memory in an allocated chunk of arena memory after + * crossing which the current chunk is chunked into 2. + * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. + * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. + * Ultimately, the allocation size is determined by the allocation memory request. + * Further allocation sizes are governed by the arena extend strategy. + * + * \param[in] arena_config_keys Keys to configure the arena + * \param[in] arena_config_values Values to configure the arena + * \param[in] num_keys Number of keys in `arena_config_keys` and `arena_config_values` + * \param[out] out Newly created ::OrtArenaCfg. Must be freed with OrtApi::ReleaseArenaCfg + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateArenaCfgV2, _In_reads_(num_keys) const char* const* arena_config_keys, + _In_reads_(num_keys) const size_t* arena_config_values, _In_ size_t num_keys, + _Outptr_ OrtArenaCfg** out); + + /// @} + /// \name OrtRunOptions + /// @{ + + /** \brief Set a single run configuration entry as a pair of strings + * + * If a configuration with same key exists, this will overwrite the configuration with the given config_value + * + * The config_key and the format of config_value are defined in onnxruntime_run_options_config_keys.h + * + * \param[in] options + * \param[in] config_key A null terminated string representation of the config key + * \param[in] config_value A null terminated string representation of the config value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddRunConfigEntry, _Inout_ OrtRunOptions* options, + _In_z_ const char* config_key, _In_z_ const char* config_value); + + /// @} + /// \name OrtPrepackedWeightsContainer + /// @{ + + /** \brief Create an ::OrtPrepackedWeightsContainer + * + * This container will hold pre-packed buffers of shared initializers for sharing between sessions + * (i.e.) if there are shared initializers that can be shared between sessions, the pre-packed buffers + * of these (if any) may possibly be shared to provide memory footprint savings. Pass this container + * to sessions that you would like to share pre-packed buffers of shared initializers at session + * creation time. + * + * \param[out] out Newly created ::OrtPrepackedWeightsContainer. Must be freed with OrtApi::ReleasePrepackedWeightsContainer + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out); + + /** \brief Release OrtPrepackedWeightsContainer instance + * + * \note instance must not be released until the sessions using it are released + */ + ORT_CLASS_RELEASE(PrepackedWeightsContainer); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Create session with prepacked weights container + * + * Same functionality offered by OrtApi::CreateSession except that a container that contains + * pre-packed weights' buffers is written into/read from by the created session. + * This is useful when used in conjunction with OrtApi::AddInitializer which injects + * shared initializer info into sessions. Wherever possible, the pre-packed versions of these + * shared initializers are cached in this container so that multiple sessions can just re-use + * these instead of duplicating these in memory. + * + * \param[in] env OrtEnv instance instance + * \param[in] model_path Null terminated string of the path (wchar on Windows, char otherwise) + * \param[in] options + * \param[in] prepacked_weights_container + * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _Outptr_ OrtSession** out); + + /** \brief Create session from memory with prepacked weights container + * + * Same functionality offered by OrtApi::CreateSessionFromArray except that a container that contains + * pre-packed weights' buffers is written into/read from by the created session. + * This is useful when used in conjunction with OrtApi::AddInitializer which injects + * shared initializer info into sessions. Wherever possible, the pre-packed versions of these + * shared initializers are cached in this container so that multiple sessions can just re-use + * these instead of duplicating these in memory. + * + * \param[in] env + * \param[in] model_data Array of bytes holding the model + * \param[in] model_data_length Number of bytes in `model_data_model` + * \param[in] options + * \param[in] prepacked_weights_container + * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _Outptr_ OrtSession** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Append TensorRT execution provider to the session options + * + * If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure. + * + * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, it takes an + * ::OrtTensorRTProviderOptions which is publicly defined. This takes an opaque ::OrtTensorRTProviderOptionsV2 + * which must be created with OrtApi::CreateTensorRTProviderOptions. + * + * For OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, the user needs to instantiate ::OrtTensorRTProviderOptions + * as well as allocate/release buffers for some members of ::OrtTensorRTProviderOptions. + * Here, OrtApi::CreateTensorRTProviderOptions and Ortapi::ReleaseTensorRTProviderOptions will do the memory management for you. + * + * \param[in] options + * \param[in] tensorrt_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT_V2, + _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options); + + /// @} + /// \name OrtTensorRTProviderOptionsV2 + /// @{ + + /** \brief Create an OrtTensorRTProviderOptionsV2 + * + * \param[out] out Newly created ::OrtTensorRTProviderOptionsV2. Must be released with OrtApi::ReleaseTensorRTProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out); + + /** \brief Set options in a TensorRT Execution Provider. + * + * Please refer to https://www.onnxruntime.ai/docs/reference/execution-providers/TensorRT-ExecutionProvider.html#c-api-example + * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 + * and value should be its related range. + * + * For example, key="trt_max_workspace_size" and value="2147483648" + * + * \param[in] tensorrt_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(UpdateTensorRTProviderOptions, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get serialized TensorRT provider options string. + * + * For example, "trt_max_workspace_size=2147483648;trt_max_partition_iterations=10;trt_int8_enable=1;......" + * + * \param tensorrt_options - OrTensorRTProviderOptionsV2 instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with OrtApi::CreateAllocator or OrtApi::GetAllocatorWithDefaultOptions + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorRTProviderOptionsAsString, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtTensorRTProviderOptionsV2 + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + */ + void(ORT_API_CALL* ReleaseTensorRTProviderOptions)(_Frees_ptr_opt_ OrtTensorRTProviderOptionsV2* input); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Enable custom operators + * + * See onnxruntime-extensions: https://github.com/microsoft/onnxruntime-extensions.git + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options); + + /// @} + /// \name OrtAllocator + /// @{ + + /** \brief Register a custom allocator + * + * Enables sharing between multiple sessions that use the same env instance. + * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. + * + * The behavior of this is exactly the same as OrtApi::CreateAndRegisterAllocator except + * instead of ORT creating an allocator based on provided info, in this case + * ORT uses the user-provided custom allocator. + * See https://onnxruntime.ai/docs/reference/api/c-api.html for details. + * + * \param[in] env + * \param[in] allocator User provided allocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RegisterAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* allocator); + + /** \brief Unregister a custom allocator + * + * It is an error if you provide an ::OrtMemoryInfo not corresponding to any + * registered allocators for sharing. + * + * \param[in] env + * \param[in] mem_info + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(UnregisterAllocator, _Inout_ OrtEnv* env, + _In_ const OrtMemoryInfo* mem_info); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Sets *out to 1 iff an ::OrtValue is a SparseTensor, and 0 otherwise + * + * \param[in] value existing ::OrtValue + * \param[out] out unless an error occurs, contains 1 iff the value contains an instance + * of sparse tensor or 0 otherwise. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out); + + /** \brief Create an ::OrtValue with a sparse tensor that is empty. + * + * Use FillSparseTensor() functions to populate sparse tensor with non-zero values and + * format specific indices data. + * Use ReleaseValue to destroy the sparse tensor, this will also release the buffer inside the output value + * if any was allocated. + * \param[in,out] allocator allocator to use when performing an allocation. Allocation will be performed + * by FillSparseTensor() APIs. The lifespan of the allocator instance must eclipse the lifespan + * this sparse tensor instance as the same allocator will be used to free memory. + * \param[in] dense_shape shape of the original dense tensor + * \param[in] dense_shape_len number of shape dimensions being passed + * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx + * \param[out] out Should be freed by calling ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape, + size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. + * This will allocate required memory and copy the supplied NNZ values and COO indices into that memory allocation. + * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. + * + * \param[in,out] ort_value ::OrtValue to populate with data + * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified + * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. + * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. + * \param[in] values_shape pointer to values shape array + * \param[in] values_shape_len length of the values_shape + * \param[in] values pointer to an array of values. For strings, pass const char**. + * \param[in] indices_data pointer to a location of COO indices + * \param[in] indices_num number of COO indices + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, + _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, + _In_ const int64_t* indices_data, size_t indices_num); + + /** + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. + * This will allocate required memory and copy the supplied NNZ values and CSR indices into that memory allocation. + * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. + * + * \param[in,out] ort_value ::OrtValue to populate with data + * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified + * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. + * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. + * \param[in] values_shape pointer to values shape array + * \param[in] values_shape_len length of the values_shape + * \param[in] values - pointer to an array of values. For strings, pass const char**. + * \param[in] inner_indices_data pointer to a location of CSR inner indices + * \param[in] inner_indices_num number of CSR inner indices + * \param[in] outer_indices_data pointer to a location of CSR outer indices + * \param[in] outer_indices_num number of CSR outer indices + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, + _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, + _In_ const int64_t* inner_indices_data, size_t inner_indices_num, + _In_ const int64_t* outer_indices_data, size_t outer_indices_num); + + /** + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. + * This will allocate required memory and copy the supplied NNZ values and BlockSparse indices into that memory allocation. + * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. + * + * \param[in,out] ort_value ::OrtValue to populate with data + * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified + * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. + * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. + * \param[in] values_shape + * \param[in] values_shape_len + * \param[in] values structure with values information + * \param[in] indices_shape_data pointer to a location of indices shape + * \param[in] indices_shape_len length of the block sparse indices shape + * \param[in] indices_data pointer to a location of indices data. Shape will determine the length of the indices data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, + _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, + _In_ const int64_t* indices_shape_data, size_t indices_shape_len, + _In_ const int32_t* indices_data); + + /** + * Create an ::OrtValue with a sparse tensor. This is the first step. + * Next, use UseIndices() functions to supply sparse tensor with + * format specific indices data and set its sparse format to a specific enum value. + * This will not perform memory allocations. It will + * use supplied user buffer which should outlive the created sparse tensor. + * Use OrtApi::ReleaseValue to destroy the sparse tensor. It would not release the supplied values buffer. + * This function can not be used to map strings from the user allocated memory. Strings must always be copied + * and have UTF-8 encoding. Therefore, use OrtApi::CreateSparseTensorAsOrtValue above and then fill it with data + * using appropriate Make*() function. + * + * \param[in] info memory info where sparse values reside. + * \param[in,out] p_data pointer to a user allocated buffer with values. To create a full sparse tensor with no non-zero + * values, pass nullptr + * \param[in] dense_shape shape of the original dense tensor + * \param[in] dense_shape_len number of shape dimensions being passed + * \param[in] values_shape shape of the values data. To create a fully sparse tensor with no non-zero values, + * pass {0} shape. + * \param[in] values_shape_len number of values shape dimensions + * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx + * \param[out] out Should be freed by calling ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, + _In_ const int64_t* dense_shape, size_t dense_shape_len, + _In_ const int64_t* values_shape, size_t values_shape_len, + ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** + * This assigns Coo format indices to the SparseTensor that was created by + * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to + * ORT_SPARSE_COO. This will not allocate any additional memory for data. The life span of + * indices_data buffer should eclipse the life span of this ::OrtValue. + * + * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue + * \param[in,out] indices_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. + * \param[in] indices_num number of COO indices. Should either be 0 for fully sparse tensors, be equal + * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue for 1-D {nnz} indices or + * be twice as number of nnz values for a 2-D indices {nnz, 2} + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num); + + /** + * The assigns CSR format indices to the SparseTensor that was created by + * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to + * ORT_SPARSE_CSRC. This will not allocate any additional memory for data. The life spans of + * inner_data and outer_data buffers should eclipse the life span of this ::OrtValue. + * + * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue + * \param[in,out] inner_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. + * \param[in] inner_num number of inner CSR indices. Should either be 0 for fully sparse tensors or be equal + * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue. + * \param[in,out] outer_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. + * \param[in] outer_num number of CSR outer indices. Should either be 0 for fully sparse tensors or + * equal to rows + 1 of the dense shape. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(UseCsrIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* inner_data, size_t inner_num, + _Inout_ int64_t* outer_data, size_t outer_num); + + /** + * The assigns BlockSparse format indices to the SparseTensor that was created by + * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to + * ORT_SPARSE_BLOCK_SPARSE. This will not allocate any additional memory for data. The life span of + * indices_data buffer must eclipse the lifespan of this ::OrtValue. + * + * \param[in,out] ort_value OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue + * \param[in] indices_shape pointer to indices shape. Use {0} for fully sparse tensors + * \param[in] indices_shape_len length of the indices shape + * \param[in,out] indices_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len, _Inout_ int32_t* indices_data); + + /** \brief Returns sparse tensor format enum iff a given ort value contains an instance of sparse tensor. + * + * \param[in] ort_value ::OrtValue that contains an instance of sparse tensor + * \param[out] out pointer to out parameter + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out); + + /** \brief Returns data type and shape of sparse tensor values (nnz) iff ::OrtValue contains a SparseTensor. + * + * \param[in] ort_value An ::OrtValue that contains a fully constructed sparse tensor + * \param[out] out Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* ort_value, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Returns numeric data for sparse tensor values (nnz). For string values use GetStringTensor*(). + * + * \param[in] ort_value an instance of ::OrtValue containing sparse tensor + * \param[out] out returns a pointer to values data. Do not attempt to free this ptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out); + + /** \brief Returns data type, shape for the type of indices specified by indices_format. + * + * \param[in] ort_value ::OrtValue containing sparse tensor. + * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse + * tensor does not contain. + * \param[out] out an instance of ::OrtTensorTypeAndShapeInfo. Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Returns indices data for the type of the indices specified by indices_format + * + * \param[in] ort_value ::OrtValue containing sparse tensor. + * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse tensor does not contain. + * \param[out] num_indices Pointer to where the number of indices entries is returned + * \param[out] indices Returned pointer to the indices data. Do not free the returned pointer as it refers to internal data owned by the ::OrtValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); + /// @} + /// \name OrtSessionOptions + /// @{ + + /** + * \brief Sets out to 1 iff an optional type OrtValue has an element, 0 otherwise (OrtValue is None) + * Use this API to find if the optional type OrtValue is None or not. + * If the optional type OrtValue is not None, use the OrtValue just like any other OrtValue. + * For example, if you get an OrtValue that corresponds to Optional(tensor) and + * if HasValue() returns true, use it as tensor and so on. + + * \param[in] value Input OrtValue. + * \param[out] out indicating if the input OrtValue contains data (1) or if it is a None (0) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out); + + /// @} + /// \name OrtKernelContext + /// Custom operator APIs. + /// @{ + + /** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom a GPU kernel + * \see ::OrtCustomOp + * \param[in] context OrtKernelContext instance + * \param[out] out Returns pointer to a GPU compute stream that can be used to launch the custom GPU kernel. + * If retrieving the GPU compute stream is not relevant (GPU not enabled in the build, kernel partitioned to + * some other EP), then a nullptr is returned as the output param. + * Do not free or mutate the returned pointer as it refers to internal data owned by the underlying session. + * Only use it for custom kernel launching. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); + + /// @} + /// \name GetTensorMemoryInfo + /// @{ + /** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor + * \param[in] value ::OrtValue containing tensor. + * \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorMemoryInfo, _In_ const OrtValue* value, _Out_ const OrtMemoryInfo** mem_info); + + /// @} + /// \name GetExecutionProviderApi + /// @{ + /** \brief Get a pointer to the requested version of the Execution Provider specific + * API extensions to the OrtApi + * \param[in] provider_name The name of the execution provider name. Currently only the following + * values are supported: "DML". + * \param[in] version Must be ::ORT_API_VERSION. + * \param[out] provider_api A void pointer containing a reference to the execution provider versioned api structure. + * For example, the provider_api pointer can be cast to the OrtDmlApi* when the provider_name is "DML". + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api); + + /// @} + + /// \name SessionOptions + /// @{ + /** \brief Set custom thread creation function + * + * \param[in] options Session options + * \param[in] ort_custom_create_thread_fn Custom thread creation function + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /** \brief Set creation options for custom thread + * + * \param[in] options Session options + * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); + + /** \brief Set custom thread join function + * + * \param[in] options Session options + * \param[in] ort_custom_join_thread_fn Custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); + /// @} + + /// \name OrtThreadingOptions + /// @{ + /** \brief Set custom thread creation function for global thread pools + * + * \param[inout] tp_options + * \param[in] ort_custom_create_thread_fn Custom thread creation function + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /** \brief Set custom thread creation options for global thread pools + * + * \param[inout] tp_options + * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options); + + /** \brief Set custom thread join function for global thread pools + * + * \param[inout] tp_options + * \param[in] ort_custom_join_thread_fn Custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); + /// @} + + /** \brief Synchronize bound inputs. The call may be necessary for some providers, such as cuda, + * in case the system that allocated bound memory operated on a different stream. However, the + * operation is provider specific and could be a no-op. + * + * \param[inout] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SynchronizeBoundInputs, _Inout_ OrtIoBinding* binding_ptr); + + /** \brief Synchronize bound outputs. The call may be necessary for some providers, such as cuda, + * in case the system that allocated bound memory operated on a different stream. However, the + * operation is provider specific and could be a no-op. + * + * \param[inout] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SynchronizeBoundOutputs, _Inout_ OrtIoBinding* binding_ptr); + + /// \name OrtSessionOptions + /// @{ + + /** \brief Append CUDA execution provider to the session options + * + * If CUDA is not available (due to a non CUDA enabled build), this function will return failure. + * + * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_CUDA, it takes an + * ::OrtCUDAProviderOptions which is publicly defined. This takes an opaque ::OrtCUDAProviderOptionsV2 + * which must be created with OrtApi::CreateCUDAProviderOptions. + * + * For OrtApi::SessionOptionsAppendExecutionProvider_CUDA, the user needs to instantiate ::OrtCUDAProviderOptions + * as well as allocate/release buffers for some members of ::OrtCUDAProviderOptions. + * Here, OrtApi::CreateCUDAProviderOptions and Ortapi::ReleaseCUDAProviderOptions will do the memory management for you. + * + * \param[in] options + * \param[in] cuda_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA_V2, + _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptionsV2* cuda_options); + + /// @} + /// \name OrtCUDAProviderOptionsV2 + /// @{ + + /** \brief Create an OrtCUDAProviderOptionsV2 + * + * \param[out] out Newly created ::OrtCUDAProviderOptionsV2. Must be released with OrtApi::ReleaseCudaProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(CreateCUDAProviderOptions, _Outptr_ OrtCUDAProviderOptionsV2** out); + + /** \brief Set options in a CUDA Execution Provider. + * + * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options + * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 + * and value should be its related range. + * + * For example, key="device_id" and value="0" + * + * \param[in] cuda_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(UpdateCUDAProviderOptions, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized CUDA provider options string. + * + * For example, "device_id=0;arena_extend_strategy=0;......" + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(GetCUDAProviderOptionsAsString, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtCUDAProviderOptionsV2 + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + * + * \since Version 1.11. + */ + void(ORT_API_CALL* ReleaseCUDAProviderOptions)(_Frees_ptr_opt_ OrtCUDAProviderOptionsV2* input); + + /// @} + + /** \brief Append MIGraphX provider to session options + * + * If MIGraphX is not available (due to a non MIGraphX enabled build, or if MIGraphX is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] migraphx_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_MIGraphX, + _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options); + + /** \brief Replace initialized Tensors with external data with the data provided in initializers. + * + * The function will find the initialized TensorProtos with external data in the graph with the provided names and + * replace them with the provided tensors. The API verifies that the TensorProto being replaced + * has an external data reference and has the same name, dimensions and data type as its replacement. The replacement + * will occur before any of the optimizations take place. The data will be copied into the graph + * since TensorProto can't refer to the user provided buffers. + * + * Once the model has been loaded, the OrtValue(s) added to SessionOptions instance will be removed + * from the internal SessionOptions copy to save memory, the user provided buffers can then be deallocated + * and the SessionOptions instance that refers to them can be destroyed. + * + * \param[in] options + * \param[in] initializer_names Array of null terminated UTF-8 encoded strings of the initializers names. + * \param[in] initializers Array of ::OrtValue type + * \param[in] initializers_num Number of elements in the initializer_names and initializers + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.12. + */ + ORT_API2_STATUS(AddExternalInitializers, _In_ OrtSessionOptions* options, + _In_reads_(input_len) const char* const* initializer_names, + _In_reads_(input_len) const OrtValue* const* initializers, size_t initializers_num); + + /** \brief: Create attribute of onnxruntime operator + * + * \param[in] name Name of the attribute + * \param[in] data Data content of the attribute + * \param[in] len Number of bytes stored in data + * \param[in] type Data type + * \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr + * + * \since Version 1.12. + */ + ORT_API2_STATUS(CreateOpAttr, + _In_ const char* name, + _In_ const void* data, + _In_ int len, + _In_ OrtOpAttrType type, + _Outptr_ OrtOpAttr** op_attr); + + /* \brief: Release op attribute + * + * \param[in] opAttr Attribute created by OrtApi::CreateOpAttr + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(OpAttr); + + /** \brief: Create onnxruntime native operator + * + * \param[in] info Kernel info + * \param[in] op_name Operator name + * \param[in] domain Operator domain + * \param[in] version Operator opset version + * \param[in] type_constraint_names Name of the type contraints, such as "T" or "T1" + * \param[in] type_constraint_values Type of each contraints + * \param[in] type_constraint_count Number of contraints + * \param[in] attr_values Attributes used to initialize the operator + * \param[in] attr_count Number of the attributes + * \param[in] input_count Number of inputs + * \param[in] output_count Number of outputs + * \param[out] ort_op Operator that has been created + * + * \since Version 1.12. + */ + ORT_API2_STATUS(CreateOp, + _In_ const OrtKernelInfo* info, + _In_ const char* op_name, + _In_ const char* domain, + _In_ int version, + _In_opt_ const char** type_constraint_names, + _In_opt_ const ONNXTensorElementDataType* type_constraint_values, + _In_opt_ int type_constraint_count, + _In_opt_ const OrtOpAttr* const* attr_values, + _In_opt_ int attr_count, + _In_ int input_count, + _In_ int output_count, + _Outptr_ OrtOp** ort_op); + + /** \brief: Invoke the operator created by OrtApi::CreateOp + * The inputs must follow the order as specified in onnx specification + * + * \param[in] context Kernel context + * \param[in] ort_op Operator that has been created + * \param[in] input_values Array of inputs + * \param[in] input_count Number of inputs + * \param[in] output_values Array of outputs + * \param[in] output_count Number of outputs + * + * \since Version 1.12. + */ + ORT_API2_STATUS(InvokeOp, + _In_ const OrtKernelContext* context, + _In_ const OrtOp* ort_op, + _In_ const OrtValue* const* input_values, + _In_ int input_count, + _Inout_ OrtValue* const* output_values, + _In_ int output_count); + + /* \brief: Release an onnxruntime operator + * + * \param[in] Op Operator created by OrtApi::CreateOp + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(Op); + + /** \brief: Append execution provider to the session options. + * \param[in] options + * \param[in] provider_name - provider to add. + * \param[in] provider_options_keys - keys to configure the provider options + * \param[in] provider_options_values - values to configure the provider options + * \param[in] num_keys - number of keys passed in + * + * Currently supported providers: + * SNPE + * XNNPACK + * + * Note: If an execution provider has a dedicated SessionOptionsAppendExecutionProvider_ function + * that should be used to add it. + * + * SNPE supported keys: + * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", + * "DSP", "DSP_FIXED8_TF", "AIP_FIXED_TF", "AIP_FIXED8_TF". + * Mapping to SNPE Runtime_t definition: CPU, CPU_FLOAT32 => zdl::DlSystem::Runtime_t::CPU; + * GPU, GPU_FLOAT32_16_HYBRID => zdl::DlSystem::Runtime_t::GPU; + * GPU_FLOAT16 => zdl::DlSystem::Runtime_t::GPU_FLOAT16; + * DSP, DSP_FIXED8_TF => zdl::DlSystem::Runtime_t::DSP. + * AIP_FIXED_TF, AIP_FIXED8_TF => zdl::DlSystem::Runtime_t::AIP_FIXED_TF. + * "priority": execution priority, options: "low", "normal". + * "buffer_type": ITensor or user buffers, options: "ITENSOR", user buffer with different types - "TF8", "TF16", "UINT8", "FLOAT". + * "ITENSOR" -- default, ITensor which is float only. + * "TF8" -- quantized model required, "FLOAT" -- for both quantized or non-quantized model + * If SNPE is not available (due to a non Snpe enabled build or its dependencies not being installed), this function will fail. + * + * XNNPACK supported keys: + * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. + * default value is 0, which means to use the session thread-pool size. + * + * \since Version 1.12. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, + _In_ const char* provider_name, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /* \brief: Get a copy of kernel info + * + * \param[in] info Kernel info + * \param[out] info_copy Copy of kernel info + * + * \since Version 1.12. + */ + ORT_API2_STATUS(CopyKernelInfo, + _In_ const OrtKernelInfo* info, + _Outptr_ OrtKernelInfo** info_copy); + + /* \brief: Release kernel info + * + * \param[in] KernelInfo A copy of kernel info returned by CopyKernelInfo + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(KernelInfo); + + /* \brief: Get the training C Api + * + * \since Version 1.13 + */ + const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION; + + /** \brief Append CANN provider to session options + * + * If CANN is not available (due to a non CANN enabled build, or if CANN is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] cann_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CANN, + _In_ OrtSessionOptions* options, _In_ const OrtCANNProviderOptions* cann_options); + + /** \brief Create an OrtCANNProviderOptions + * + * \param[out] out created ::OrtCANNProviderOptions. Must be released with OrtApi::ReleaseCANNProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(CreateCANNProviderOptions, _Outptr_ OrtCANNProviderOptions** out); + + /** \brief Set options in a CANN Execution Provider. + * + * \param[in] cann_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(UpdateCANNProviderOptions, _Inout_ OrtCANNProviderOptions* cann_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get serialized CANN provider options string. + * + * \param[in] cann_options OrtCANNProviderOptions instance + * \param[in] allocator a ptr to an instance of OrtAllocator obtained with CreateAllocator() + * or GetAllocatorWithDefaultOptions(), the specified allocator will be used to allocate + * continuous buffers for output strings and lengths. + * \param[out] ptr is a UTF-8 null terminated string allocated using 'allocator'. + * The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(GetCANNProviderOptionsAsString, _In_ const OrtCANNProviderOptions* cann_options, + _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an OrtCANNProviderOptions + * + * \param[in] the pointer of OrtCANNProviderOptions which will been deleted + * + * \since Version 1.13. + */ + void(ORT_API_CALL* ReleaseCANNProviderOptions)(_Frees_ptr_opt_ OrtCANNProviderOptions* input); + + /* \brief Get OrtDevice type from MemoryInfo + * + * \since Version 1.14 + */ + void(ORT_API_CALL* MemoryInfoGetDeviceType)(_In_ const OrtMemoryInfo* ptr, _Out_ OrtMemoryInfoDeviceType* out); + + /* \brief Update the OrtEnv instance with custom log severity level + * + * \param[in] ort_env The OrtEnv instance being used + * \param[in] log_severity_level The log severity level. + * + * \since Version 1.14. + */ + ORT_API2_STATUS(UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, OrtLoggingLevel log_severity_level); + + /* \brief Set affinities for intra op threads + * + * Affinity string follows format: + * logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id + * Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to. + * e.g. 1,2,3;4,5 + * specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. + * To ease the configuration, an "interval" is also allowed: + * e.g. 1-8;8-16;17-24 + * orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. + * Note: + * 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, + * ort does not set affinity on the main thread which is started and managed by the calling app; + * 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors, + * an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. + * Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. + * + * \since Version 1.14 + */ + ORT_API2_STATUS(SetGlobalIntraOpThreadAffinity, _Inout_ OrtThreadingOptions* tp_options, const char* affinity_string); + + /** \brief Register custom ops from a shared library. + * + * Loads a shared library (.dll on windows, .so on linux, etc) named 'library_name' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. + * + * The handle to the loaded library is automatically released by ORT when the last OrtSession that references the + * library handle is released. If no OrtSession is created, then the library handle is released when the provided + * OrtSessionOptions is released. + * + * \param[in] options The session options. + * \param[in] library_name The name of the shared library to load and register. Refer to OS-specific dynamic library + * loading utilities (e.g., LoadLibraryEx on Windows or dlopen on Linux/MacOS) for information + * on the format of library names and search paths. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(RegisterCustomOpsLibrary_V2, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* library_name); + + /** \brief Register custom ops by calling a RegisterCustomOpsFn function. + * + * Searches for registration_func_name and if found calls it. + * + * The library containing the function must either be linked against or previously loaded by the executable. + * + * If you want ONNX Runtime to load the library and manage its lifetime, use RegisterCustomOpsLibrary_V2. + * + * RegisterCustomOpsUsingFunction can be used in scenarios where it may not be possible for ONNX Runtime to load + * the library from a path. e.g. mobile platforms where the library must be linked into the app. + * + * The registration function must have the signature of RegisterCustomOpsFn: + * OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); + * + * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for details on how the registration + * function should be implemented. + * + * \param[in] options OrtSessionOptions that is passed through as the first argument in the call to the + * registration function. + * \param[in] registration_func_name Name of registration function to use. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* options, + _In_ const char* registration_func_name); + + /// @} + /// \name OrtKernelInfo + /// Custom operator APIs. + /// @{ + + /** \brief Get the number of inputs from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the number of inputs + * during kernel/session creation. + * + * \param[in] info Instance of ::OrtKernelInfo. + * \param[out] out Pointer to variable assigned with the result on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); + + /** \brief Get the number of outputs from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the number of outputs + * during kernel/session creation. + * + * \param[in] info Instance of ::OrtKernelInfo. + * \param[out] out Pointer to variable assigned with the result on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); + + /** \brief Get the name of a ::OrtKernelInfo's input. + * + * Used in the CreateKernel callback of an OrtCustomOp to query an input's name + * during kernel/session creation. + * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status is returned. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[in] index The index of the input name to get. Returns a failure status if out-of-bounds. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the input's name. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + _Inout_ size_t* size); + + /** \brief Get the name of a ::OrtKernelInfo's output. + * + * Used in the CreateKernel callback of an OrtCustomOp to query an output's name + * during kernel/session creation. + * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status is returned. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[in] index The index of the output name to get. Returns a failure status if out-of-bounds. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the output's + * name. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + _Inout_ size_t* size); + + /** \brief Get the type information for a ::OrtKernelInfo's input. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information + * of an input during kernel/session creation. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get the type information for a ::OrtKernelInfo's output. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information + * of an output during kernel/session creation. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get a ::OrtValue tensor stored as an attribute in the graph node. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a tensor attribute. + * + * \param[in] info ::OrtKernelInfo instance. + * \param[in] name UTF-8 null-terminated string representing the attribute's name. + * \param[in] allocator Allocator used to allocate the internal tensor state. + * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue, + * which will also free internal tensor state allocated with the provided allocator. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out); + + /// @} + /// \name OrtSessionOptions + /// Custom operator APIs + /// @{ + + /** \brief Checks if the given session configuration entry exists. + * + * The config_key formats are defined in onnxruntime_session_options_config_keys.h + * + * Can be used in a custom operator library to check for session configuration entries + * that target one or more custom operators in the library. Example: The config entry + * custom_op.myop.some_key targets a custom op named "myop". + * + * \param[in] options The ::OrtSessionOptions instance. + * \param[in] config_key A null-terminated UTF-8 string representation of the configuration key. + * \param[out] out Pointer set to 1 if the entry exists and 0 otherwise. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(HasSessionConfigEntry, _In_ const OrtSessionOptions* options, + _In_z_ const char* config_key, _Out_ int* out); + + /** \brief Get a session configuration value. + * + * Returns a failure status if the configuration key does not exist. + * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h + * + * If `config_value` is nullptr, the value of `size` is set to the true size of the string + * value (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual string value's size, + * the value of `size` is set to the true size of the string value, the provided memory + * is filled with the value's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string value's size and `config_value` + * is not nullptr, the value of `size` is set to the true size of the string value + * and a failure status is returned. + * + * Can be used in a custom operator library to get session configuration entries + * that target one or more custom operators in the library. Example: The config entry + * custom_op.myop.some_key targets a custom op named "myop". + * + * \param[in] options The session options. + * \param[in] config_key A null-terminated UTF-8 string representation of the config key. + * \param[in] config_value Pointer to memory where the null-terminated UTF-8 string value will be stored. + * \param[in,out] size Pointer to the size of the `config_value` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(GetSessionConfigEntry, _In_ const OrtSessionOptions* options, + _In_z_ const char* config_key, _Out_ char* config_value, _Inout_ size_t* size); + + /// @} + +#ifdef __cplusplus + OrtApi(const OrtApi&) = delete; // Prevent users from accidentally copying the API structure, it should always be passed as a pointer +#endif +}; + +/* + * Steps to use a custom op: + * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops + * 2 Create an OrtCustomOp structure for each op and add them to the domain + * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options + */ + +// Specifies some characteristics of inputs/outputs of custom ops: +// Specify if the inputs/outputs are one of: +// 1) Non-optional (input/output must be present in the node) +// 2) Optional (input/output may be absent in the node) +// 3) Variadic: A variadic input or output specifies N (i.e., the minimum arity) or more operands. +// Only the last input or output of a custom op may be marked as variadic. +// The homogeneity of the variadic input or output determines whether all operands must be of the same +// tensor element type. +typedef enum OrtCustomOpInputOutputCharacteristic { + INPUT_OUTPUT_REQUIRED = 0, + INPUT_OUTPUT_OPTIONAL, + INPUT_OUTPUT_VARIADIC, +} OrtCustomOpInputOutputCharacteristic; + +/* + * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by + * the implementor of the custom op. + */ +struct OrtCustomOp { + uint32_t version; // Must be initialized to ORT_API_VERSION + + // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. + void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info); + + // Returns the name of the op + const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); + + // Returns the type of the execution provider, return nullptr to use CPU execution provider + const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); + + // Returns the count and types of the input & output tensors + ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); + ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); + + // Op kernel callbacks + void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); + + // Returns the characteristics of the input & output tensors + OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + + // Returns the memory type of the input tensors. This API allows the custom op + // to place the inputs on specific devices. By default, it returns + // OrtMemTypeDefault, which means the input is placed on the default device for + // the execution provider. If the inputs need to be with different memory tyeps, + // this function can be overridden to return the specific memory types. + OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + + // Returns the minimum number of input arguments expected for the variadic input. + // Applicable only for custom ops that have a variadic input. + int(ORT_API_CALL* GetVariadicInputMinArity)(_In_ const struct OrtCustomOp* op); + + // Returns true (non-zero) if all arguments of a variadic input have to be of the same type (homogeneous), + // and false (zero) otherwise. + // Applicable only for custom ops that have a variadic input. + int(ORT_API_CALL* GetVariadicInputHomogeneity)(_In_ const struct OrtCustomOp* op); + + // Returns the minimum number of output values expected for the variadic output. + // Applicable only for custom ops that have a variadic output. + int(ORT_API_CALL* GetVariadicOutputMinArity)(_In_ const struct OrtCustomOp* op); + + // Returns true (non-zero) if all outputs values of a variadic output have to be of the same type (homogeneous), + // and false (zero) otherwise. + // Applicable only for custom ops that have a variadic output. + int(ORT_API_CALL* GetVariadicOutputHomogeneity)(_In_ const struct OrtCustomOp* op); +}; + +/* + * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists + * + * \param device_id CUDA device id, starts from zero. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); + +/* + * This is the old way to add the MIGraphX provider to the session, please use + * SessionOptionsAppendExecutionProvider_MIGraphX above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * HIP support and the MIGraphX provider shared library exists + * + * \param device_id HIP device id, starts from zero. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id); + +#ifdef __cplusplus +} +#endif + +//! @} diff --git a/ThirdParty/onnxruntime/include/onnxruntime_cxx_api.h b/ThirdParty/onnxruntime/include/onnxruntime_cxx_api.h new file mode 100644 index 0000000..97b2aa4 --- /dev/null +++ b/ThirdParty/onnxruntime/include/onnxruntime_cxx_api.h @@ -0,0 +1,1876 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Summary: The Ort C++ API is a header only wrapper around the Ort C API. +// +// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors +// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so +// all the resources follow RAII and do not leak memory. +// +// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. +// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them +// until you assign an instance that actually holds an underlying object. +// +// For Ort objects only move assignment between objects is allowed, there are no copy constructors. +// Some objects have explicit 'Clone' methods for this purpose. +// +// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments +// by value or by reference. ConstXXXX types are restricted to const only interfaces. +// +// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces. +// +// The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exists so you do not +// have to fallback to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code. + +#pragma once +#include "onnxruntime_c_api.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef ORT_NO_EXCEPTIONS +#include +#endif + +/** \brief All C++ Onnxruntime APIs are defined inside this namespace + * + */ +namespace Ort { + +/** \brief All C++ methods that can fail will throw an exception of this type + * + * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() + */ +struct Exception : std::exception { + Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} + + OrtErrorCode GetOrtErrorCode() const { return code_; } + const char* what() const noexcept override { return message_.c_str(); } + + private: + std::string message_; + OrtErrorCode code_; +}; + +#ifdef ORT_NO_EXCEPTIONS +// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors. +// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW +#ifndef ORT_CXX_API_THROW +#define ORT_CXX_API_THROW(string, code) \ + do { \ + std::cerr << Ort::Exception(string, code) \ + .what() \ + << std::endl; \ + abort(); \ + } while (false) +#endif +#else +#define ORT_CXX_API_THROW(string, code) \ + throw Ort::Exception(string, code) +#endif + +// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, +// it's in a template so that we can define a global variable in a header and make +// it transparent to the users of the API. +template +struct Global { + static const OrtApi* api_; +}; + +// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. +template +#ifdef ORT_API_MANUAL_INIT +const OrtApi* Global::api_{}; +inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } + +// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is +// required by C++ APIs. +// +// Example mycustomop.cc: +// +// #define ORT_API_MANUAL_INIT +// #include +// #undef ORT_API_MANUAL_INIT +// +// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) { +// Ort::InitApi(api_base->GetApi(ORT_API_VERSION)); +// // ... +// } +// +inline void InitApi(const OrtApi* api) { Global::api_ = api; } +#else +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. +// Please define ORT_API_MANUAL_INIT if it conerns you. +#pragma warning(disable : 26426) +#endif +const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif +#endif + +/// This returns a reference to the OrtApi interface in use +inline const OrtApi& GetApi() { return *Global::api_; } + +/// +/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and +/// returns a vector of strings representing the available execution providers. +/// +/// vector of strings +std::vector GetAvailableProviders(); + +/** \brief IEEE 754 half-precision floating point data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint16_t. + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. + * + * Generally, you can feed any of your types as float16/blfoat16 data to create a tensor + * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding. + * And you can also feed a array of uint16_t elements directly. For example, + * + * \code{.unparsed} + * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664}; + * constexpr size_t values_length = sizeof(values) / sizeof(values[0]); + * std::vector dims = {values_length}; // one dimensional example + * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + * // Note we are passing bytes count in this api, not number of elements -> sizeof(values) + * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values), + * dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16); + * \endcode + * + * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use + * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra + * template specialization. + * + * \code{.unparsed} + * namespace yours { struct half {}; } // assume this is your type, define this: + * namespace Ort { + * template<> + * struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; + * } //namespace Ort + * + * std::vector values; + * std::vector dims = {values.size()}; // one dimensional example + * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + * // Here we are passing element count -> values.size() + * auto float16_tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size()); + * + * \endcode + */ +struct Float16_t { + uint16_t value; + constexpr Float16_t() noexcept : value(0) {} + constexpr Float16_t(uint16_t v) noexcept : value(v) {} + constexpr operator uint16_t() const noexcept { return value; } + constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); + +/** \brief bfloat16 (Brain Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint16_t. + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. + * + * See also code examples for Float16_t above. + */ +struct BFloat16_t { + uint16_t value; + constexpr BFloat16_t() noexcept : value(0) {} + constexpr BFloat16_t(uint16_t v) noexcept : value(v) {} + constexpr operator uint16_t() const noexcept { return value; } + constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); + +namespace detail { +// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type +// This can't be done in the C API since C doesn't have function overloading. +#define ORT_DEFINE_RELEASE(NAME) \ + inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } + +ORT_DEFINE_RELEASE(Allocator); +ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(CustomOpDomain); +ORT_DEFINE_RELEASE(ThreadingOptions); +ORT_DEFINE_RELEASE(Env); +ORT_DEFINE_RELEASE(RunOptions); +ORT_DEFINE_RELEASE(Session); +ORT_DEFINE_RELEASE(SessionOptions); +ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); +ORT_DEFINE_RELEASE(SequenceTypeInfo); +ORT_DEFINE_RELEASE(MapTypeInfo); +ORT_DEFINE_RELEASE(TypeInfo); +ORT_DEFINE_RELEASE(Value); +ORT_DEFINE_RELEASE(ModelMetadata); +ORT_DEFINE_RELEASE(IoBinding); +ORT_DEFINE_RELEASE(ArenaCfg); +ORT_DEFINE_RELEASE(Status); +ORT_DEFINE_RELEASE(OpAttr); +ORT_DEFINE_RELEASE(Op); +ORT_DEFINE_RELEASE(KernelInfo); + +#undef ORT_DEFINE_RELEASE + +/** \brief This is a tagging template type. Use it with Base to indicate that the C++ interface object + * has no ownership of the underlying C object. + */ +template +struct Unowned { + using Type = T; +}; + +/** \brief Used internally by the C++ API. C++ wrapper types inherit from this. + * This is a zero cost abstraction to wrap the C API objects and delete them on destruction. + * + * All of the C++ classes + * a) serve as containers for pointers to objects that are created by the underlying C API. + * Their size is just a pointer size, no need to dynamically allocate them. Use them by value. + * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects. + * they would release objects owned automatically when going out of scope, they are move-only. + * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers. + * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else + * such as Onnxruntime or instances of XXXX classes. + * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used + * in C++ code. + * + */ + +/// +/// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction. +/// +template +struct Base { + using contained_type = T; + + constexpr Base() = default; + constexpr explicit Base(contained_type* p) noexcept : p_{p} {} + ~Base() { OrtRelease(p_); } + + Base(const Base&) = delete; + Base& operator=(const Base&) = delete; + + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + Base& operator=(Base&& v) noexcept { + OrtRelease(p_); + p_ = v.release(); + return *this; + } + + constexpr operator contained_type*() const noexcept { return p_; } + + /// \brief Relinquishes ownership of the contained C object pointer + /// The underlying object is not destroyed + contained_type* release() { + T* p = p_; + p_ = nullptr; + return p; + } + + protected: + contained_type* p_{}; +}; + +// Undefined. For const types use Base> +template +struct Base; + +/// +/// Covers unowned pointers owned by either the ORT +/// or some other instance of CPP wrappers. +/// Used for ConstXXX and UnownedXXXX types that are copyable. +/// Also convenient to wrap raw OrtXX pointers . +/// +/// +template +struct Base> { + using contained_type = typename Unowned::Type; + + constexpr Base() = default; + constexpr explicit Base(contained_type* p) noexcept : p_{p} {} + + ~Base() = default; + + Base(const Base&) = default; + Base& operator=(const Base&) = default; + + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + Base& operator=(Base&& v) noexcept { + p_ = nullptr; + std::swap(p_, v.p_); + return *this; + } + + constexpr operator contained_type*() const noexcept { return p_; } + + protected: + contained_type* p_{}; +}; + +// Light functor to release memory with OrtAllocator +struct AllocatedFree { + OrtAllocator* allocator_; + explicit AllocatedFree(OrtAllocator* allocator) + : allocator_(allocator) {} + void operator()(void* ptr) const { + if (ptr) allocator_->Free(allocator_, ptr); + } +}; + +} // namespace detail + +struct AllocatorWithDefaultOptions; +struct Env; +struct TypeInfo; +struct Value; +struct ModelMetadata; + +/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators + * and release them at the end of the scope. The lifespan of the given allocator + * must eclipse the lifespan of AllocatedStringPtr instance + */ +using AllocatedStringPtr = std::unique_ptr; + +/** \brief The Status that holds ownership of OrtStatus received from C API + * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate + * constructors to construct an instance of a Status object from exceptions. + */ +struct Status : detail::Base { + explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used + explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null + explicit Status(const Exception&); ///< Creates status instance out of exception + explicit Status(const std::exception&); ///< Creates status instance out of exception + std::string GetErrorMessage() const; + OrtErrorCode GetErrorCode() const; +}; + +/** \brief The ThreadingOptions + * + * The ThreadingOptions used for set global threadpools' options of The Env. + */ +struct ThreadingOptions : detail::Base { + /// \brief Wraps OrtApi::CreateThreadingOptions + ThreadingOptions(); + + /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads + ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads); + + /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads + ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads); + + /// \brief Wraps OrtApi::SetGlobalSpinControl + ThreadingOptions& SetGlobalSpinControl(int allow_spinning); + + /// \brief Wraps OrtApi::SetGlobalDenormalAsZero + ThreadingOptions& SetGlobalDenormalAsZero(); + + /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn + ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions + ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options); + + /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn + ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); +}; + +/** \brief The Env (Environment) + * + * The Env holds the logging state used by all other objects. + * Note: One Env must be created before using any other Onnxruntime functionality + */ +struct Env : detail::Base { + explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used + + /// \brief Wraps OrtApi::CreateEnv + Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + + /// \brief Wraps OrtApi::CreateEnvWithCustomLogger + Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); + + /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools + Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + + /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools + Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, + OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + + /// \brief C Interop Helper + explicit Env(OrtEnv* p) : Base{p} {} + + Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents + Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents + + Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel + + Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator +}; + +/** \brief Custom Op Domain + * + */ +struct CustomOpDomain : detail::Base { + explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used + + /// \brief Wraps OrtApi::CreateCustomOpDomain + explicit CustomOpDomain(const char* domain); + + // This does not take ownership of the op, simply registers it. + void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add +}; + +/** \brief RunOptions + * + */ +struct RunOptions : detail::Base { + explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used + RunOptions(); ///< Wraps OrtApi::CreateRunOptions + + RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel + int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel + + RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel + int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel + + RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag + const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag + + RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry + + /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance + * + * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error + * Wraps OrtApi::RunOptionsSetTerminate + */ + RunOptions& SetTerminate(); + + /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating + * + * Wraps OrtApi::RunOptionsUnsetTerminate + */ + RunOptions& UnsetTerminate(); +}; + + +namespace detail { +// Utility function that returns a SessionOption config entry key for a specific custom operator. +// Ex: custom_op.[custom_op_name].[config] +std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); +} // namespace detail + +/// +/// Class that represents session configuration entries for one or more custom operators. +/// +/// Example: +/// Ort::CustomOpConfigs op_configs; +/// op_configs.AddConfig("my_custom_op", "device_type", "CPU"); +/// +/// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary. +/// +struct CustomOpConfigs { + CustomOpConfigs() = default; + ~CustomOpConfigs() = default; + CustomOpConfigs(const CustomOpConfigs&) = default; + CustomOpConfigs& operator=(const CustomOpConfigs&) = default; + CustomOpConfigs(CustomOpConfigs&& o) = default; + CustomOpConfigs& operator=(CustomOpConfigs&& o) = default; + + /** \brief Adds a session configuration entry/value for a specific custom operator. + * + * \param custom_op_name The name of the custom operator for which to add a configuration entry. + * Must match the name returned by the CustomOp's GetName() method. + * \param config_key The name of the configuration entry. + * \param config_value The value of the configuration entry. + * \return A reference to this object to enable call chaining. + */ + CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value); + + /** \brief Returns a flattened map of custom operator configuration entries and their values. + * + * The keys has been flattened to include both the custom operator name and the configuration entry key name. + * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair + * {"my_op.key", "value"}. + * + * \return An unordered map of flattened configurations. + */ + const std::unordered_map& GetFlattenedConfigs() const; + + private: + std::unordered_map flat_configs_; +}; + +/** \brief Options object used when creating a new Session object + * + * Wraps ::OrtSessionOptions object and methods + */ + +struct SessionOptions; + +namespace detail { +// we separate const-only methods because passing const ptr to non-const methods +// is only discovered when inline methods are compiled which is counter-intuitive +template +struct ConstSessionOptionsImpl : Base { + using B = Base; + using B::B; + + SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions + + std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry + bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry + std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); +}; + +template +struct SessionOptionsImpl : ConstSessionOptionsImpl { + using B = ConstSessionOptionsImpl; + using B::B; + + SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads + SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads + SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel + + SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena + SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena + + SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath + + SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling + SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling + + SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps + + SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern + SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern + + SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode + + SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId + SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel + + SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain + + SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads + + SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry + + SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer + SessionOptionsImpl& AddExternalInitializers(const std::vector& names, const std::vector& ort_values); ///< Wraps OrtApi::AddExternalInitializers + + SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA + SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 + SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM + SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN + SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK. + SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, + const std::unordered_map& provider_options = {}); + + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn + SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions + SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn + + ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2. + ///< The custom operator configurations are optional. If provided, custom operator configs are set via + ///< OrtApi::AddSessionConfigEntry. + SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); + + SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction +}; +} // namespace detail + +using UnownedSessionOptions = detail::SessionOptionsImpl>; +using ConstSessionOptions = detail::ConstSessionOptionsImpl>; + +/** \brief Wrapper around ::OrtSessionOptions + * + */ +struct SessionOptions : detail::SessionOptionsImpl { + explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used + SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions + explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl{p} {} ///< Used for interop with the C API + UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } + ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } +}; + +/** \brief Wrapper around ::OrtModelMetadata + * + */ +struct ModelMetadata : detail::Base { + explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used + explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API + + /** \brief Returns a copy of the producer name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName + + /** \brief Returns a copy of the graph name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName + + /** \brief Returns a copy of the domain name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain + + /** \brief Returns a copy of the description. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription + + /** \brief Returns a copy of the graph description. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription + + /** \brief Returns a vector of copies of the custom metadata keys. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope. + * The OrtAllocator instance must be valid at the point of memory release. + */ + std::vector GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys + + /** \brief Looks up a value by a key in the Custom Metadata map + * + * \param key zero terminated string key to lookup + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * maybe nullptr if key is not found. + * + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap + + int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion +}; + +struct IoBinding; + +namespace detail { + +// we separate const-only methods because passing const ptr to non-const methods +// is only discovered when inline methods are compiled which is counter-intuitive +template +struct ConstSessionImpl : Base { + using B = Base; + using B::B; + + size_t GetInputCount() const; ///< Returns the number of model inputs + size_t GetOutputCount() const; ///< Returns the number of model outputs + size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden + + /** \brief Returns a copy of input name at the specified index. + * + * \param index must less than the value returned by GetInputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \brief Returns a copy of output name at then specified index. + * + * \param index must less than the value returned by GetOutputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \brief Returns a copy of the overridable initializer name at then specified index. + * + * \param index must less than the value returned by GetOverridableInitializerCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName + + uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs + ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata + + TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo + TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo + TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo +}; + +template +struct SessionImpl : ConstSessionImpl { + using B = ConstSessionImpl; + using B::B; + + /** \brief Run the model returning results in an Ort allocated vector. + * + * Wraps OrtApi::Run + * + * The caller provides a list of inputs and a list of the desired outputs to return. + * + * See the output logs for more information on warnings/errors that occur while processing the model. + * Common errors are.. (TODO) + * + * \param[in] run_options + * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names + * \param[in] input_values Array of Value objects of length input_count that is the list of input values + * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) + * \param[in] output_names Array of C style strings of length output_count that is the list of output names + * \param[in] output_count Number of outputs (the size of the output_names array) + * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector) + */ + std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_count); + + /** \brief Run the model returning results in user provided outputs + * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) + */ + void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count); + + void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding + + /** \brief End profiling and return a copy of the profiling file name. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling +}; + +} // namespace detail + +using ConstSession = detail::ConstSessionImpl>; +using UnownedSession = detail::SessionImpl>; + +/** \brief Wrapper around ::OrtSession + * + */ +struct Session : detail::SessionImpl { + explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer + + ConstSession GetConst() const { return ConstSession{this->p_}; } + UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } +}; + +namespace detail { +template +struct MemoryInfoImpl : Base { + using B = Base; + using B::B; + + std::string GetAllocatorName() const; + OrtAllocatorType GetAllocatorType() const; + int GetDeviceId() const; + OrtMemoryInfoDeviceType GetDeviceType() const; + OrtMemType GetMemoryType() const; + + template + bool operator==(const MemoryInfoImpl& o) const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstMemoryInfo = detail::MemoryInfoImpl>; + +/** \brief Wrapper around ::OrtMemoryInfo + * + */ +struct MemoryInfo : detail::MemoryInfoImpl { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); + ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } +}; + +namespace detail { +template +struct TensorTypeAndShapeInfoImpl : Base { + using B = Base; + using B::B; + + ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType + size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount + + size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount + + /** \deprecated use GetShape() returning std::vector + * [[deprecated]] + * This interface is unsafe to use + */ + [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions + + void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions + + std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape +}; + +} // namespace detail + +using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl>; + +/** \brief Wrapper around ::OrtTensorTypeAndShapeInfo + * + */ +struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl { + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API + ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } +}; + +namespace detail { +template +struct SequenceTypeInfoImpl : Base { + using B = Base; + using B::B; + TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType +}; + +} // namespace detail + +using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl>; + +/** \brief Wrapper around ::OrtSequenceTypeInfo + * + */ +struct SequenceTypeInfo : detail::SequenceTypeInfoImpl { + explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used + explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl{p} {} ///< Used for interop with the C API + ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } +}; + +namespace detail { +template +struct MapTypeInfoImpl : detail::Base { + using B = Base; + using B::B; + ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType + TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType +}; + +} // namespace detail + +using ConstMapTypeInfo = detail::MapTypeInfoImpl>; + +/** \brief Wrapper around ::OrtMapTypeInfo + * + */ +struct MapTypeInfo : detail::MapTypeInfoImpl { + explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used + explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl{p} {} ///< Used for interop with the C API + ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } +}; + +namespace detail { +template +struct TypeInfoImpl : detail::Base { + using B = Base; + using B::B; + + ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo + ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo + ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo + + ONNXType GetONNXType() const; +}; +} // namespace detail + +/// +/// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. +/// Provides access to const OrtTypeInfo APIs. +/// +using ConstTypeInfo = detail::TypeInfoImpl>; + +/// +/// Type information that may contain either TensorTypeAndShapeInfo or +/// the information about contained sequence or map depending on the ONNXType. +/// +struct TypeInfo : detail::TypeInfoImpl { + explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used + explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop + + ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } +}; + +namespace detail { +// This structure is used to feed sparse tensor values +// information for use with FillSparseTensor() API +// if the data type for the sparse tensor values is numeric +// use data.p_data, otherwise, use data.str pointer to feed +// values. data.str is an array of const char* that are zero terminated. +// number of strings in the array must match shape size. +// For fully sparse tensors use shape {0} and set p_data/str +// to nullptr. +struct OrtSparseValuesParam { + const int64_t* values_shape; + size_t values_shape_len; + union { + const void* p_data; + const char** str; + } data; +}; + +// Provides a way to pass shape in a single +// argument +struct Shape { + const int64_t* shape; + size_t shape_len; +}; + +template +struct ConstValueImpl : Base { + using B = Base; + using B::B; + + /// + /// Obtains a pointer to a user defined data for experimental purposes + /// + template + void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue + + bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc + bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None + + size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements + Value GetValue(int index, OrtAllocator* allocator) const; + + /// + /// This API returns a full length of string data contained within either a tensor or a sparse Tensor. + /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful + /// for allocating necessary memory and calling GetStringTensorContent(). + /// + /// total length of UTF-8 encoded bytes contained. No zero terminators counted. + size_t GetStringTensorDataLength() const; + + /// + /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor + /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. + /// The user must also allocate offsets buffer with the number of entries equal to that of the contained + /// strings. + /// + /// Strings are always assumed to be on CPU, no X-device copy. + /// + /// user allocated buffer + /// length in bytes of the allocated buffer + /// a pointer to the offsets user allocated buffer + /// count of offsets, must be equal to the number of strings contained. + /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() + /// for sparse tensors + void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; + + /// + /// Returns a const typed pointer to the tensor contained data. + /// No type checking is performed, the caller must ensure the type matches the tensor type. + /// + /// + /// const pointer to data, no copies made + template + const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// + + /// + /// Returns a non-typed pointer to a tensor contained data. + /// + /// const pointer to data, no copies made + const void* GetTensorRawData() const; + + /// + /// The API returns type information for data contained in a tensor. For sparse + /// tensors it returns type information for contained non-zero values. + /// It returns dense shape for sparse tensors. + /// + /// TypeInfo + TypeInfo GetTypeInfo() const; + + /// + /// The API returns type information for data contained in a tensor. For sparse + /// tensors it returns type information for contained non-zero values. + /// It returns dense shape for sparse tensors. + /// + /// TensorTypeAndShapeInfo + TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; + + /// + /// This API returns information about the memory allocation used to hold data. + /// + /// Non owning instance of MemoryInfo + ConstMemoryInfo GetTensorMemoryInfo() const; + + /// + /// The API copies UTF-8 encoded bytes for the requested string element + /// contained within a tensor or a sparse tensor into a provided buffer. + /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. + /// + /// + /// + /// + void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; + + /// + /// The API returns a byte length of UTF-8 encoded string element + /// contained in either a tensor or a spare tensor values. + /// + /// + /// byte length for the specified string element + size_t GetStringTensorElementLength(size_t element_index) const; + +#if !defined(DISABLE_SPARSE_TENSORS) + /// + /// The API returns the sparse data format this OrtValue holds in a sparse tensor. + /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used + /// the value returned is ORT_SPARSE_UNDEFINED. + /// + /// Format enum + OrtSparseFormat GetSparseFormat() const; + + /// + /// The API returns type and shape information for stored non-zero values of the + /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer. + /// + /// TensorTypeAndShapeInfo values information + TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; + + /// + /// The API returns type and shape information for the specified indices. Each supported + /// indices have their own enum values even if a give format has more than one kind of indices. + /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer. + /// + /// enum requested + /// type and shape information + TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; + + /// + /// The API retrieves a pointer to the internal indices buffer. The API merely performs + /// a convenience data type casting on the return type pointer. Make sure you are requesting + /// the right type, use GetSparseTensorIndicesTypeShapeInfo(); + /// + /// type to cast to + /// requested indices kind + /// number of indices entries + /// Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer. + template + const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; + + /// + /// Returns true if the OrtValue contains a sparse tensor + /// + /// + bool IsSparseTensor() const; + + /// + /// The API returns a pointer to an internal buffer of the sparse tensor + /// containing non-zero values. The API merely does casting. Make sure you + /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() + /// first. + /// + /// numeric data types only. Use GetStringTensor*() to retrieve strings. + /// a pointer to the internal values buffer. Do not free this pointer. + template + const R* GetSparseTensorValues() const; + +#endif +}; + +template +struct ValueImpl : ConstValueImpl { + using B = ConstValueImpl; + using B::B; + + /// + /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer + /// No type checking is performed, the caller must ensure the type matches the tensor type. + /// + /// non-const pointer to data, no copies made + template + R* GetTensorMutableData(); + + /// + /// Returns a non-typed non-const pointer to a tensor contained data. + /// + /// pointer to data, no copies made + void* GetTensorMutableRawData(); + + /// + // Obtain a reference to an element of data at the location specified + /// by the vector of dims. + /// + /// + /// [in] expressed by a vecotr of dimensions offsets + /// + template + R& At(const std::vector& location); + + /// + /// Set all strings at once in a string tensor + /// + /// [in] An array of strings. Each string in this array must be null terminated. + /// [in] Count of strings in s (Must match the size of \p value's tensor shape) + void FillStringTensor(const char* const* s, size_t s_len); + + /// + /// Set a single string in a string tensor + /// + /// [in] A null terminated UTF-8 encoded string + /// [in] Index of the string in the tensor to set + void FillStringTensorElement(const char* s, size_t index); + +#if !defined(DISABLE_SPARSE_TENSORS) + /// + /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor. + /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user + /// allocated buffers lifespan must eclipse that of the OrtValue. + /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. + /// + /// pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors. + /// number of indices entries. Use 0 for fully sparse tensors + void UseCooIndices(int64_t* indices_data, size_t indices_num); + + /// + /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor. + /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user + /// allocated buffers lifespan must eclipse that of the OrtValue. + /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. + /// + /// pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors + /// number of csr inner indices or 0 for fully sparse tensors + /// pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors + /// number of csr outer indices or 0 for fully sparse tensors + void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num); + + /// + /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor. + /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user + /// allocated buffers lifespan must eclipse that of the OrtValue. + /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. + /// + /// indices shape or a {0} for fully sparse + /// user allocated buffer with indices or nullptr for fully spare tensors + void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data); + + /// + /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API + /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located + /// at difference device than the allocator, a X-device copy will be performed if possible. + /// + /// specified buffer memory description + /// values buffer information. + /// coo indices buffer or nullptr for fully sparse data + /// number of COO indices or 0 for fully sparse data + void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param, + const int64_t* indices_data, size_t indices_num); + + /// + /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API + /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located + /// at difference device than the allocator, a X-device copy will be performed if possible. + /// + /// specified buffer memory description + /// values buffer information + /// csr inner indices pointer or nullptr for fully sparse tensors + /// number of csr inner indices or 0 for fully sparse tensors + /// pointer to csr indices data or nullptr for fully sparse tensors + /// number of csr outer indices or 0 + void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const int64_t* inner_indices_data, size_t inner_indices_num, + const int64_t* outer_indices_data, size_t outer_indices_num); + + /// + /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API + /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located + /// at difference device than the allocator, a X-device copy will be performed if possible. + /// + /// specified buffer memory description + /// values buffer information + /// indices shape. use {0} for fully sparse tensors + /// pointer to indices data or nullptr for fully sparse tensors + void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const Shape& indices_shape, + const int32_t* indices_data); + +#endif +}; + +} // namespace detail + +using ConstValue = detail::ConstValueImpl>; +using UnownedValue = detail::ValueImpl>; + +/** \brief Wrapper around ::OrtValue + * + */ +struct Value : detail::ValueImpl { + using Base = detail::ValueImpl; + using OrtSparseValuesParam = detail::OrtSparseValuesParam; + using Shape = detail::Shape; + + explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used + explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API + Value(Value&&) = default; + Value& operator=(Value&&) = default; + + ConstValue GetConst() const { return ConstValue{this->p_}; } + UnownedValue GetUnowned() const { return UnownedValue{this->p_}; } + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. + * \tparam T The numeric datatype. This API is not suitable for strings. + * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param p_data Pointer to the data buffer. + * \param p_data_element_count The number of elements in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + */ + template + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. + * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + + /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. + * \tparam T The numeric datatype. This API is not suitable for strings. + * \param allocator The allocator to use. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + */ + template + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); + + /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. + * \param allocator The allocator to use. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + + static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue + static Value CreateSequence(std::vector& values); ///< Wraps OrtApi::CreateValue + + template + static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue + +#if !defined(DISABLE_SPARSE_TENSORS) + /// + /// This is a simple forwarding method to the other overload that helps deducing + /// data type enum value from the type of the buffer. + /// + /// numeric datatype. This API is not suitable for strings. + /// Memory description where the user buffers reside (CPU vs GPU etc) + /// pointer to the user supplied buffer, use nullptr for fully sparse tensors + /// a would be dense shape of the tensor + /// non zero values shape. Use a single 0 shape for fully sparse tensors. + /// + template + static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, + const Shape& values_shape); + + /// + /// Creates an OrtValue instance containing SparseTensor. This constructs + /// a sparse tensor that makes use of user allocated buffers. It does not make copies + /// of the user provided data and does not modify it. The lifespan of user provided buffers should + /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain + /// a pointer to non-zero values. To fully populate the sparse tensor call UseIndices() API below + /// to supply a sparse format specific indices. + /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings + /// can be properly copied into the allocated buffer. + /// + /// Memory description where the user buffers reside (CPU vs GPU etc) + /// pointer to the user supplied buffer, use nullptr for fully sparse tensors + /// a would be dense shape of the tensor + /// non zero values shape. Use a single 0 shape for fully sparse tensors. + /// data type + /// Ort::Value instance containing SparseTensor + static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, + const Shape& values_shape, ONNXTensorElementDataType type); + + /// + /// This is a simple forwarding method to the below CreateSparseTensor. + /// This helps to specify data type enum in terms of C++ data type. + /// Use CreateSparseTensor + /// + /// numeric data type only. String data enum must be specified explicitly. + /// allocator to use + /// a would be dense shape of the tensor + /// Ort::Value + template + static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); + + /// + /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. + /// The data must be supplied by on of the FillSparseTensor() methods that take both non-zero values + /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. + /// Use this API to create OrtValues that contain sparse tensors with all supported data types including + /// strings. + /// + /// allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue + /// a would be dense shape of the tensor + /// data type + /// an instance of Ort::Value + static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); + +#endif // !defined(DISABLE_SPARSE_TENSORS) +}; + +/// +/// Represents native memory allocation coming from one of the +/// OrtAllocators registered with OnnxRuntime. +/// Use it to wrap an allocation made by an allocator +/// so it can be automatically released when no longer needed. +/// +struct MemoryAllocation { + MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); + ~MemoryAllocation(); + MemoryAllocation(const MemoryAllocation&) = delete; + MemoryAllocation& operator=(const MemoryAllocation&) = delete; + MemoryAllocation(MemoryAllocation&&) noexcept; + MemoryAllocation& operator=(MemoryAllocation&&) noexcept; + + void* get() { return p_; } + size_t size() const { return size_; } + + private: + OrtAllocator* allocator_; + void* p_; + size_t size_; +}; + +namespace detail { +template +struct AllocatorImpl : Base { + using B = Base; + using B::B; + + void* Alloc(size_t size); + MemoryAllocation GetAllocation(size_t size); + void Free(void* p); + ConstMemoryInfo GetInfo() const; +}; + +} // namespace detail + +/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime + * + */ +struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { + explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + AllocatorWithDefaultOptions(); +}; + +/** \brief Wrapper around ::OrtAllocator + * + */ +struct Allocator : detail::AllocatorImpl { + explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + Allocator(const Session& session, const OrtMemoryInfo*); +}; + +using UnownedAllocator = detail::AllocatorImpl>; + +namespace detail { +namespace binding_utils { +// Bring these out of template +std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*); +std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*); +} // namespace binding_utils + +template +struct ConstIoBindingImpl : Base { + using B = Base; + using B::B; + + std::vector GetOutputNames() const; + std::vector GetOutputNames(OrtAllocator*) const; + std::vector GetOutputValues() const; + std::vector GetOutputValues(OrtAllocator*) const; +}; + +template +struct IoBindingImpl : ConstIoBindingImpl { + using B = ConstIoBindingImpl; + using B::B; + + void BindInput(const char* name, const Value&); + void BindOutput(const char* name, const Value&); + void BindOutput(const char* name, const OrtMemoryInfo*); + void ClearBoundInputs(); + void ClearBoundOutputs(); + void SynchronizeInputs(); + void SynchronizeOutputs(); +}; + +} // namespace detail + +using ConstIoBinding = detail::ConstIoBindingImpl>; +using UnownedIoBinding = detail::IoBindingImpl>; + +/** \brief Wrapper around ::OrtIoBinding + * + */ +struct IoBinding : detail::IoBindingImpl { + explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later. + explicit IoBinding(Session& session); + ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; } + UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; } +}; + +/*! \struct Ort::ArenaCfg + * \brief it is a structure that represents the configuration of an arena based allocator + * \details Please see docs/C_API.md for details + */ +struct ArenaCfg : detail::Base { + explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used + /** + * Wraps OrtApi::CreateArenaCfg + * \param max_mem - use 0 to allow ORT to choose the default + * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default + * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default + * See docs/C_API.md for details on what the following parameters mean and how to choose these values + */ + ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); +}; + +// +// Custom OPs (only needed to implement custom OPs) +// + +/// +/// This struct provides life time management for custom op attribute +/// +struct OpAttr : detail::Base { + OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); +}; + +/// +/// This class wraps a raw pointer OrtKernelContext* that is being passed +/// to the custom kernel Compute() method. Use it to safely access context +/// attributes, input and output parameters with exception safety guarantees. +/// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +/// +struct KernelContext { + explicit KernelContext(OrtKernelContext* context); + size_t GetInputCount() const; + size_t GetOutputCount() const; + ConstValue GetInput(size_t index) const; + UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; + UnownedValue GetOutput(size_t index, const std::vector& dims) const; + void* GetGPUComputeStream() const; + + private: + OrtKernelContext* ctx_; +}; + +struct KernelInfo; + +namespace detail { +namespace attr_utils { +void GetAttr(const OrtKernelInfo* p, const char* name, float&); +void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&); +void GetAttr(const OrtKernelInfo* p, const char* name, std::string&); +void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); +void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); +} // namespace attr_utils + +template +struct KernelInfoImpl : Base { + using B = Base; + using B::B; + + KernelInfo Copy() const; + + template // R is only implemented for float, int64_t, and string + R GetAttribute(const char* name) const { + R val; + attr_utils::GetAttr(this->p_, name, val); + return val; + } + + template // R is only implemented for std::vector, std::vector + std::vector GetAttributes(const char* name) const { + std::vector result; + attr_utils::GetAttrs(this->p_, name, result); + return result; + } + + Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const; + + size_t GetInputCount() const; + size_t GetOutputCount() const; + + std::string GetInputName(size_t index) const; + std::string GetOutputName(size_t index) const; + + TypeInfo GetInputTypeInfo(size_t index) const; + TypeInfo GetOutputTypeInfo(size_t index) const; +}; + +} // namespace detail + +using ConstKernelInfo = detail::KernelInfoImpl>; + +/// +/// This struct owns the OrtKernInfo* pointer when a copy is made. +/// For convenient wrapping of OrtKernelInfo* passed to kernel constructor +/// and query attributes, warp the pointer with Ort::Unowned instance +/// so it does not destroy the pointer the kernel does not own. +/// +struct KernelInfo : detail::KernelInfoImpl { + explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later + explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance + ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } +}; + +/// +/// Create and own custom defined operation. +/// +struct Op : detail::Base { + explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used + + explicit Op(OrtOp*); ///< Take ownership of the OrtOp + + static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain, + int version, const char** type_constraint_names, + const ONNXTensorElementDataType* type_constraint_values, + size_t type_constraint_count, + const OpAttr* attr_values, + size_t attr_count, + size_t input_count, size_t output_count); + + void Invoke(const OrtKernelContext* context, + const Value* input_values, + size_t input_count, + Value* output_values, + size_t output_count); + + // For easier refactoring + void Invoke(const OrtKernelContext* context, + const OrtValue* const* input_values, + size_t input_count, + OrtValue* const* output_values, + size_t output_count); +}; + +/// +/// This entire structure is deprecated, but we not marking +/// it as a whole yet since we want to preserve for the next release. +/// +struct CustomOpApi { + CustomOpApi(const OrtApi& api) : api_(api) {} + + /** \deprecated use Ort::Value::GetTensorTypeAndShape() + * [[deprecated]] + * This interface produces a pointer that must be released. Not exception safe. + */ + [[deprecated("use Ort::Value::GetTensorTypeAndShape()")]] OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); + + /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementCount() + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementCount()")]] size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); + + /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementType() + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementType()")]] ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); + + /** \deprecated use Ort::TensorTypeAndShapeInfo::GetDimensionsCount() + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()")]] size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info); + + /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape() + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); + + /** \deprecated + * [[deprecated]] + * This interface sets dimensions to TensorTypeAndShapeInfo, but has no effect on the OrtValue. + */ + [[deprecated("Do not use")]] void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + + /** \deprecated use Ort::Value::GetTensorMutableData() + * [[deprecated]] + * This interface is redundant. + */ + template + [[deprecated("use Ort::Value::GetTensorMutableData()")]] T* GetTensorMutableData(_Inout_ OrtValue* value); + + /** \deprecated use Ort::Value::GetTensorData() + * [[deprecated]] + * This interface is redundant. + */ + template + [[deprecated("use Ort::Value::GetTensorData()")]] const T* GetTensorData(_Inout_ const OrtValue* value); + + /** \deprecated use Ort::Value::GetTensorMemoryInfo() + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::Value::GetTensorMemoryInfo()")]] const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value); + + /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape() + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); + + /** \deprecated use TensorTypeAndShapeInfo instances for automatic ownership. + * [[deprecated]] + * This interface is not exception safe. + */ + [[deprecated("use TensorTypeAndShapeInfo")]] void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); + + /** \deprecated use Ort::KernelContext::GetInputCount + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::KernelContext::GetInputCount")]] size_t KernelContext_GetInputCount(const OrtKernelContext* context); + + /** \deprecated use Ort::KernelContext::GetInput + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::KernelContext::GetInput")]] const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); + + /** \deprecated use Ort::KernelContext::GetOutputCount + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::KernelContext::GetOutputCount")]] size_t KernelContext_GetOutputCount(const OrtKernelContext* context); + + /** \deprecated use Ort::KernelContext::GetOutput + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::KernelContext::GetOutput")]] OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); + + /** \deprecated use Ort::KernelContext::GetGPUComputeStream + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::KernelContext::GetGPUComputeStream")]] void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context); + + /** \deprecated use Ort::ThrowOnError() + * [[deprecated]] + * This interface is redundant. + */ + [[deprecated("use Ort::ThrowOnError()")]] void ThrowOnError(OrtStatus* result); + + /** \deprecated use Ort::OpAttr + * [[deprecated]] + * This interface is not exception safe. + */ + [[deprecated("use Ort::OpAttr")]] OrtOpAttr* CreateOpAttr(_In_ const char* name, + _In_ const void* data, + _In_ int len, + _In_ OrtOpAttrType type); + + /** \deprecated use Ort::OpAttr + * [[deprecated]] + * This interface is not exception safe. + */ + [[deprecated("use Ort::OpAttr")]] void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr); + + /** \deprecated use Ort::Op + * [[deprecated]] + * This interface is not exception safe. + */ + [[deprecated("use Ort::Op")]] OrtOp* CreateOp(_In_ const OrtKernelInfo* info, + _In_ const char* op_name, + _In_ const char* domain, + _In_ int version, + _In_opt_ const char** type_constraint_names, + _In_opt_ const ONNXTensorElementDataType* type_constraint_values, + _In_opt_ int type_constraint_count, + _In_opt_ const OrtOpAttr* const* attr_values, + _In_opt_ int attr_count, + _In_ int input_count, + _In_ int output_count); + + /** \deprecated use Ort::Op::Invoke + * [[deprecated]] + * This interface is redundant + */ + [[deprecated("use Ort::Op::Invoke")]] void InvokeOp(_In_ const OrtKernelContext* context, + _In_ const OrtOp* ort_op, + _In_ const OrtValue* const* input_values, + _In_ int input_count, + _Inout_ OrtValue* const* output_values, + _In_ int output_count); + + /** \deprecated use Ort::Op for automatic lifespan management. + * [[deprecated]] + * This interface is not exception safe. + */ + [[deprecated("use Ort::Op")]] void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op); + + /** \deprecated use Ort::KernelInfo for automatic lifespan management or for + * querying attributes + * [[deprecated]] + * This interface is redundant + */ + template // T is only implemented for std::vector, std::vector, float, int64_t, and string + [[deprecated("use Ort::KernelInfo::GetAttribute")]] T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); + + /** \deprecated use Ort::KernelInfo::Copy + * querying attributes + * [[deprecated]] + * This interface is not exception safe + */ + [[deprecated("use Ort::KernelInfo::Copy")]] OrtKernelInfo* CopyKernelInfo(_In_ const OrtKernelInfo* info); + + /** \deprecated use Ort::KernelInfo for lifespan management + * querying attributes + * [[deprecated]] + * This interface is not exception safe + */ + [[deprecated("use Ort::KernelInfo")]] void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy); + + private: + const OrtApi& api_; +}; + +template +struct CustomOpBase : OrtCustomOp { + CustomOpBase() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; + OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; + + OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; + + OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; + OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputMemoryType(index); }; + + OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; + OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; + + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 26409) +#endif + OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicInputMinArity(); }; + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicInputHomogeneity()); }; + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicOutputMinArity(); }; + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicOutputHomogeneity()); }; + } + + // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider + const char* GetExecutionProviderType() const { return nullptr; } + + // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below + // (inputs and outputs are required by default) + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault + OrtMemType GetInputMemoryType(size_t /*index*/) const { + return OrtMemTypeDefault; + } + + // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input + // should expect at least 1 argument. + int GetVariadicInputMinArity() const { + return 1; + } + + // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments + // to a variadic input should be of the same type. + bool GetVariadicInputHomogeneity() const { + return true; + } + + // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output + // should produce at least 1 output value. + int GetVariadicOutputMinArity() const { + return 1; + } + + // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values + // produced by a variadic output should be of the same type. + bool GetVariadicOutputHomogeneity() const { + return true; + } + + // Declare list of session config entries used by this Custom Op. + // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs(). + // This default implementation returns an empty vector of config entries. + std::vector GetSessionConfigKeys() const { + return std::vector{}; + } + + protected: + // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. + void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; +}; + +} // namespace Ort + +#include "onnxruntime_cxx_inline.h" diff --git a/ThirdParty/onnxruntime/include/onnxruntime_cxx_inline.h b/ThirdParty/onnxruntime/include/onnxruntime_cxx_inline.h new file mode 100644 index 0000000..6d391ad --- /dev/null +++ b/ThirdParty/onnxruntime/include/onnxruntime_cxx_inline.h @@ -0,0 +1,1874 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead. +// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead. +// +// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter +// the main C++ file with implementation details. + +namespace Ort { + +namespace detail { +inline void ThrowStatus(const Status& st) { + std::string error_message = st.GetErrorMessage(); + OrtErrorCode error_code = st.GetErrorCode(); + ORT_CXX_API_THROW(std::move(error_message), error_code); +} +} // namespace detail + +inline void ThrowOnError(OrtStatus* ort_status) { + if (ort_status) { + Ort::Status st(ort_status); + detail::ThrowStatus(st); + } +} + +inline void ThrowOnError(const Status& st) { + if (st) { + detail::ThrowStatus(st); + } +} + +inline Status::Status(OrtStatus* status) : Base{status} { +} + +inline Status::Status(const std::exception& e) { + p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); +} + +inline Status::Status(const Exception& e) { + p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); +} + +inline std::string Status::GetErrorMessage() const { + std::string message(GetApi().GetErrorMessage(p_)); + return message; +} + +inline OrtErrorCode Status::GetErrorCode() const { + return GetApi().GetErrorCode(p_); +} + +// This template converts a C++ type into it's ONNXTensorElementDataType +template +struct TypeToTensorType; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; +}; + +inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size) + : allocator_(allocator), p_(p), size_(size) { +} + +inline MemoryAllocation::~MemoryAllocation() { + if (p_ != nullptr) { + // We do not throw out of destructor + auto ret = GetApi().AllocatorFree(allocator_, p_); + static_cast(ret); + } +} + +inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) { + *this = std::move(o); +} + +inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept { + OrtAllocator* alloc = nullptr; + void* p = nullptr; + size_t sz = 0; + + // Swap out this + std::swap(alloc, allocator_); + std::swap(p, p_); + std::swap(sz, size_); + + // Swap with incoming + std::swap(allocator_, o.allocator_); + std::swap(p_, o.p_); + std::swap(size_, o.size_); + + // Destroy this instance if needed + MemoryAllocation this_alloc(alloc, p, sz); + return *this; +} + +namespace detail { + +template +inline void* AllocatorImpl::Alloc(size_t size) { + void* out; + ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); + return out; +} + +template +inline MemoryAllocation AllocatorImpl::GetAllocation(size_t size) { + void* out; + ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); + MemoryAllocation result(this->p_, out, size); + return result; +} + +template +inline void AllocatorImpl::Free(void* p) { + ThrowOnError(GetApi().AllocatorFree(this->p_, p)); +} + +template +inline ConstMemoryInfo AllocatorImpl::GetInfo() const { + const OrtMemoryInfo* out; + ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out)); + return ConstMemoryInfo{out}; +} + +} // namespace detail + +inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { + ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_)); +} + +inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_)); +} + +namespace detail { + +template +inline std::string MemoryInfoImpl::GetAllocatorName() const { + const char* name = nullptr; + ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name)); + return std::string(name); +} + +template +inline OrtAllocatorType MemoryInfoImpl::GetAllocatorType() const { + OrtAllocatorType type; + ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type)); + return type; +} + +template +inline int MemoryInfoImpl::GetDeviceId() const { + int id = 0; + ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id)); + return id; +} + +template +inline OrtMemoryInfoDeviceType MemoryInfoImpl::GetDeviceType() const { + OrtMemoryInfoDeviceType type; + GetApi().MemoryInfoGetDeviceType(this->p_, &type); + return type; +} + +template +inline OrtMemType MemoryInfoImpl::GetMemoryType() const { + OrtMemType type; + ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type)); + return type; +} + +template +template +inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { + int comp_result = 0; + ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result)); + return comp_result == 0; +} + +} // namespace detail + +inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { + OrtMemoryInfo* p; + ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p)); + return MemoryInfo(p); +} + +inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); +} + +namespace detail { +template +inline std::vector ConstIoBindingImpl::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + return binding_utils::GetOutputNamesHelper(this->p_, allocator); +} + +template +inline std::vector ConstIoBindingImpl::GetOutputNames(OrtAllocator* allocator) const { + return binding_utils::GetOutputNamesHelper(this->p_, allocator); +} + +template +inline std::vector ConstIoBindingImpl::GetOutputValues() const { + AllocatorWithDefaultOptions allocator; + return binding_utils::GetOutputValuesHelper(this->p_, allocator); +} + +template +inline std::vector ConstIoBindingImpl::GetOutputValues(OrtAllocator* allocator) const { + return binding_utils::GetOutputValuesHelper(this->p_, allocator); +} + +template +inline void IoBindingImpl::BindInput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindInput(this->p_, name, value)); +} + +template +inline void IoBindingImpl::BindOutput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindOutput(this->p_, name, value)); +} + +template +inline void IoBindingImpl::BindOutput(const char* name, const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info)); +} + +template +inline void IoBindingImpl::ClearBoundInputs() { + GetApi().ClearBoundInputs(this->p_); +} + +template +inline void IoBindingImpl::ClearBoundOutputs() { + GetApi().ClearBoundOutputs(this->p_); +} + +template +inline void IoBindingImpl::SynchronizeInputs() { + ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_)); +} + +template +inline void IoBindingImpl::SynchronizeOutputs() { + ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_)); +} + +namespace binding_utils { +inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { + std::vector result; + auto free_fn = detail::AllocatedFree(allocator); + using Ptr = std::unique_ptr; + + char* buffer = nullptr; + size_t* lengths = nullptr; + size_t count = 0; + ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count)); + + if (count == 0) { + return result; + } + + Ptr buffer_g(buffer, free_fn); + Ptr lengths_g(lengths, free_fn); + + result.reserve(count); + for (size_t i = 0; i < count; ++i) { + auto sz = *lengths; + result.emplace_back(buffer, sz); + buffer += sz; + ++lengths; + } + return result; +} + +inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { + std::vector result; + size_t owned = 0; + size_t output_count = 0; + // Lambda to release the buffer when no longer needed and + // make sure that we destroy all instances on exception + auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { + if (buffer) { + while (owned < output_count) { + auto* p = buffer + owned++; + GetApi().ReleaseValue(*p); + } + allocator->Free(allocator, buffer); + } + }; + using Ptr = std::unique_ptr; + + OrtValue** output_buffer = nullptr; + ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); + if (output_count == 0) { + return result; + } + + Ptr buffer_g(output_buffer, free_fn); + + result.reserve(output_count); + for (size_t i = 0; i < output_count; ++i) { + result.emplace_back(output_buffer[i]); + ++owned; + } + return result; +} + +} // namespace binding_utils +} // namespace detail + +inline IoBinding::IoBinding(Session& session) { + ThrowOnError(GetApi().CreateIoBinding(session, &this->p_)); +} + +inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) { + ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); +} + +inline ThreadingOptions::ThreadingOptions() { + ThrowOnError(GetApi().CreateThreadingOptions(&p_)); +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) { + ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) { + ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) { + ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() { + ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { + ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn)); + return *this; +} + +inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) { + ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, + OrtLoggingLevel logging_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env& Env::EnableTelemetryEvents() { + ThrowOnError(GetApi().EnableTelemetryEvents(p_)); + return *this; +} + +inline Env& Env::DisableTelemetryEvents() { + ThrowOnError(GetApi().DisableTelemetryEvents(p_)); + return *this; +} + +inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) { + ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level)); + return *this; +} + +inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) { + ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg)); + return *this; +} + +inline CustomOpDomain::CustomOpDomain(const char* domain) { + ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); +} + +inline void CustomOpDomain::Add(const OrtCustomOp* op) { + ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); +} + +inline RunOptions::RunOptions() { + ThrowOnError(GetApi().CreateRunOptions(&p_)); +} + +inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) { + ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level)); + return *this; +} + +inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) { + ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level)); + return *this; +} + +inline int RunOptions::GetRunLogVerbosityLevel() const { + int out; + ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out)); + return out; +} + +inline int RunOptions::GetRunLogSeverityLevel() const { + int out; + ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out)); + return out; +} + +inline RunOptions& RunOptions::SetRunTag(const char* run_tag) { + ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag)); + return *this; +} + +inline const char* RunOptions::GetRunTag() const { + const char* out; + ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out)); + return out; +} + +inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) { + ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value)); + return *this; +} + +inline RunOptions& RunOptions::SetTerminate() { + ThrowOnError(GetApi().RunOptionsSetTerminate(p_)); + return *this; +} + +inline RunOptions& RunOptions::UnsetTerminate() { + ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_)); + return *this; +} + +namespace detail { + +template +inline Ort::SessionOptions ConstSessionOptionsImpl::Clone() const { + OrtSessionOptions* out; + ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out)); + return SessionOptions{out}; +} + +template +inline std::string ConstSessionOptionsImpl::GetConfigEntry(const char* config_key) const { + size_t size = 0; + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline bool ConstSessionOptionsImpl::HasConfigEntry(const char* config_key) const { + int out = 0; + Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out)); + return static_cast(out); +} + +template +inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, const std::string& def) { + if (!this->HasConfigEntry(config_key)) { + return def; + } + + return this->GetConfigEntry(config_key); +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetIntraOpNumThreads(int intra_op_num_threads) { + ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetInterOpNumThreads(int inter_op_num_threads) { + ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { + ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { + ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { + ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableProfiling() { + ThrowOnError(GetApi().DisableProfiling(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableOrtCustomOps() { + ThrowOnError(GetApi().EnableOrtCustomOps(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableMemPattern() { + ThrowOnError(GetApi().EnableMemPattern(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableMemPattern() { + ThrowOnError(GetApi().DisableMemPattern(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableCpuMemArena() { + ThrowOnError(GetApi().EnableCpuMemArena(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableCpuMemArena() { + ThrowOnError(GetApi().DisableCpuMemArena(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionMode execution_mode) { + ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { + ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLogSeverityLevel(int level) { + ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::Add(OrtCustomOpDomain* custom_op_domain) { + ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddConfigEntry(const char* config_key, const char* config_value) { + ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddInitializer(const char* name, const OrtValue* ort_val) { + ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisablePerSessionThreads() { + ThrowOnError(GetApi().DisablePerSessionThreads(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializers(const std::vector& names, + const std::vector& ort_values) { + const size_t inputs_num = names.size(); + if (inputs_num != ort_values.size()) { + ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT); + } + std::vector names_ptr; + std::vector ort_values_ptrs; + names_ptr.reserve(inputs_num); + ort_values_ptrs.reserve(inputs_num); + for (size_t i = 0; i < inputs_num; ++i) { + names_ptr.push_back(names[i].c_str()); + ort_values_ptrs.push_back(ort_values[i]); + } + ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider( + const std::string& provider_name, + const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(), + keys.data(), values.data(), num_entries)); + + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { + ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, + const CustomOpConfigs& custom_op_configs) { + // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by + // the custom op library. + for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) { + AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str()); + } + + ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsUsingFunction(const char* registration_function_name) { + ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name)); + return *this; +} + +/// Session +template +inline size_t ConstSessionImpl::GetInputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out)); + return out; +} + +template +inline size_t ConstSessionImpl::GetOutputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out)); + return out; +} + +template +inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out)); + return out; +} + +template +inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline AllocatedStringPtr ConstSessionImpl::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { + uint64_t out; + ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out)); + return out; +} + +template +inline ModelMetadata ConstSessionImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; +} + +template +inline TypeInfo ConstSessionImpl::GetInputTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline TypeInfo ConstSessionImpl::GetOutputTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_count) { + std::vector output_values; + output_values.reserve(output_count); + for (size_t i = 0; i < output_count; i++) + output_values.emplace_back(nullptr); + Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count); + return output_values; +} + +template +inline void SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count) { + static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); +} + +template +inline void SessionImpl::Run(const RunOptions& run_options, const IoBinding& io_binding) { + ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding)); +} + +template +inline AllocatedStringPtr SessionImpl::EndProfilingAllocated(OrtAllocator* allocator) { + char* out = nullptr; + ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +} // namespace detail + +inline SessionOptions::SessionOptions() { + ThrowOnError(GetApi().CreateSessionOptions(&this->p_)); +} + +/// CustomOpConfigs +inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) { + std::string config_key = "custom_op."; + + config_key += custom_op_name; + config_key += "."; + config_key += config; + + return config_key; +} + +inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) { + const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key); + flat_configs_[full_flat_key] = config_value; + return *this; +} + +inline const std::unordered_map& CustomOpConfigs::GetFlattenedConfigs() const { + return flat_configs_; +} + +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_)); +} + +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_)); +} + +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_)); +} + +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options, + prepacked_weights_container, &this->p_)); +} + +inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline std::vector ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const { + auto deletor = detail::AllocatedFree(allocator); + std::vector result; + + char** out = nullptr; + int64_t num_keys = 0; + ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); + if (num_keys <= 0) { + return result; + } + + // array of pointers will be freed + std::unique_ptr array_guard(out, deletor); + // reserve may throw + auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); }; + std::unique_ptr strings_guard(out, strings_deletor); + result.reserve(static_cast(num_keys)); + strings_guard.release(); + for (int64_t i = 0; i < num_keys; ++i) { + result.push_back(AllocatedStringPtr(out[i], deletor)); + } + + return result; +} + +inline int64_t ModelMetadata::GetVersion() const { + int64_t out; + ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); + return out; +} + +namespace detail { + +template +inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl::GetElementType() const { + ONNXTensorElementDataType out; + ThrowOnError(GetApi().GetTensorElementType(this->p_, &out)); + return out; +} + +template +inline size_t TensorTypeAndShapeInfoImpl::GetElementCount() const { + size_t out; + ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out)); + return static_cast(out); +} + +template +inline size_t TensorTypeAndShapeInfoImpl::GetDimensionsCount() const { + size_t out; + ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out)); + return out; +} + +template +inline void TensorTypeAndShapeInfoImpl::GetDimensions(int64_t* values, size_t values_count) const { + ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count)); +} + +template +inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** values, size_t values_count) const { + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); +} + +template +inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { + std::vector out(GetDimensionsCount(), 0); + ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); + return out; +} + +} // namespace detail + +namespace detail { +template +inline ConstTensorTypeAndShapeInfo TypeInfoImpl::GetTensorTypeAndShapeInfo() const { + const OrtTensorTypeAndShapeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out)); + return ConstTensorTypeAndShapeInfo{out}; +} + +template +inline ConstSequenceTypeInfo TypeInfoImpl::GetSequenceTypeInfo() const { + const OrtSequenceTypeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out)); + return ConstSequenceTypeInfo{out}; +} + +template +inline ConstMapTypeInfo TypeInfoImpl::GetMapTypeInfo() const { + const OrtMapTypeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out)); + return ConstMapTypeInfo{out}; +} + +template +inline ONNXType TypeInfoImpl::GetONNXType() const { + ONNXType out; + ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out)); + return out; +} + +} // namespace detail + +namespace detail { +template +inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output)); + return TypeInfo{output}; +} + +} // namespace detail + +namespace detail { +template +inline ONNXTensorElementDataType MapTypeInfoImpl::GetMapKeyType() const { + ONNXTensorElementDataType out; + ThrowOnError(GetApi().GetMapKeyType(this->p_, &out)); + return out; +} + +template +inline TypeInfo MapTypeInfoImpl::GetMapValueType() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetMapValueType(this->p_, &output)); + return TypeInfo{output}; +} +} // namespace detail + +namespace detail { + +template +template +inline void ConstValueImpl::GetOpaqueData(const char* domain, const char* type_name, R& out) const { + ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R))); +} + +template +inline bool ConstValueImpl::IsTensor() const { + int out; + ThrowOnError(GetApi().IsTensor(this->p_, &out)); + return out != 0; +} + +template +inline bool ConstValueImpl::HasValue() const { + int out; + ThrowOnError(GetApi().HasValue(this->p_, &out)); + return out != 0; +} + +template +inline size_t ConstValueImpl::GetCount() const { + size_t out; + ThrowOnError(GetApi().GetValueCount(this->p_, &out)); + return out; +} + +template +inline Value ConstValueImpl::GetValue(int index, OrtAllocator* allocator) const { + OrtValue* out; + ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out)); + return Value{out}; +} + +template +inline size_t ConstValueImpl::GetStringTensorDataLength() const { + size_t out; + ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out)); + return out; +} + +template +inline size_t ConstValueImpl::GetStringTensorElementLength(size_t element_index) const { + size_t out; + ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out)); + return out; +} + +template +template +inline const R* ConstValueImpl::GetTensorData() const { + R* out; + ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); + return out; +} + +template +inline const void* ConstValueImpl::GetTensorRawData() const { + void* out; + ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); + return out; +} + +template +inline TypeInfo ConstValueImpl::GetTypeInfo() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetTypeInfo(this->p_, &output)); + return TypeInfo{output}; +} + +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetTensorTypeAndShapeInfo() const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output)); + return TensorTypeAndShapeInfo{output}; +} + +template +inline ConstMemoryInfo ConstValueImpl::GetTensorMemoryInfo() const { + const OrtMemoryInfo* mem_info; + ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info)); + return ConstMemoryInfo(mem_info); +} + +template +inline void ConstValueImpl::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const { + ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer)); +} + +template +inline void ConstValueImpl::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { + ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count)); +} + +#if !defined(DISABLE_SPARSE_TENSORS) +template +inline OrtSparseFormat ConstValueImpl::GetSparseFormat() const { + OrtSparseFormat format; + ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format)); + return format; +} + +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorValuesTypeAndShapeInfo() const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output)); + return TensorTypeAndShapeInfo{output}; +} + +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output)); + return TensorTypeAndShapeInfo{output}; +} + +template +template +inline const R* ConstValueImpl::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const { + const void* out; + ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out)); + return reinterpret_cast(out); +} + +template +inline bool ConstValueImpl::IsSparseTensor() const { + int out; + ThrowOnError(GetApi().IsSparseTensor(this->p_, &out)); + return out != 0; +} + +template +template +inline const R* ConstValueImpl::GetSparseTensorValues() const { + const void* out; + ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out)); + return reinterpret_cast(out); +} + +#endif + +template +void ValueImpl::FillStringTensor(const char* const* s, size_t s_len) { + ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len)); +} + +template +void ValueImpl::FillStringTensorElement(const char* s, size_t index) { + ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index)); +} + +template +void* ValueImpl::GetTensorMutableRawData() { + void* out; + ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out)); + return out; +} + +template +template +R* ValueImpl::GetTensorMutableData() { + R* out; + ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out)); + return out; +} + +template +template +R& ValueImpl::At(const std::vector& location) { + static_assert(!std::is_same::value, "this api does not support std::string"); + R* out; + ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out)); + return *out; +} + +#if !defined(DISABLE_SPARSE_TENSORS) +template +void ValueImpl::UseCooIndices(int64_t* indices_data, size_t indices_num) { + ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num)); +} + +template +void ValueImpl::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) { + ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num)); +} + +template +void ValueImpl::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) { + ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data)); +} + +template +void ValueImpl::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param, + const int64_t* indices_data, size_t indices_num) { + ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape, + values_param.values_shape_len, values_param.data.p_data, + indices_data, indices_num)); +} + +template +void ValueImpl::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const int64_t* inner_indices_data, size_t inner_indices_num, + const int64_t* outer_indices_data, size_t outer_indices_num) { + ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, + inner_indices_data, inner_indices_num, + outer_indices_data, outer_indices_num)); +} + +template +void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const Shape& indices_shape, + const int32_t* indices_data) { + ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, + indices_shape.shape, indices_shape.shape_len, + indices_data)); +} + +#endif // !defined(DISABLE_SPARSE_TENSORS) + +} // namespace detail + +template +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { + return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); +} + +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); + return Value{out}; +} + +template +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { + return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); +} + +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); + return Value{out}; +} + +#if !defined(DISABLE_SPARSE_TENSORS) + +template +inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, + const Shape& values_shape) { + return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType::type); +} + +inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, + const Shape& values_shape, ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, + values_shape.shape, values_shape.shape_len, type, &out)); + return Value{out}; +} + +template +inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) { + return CreateSparseTensor(allocator, dense_shape, TypeToTensorType::type); +} + +inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out)); + return Value{out}; +} +#endif // !defined(DISABLE_SPARSE_TENSORS) + +inline Value Value::CreateMap(Value& keys, Value& values) { + OrtValue* out; + OrtValue* inputs[2] = {keys, values}; + ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out)); + return Value{out}; +} + +inline Value Value::CreateSequence(std::vector& values) { + OrtValue* out; + std::vector values_ort{values.data(), values.data() + values.size()}; + ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out)); + return Value{out}; +} + +template +inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) { + OrtValue* out; + ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out)); + return Value{out}; +} + +// +// Custom OP Inlines +// +inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) { +} + +inline size_t KernelContext::GetInputCount() const { + size_t out = 0; + Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out)); + return out; +} + +inline size_t KernelContext::GetOutputCount() const { + size_t out = 0; + Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out)); + return out; +} + +inline ConstValue KernelContext::GetInput(size_t index) const { + const OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out)); + return ConstValue{out}; +} + +inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const { + OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out)); + return UnownedValue(out); +} + +inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector& dims) const { + OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out)); + return UnownedValue(out); +} + +inline void* KernelContext::GetGPUComputeStream() const { + void* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out)); + return out; +} + +inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { + Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); +} + +namespace detail { +template +inline KernelInfo KernelInfoImpl::Copy() const { + OrtKernelInfo* info_copy = nullptr; + Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy)); + return KernelInfo{info_copy}; +} + +template +inline size_t KernelInfoImpl::GetInputCount() const { + size_t out = 0; + ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out)); + return out; +} + +template +inline size_t KernelInfoImpl::GetOutputCount() const { + size_t out = 0; + ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out)); + return out; +} + +template +inline std::string KernelInfoImpl::GetInputName(size_t index) const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline std::string KernelInfoImpl::GetOutputName(size_t index) const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline TypeInfo KernelInfoImpl::GetInputTypeInfo(size_t index) const { + OrtTypeInfo* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline TypeInfo KernelInfoImpl::GetOutputTypeInfo(size_t index) const { + OrtTypeInfo* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline Value KernelInfoImpl::GetTensorAttribute(const char* name, OrtAllocator* allocator) const { + OrtValue* out = nullptr; + ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out)); + return Value{out}; +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) { + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out)); +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) { + size_t size = 0; + // Feed nullptr for the data buffer to query the true size of the string attribute + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + out.swap(result); +} + +inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { + size_t size = 0; + // Feed nullptr for the data buffer to query the true size of the attribute + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size)); + + std::vector out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size)); + out.swap(result); +} + +inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the attribute + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size)); + + std::vector out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size)); + out.swap(result); +} +} // namespace detail + +inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} + +inline Op::Op(OrtOp* p) : Base(p) {} + +inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, + const char** type_constraint_names, + const ONNXTensorElementDataType* type_constraint_values, + size_t type_constraint_count, + const OpAttr* attr_values, size_t attr_count, + size_t input_count, size_t output_count) { + static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*), + "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely"); + auto attr_input_values = reinterpret_cast(attr_values); + OrtOp* op; + Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values, + static_cast(type_constraint_count), + attr_input_values, + static_cast(attr_count), + static_cast(input_count), + static_cast(output_count), &op)); + return Op{op}; +} + +inline void Op::Invoke(const OrtKernelContext* context, + const Value* input_values, + size_t input_count, + Value* output_values, + size_t output_count) { + static_assert(sizeof(Value) == sizeof(OrtValue*), + "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast(input_count), + ort_output_values, static_cast(output_count))); +} + +inline void Op::Invoke(const OrtKernelContext* context, + const OrtValue* const* input_values, + size_t input_count, + OrtValue* const* output_values, + size_t output_count) { + Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast(input_count), + output_values, static_cast(output_count))); +} + +inline void CustomOpApi::ThrowOnError(OrtStatus* status) { + Ort::ThrowOnError(status); +} + +template <> +inline float CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + float out; + Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out)); + return out; +} + +template <> +inline int64_t CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + int64_t out; + Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out)); + return out; +} + +template <> +inline std::string CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + size_t size = 0; + std::string out; + + // Feed nullptr for the data buffer to query the true size of the string attribute + OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size); + + if (status == nullptr) { + out.resize(size); + Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + } else { + Ort::ThrowOnError(status); + } + return out; +} + +template <> +inline std::vector CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + size_t size = 0; + std::vector out; + + // Feed nullptr for the data buffer to query the true size of the attribute + OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size); + + if (status == nullptr) { + out.resize(size); + Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size)); + } else { + Ort::ThrowOnError(status); + } + return out; +} + +template <> +inline std::vector CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + size_t size = 0; + std::vector out; + + // Feed nullptr for the data buffer to query the true size of the attribute + OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size); + + if (status == nullptr) { + out.resize(size); + Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size)); + } else { + Ort::ThrowOnError(status); + } + return out; +} +inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) { + OrtTensorTypeAndShapeInfo* out; + Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out)); + return out; +} + +inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) { + size_t out; + Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out)); + return out; +} + +inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) { + ONNXTensorElementDataType out; + Ort::ThrowOnError(api_.GetTensorElementType(info, &out)); + return out; +} + +inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) { + size_t out; + Ort::ThrowOnError(api_.GetDimensionsCount(info, &out)); + return out; +} + +inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { + Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length)); +} + +inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { + Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count)); +} + +template +inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) { + T* data; + Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast(&data))); + return data; +} + +inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) { + const OrtMemoryInfo* mem_info; + Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info)); + return mem_info; +} + +template +inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) { + T* data = nullptr; + Ort::ThrowOnError(api_.GetTensorMutableData(const_cast(value), reinterpret_cast(&data))); + return data; +} + +inline std::vector CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) { + size_t out; + Ort::ThrowOnError(api_.GetDimensionsCount(info, &out)); + std::vector output(out); + Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out)); + return output; +} + +inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) { + api_.ReleaseTensorTypeAndShapeInfo(input); +} + +inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) { + size_t out; + Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out)); + return out; +} + +inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) { + const OrtValue* out; + Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out)); + return out; +} + +inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) { + size_t out; + Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out)); + return out; +} + +inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, + _In_ const int64_t* dim_values, size_t dim_count) { + OrtValue* out; + Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out)); + return out; +} + +inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) { + void* out; + Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out)); + return out; +} + +inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name, + _In_ const void* data, + _In_ int len, + _In_ OrtOpAttrType type) { + OrtOpAttr* op_attr{}; + Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr)); + return op_attr; +} + +inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) { + api_.ReleaseOpAttr(op_attr); +} + +inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info, + _In_ const char* op_name, + _In_ const char* domain, + _In_ int version, + _In_opt_ const char** type_constraint_names, + _In_opt_ const ONNXTensorElementDataType* type_constraint_values, + _In_opt_ int type_constraint_count, + _In_opt_ const OrtOpAttr* const* attr_values, + _In_opt_ int attr_count, + _In_ int input_count, + _In_ int output_count) { + OrtOp* ort_op{}; + Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values, + type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op)); + return ort_op; +} + +inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context, + _In_ const OrtOp* ort_op, + _In_ const OrtValue* const* input_values, + _In_ int input_count, + _Inout_ OrtValue* const* output_values, + _In_ int output_count) { + Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count)); +} + +inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) { + api_.ReleaseOp(ort_op); +} + +inline OrtKernelInfo* CustomOpApi::CopyKernelInfo(_In_ const OrtKernelInfo* info) { + OrtKernelInfo* info_copy{}; + Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy)); + return info_copy; +} + +inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy) { + api_.ReleaseKernelInfo(info_copy); +} + +inline std::vector GetAvailableProviders() { + int len; + char** providers; + ThrowOnError(GetApi().GetAvailableProviders(&providers, &len)); + std::vector available_providers(providers, providers + len); + ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len)); + return available_providers; +} + +SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); + +template +void CustomOpBase::GetSessionConfigs(std::unordered_map& out, + ConstSessionOptions options) const { + const TOp* derived = static_cast(this); + std::vector keys = derived->GetSessionConfigKeys(); + + out.reserve(keys.size()); + + std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), ""); + const size_t prefix_size = config_entry_key.length(); + + for (const auto& key : keys) { + config_entry_key.resize(prefix_size); + config_entry_key.append(key); + out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), ""); + } +} + +} // namespace Ort diff --git a/ThirdParty/onnxruntime/include/onnxruntime_run_options_config_keys.h b/ThirdParty/onnxruntime/include/onnxruntime_run_options_config_keys.h new file mode 100644 index 0000000..1f5fcd5 --- /dev/null +++ b/ThirdParty/onnxruntime/include/onnxruntime_run_options_config_keys.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* + * This file defines RunOptions Config Keys and format of the Config Values. + * + * The Naming Convention for a RunOptions Config Key, + * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" + * Such as "ep.cuda.use_arena" + * The Config Key cannot be empty + * The maximum length of the Config Key is 128 + * + * The string format of a RunOptions Config Value is defined individually for each Config. + * The maximum length of the Config Value is 1024 + */ + +// Key for enabling shrinkages of user listed device memory arenas. +// Expects a list of semi-colon separated key value pairs separated by colon in the following format: +// "device_0:device_id_0;device_1:device_id_1" +// No white-spaces allowed in the provided list string. +// Currently, the only supported devices are : "cpu", "gpu" (case sensitive). +// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled. +// Example usage: "cpu:0;gpu:0" (or) "gpu:0" +// By default, the value for this key is empty (i.e.) no memory arenas are shrunk +static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage"; + +// Set to '1' to not synchronize execution providers with CPU at the end of session run. +// Per default it will be set to '0' +// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. +static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; diff --git a/ThirdParty/onnxruntime/include/onnxruntime_session_options_config_keys.h b/ThirdParty/onnxruntime/include/onnxruntime_session_options_config_keys.h new file mode 100644 index 0000000..92482d7 --- /dev/null +++ b/ThirdParty/onnxruntime/include/onnxruntime_session_options_config_keys.h @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* + * This file defines SessionOptions Config Keys and format of the Config Values. + * + * The Naming Convention for a SessionOptions Config Key, + * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" + * Such as "ep.cuda.use_arena" + * The Config Key cannot be empty + * The maximum length of the Config Key is 128 + * + * The string format of a SessionOptions Config Value is defined individually for each Config. + * The maximum length of the Config Value is 1024 + */ + +// Key for disable PrePacking, +// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value) +static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking"; + +// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session +// will be used. Use this to override the usage of env allocators on a per session level. +static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators"; + +// Set to 'ORT' (case sensitive) to load an ORT format model. +// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT +static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format"; + +// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set. +// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'. +static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format"; + +// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0". +// When multiple sessions are created, a main thread doesn't override changes from succeeding session options, +// but threads in session thread pools follow option changes. +// When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and +// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool. +// Note that an alternative way not using this option at runtime is to train and export a model without denormals +// and that's recommended because turning this option on may hurt model accuracy. +static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero"; + +// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not. +// "0": enable. ORT does fusion logic for QDQ format. +// "1": disable. ORT doesn't do fusion logic for QDQ format. +// Its default value is "0" +static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq"; + +// It controls whether to enable Double QDQ remover and Identical Children Consolidation +// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs +// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs +// Its default value is "0" +static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover"; + +// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been +// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the +// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to +// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on +// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization. +// As such, it's best to test to determine if enabling this works well for your scenario. +// The default value is "0" +// Available since version 1.11. +static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup"; + +// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0". +// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. +static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; + +#ifdef ENABLE_TRAINING +// Specifies a list of op types for memory footprint reduction. +// The value should be a ","-delimited list of pair of +// . +// For example, "Gelu+Cast+:1:0,Dropout+:1:1". +// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations. +// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute. +// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving" +// the memory. +static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer"; + +// Specifies the level for detecting subgraphs for memory footprint reduction. +// The value should be an integer. The default value is 0. +static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level"; +#endif + +// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". +// Using device allocators means the memory allocation is made using malloc/new. +static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers"; + +// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking +// "0": thread will block if found no job to run +// "1": default, thread will spin a number of times before blocking +static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; +static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; + +// Key for using model bytes directly for ORT format +// If a session is created using an input byte array contains the ORT format model data, +// By default we will copy the model bytes at the time of session creation to ensure the model bytes +// buffer is valid. +// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller +// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed. +static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly"; + +/// +/// Key for using the ORT format model flatbuffer bytes directly for initializers. +/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization. +/// Requires `session.use_ort_model_bytes_directly` to be true. +/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire +/// duration of the InferenceSession. +/// +static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers = + "session.use_ort_model_bytes_for_initializers"; + +// This should only be specified when exporting an ORT format model for use on a different platform. +// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0" +// Available since version 1.11. +static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed"; + +// x64 SSE4.1/AVX2/AVX512(with no VNNI) has overflow problem with quantizied matrix multiplication with U8S8. +// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if +// turned on, use slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512 +// platforms. +static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision"; + +// Specifies how minimal build graph optimizations are handled in a full build. +// These optimizations are at the extended level or higher. +// Possible values and their effects are: +// "save": Save runtime optimizations when saving an ORT format model. +// "apply": Only apply optimizations available in a minimal build. +// ""/: Apply optimizations available in a full build. +// Available since version 1.11. +static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations = + "optimization.minimal_build_optimizations"; + +// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in +// order for them to take effect. + +// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be +// run by the NNAPI EP. +// The value should be a ","-delimited list of op types. For example, "Add,Sub". +// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op +// exclusion, set the value to "". +static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops"; + +// Enabling dynamic block-sizing for multithreading. +// With a positive value, thread pool will split a task of N iterations to blocks of size starting from: +// N / (num_of_threads * dynamic_block_base) +// As execution progresses, the size will decrease according to the diminishing residual of N, +// meaning the task will be distributed in smaller granularity for better parallelism. +// For some models, it helps to reduce the variance of E2E inference latency and boost performance. +// The feature will not function by default, specify any positive integer, e.g. "4", to enable it. +// Available since version 1.11. +static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base"; + +// This option allows to decrease CPU usage between infrequent +// requests and forces any TP threads spinning stop immediately when the last of +// concurrent Run() call returns. +// Spinning is restarted on the next Run() call. +// Applies only to internal thread-pools +static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop"; + +// "1": all inconsistencies encountered during shape and type inference +// will result in failures. +// "0": in some cases warnings will be logged but processing will continue. The default. +// May be useful to expose bugs in models. +static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference"; + +// The file saves configuration for partitioning node among logic streams +static const char* const kNodePartitionConfigFile = "session.node_partition_config_file"; + +// This Option allows setting affinities for intra op threads. +// Affinity string follows format: +// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id +// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to. +// e.g.1,2,3;4,5 +// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. +// To ease the configuration, an "interval" is also allowed: +// e.g. 1-8;8-16;17-24 +// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. +// Note: +// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which +// is started and managed by the calling app; +// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors, +// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. +// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. +static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities"; diff --git a/ThirdParty/onnxruntime/include/provider_options.h b/ThirdParty/onnxruntime/include/provider_options.h new file mode 100644 index 0000000..aab13e8 --- /dev/null +++ b/ThirdParty/onnxruntime/include/provider_options.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { + +// data types for execution provider options + +using ProviderOptions = std::unordered_map; +using ProviderOptionsVector = std::vector; +using ProviderOptionsMap = std::unordered_map; + +} // namespace onnxruntime diff --git a/ThirdParty/onnxruntime/lib/libonnxruntime.a b/ThirdParty/onnxruntime/lib/libonnxruntime.a new file mode 100644 index 0000000..cd05adf Binary files /dev/null and b/ThirdParty/onnxruntime/lib/libonnxruntime.a differ