onnxruntime/nodejs/src/onnxruntime_cxx_api.h
Yulong Wang 5dfc91db51
Node.js binding for ONNX Runtime (#3613)
* initial commit for Node.js binding

* add c++ code

* add inference session impl

* e2e working

* add settings.json

* add test data

* adjust binding declaration

* refine tensor constructor declaration

* update tests

* enable onnx tests

* simply refine readme

* refine cpp impl

* refine tests

* formatting

* add linting

* move bin folder

* fix linux build

* manually update test filter list

* update C++ API headers: fix crash in release build

* make (manually) prebuild work

* add test into prepack script

* specify prebuild runtime type (N-API)

* build.ts: update rebuild and include regex

* fix lazy load on electron.js

* update dev version, git link and binary host

* support session options and run options

* bump dev version

* update README

* add 1 example

* move folder

* adjust path

* update document for examples

* rename example 01

* add example 02

* add session option: log severity level

* add example 04

* resolve comments

* fix typo

* remove double guard in header files

* add copyright banner

* move BUILD outside from README

* consume test filter list from onnxruntime
2020-05-05 11:45:12 -07:00

379 lines
14 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
//
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
// and automatically releasing resources in the destructors.
//
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};).
//
// Only move assignment between objects is allowed; there are no copy constructors. Some objects have explicit 'Clone'
// methods for when a copy is needed.
#pragma once
#include <core/session/onnxruntime_c_api.h>
#include <array>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
namespace Ort {
// Exception type thrown by every C++ API method that can fail.
// It pairs a message string with the OrtErrorCode that was supplied at construction.
struct Exception : std::exception {
  // Moves `string` into the exception and records `code` alongside it.
  Exception(std::string &&string, OrtErrorCode code)
      : message_(std::move(string)),
        code_(code) {}

  // Error code passed to the constructor.
  OrtErrorCode GetOrtErrorCode() const {
    return code_;
  }

