mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
### Description This change enhances the Node.js binding with the following features: - support WebGPU EP - lazy initialization of `OrtEnv` - being able to initialize ORT with default log level setting from `ort.env.logLevel`. - session options: - `enableProfiling` and `profileFilePrefix`: support profiling. - `externalData`: explicit external data (optional in Node.js binding) - `optimizedModelFilePath`: allow dumping optimized model for diagnosis purpose - `preferredOutputLocation`: support IO binding. ====================================================== `Tensor.download()` is not implemented in this PR. Build pipeline update is not included in this PR.
378 lines
17 KiB
C++
378 lines
17 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "onnxruntime_cxx_api.h"
|
|
#include <napi.h>
|
|
|
|
#include <cmath>
|
|
#include <unordered_map>
|
|
#include <filesystem>
|
|
|
|
#include "common.h"
|
|
#include "session_options_helper.h"
|
|
#include "tensor_helper.h"
|
|
#ifdef USE_CUDA
|
|
#include "core/providers/cuda/cuda_provider_options.h"
|
|
#endif
|
|
#ifdef USE_DML
|
|
#include "core/providers/dml/dml_provider_factory.h"
|
|
#endif
|
|
#ifdef USE_WEBGPU
|
|
#include "core/providers/webgpu/webgpu_provider_factory.h"
|
|
#endif
|
|
#ifdef USE_TENSORRT
|
|
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
|
#endif
|
|
#ifdef USE_COREML
|
|
#include "core/providers/coreml/coreml_provider_factory.h"
|
|
#endif
|
|
|
|
const std::unordered_map<std::string, GraphOptimizationLevel> GRAPH_OPT_LEVEL_NAME_TO_ID_MAP = {
|
|
{"disabled", ORT_DISABLE_ALL},
|
|
{"basic", ORT_ENABLE_BASIC},
|
|
{"extended", ORT_ENABLE_EXTENDED},
|
|
{"all", ORT_ENABLE_ALL}};
|
|
|
|
const std::unordered_map<std::string, ExecutionMode> EXECUTION_MODE_NAME_TO_ID_MAP = {{"sequential", ORT_SEQUENTIAL},
|
|
{"parallel", ORT_PARALLEL}};
|
|
|
|
void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sessionOptions) {
|
|
for (uint32_t i = 0; i < epList.Length(); i++) {
|
|
Napi::Value epValue = epList[i];
|
|
std::string name;
|
|
int deviceId = 0;
|
|
#ifdef USE_COREML
|
|
int coreMlFlags = 0;
|
|
#endif
|
|
#ifdef USE_WEBGPU
|
|
std::unordered_map<std::string, std::string> webgpu_options;
|
|
#endif
|
|
if (epValue.IsString()) {
|
|
name = epValue.As<Napi::String>().Utf8Value();
|
|
} else if (!epValue.IsObject() || epValue.IsNull() || !epValue.As<Napi::Object>().Has("name") ||
|
|
!epValue.As<Napi::Object>().Get("name").IsString()) {
|
|
ORT_NAPI_THROW_TYPEERROR(epList.Env(), "Invalid argument: sessionOptions.executionProviders[", i,
|
|
"] must be either a string or an object with property 'name'.");
|
|
} else {
|
|
auto obj = epValue.As<Napi::Object>();
|
|
name = obj.Get("name").As<Napi::String>().Utf8Value();
|
|
if (obj.Has("deviceId")) {
|
|
deviceId = obj.Get("deviceId").As<Napi::Number>();
|
|
}
|
|
#ifdef USE_COREML
|
|
if (obj.Has("coreMlFlags")) {
|
|
coreMlFlags = obj.Get("coreMlFlags").As<Napi::Number>();
|
|
}
|
|
#endif
|
|
#ifdef USE_WEBGPU
|
|
for (const auto& nameIter : obj.GetPropertyNames()) {
|
|
Napi::Value nameVar = nameIter.second;
|
|
std::string name = nameVar.As<Napi::String>().Utf8Value();
|
|
if (name != "name") {
|
|
Napi::Value valueVar = obj.Get(nameVar);
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!valueVar.IsString(), epList.Env(), "Invalid argument: sessionOptions.executionProviders must be a string or an object with property 'name'.");
|
|
std::string value = valueVar.As<Napi::String>().Utf8Value();
|
|
webgpu_options[name] = value;
|
|
}
|
|
}
|
|
#endif
|
|
}
|
|
|
|
// CPU execution provider
|
|
if (name == "cpu") {
|
|
// TODO: handling CPU EP options
|
|
#ifdef USE_CUDA
|
|
} else if (name == "cuda") {
|
|
OrtCUDAProviderOptionsV2* options;
|
|
Ort::GetApi().CreateCUDAProviderOptions(&options);
|
|
options->device_id = deviceId;
|
|
sessionOptions.AppendExecutionProvider_CUDA_V2(*options);
|
|
Ort::GetApi().ReleaseCUDAProviderOptions(options);
|
|
#endif
|
|
#ifdef USE_TENSORRT
|
|
} else if (name == "tensorrt") {
|
|
OrtTensorRTProviderOptionsV2* options;
|
|
Ort::GetApi().CreateTensorRTProviderOptions(&options);
|
|
options->device_id = deviceId;
|
|
sessionOptions.AppendExecutionProvider_TensorRT_V2(*options);
|
|
Ort::GetApi().ReleaseTensorRTProviderOptions(options);
|
|
#endif
|
|
#ifdef USE_DML
|
|
} else if (name == "dml") {
|
|
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId));
|
|
#endif
|
|
#ifdef USE_WEBGPU
|
|
} else if (name == "webgpu") {
|
|
sessionOptions.AppendExecutionProvider("WebGPU", webgpu_options);
|
|
#endif
|
|
#ifdef USE_COREML
|
|
} else if (name == "coreml") {
|
|
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags));
|
|
#endif
|
|
#ifdef USE_QNN
|
|
} else if (name == "qnn") {
|
|
std::unordered_map<std::string, std::string> qnn_options;
|
|
qnn_options["backend_path"] = "QnnHtp.dll";
|
|
qnn_options["enable_htp_fp16_precision"] = "1";
|
|
sessionOptions.AppendExecutionProvider("QNN", qnn_options);
|
|
#endif
|
|
} else {
|
|
ORT_NAPI_THROW_ERROR(epList.Env(), "Invalid argument: sessionOptions.executionProviders[", i,
|
|
"] is unsupported: '", name, "'.");
|
|
}
|
|
}
|
|
}
|
|
|
|
void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) {
|
|
for (const auto& kvp : obj) {
|
|
auto key = kvp.first.As<Napi::String>().Utf8Value();
|
|
Napi::Value value = kvp.second;
|
|
if (value.IsObject()) {
|
|
IterateExtraOptions(prefix + key + ".", value.As<Napi::Object>(), sessionOptions);
|
|
} else {
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString(), obj.Env(),
|
|
"Invalid argument: sessionOptions.extra value must be a string in Node.js binding.");
|
|
std::string entry = prefix + key;
|
|
auto val = value.As<Napi::String>().Utf8Value();
|
|
sessionOptions.AddConfigEntry(entry.c_str(), val.c_str());
|
|
}
|
|
}
|
|
}
|
|
|
|
void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) {
|
|
// Execution provider
|
|
if (options.Has("executionProviders")) {
|
|
auto epsValue = options.Get("executionProviders");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!epsValue.IsArray(), options.Env(),
|
|
"Invalid argument: sessionOptions.executionProviders must be an array.");
|
|
ParseExecutionProviders(epsValue.As<Napi::Array>(), sessionOptions);
|
|
}
|
|
|
|
// Intra threads number
|
|
if (options.Has("intraOpNumThreads")) {
|
|
auto numValue = options.Get("intraOpNumThreads");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!numValue.IsNumber(), options.Env(),
|
|
"Invalid argument: sessionOptions.intraOpNumThreads must be a number.");
|
|
double num = numValue.As<Napi::Number>().DoubleValue();
|
|
ORT_NAPI_THROW_RANGEERROR_IF(std::floor(num) != num || num < 0 || num > 4294967295, options.Env(),
|
|
"'intraOpNumThreads' is invalid: ", num);
|
|
sessionOptions.SetIntraOpNumThreads(static_cast<int>(num));
|
|
}
|
|
|
|
// Inter threads number
|
|
if (options.Has("interOpNumThreads")) {
|
|
auto numValue = options.Get("interOpNumThreads");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!numValue.IsNumber(), options.Env(),
|
|
"Invalid argument: sessionOptions.interOpNumThreads must be a number.");
|
|
double num = numValue.As<Napi::Number>().DoubleValue();
|
|
ORT_NAPI_THROW_RANGEERROR_IF(std::floor(num) != num || num < 0 || num > 4294967295, options.Env(),
|
|
"'interOpNumThreads' is invalid: ", num);
|
|
sessionOptions.SetInterOpNumThreads(static_cast<int>(num));
|
|
}
|
|
|
|
// Optimization level
|
|
if (options.Has("graphOptimizationLevel")) {
|
|
auto optLevelValue = options.Get("graphOptimizationLevel");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!optLevelValue.IsString(), options.Env(),
|
|
"Invalid argument: sessionOptions.graphOptimizationLevel must be a string.");
|
|
auto optLevelString = optLevelValue.As<Napi::String>().Utf8Value();
|
|
auto v = GRAPH_OPT_LEVEL_NAME_TO_ID_MAP.find(optLevelString);
|
|
ORT_NAPI_THROW_TYPEERROR_IF(v == GRAPH_OPT_LEVEL_NAME_TO_ID_MAP.end(), options.Env(),
|
|
"'graphOptimizationLevel' is not supported: ", optLevelString);
|
|
sessionOptions.SetGraphOptimizationLevel(v->second);
|
|
}
|
|
|
|
// CPU memory arena
|
|
if (options.Has("enableCpuMemArena")) {
|
|
auto enableCpuMemArenaValue = options.Get("enableCpuMemArena");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!enableCpuMemArenaValue.IsBoolean(), options.Env(),
|
|
"Invalid argument: sessionOptions.enableCpuMemArena must be a boolean value.");
|
|
if (enableCpuMemArenaValue.As<Napi::Boolean>().Value()) {
|
|
sessionOptions.EnableCpuMemArena();
|
|
} else {
|
|
sessionOptions.DisableCpuMemArena();
|
|
}
|
|
}
|
|
|
|
// memory pattern
|
|
if (options.Has("enableMemPattern")) {
|
|
auto enableMemPatternValue = options.Get("enableMemPattern");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!enableMemPatternValue.IsBoolean(), options.Env(),
|
|
"Invalid argument: sessionOptions.enableMemPattern must be a boolean value.");
|
|
if (enableMemPatternValue.As<Napi::Boolean>().Value()) {
|
|
sessionOptions.EnableMemPattern();
|
|
} else {
|
|
sessionOptions.DisableMemPattern();
|
|
}
|
|
}
|
|
|
|
// optimizedModelFilePath
|
|
if (options.Has("optimizedModelFilePath")) {
|
|
auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(),
|
|
"Invalid argument: sessionOptions.optimizedModelFilePath must be a string.");
|
|
#ifdef _WIN32
|
|
auto str = optimizedModelFilePathValue.As<Napi::String>().Utf16Value();
|
|
std::filesystem::path optimizedModelFilePath{std::wstring{str.begin(), str.end()}};
|
|
#else
|
|
std::filesystem::path optimizedModelFilePath{optimizedModelFilePathValue.As<Napi::String>().Utf8Value()};
|
|
#endif
|
|
sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str());
|
|
}
|
|
|
|
// extra
|
|
if (options.Has("extra")) {
|
|
auto extraValue = options.Get("extra");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(),
|
|
"Invalid argument: sessionOptions.extra must be an object.");
|
|
IterateExtraOptions("", extraValue.As<Napi::Object>(), sessionOptions);
|
|
}
|
|
|
|
// execution mode
|
|
if (options.Has("executionMode")) {
|
|
auto executionModeValue = options.Get("executionMode");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!executionModeValue.IsString(), options.Env(),
|
|
"Invalid argument: sessionOptions.executionMode must be a string.");
|
|
auto executionModeString = executionModeValue.As<Napi::String>().Utf8Value();
|
|
auto v = EXECUTION_MODE_NAME_TO_ID_MAP.find(executionModeString);
|
|
ORT_NAPI_THROW_TYPEERROR_IF(v == EXECUTION_MODE_NAME_TO_ID_MAP.end(), options.Env(),
|
|
"'executionMode' is not supported: ", executionModeString);
|
|
sessionOptions.SetExecutionMode(v->second);
|
|
}
|
|
|
|
// log ID
|
|
if (options.Has("logId")) {
|
|
auto logIdValue = options.Get("logId");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!logIdValue.IsString(), options.Env(),
|
|
"Invalid argument: sessionOptions.logId must be a string.");
|
|
auto logIdString = logIdValue.As<Napi::String>().Utf8Value();
|
|
sessionOptions.SetLogId(logIdString.c_str());
|
|
}
|
|
|
|
// Log severity level
|
|
if (options.Has("logSeverityLevel")) {
|
|
auto logLevelValue = options.Get("logSeverityLevel");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!logLevelValue.IsNumber(), options.Env(),
|
|
"Invalid argument: sessionOptions.logSeverityLevel must be a number.");
|
|
double logLevelNumber = logLevelValue.As<Napi::Number>().DoubleValue();
|
|
ORT_NAPI_THROW_RANGEERROR_IF(
|
|
std::floor(logLevelNumber) != logLevelNumber || logLevelNumber < 0 || logLevelNumber > 4, options.Env(),
|
|
"Invalid argument: sessionOptions.logSeverityLevel must be one of the following: 0, 1, 2, 3, 4.");
|
|
|
|
sessionOptions.SetLogSeverityLevel(static_cast<int>(logLevelNumber));
|
|
}
|
|
|
|
// Profiling
|
|
if (options.Has("enableProfiling")) {
|
|
auto enableProfilingValue = options.Get("enableProfiling");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!enableProfilingValue.IsBoolean(), options.Env(),
|
|
"Invalid argument: sessionOptions.enableProfiling must be a boolean value.");
|
|
|
|
if (enableProfilingValue.As<Napi::Boolean>().Value()) {
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!options.Has("profileFilePrefix"), options.Env(),
|
|
"Invalid argument: sessionOptions.profileFilePrefix is required"
|
|
" when sessionOptions.enableProfiling is set to true.");
|
|
auto profileFilePrefixValue = options.Get("profileFilePrefix");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!profileFilePrefixValue.IsString(), options.Env(),
|
|
"Invalid argument: sessionOptions.profileFilePrefix must be a string."
|
|
" when sessionOptions.enableProfiling is set to true.");
|
|
#ifdef _WIN32
|
|
auto str = profileFilePrefixValue.As<Napi::String>().Utf16Value();
|
|
std::basic_string<ORTCHAR_T> profileFilePrefix = std::wstring{str.begin(), str.end()};
|
|
#else
|
|
std::basic_string<ORTCHAR_T> profileFilePrefix = profileFilePrefixValue.As<Napi::String>().Utf8Value();
|
|
#endif
|
|
sessionOptions.EnableProfiling(profileFilePrefix.c_str());
|
|
} else {
|
|
sessionOptions.DisableProfiling();
|
|
}
|
|
}
|
|
|
|
// external data
|
|
if (options.Has("externalData")) {
|
|
auto externalDataValue = options.Get("externalData");
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(),
|
|
"Invalid argument: sessionOptions.externalData must be an array.");
|
|
auto externalData = externalDataValue.As<Napi::Array>();
|
|
std::vector<std::basic_string<ORTCHAR_T>> paths;
|
|
std::vector<char*> buffs;
|
|
std::vector<size_t> sizes;
|
|
|
|
for (const auto& kvp : externalData) {
|
|
Napi::Value value = kvp.second;
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(),
|
|
"Invalid argument: sessionOptions.externalData value must be an object in Node.js binding.");
|
|
Napi::Object obj = value.As<Napi::Object>();
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(),
|
|
"Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding.");
|
|
#ifdef _WIN32
|
|
auto path = obj.Get("path").As<Napi::String>().Utf16Value();
|
|
paths.push_back(std::wstring{path.begin(), path.end()});
|
|
#else
|
|
auto path = obj.Get("path").As<Napi::String>().Utf8Value();
|
|
paths.push_back(path);
|
|
#endif
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") ||
|
|
!obj.Get("data").IsBuffer() ||
|
|
!(obj.Get("data").IsTypedArray() && obj.Get("data").As<Napi::TypedArray>().TypedArrayType() == napi_uint8_array),
|
|
options.Env(),
|
|
"Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding.");
|
|
|
|
auto data = obj.Get("data");
|
|
if (data.IsBuffer()) {
|
|
buffs.push_back(data.As<Napi::Buffer<char>>().Data());
|
|
sizes.push_back(data.As<Napi::Buffer<char>>().Length());
|
|
} else {
|
|
auto typedArray = data.As<Napi::TypedArray>();
|
|
buffs.push_back(reinterpret_cast<char*>(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset());
|
|
sizes.push_back(typedArray.ByteLength());
|
|
}
|
|
}
|
|
sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes);
|
|
}
|
|
}
|
|
|
|
void ParsePreferredOutputLocations(const Napi::Object options, const std::vector<std::string>& outputNames, std::vector<int>& preferredOutputLocations) {
|
|
if (options.Has("preferredOutputLocation")) {
|
|
auto polValue = options.Get("preferredOutputLocation");
|
|
if (polValue.IsNull() || polValue.IsUndefined()) {
|
|
return;
|
|
}
|
|
if (polValue.IsString()) {
|
|
DataLocation location = ParseDataLocation(polValue.As<Napi::String>().Utf8Value());
|
|
ORT_NAPI_THROW_TYPEERROR_IF(location == DATA_LOCATION_NONE, options.Env(),
|
|
"Invalid argument: preferredOutputLocation must be an array or a valid string.");
|
|
|
|
if (location == DATA_LOCATION_GPU_BUFFER || location == DATA_LOCATION_ML_TENSOR) {
|
|
preferredOutputLocations.resize(outputNames.size(), location);
|
|
}
|
|
} else if (polValue.IsObject()) {
|
|
preferredOutputLocations.resize(outputNames.size(), DATA_LOCATION_CPU);
|
|
|
|
auto pol = polValue.As<Napi::Object>();
|
|
for (const auto& nameIter : pol.GetPropertyNames()) {
|
|
Napi::Value nameVar = nameIter.second;
|
|
std::string name = nameVar.As<Napi::String>().Utf8Value();
|
|
// find the name in outputNames
|
|
auto it = std::find(outputNames.begin(), outputNames.end(), name);
|
|
ORT_NAPI_THROW_TYPEERROR_IF(it == outputNames.end(), options.Env(),
|
|
"Invalid argument: \"", name, "\" is not a valid output name.");
|
|
|
|
Napi::Value value = pol.Get(nameVar);
|
|
DataLocation location = DATA_LOCATION_NONE;
|
|
ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString() || (location = ParseDataLocation(value.As<Napi::String>().Utf8Value())) == DATA_LOCATION_NONE,
|
|
options.Env(),
|
|
"Invalid argument: preferredOutputLocation[\"", name, "\"] must be a valid string.");
|
|
|
|
size_t index = it - outputNames.begin();
|
|
preferredOutputLocations[index] = location;
|
|
}
|
|
|
|
if (std::all_of(preferredOutputLocations.begin(), preferredOutputLocations.end(), [](int loc) { return loc == DATA_LOCATION_CPU; })) {
|
|
preferredOutputLocations.clear();
|
|
}
|
|
} else {
|
|
ORT_NAPI_THROW_TYPEERROR(options.Env(), "Invalid argument: preferredOutputLocation must be an array or a valid string.");
|
|
}
|
|
}
|
|
}
|