  // Message passed to the constructor; the returned pointer refers to storage owned by this object.
  const char *what() const noexcept override {
    return message_.c_str();
  }

private:
  std::string message_;
  OrtErrorCode code_;
};
// Pointer to the OrtApi function table used by the wrappers in this header. It is only declared here:
// this needs to be defined in a .cpp file.
extern const OrtApi *g_api_;
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions.
// g_api_ is dereferenced unconditionally, so it must point at a valid OrtApi before GetApi() is called.
inline const OrtApi &GetApi() { return *g_api_; }
// This is used internally by the C++ API. This macro makes it easy to generate overloaded OrtRelease functions for
// all of the Ort* types: ORT_DEFINE_RELEASE(NAME) defines an inline OrtRelease(Ort<NAME> *) that forwards to
// GetApi().Release<NAME>(ptr). This can't be done in the C API since C doesn't have function overloading.
#define ORT_DEFINE_RELEASE(NAME) \
inline void OrtRelease(Ort##NAME *ptr) { GetApi().Release##NAME(ptr); }
// One OrtRelease overload per handle type.
ORT_DEFINE_RELEASE(MemoryInfo);
ORT_DEFINE_RELEASE(CustomOpDomain);
ORT_DEFINE_RELEASE(Env);
ORT_DEFINE_RELEASE(RunOptions);
ORT_DEFINE_RELEASE(Session);
ORT_DEFINE_RELEASE(SessionOptions);
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
ORT_DEFINE_RELEASE(TypeInfo);
ORT_DEFINE_RELEASE(Value);
ORT_DEFINE_RELEASE(ModelMetadata);
ORT_DEFINE_RELEASE(ThreadingOptions);
// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
// Base<T> holds a single T* handle and releases it through the matching OrtRelease(T *) overload.
template <typename T> struct Base {
  // Creates an empty wrapper (p_ == nullptr).
  Base() = default;

  // Takes ownership of `p`. Throws Ort::Exception with ORT_FAIL when `p` is null.
  Base(T *p) : p_{p} {
    if (!p)
      throw Ort::Exception("Allocation failure", ORT_FAIL);
  }

  // Releases the owned handle, if any. Empty and moved-from wrappers hold nullptr and skip the release call
  // entirely, which mirrors the guard used in move assignment and avoids touching the API table for them.
  ~Base() {
    if (p_) {
      OrtRelease(p_);
    }
  }

  // Implicit access to the raw handle; ownership stays with this wrapper.
  operator T *() { return p_; }
  operator const T *() const { return p_; }

  // Gives up ownership: returns the handle and leaves this wrapper empty.
  // The caller becomes responsible for releasing the returned pointer.
  T *release() {
    T *p = p_;
    p_ = nullptr;
    return p;
  }

protected:
  // Wrappers are move-only.
  Base(const Base &) = delete;
  Base &operator=(const Base &) = delete;

  // Steals the handle from `v`, leaving `v` empty.
  Base(Base &&v) noexcept : p_{v.p_} { v.p_ = nullptr; }

  // Releases the current handle (if any) and steals the one held by `v`, leaving `v` empty.
  void operator=(Base &&v) noexcept {
    // Self-move must be a no-op: without this check the handle would be released and then
    // the wrapper reset to nullptr, silently destroying the object it was supposed to keep.
    if (this == &v) {
      return;
    }
    if (p_) {
      OrtRelease(p_);
    }
    p_ = v.p_;
    v.p_ = nullptr;
  }

  T *p_{};

  template <typename>
  friend struct Unowned;  // This friend line is needed to keep the centos C++ compiler from giving an error
};
// Non-owning variant of a wrapper type T. It is built from a raw handle like T, but its destructor clears p_
// first, so the base-class destructor sees nullptr rather than the borrowed handle.
template <typename T> struct Unowned : T {
// Wraps `p` by forwarding it to T's constructor.
Unowned(decltype(T::p_) p) : T{p} {}
// Copies the handle out of `v` without clearing it; afterwards both objects refer to the same handle.
Unowned(Unowned &&v) : T{v.p_} {}
// Forget the handle so that it is not released when this object goes away.
~Unowned() { this->p_ = nullptr; }
};
// Forward declarations, so the wrapper types below can be named before their definitions appear.
struct AllocatorWithDefaultOptions;
struct MemoryInfo;
struct Env;
struct TypeInfo;
struct Value;
struct ModelMetadata;
// Owning wrapper around an OrtEnv handle, built on Base<OrtEnv>.
// Members that are only declared here are defined outside this struct (presumably in
// onnxruntime_cxx_inline.h, which is included at the end of this header — confirm).
struct Env : Base<OrtEnv> {
// Creates an empty Env that holds no handle.
Env(std::nullptr_t) {}
// Constructs an Env from a default logging level (ORT_LOGGING_LEVEL_WARNING if omitted) and a log id ("" if omitted).
Env(OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char *logid = "");
// Same parameters as above, preceded by threading options.
Env(const OrtThreadingOptions *tp_options, OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING,
_In_ const char *logid = "");
// Variant taking a logging callback and an opaque `logger_param`
// (presumably handed back to `logging_function` — confirm against the definition).
Env(OrtLoggingLevel default_logging_level, const char *logid, OrtLoggingFunction logging_function,
void *logger_param);
// Takes ownership of an existing OrtEnv handle; the Base constructor throws Ort::Exception if `p` is null.
explicit Env(OrtEnv *p) : Base<OrtEnv>{p} {}
// Both return Env& (presumably *this, to allow chaining — confirm).
Env &EnableTelemetryEvents();
Env &DisableTelemetryEvents();
// NOTE(review): no definition or use of s_api is visible in this header — confirm where it is defined.
static const OrtApi *s_api;
};
// Owning wrapper around an OrtCustomOpDomain handle, built on Base<OrtCustomOpDomain>.
struct CustomOpDomain : Base<OrtCustomOpDomain> {
// Creates an empty CustomOpDomain that holds no handle.
explicit CustomOpDomain(std::nullptr_t) {}
// Constructs a domain from its name; defined outside this struct.
explicit CustomOpDomain(const char *domain);
// Adds `op` to this domain. Takes a raw pointer; presumably `op` must outlive the domain — confirm.
void Add(OrtCustomOp *op);
};
// Owning wrapper around an OrtRunOptions handle, built on Base<OrtRunOptions>.
// The setters return RunOptions& (presumably *this, to allow chaining — the definitions are not in this struct).
struct RunOptions : Base<OrtRunOptions> {
// Creates an empty RunOptions that holds no handle.
RunOptions(std::nullptr_t) {}
// Default constructor; defined outside this struct.
RunOptions();
// Getter/setter pairs for the run log verbosity level, run log severity level and run tag.
RunOptions &SetRunLogVerbosityLevel(int);
int GetRunLogVerbosityLevel() const;
RunOptions &SetRunLogSeverityLevel(int);
int GetRunLogSeverityLevel() const;
RunOptions &SetRunTag(const char *run_tag);
const char *GetRunTag() const;
// terminate ALL currently executing Session::Run calls that were made using this RunOptions instance
RunOptions &SetTerminate();
// unset the terminate flag so this RunOptions instance can be used in a new Session::Run call
RunOptions &UnsetTerminate();
};
// Owning wrapper around an OrtSessionOptions handle, built on Base<OrtSessionOptions>.
// Every setter returns SessionOptions& (presumably *this, to allow chaining — the definitions are not in this struct).
struct SessionOptions : Base<OrtSessionOptions> {
// Creates an empty SessionOptions that holds no handle.
explicit SessionOptions(std::nullptr_t) {}
// Default constructor; defined outside this struct.
SessionOptions();
// Takes ownership of an existing handle; the Base constructor throws Ort::Exception if `p` is null.
explicit SessionOptions(OrtSessionOptions *p) : Base<OrtSessionOptions>{p} {}
// Returns a new SessionOptions by value; the explicit way to copy, since the wrapper itself is move-only.
SessionOptions Clone() const;
SessionOptions &SetIntraOpNumThreads(int intra_op_num_threads);
SessionOptions &SetInterOpNumThreads(int inter_op_num_threads);
SessionOptions &SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
SessionOptions &EnableCpuMemArena();
SessionOptions &DisableCpuMemArena();
// Paths use ORTCHAR_T, the character type defined by the C API header.
SessionOptions &SetOptimizedModelFilePath(const ORTCHAR_T *optimized_model_file);
SessionOptions &EnableProfiling(const ORTCHAR_T *profile_file_prefix);
SessionOptions &DisableProfiling();
SessionOptions &EnableMemPattern();
SessionOptions &DisableMemPattern();
SessionOptions &SetExecutionMode(ExecutionMode execution_mode);
SessionOptions &SetLogId(const char *logid);
SessionOptions &SetLogSeverityLevel(int);
// Registers a custom op domain with these options.
SessionOptions &Add(OrtCustomOpDomain *custom_op_domain);
SessionOptions &DisablePerSessionThreads();
};
// Owning wrapper around an OrtModelMetadata handle, built on Base<OrtModelMetadata>.
// The string getters take an OrtAllocator and return raw char pointers; presumably the result is allocated
// with `allocator` and must be freed by the caller — confirm against the definitions.
struct ModelMetadata : Base<OrtModelMetadata> {
// Creates an empty ModelMetadata that holds no handle.
explicit ModelMetadata(std::nullptr_t) {}
// Takes ownership of an existing handle; the Base constructor throws Ort::Exception if `p` is null.
explicit ModelMetadata(OrtModelMetadata *p) : Base<OrtModelMetadata>{p} {}
char *GetProducerName(OrtAllocator *allocator) const;
char *GetGraphName(OrtAllocator *allocator) const;
char *GetDomain(OrtAllocator *allocator) const;
char *GetDescription(OrtAllocator *allocator) const;
// Returns an array of key strings and writes its length to `num_keys`.
char **GetCustomMetadataMapKeys(OrtAllocator *allocator, _Out_ int64_t &num_keys) const;
// Looks up the custom metadata value stored under `key`.
char *LookupCustomMetadataMap(const char *key, OrtAllocator *allocator) const;
int64_t GetVersion() const;
};
// Owning wrapper around an OrtSession handle, built on Base<OrtSession>.
struct Session : Base<OrtSession> {
// Creates an empty Session that holds no handle.
explicit Session(std::nullptr_t) {}
// Constructs a session from a model file path.
Session(Env &env, const ORTCHAR_T *model_path, const SessionOptions &options);
// Constructs a session from an in-memory model buffer of `model_data_length` bytes.
Session(Env &env, const void *model_data, size_t model_data_length, const SessionOptions &options);
// Run that will allocate the output values
std::vector<Value> Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values,
size_t input_count, const char *const *output_names, size_t output_count);
// Run for when there is a list of preallocated outputs
void Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count,
const char *const *output_names, Value *output_values, size_t output_count);
size_t GetInputCount() const;
size_t GetOutputCount() const;
size_t GetOverridableInitializerCount() const;
// Name getters take an OrtAllocator and return a raw char pointer; presumably the string is allocated
// with `allocator` and must be freed by the caller — confirm against the definitions.
char *GetInputName(size_t index, OrtAllocator *allocator) const;
char *GetOutputName(size_t index, OrtAllocator *allocator) const;
char *GetOverridableInitializerName(size_t index, OrtAllocator *allocator) const;
char *EndProfiling(OrtAllocator *allocator) const;
// The following return owning wrapper objects by value.
ModelMetadata GetModelMetadata() const;
TypeInfo GetInputTypeInfo(size_t index) const;
TypeInfo GetOutputTypeInfo(size_t index) const;
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const;
};
// Owning wrapper around an OrtTensorTypeAndShapeInfo handle, built on Base<OrtTensorTypeAndShapeInfo>.
struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
// Creates an empty TensorTypeAndShapeInfo that holds no handle.
explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
// Takes ownership of an existing handle; the Base constructor throws Ort::Exception if `p` is null.
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p) : Base<OrtTensorTypeAndShapeInfo>{p} {}
ONNXTensorElementDataType GetElementType() const;
size_t GetElementCount() const;
size_t GetDimensionsCount() const;
// Writes into the caller-provided array `values`, which has room for `values_count` entries.
void GetDimensions(int64_t *values, size_t values_count) const;
void GetSymbolicDimensions(const char **values, size_t values_count) const;
// Returns the shape as a vector instead of filling a caller-provided buffer.
std::vector<int64_t> GetShape() const;
};
// Owning wrapper around an OrtTypeInfo handle, built on Base<OrtTypeInfo>.
struct TypeInfo : Base<OrtTypeInfo> {
// Creates an empty TypeInfo that holds no handle.
explicit TypeInfo(std::nullptr_t) {}
// Takes ownership of an existing handle; the Base constructor throws Ort::Exception if `p` is null.
explicit TypeInfo(OrtTypeInfo *p) : Base<OrtTypeInfo>{p} {}
// Returns a non-owning (Unowned) view; presumably it is only valid while this TypeInfo is alive — confirm.
Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const;
ONNXType GetONNXType() const;
};
// Owning wrapper around an OrtValue handle, built on Base<OrtValue>.
// The static Create* factories return a Value by value; their definitions are not in this struct.
struct Value : Base<OrtValue> {
// Tensor over caller-provided data `p_data`; the size is given in elements of T.
// Presumably the buffer is not copied and must outlive the Value — confirm against the definition.
template <typename T>
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape,
size_t shape_len);
// Untyped variant: the size is given in bytes and the element type is passed explicitly.
static Value CreateTensor(const OrtMemoryInfo *info, void *p_data, size_t p_data_byte_count, const int64_t *shape,
size_t shape_len, ONNXTensorElementDataType type);
// Tensor of the given shape obtained through `allocator`, with element type taken from T.
template <typename T> static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len);
// Same as above, with the element type passed explicitly.
static Value CreateTensor(OrtAllocator *allocator, const int64_t *shape, size_t shape_len,
ONNXTensorElementDataType type);
static Value CreateMap(Value &keys, Value &values);
static Value CreateSequence(std::vector<Value> &values);
template <typename T> static Value CreateOpaque(const char *domain, const char *type_name, const T &);
template <typename T> void GetOpaqueData(const char *domain, const char *type_name, T &);
// Creates an empty Value that holds no handle.
explicit Value(std::nullptr_t) {}
// Takes ownership of an existing handle; the Base constructor throws Ort::Exception if `p` is null.
explicit Value(OrtValue *p) : Base<OrtValue>{p} {}
// Value is explicitly movable (move construction and move assignment are defaulted).
Value(Value &&) = default;
Value &operator=(Value &&) = default;
bool IsTensor() const;
size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
Value GetValue(int index, OrtAllocator *allocator) const;
size_t GetStringTensorDataLength() const;
// Fills the caller-provided `buffer` and `offsets` arrays.
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const;
template <typename T> T *GetTensorMutableData();
TypeInfo GetTypeInfo() const;
// Unlike TypeInfo::GetTensorTypeAndShapeInfo, this returns an owning TensorTypeAndShapeInfo.
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
};
// Thin handle around an OrtAllocator pointer. Unlike the other wrappers it does not derive from Base
// and declares no destructor, so nothing is released when it goes out of scope.
struct AllocatorWithDefaultOptions {
// Defined outside this struct; presumably obtains the default allocator from the C API — confirm.
AllocatorWithDefaultOptions();
// Implicit access to the underlying OrtAllocator pointer.
operator OrtAllocator *() { return p_; }
operator const OrtAllocator *() const { return p_; }
void *Alloc(size_t size);
void Free(void *p);
const OrtMemoryInfo *GetInfo() const;
private:
OrtAllocator *p_{};
};
// Owning wrapper around an OrtMemoryInfo handle, built on Base<OrtMemoryInfo>.
struct MemoryInfo : Base<OrtMemoryInfo> {
// Factory returning a MemoryInfo by value; defined outside this struct.
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
// Creates an empty MemoryInfo that holds no handle.
explicit MemoryInfo(std::nullptr_t) {}
MemoryInfo(const char *name, OrtAllocatorType type, int id, OrtMemType mem_type);
// Takes ownership of an existing handle; the Base constructor throws Ort::Exception if `p` is null.
explicit MemoryInfo(OrtMemoryInfo *p) : Base<OrtMemoryInfo>{p} {}
};
//
// Custom OPs (only needed to implement custom OPs)
//
// Helper bound to a specific OrtApi instance. The API is stored by reference, so the OrtApi passed to the
// constructor must outlive this object. Method definitions are not in this struct.
struct CustomOpApi {
CustomOpApi(const OrtApi &api) : api_(api) {}
template <typename T> // T is only implemented for float, int64_t, and string
T KernelInfoGetAttribute(_In_ const OrtKernelInfo *info, _In_ const char *name);
// Returns a raw OrtTensorTypeAndShapeInfo pointer; presumably the caller releases it with
// ReleaseTensorTypeAndShapeInfo below — confirm against the definition.
OrtTensorTypeAndShapeInfo *GetTensorTypeAndShape(_In_ const OrtValue *value);
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo *info);
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo *info);
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo *info);
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo *info, _Out_ int64_t *dim_values, size_t dim_values_length);
void SetDimensions(OrtTensorTypeAndShapeInfo *info, _In_ const int64_t *dim_values, size_t dim_count);
template <typename T> T *GetTensorMutableData(_Inout_ OrtValue *value);
template <typename T> const T *GetTensorData(_Inout_ const OrtValue *value);
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo *info);
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input);
// Accessors for the inputs and outputs of a kernel context.
size_t KernelContext_GetInputCount(const OrtKernelContext *context);
const OrtValue *KernelContext_GetInput(const OrtKernelContext *context, _In_ size_t index);
size_t KernelContext_GetOutputCount(const OrtKernelContext *context);
OrtValue *KernelContext_GetOutput(OrtKernelContext *context, _In_ size_t index, _In_ const int64_t *dim_values,
size_t dim_count);
// Takes an OrtStatus result; presumably throws when it indicates an error — confirm against the definition.
void ThrowOnError(OrtStatus *result);
private:
const OrtApi &api_;
};
// Adapter that fills in the OrtCustomOp callback table on behalf of a concrete op type.
// Each callback casts the OrtCustomOp pointer it receives to TOp* and forwards to the TOp member of the
// same name, so TOp is expected to derive from this class. Kernel callbacks receive the kernel as void*
// and cast it to TKernel*.
template <typename TOp, typename TKernel> struct CustomOpBase : OrtCustomOp {
  CustomOpBase() {
    OrtCustomOp::version = ORT_API_VERSION;

    // Op identity.
    OrtCustomOp::GetName = [](OrtCustomOp *op) { return static_cast<TOp *>(op)->GetName(); };
    OrtCustomOp::GetExecutionProviderType = [](OrtCustomOp *op) {
      return static_cast<TOp *>(op)->GetExecutionProviderType();
    };

    // Input signature.
    OrtCustomOp::GetInputTypeCount = [](OrtCustomOp *op) { return static_cast<TOp *>(op)->GetInputTypeCount(); };
    OrtCustomOp::GetInputType = [](OrtCustomOp *op, size_t position) {
      return static_cast<TOp *>(op)->GetInputType(position);
    };

    // Output signature.
    OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp *op) { return static_cast<TOp *>(op)->GetOutputTypeCount(); };
    OrtCustomOp::GetOutputType = [](OrtCustomOp *op, size_t position) {
      return static_cast<TOp *>(op)->GetOutputType(position);
    };

    // Kernel lifecycle: creation is delegated to TOp, compute and destroy operate on the TKernel instance.
    OrtCustomOp::CreateKernel = [](OrtCustomOp *op, const OrtApi *ort_api, const OrtKernelInfo *kernel_info) {
      return static_cast<TOp *>(op)->CreateKernel(*ort_api, kernel_info);
    };
    OrtCustomOp::KernelCompute = [](void *kernel_ptr, OrtKernelContext *ctx) {
      TKernel *kernel = static_cast<TKernel *>(kernel_ptr);
      kernel->Compute(ctx);
    };
    OrtCustomOp::KernelDestroy = [](void *kernel_ptr) {
      TKernel *kernel = static_cast<TKernel *>(kernel_ptr);
      delete kernel;
    };
  }

  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
  const char *GetExecutionProviderType() const {
    return nullptr;
  }
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"