support WebGPU EP in Node.js binding (#22660)

### Description

This change enhances the Node.js binding with the following features:
- support WebGPU EP
- lazy initialization of `OrtEnv`
- being able to initialize ORT with default log level setting from
`ort.env.logLevel`.
- session options:
  - `enableProfiling` and `profileFilePrefix`: support profiling.
  - `externalData`: explicit external data (optional in Node.js binding)
  - `optimizedModelFilePath`: allow dumping the optimized model for diagnosis
purposes
  - `preferredOutputLocation`: support IO binding.

======================================================
`Tensor.download()` is not implemented in this PR.
Build pipeline update is not included in this PR.
This commit is contained in:
Yulong Wang 2024-11-04 13:09:07 -08:00 committed by GitHub
parent 6c21ab7337
commit bd5dbf86fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 480 additions and 64 deletions

View file

@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.11)
project (onnxruntime-node)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
add_compile_definitions(NAPI_VERSION=${napi_build_version})
add_compile_definitions(ORT_API_MANUAL_INIT)
@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api)
# optional providers
option(USE_DML "Build with DirectML support" OFF)
option(USE_WEBGPU "Build with WebGPU support" OFF)
option(USE_CUDA "Build with CUDA support" OFF)
option(USE_TENSORRT "Build with TensorRT support" OFF)
option(USE_COREML "Build with CoreML support" OFF)
@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF)
if(USE_DML)
add_compile_definitions(USE_DML=1)
endif()
if(USE_WEBGPU)
add_compile_definitions(USE_WEBGPU=1)
endif()
if(USE_CUDA)
add_compile_definitions(USE_CUDA=1)
endif()

View file

@ -3,12 +3,14 @@
import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common';
import { Binding, binding } from './binding';
import { Binding, binding, initOrt } from './binding';
class OnnxruntimeSessionHandler implements InferenceSessionHandler {
#inferenceSession: Binding.InferenceSession;
constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) {
initOrt();
this.#inferenceSession = new binding.InferenceSession();
if (typeof pathOrBuffer === 'string') {
this.#inferenceSession.loadModel(pathOrBuffer, options);
@ -27,10 +29,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
readonly outputNames: string[];
startProfiling(): void {
// TODO: implement profiling
// startProfiling is a no-op.
//
// if sessionOptions.enableProfiling is true, profiling will be enabled when the model is loaded.
}
endProfiling(): void {
// TODO: implement profiling
this.#inferenceSession.endProfiling();
}
async run(

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession, OnnxValue } from 'onnxruntime-common';
import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common';
type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = {
@ -28,6 +28,8 @@ export declare namespace Binding {
run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType;
endProfiling(): void;
dispose(): void;
}
@ -48,4 +50,35 @@ export const binding =
// eslint-disable-next-line @typescript-eslint/naming-convention
InferenceSession: Binding.InferenceSessionConstructor;
listSupportedBackends: () => Binding.SupportedBackend[];
initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void;
};
let ortInitialized = false;

/**
 * Initialize the ONNX Runtime native binding exactly once.
 *
 * Translates `ort.env.logLevel` ('verbose'..'fatal') into the numeric ORT
 * logging level (0..4) and passes it, together with the `Tensor` constructor,
 * to the native `initOrtOnce` entry point.
 *
 * @throws Error if `env.logLevel` holds an unsupported value.
 */
export const initOrt = (): void => {
  if (ortInitialized) {
    return;
  }

  // default is 'warning' (2) when `ort.env.logLevel` is not set.
  let logLevel = 2;
  if (env.logLevel) {
    switch (env.logLevel) {
      case 'verbose':
        logLevel = 0;
        break;
      case 'info':
        logLevel = 1;
        break;
      case 'warning':
        logLevel = 2;
        break;
      case 'error':
        logLevel = 3;
        break;
      case 'fatal':
        logLevel = 4;
        break;
      default:
        throw new Error(`Unsupported log level: ${env.logLevel}`);
    }
  }
  binding.initOrtOnce(logLevel, Tensor);

  // Mark as initialized only after the native call succeeded, so that a
  // failed initialization (bad log level or native error) is not silently
  // swallowed on subsequent calls and can be retried.
  ortInitialized = true;
};

View file

@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator'];
const REBUILD = !!buildArgs.rebuild;
// --use_dml
const USE_DML = !!buildArgs.use_dml;
// --use_webgpu
const USE_WEBGPU = !!buildArgs.use_webgpu;
// --use_cuda
const USE_CUDA = !!buildArgs.use_cuda;
// --use_tensorrt
@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') {
if (USE_DML) {
args.push('--CDUSE_DML=ON');
}
if (USE_WEBGPU) {
args.push('--CDUSE_WEBGPU=ON');
}
if (USE_CUDA) {
args.push('--CDUSE_CUDA=ON');
}

View file

@ -11,7 +11,12 @@
#include "tensor_helper.h"
#include <string>
Napi::FunctionReference InferenceSessionWrap::constructor;
Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor;
Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor;
// Returns the persistent reference to the JS `Tensor` constructor that was
// registered via InitOrtOnce(). Used when converting OrtValues back to JS.
Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
  return ortTensorConstructor;
}
Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
#if defined(USE_DML) && defined(_WIN32)
@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
Ort::Global<void>::api_ == nullptr, env,
"Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version "
"ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library).");
auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"};
env.SetInstanceData(ortEnv);
// initialize binding
Napi::HandleScope scope(env);
Napi::Function func = DefineClass(
env, "InferenceSession",
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run),
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel),
InstanceMethod("run", &InferenceSessionWrap::Run),
InstanceMethod("dispose", &InferenceSessionWrap::Dispose),
InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling),
InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr),
InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)});
constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
wrappedSessionConstructor = Napi::Persistent(func);
wrappedSessionConstructor.SuppressDestruct();
exports.Set("InferenceSession", func);
Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends);
exports.Set("listSupportedBackends", listSupportedBackends);
Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce);
exports.Set("initOrtOnce", initOrtOnce);
return exports;
}
// One-time initialization entry point exposed to JS.
//
// info[0]: number — ORT logging level (0=verbose .. 4=fatal).
// info[1]: function — the `Tensor` constructor from onnxruntime-common.
//
// NOTE(review): if called again after the Ort::Env exists, the log level
// argument is read but the existing environment is kept unchanged.
Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) {
  Napi::Env env = info.Env();
  Napi::HandleScope scope(env);

  const int logLevel = info[0].As<Napi::Number>().Int32Value();

  // Create the process-wide Ort::Env lazily and attach it to the N-API
  // environment so it is shared by all sessions.
  auto* existingOrtEnv = env.GetInstanceData<Ort::Env>();
  if (existingOrtEnv == nullptr) {
    env.SetInstanceData(new Ort::Env{OrtLoggingLevel(logLevel), "onnxruntime-node"});
  }

  // Keep a persistent reference to the Tensor constructor so it outlives
  // this handle scope; it is needed later when producing output tensors.
  Napi::Function jsTensorConstructor = info[1].As<Napi::Function>();
  ortTensorConstructor = Napi::Persistent(jsTensorConstructor);
  ortTensorConstructor.SuppressDestruct();

  return env.Undefined();
}
InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info)
: Napi::ObjectWrap<InferenceSessionWrap>(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {}
@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
? typeInfo.GetTensorTypeAndShapeInfo().GetElementType()
: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
}
// cache preferred output locations
ParsePreferredOutputLocations(info[argsLength - 1].As<Napi::Object>(), outputNames_, preferredOutputLocations_);
if (preferredOutputLocations_.size() > 0) {
ioBinding_ = std::make_unique<Ort::IoBinding>(*session_);
}
} catch (Napi::Error const& e) {
throw e;
} catch (std::exception const& e) {
@ -167,7 +201,8 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
std::vector<bool> reuseOutput;
size_t inputIndex = 0;
size_t outputIndex = 0;
OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release();
Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault};
try {
for (auto& name : inputNames_) {
@ -175,7 +210,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
inputIndex++;
inputNames_cstr.push_back(name.c_str());
auto value = feed.Get(name);
inputValues.push_back(NapiValueToOrtValue(env, value, memory_info));
inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
}
}
for (auto& name : outputNames_) {
@ -184,7 +219,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
outputNames_cstr.push_back(name.c_str());
auto value = fetch.Get(name);
reuseOutput.push_back(!value.IsNull());
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info));
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
}
}
@ -193,19 +228,47 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
runOptions = Ort::RunOptions{};
ParseRunOptions(info[2].As<Napi::Object>(), runOptions);
}
if (preferredOutputLocations_.size() == 0) {
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
Napi::Object result = Napi::Object::New(env);
Napi::Object result = Napi::Object::New(env);
for (size_t i = 0; i < outputIndex; i++) {
result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i])));
}
return scope.Escape(result);
} else {
// IO binding
ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env,
"Preferred output locations must have the same size as output names.");
for (size_t i = 0; i < outputIndex; i++) {
result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i]));
for (size_t i = 0; i < inputIndex; i++) {
ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]);
}
for (size_t i = 0; i < outputIndex; i++) {
// TODO: support preallocated output tensor (outputValues[i])
if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) {
ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo);
} else {
ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo);
}
}
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_);
auto outputs = ioBinding_->GetOutputValues();
ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch.");
Napi::Object result = Napi::Object::New(env);
for (size_t i = 0; i < outputIndex; i++) {
result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i])));
}
return scope.Escape(result);
}
return scope.Escape(result);
} catch (Napi::Error const& e) {
throw e;
} catch (std::exception const& e) {
@ -218,6 +281,8 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
this->ioBinding_.reset(nullptr);
this->defaultRunOptions_.reset(nullptr);
this->session_.reset(nullptr);
@ -225,6 +290,20 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
return env.Undefined();
}
// Ends profiling for this session and returns the profile file name to JS.
// Throws if the session was never initialized or has been disposed.
Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) {
  Napi::Env env = info.Env();
  ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
  ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");

  Napi::EscapableHandleScope scope(env);

  // EndProfilingAllocated hands back an allocator-owned C string containing
  // the path of the generated profile file.
  Ort::AllocatorWithDefaultOptions defaultAllocator;
  auto profileFileName = session_->EndProfilingAllocated(defaultAllocator);
  return scope.Escape(Napi::String::From(env, profileFileName.get()));
}
Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::EscapableHandleScope scope(env);
@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo
#ifdef USE_DML
result.Set(result.Length(), createObject("dml", true));
#endif
#ifdef USE_WEBGPU
result.Set(result.Length(), createObject("webgpu", true));
#endif
#ifdef USE_CUDA
result.Set(result.Length(), createObject("cuda", false));
#endif

View file

@ -12,9 +12,22 @@
class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
public:
static Napi::Object Init(Napi::Env env, Napi::Object exports);
static Napi::FunctionReference& GetTensorConstructor();
InferenceSessionWrap(const Napi::CallbackInfo& info);
private:
/**
* [sync] initialize ONNX Runtime once.
*
* This function must be called before any other functions.
*
* @param arg0 a number specifying the log level.
* @param arg1 the `Tensor` constructor function from onnxruntime-common.
*
* @returns undefined
*/
static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info);
/**
* [sync] list supported backend list
* @returns array with objects { "name": "cpu", requirementsInstalled: true }
@ -63,10 +76,19 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
*/
Napi::Value Dispose(const Napi::CallbackInfo& info);
/**
* [sync] end profiling for this session.
* @param nothing
* @returns the profile file name as a string
* @throw if the session is not initialized or already disposed
*/
Napi::Value EndProfiling(const Napi::CallbackInfo& info);
// private members
// persistent constructor
static Napi::FunctionReference constructor;
static Napi::FunctionReference wrappedSessionConstructor;
static Napi::FunctionReference ortTensorConstructor;
// session objects
bool initialized_;
@ -81,4 +103,8 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
std::vector<std::string> outputNames_;
std::vector<ONNXType> outputTypes_;
std::vector<ONNXTensorElementDataType> outputTensorElementDataTypes_;
// preferred output locations
std::vector<int> preferredOutputLocations_;
std::unique_ptr<Ort::IoBinding> ioBinding_;
};

View file

@ -6,15 +6,20 @@
#include <cmath>
#include <unordered_map>
#include <filesystem>
#include "common.h"
#include "session_options_helper.h"
#include "tensor_helper.h"
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_options.h"
#endif
#ifdef USE_DML
#include "core/providers/dml/dml_provider_factory.h"
#endif
#ifdef USE_WEBGPU
#include "core/providers/webgpu/webgpu_provider_factory.h"
#endif
#ifdef USE_TENSORRT
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#endif
@ -36,7 +41,12 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
Napi::Value epValue = epList[i];
std::string name;
int deviceId = 0;
#ifdef USE_COREML
int coreMlFlags = 0;
#endif
#ifdef USE_WEBGPU
std::unordered_map<std::string, std::string> webgpu_options;
#endif
if (epValue.IsString()) {
name = epValue.As<Napi::String>().Utf8Value();
} else if (!epValue.IsObject() || epValue.IsNull() || !epValue.As<Napi::Object>().Has("name") ||
@ -49,9 +59,23 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
if (obj.Has("deviceId")) {
deviceId = obj.Get("deviceId").As<Napi::Number>();
}
#ifdef USE_COREML
if (obj.Has("coreMlFlags")) {
coreMlFlags = obj.Get("coreMlFlags").As<Napi::Number>();
}
#endif
#ifdef USE_WEBGPU
for (const auto& nameIter : obj.GetPropertyNames()) {
Napi::Value nameVar = nameIter.second;
std::string name = nameVar.As<Napi::String>().Utf8Value();
if (name != "name") {
Napi::Value valueVar = obj.Get(nameVar);
ORT_NAPI_THROW_TYPEERROR_IF(!valueVar.IsString(), epList.Env(), "Invalid argument: sessionOptions.executionProviders must be a string or an object with property 'name'.");
std::string value = valueVar.As<Napi::String>().Utf8Value();
webgpu_options[name] = value;
}
}
#endif
}
// CPU execution provider
@ -77,6 +101,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
} else if (name == "dml") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId));
#endif
#ifdef USE_WEBGPU
} else if (name == "webgpu") {
sessionOptions.AppendExecutionProvider("WebGPU", webgpu_options);
#endif
#ifdef USE_COREML
} else if (name == "coreml") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags));
@ -95,6 +123,22 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
}
}
// Recursively flattens a nested JS options object into dot-separated config
// entries and adds each leaf (which must be a string) to the session options,
// e.g. { session: { x: "1" } } -> AddConfigEntry("session.x", "1").
void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) {
  for (const auto& entry : obj) {
    const auto key = entry.first.As<Napi::String>().Utf8Value();
    const Napi::Value entryValue = entry.second;

    if (entryValue.IsObject()) {
      // nested object: descend with the extended prefix.
      IterateExtraOptions(prefix + key + ".", entryValue.As<Napi::Object>(), sessionOptions);
      continue;
    }

    ORT_NAPI_THROW_TYPEERROR_IF(!entryValue.IsString(), obj.Env(),
                                "Invalid argument: sessionOptions.extra value must be a string in Node.js binding.");
    const std::string configKey = prefix + key;
    const std::string configValue = entryValue.As<Napi::String>().Utf8Value();
    sessionOptions.AddConfigEntry(configKey.c_str(), configValue.c_str());
  }
}
void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) {
// Execution provider
if (options.Has("executionProviders")) {
@ -162,6 +206,28 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio
}
}
// optimizedModelFilePath
if (options.Has("optimizedModelFilePath")) {
auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath");
ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(),
"Invalid argument: sessionOptions.optimizedModelFilePath must be a string.");
#ifdef _WIN32
auto str = optimizedModelFilePathValue.As<Napi::String>().Utf16Value();
std::filesystem::path optimizedModelFilePath{std::wstring{str.begin(), str.end()}};
#else
std::filesystem::path optimizedModelFilePath{optimizedModelFilePathValue.As<Napi::String>().Utf8Value()};
#endif
sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str());
}
// extra
if (options.Has("extra")) {
auto extraValue = options.Get("extra");
ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(),
"Invalid argument: sessionOptions.extra must be an object.");
IterateExtraOptions("", extraValue.As<Napi::Object>(), sessionOptions);
}
// execution mode
if (options.Has("executionMode")) {
auto executionModeValue = options.Get("executionMode");
@ -195,4 +261,118 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio
sessionOptions.SetLogSeverityLevel(static_cast<int>(logLevelNumber));
}
// Profiling
if (options.Has("enableProfiling")) {
auto enableProfilingValue = options.Get("enableProfiling");
ORT_NAPI_THROW_TYPEERROR_IF(!enableProfilingValue.IsBoolean(), options.Env(),
"Invalid argument: sessionOptions.enableProfiling must be a boolean value.");
if (enableProfilingValue.As<Napi::Boolean>().Value()) {
ORT_NAPI_THROW_TYPEERROR_IF(!options.Has("profileFilePrefix"), options.Env(),
"Invalid argument: sessionOptions.profileFilePrefix is required"
" when sessionOptions.enableProfiling is set to true.");
auto profileFilePrefixValue = options.Get("profileFilePrefix");
ORT_NAPI_THROW_TYPEERROR_IF(!profileFilePrefixValue.IsString(), options.Env(),
"Invalid argument: sessionOptions.profileFilePrefix must be a string."
" when sessionOptions.enableProfiling is set to true.");
#ifdef _WIN32
auto str = profileFilePrefixValue.As<Napi::String>().Utf16Value();
std::basic_string<ORTCHAR_T> profileFilePrefix = std::wstring{str.begin(), str.end()};
#else
std::basic_string<ORTCHAR_T> profileFilePrefix = profileFilePrefixValue.As<Napi::String>().Utf8Value();
#endif
sessionOptions.EnableProfiling(profileFilePrefix.c_str());
} else {
sessionOptions.DisableProfiling();
}
}
// external data
if (options.Has("externalData")) {
auto externalDataValue = options.Get("externalData");
ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(),
"Invalid argument: sessionOptions.externalData must be an array.");
auto externalData = externalDataValue.As<Napi::Array>();
std::vector<std::basic_string<ORTCHAR_T>> paths;
std::vector<char*> buffs;
std::vector<size_t> sizes;
for (const auto& kvp : externalData) {
Napi::Value value = kvp.second;
ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(),
"Invalid argument: sessionOptions.externalData value must be an object in Node.js binding.");
Napi::Object obj = value.As<Napi::Object>();
ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(),
"Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding.");
#ifdef _WIN32
auto path = obj.Get("path").As<Napi::String>().Utf16Value();
paths.push_back(std::wstring{path.begin(), path.end()});
#else
auto path = obj.Get("path").As<Napi::String>().Utf8Value();
paths.push_back(path);
#endif
ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") ||
!obj.Get("data").IsBuffer() ||
!(obj.Get("data").IsTypedArray() && obj.Get("data").As<Napi::TypedArray>().TypedArrayType() == napi_uint8_array),
options.Env(),
"Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding.");
auto data = obj.Get("data");
if (data.IsBuffer()) {
buffs.push_back(data.As<Napi::Buffer<char>>().Data());
sizes.push_back(data.As<Napi::Buffer<char>>().Length());
} else {
auto typedArray = data.As<Napi::TypedArray>();
buffs.push_back(reinterpret_cast<char*>(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset());
sizes.push_back(typedArray.ByteLength());
}
}
sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes);
}
}
// Parses `sessionOptions.preferredOutputLocation` and fills
// `preferredOutputLocations` with one DataLocation (as int) per output name.
//
// Accepted forms:
//   - a single string: applied to all outputs. Only 'gpu-buffer'/'ml-tensor'
//     populate the vector; other valid locations leave it empty (no IO binding).
//   - an object mapping output name -> location string; outputs not named in
//     the object default to 'cpu'.
// On return, an empty vector means "use the default (non-IO-binding) path".
void ParsePreferredOutputLocations(const Napi::Object options, const std::vector<std::string>& outputNames, std::vector<int>& preferredOutputLocations) {
  if (options.Has("preferredOutputLocation")) {
    auto polValue = options.Get("preferredOutputLocation");
    if (polValue.IsNull() || polValue.IsUndefined()) {
      return;
    }
    if (polValue.IsString()) {
      DataLocation location = ParseDataLocation(polValue.As<Napi::String>().Utf8Value());
      ORT_NAPI_THROW_TYPEERROR_IF(location == DATA_LOCATION_NONE, options.Env(),
                                  "Invalid argument: preferredOutputLocation must be an array or a valid string.");

      // only GPU-buffer / ML-tensor locations require IO binding; for other
      // locations the vector is intentionally left empty.
      if (location == DATA_LOCATION_GPU_BUFFER || location == DATA_LOCATION_ML_TENSOR) {
        preferredOutputLocations.resize(outputNames.size(), location);
      }
    } else if (polValue.IsObject()) {
      // object form: default every output to CPU, then override per name.
      preferredOutputLocations.resize(outputNames.size(), DATA_LOCATION_CPU);

      auto pol = polValue.As<Napi::Object>();
      for (const auto& nameIter : pol.GetPropertyNames()) {
        Napi::Value nameVar = nameIter.second;
        std::string name = nameVar.As<Napi::String>().Utf8Value();

        // find the name in outputNames
        auto it = std::find(outputNames.begin(), outputNames.end(), name);
        ORT_NAPI_THROW_TYPEERROR_IF(it == outputNames.end(), options.Env(),
                                    "Invalid argument: \"", name, "\" is not a valid output name.");

        Napi::Value value = pol.Get(nameVar);
        DataLocation location = DATA_LOCATION_NONE;
        // NOTE: the condition deliberately assigns `location` as a side effect
        // of the validity check (parse once, validate and capture together).
        ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString() || (location = ParseDataLocation(value.As<Napi::String>().Utf8Value())) == DATA_LOCATION_NONE,
                                    options.Env(),
                                    "Invalid argument: preferredOutputLocation[\"", name, "\"] must be a valid string.");

        size_t index = it - outputNames.begin();
        preferredOutputLocations[index] = location;
      }

      // if every output ended up on CPU, IO binding is unnecessary: clear the
      // vector so callers take the default run path.
      if (std::all_of(preferredOutputLocations.begin(), preferredOutputLocations.end(), [](int loc) { return loc == DATA_LOCATION_CPU; })) {
        preferredOutputLocations.clear();
      }
    } else {
      ORT_NAPI_THROW_TYPEERROR(options.Env(), "Invalid argument: preferredOutputLocation must be an array or a valid string.");
    }
  }
}

View file

@ -11,3 +11,6 @@ struct SessionOptions;
// parse a Javascript session options object and fill the native SessionOptions object.
void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions);
// parse a Javascript session options object and prepare the preferred output locations.
void ParsePreferredOutputLocations(const Napi::Object options, const std::vector<std::string>& outputNames, std::vector<int>& preferredOutputLocations);

View file

@ -8,6 +8,7 @@
#include "common.h"
#include "tensor_helper.h"
#include "inference_session_wrap.h"
// make sure consistent with origin definition
static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == 0, "definition not consistent with OnnxRuntime");
@ -100,7 +101,7 @@ const std::unordered_map<std::string, ONNXTensorElementDataType> DATA_TYPE_NAME_
{"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"string", ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, {"bool", ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, {"float16", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16}, {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}};
// currently only support tensor
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info) {
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info) {
ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), env, "Tensor must be an object.");
// check 'dims'
@ -110,6 +111,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo*
auto dimsArray = dimsValue.As<Napi::Array>();
auto len = dimsArray.Length();
size_t elementSize = 1;
std::vector<int64_t> dims;
if (len > 0) {
dims.reserve(len);
@ -122,17 +124,26 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo*
"Tensor.dims[", i, "] is invalid: ", dimDouble);
int64_t dim = static_cast<int64_t>(dimDouble);
dims.push_back(dim);
elementSize *= dim;
}
}
// check 'location'
auto tensorLocationValue = tensorObject.Get("location");
ORT_NAPI_THROW_TYPEERROR_IF(!tensorLocationValue.IsString(), env, "Tensor.location must be a string.");
DataLocation tensorLocation = ParseDataLocation(tensorLocationValue.As<Napi::String>().Utf8Value());
ORT_NAPI_THROW_RANGEERROR_IF(tensorLocation == DATA_LOCATION_NONE, env, "Tensor.location is not supported.");
// check 'data' and 'type'
auto tensorDataValue = tensorObject.Get("data");
auto tensorTypeValue = tensorObject.Get("type");
ORT_NAPI_THROW_TYPEERROR_IF(!tensorTypeValue.IsString(), env, "Tensor.type must be a string.");
auto tensorTypeString = tensorTypeValue.As<Napi::String>().Utf8Value();
if (tensorTypeString == "string") {
auto tensorDataValue = tensorObject.Get("data");
ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_CPU, env, "Tensor.location must be 'cpu' for string tensors.");
ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsArray(), env, "Tensor.data must be an array for string tensors.");
auto tensorDataArray = tensorDataValue.As<Napi::Array>();
@ -162,29 +173,42 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo*
auto v = DATA_TYPE_NAME_TO_ID_MAP.find(tensorTypeString);
ORT_NAPI_THROW_TYPEERROR_IF(v == DATA_TYPE_NAME_TO_ID_MAP.end(), env,
"Tensor.type is not supported: ", tensorTypeString);
ONNXTensorElementDataType elemType = v->second;
ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env,
"Tensor.data must be a typed array for numeric tensor.");
if (tensorLocation == DATA_LOCATION_CPU) {
auto tensorDataValue = tensorObject.Get("data");
ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env,
"Tensor.data must be a typed array for numeric tensor.");
auto tensorDataTypedArray = tensorDataValue.As<Napi::TypedArray>();
auto typedArrayType = tensorDataValue.As<Napi::TypedArray>().TypedArrayType();
ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env,
"Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ",
tensorTypeString, " tensors, but got typed array (", typedArrayType, ").");
auto tensorDataTypedArray = tensorDataValue.As<Napi::TypedArray>();
auto typedArrayType = tensorDataValue.As<Napi::TypedArray>().TypedArrayType();
ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env,
"Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ",
tensorTypeString, " tensors, but got typed array (", typedArrayType, ").");
char* buffer = reinterpret_cast<char*>(tensorDataTypedArray.ArrayBuffer().Data());
size_t bufferByteOffset = tensorDataTypedArray.ByteOffset();
size_t bufferByteLength = tensorDataTypedArray.ByteLength();
return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength,
dims.empty() ? nullptr : &dims[0], dims.size(), elemType);
char* buffer = reinterpret_cast<char*>(tensorDataTypedArray.ArrayBuffer().Data());
size_t bufferByteOffset = tensorDataTypedArray.ByteOffset();
size_t bufferByteLength = tensorDataTypedArray.ByteLength();
return Ort::Value::CreateTensor(cpu_memory_info, buffer + bufferByteOffset, bufferByteLength,
dims.empty() ? nullptr : &dims[0], dims.size(), elemType);
} else {
ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_GPU_BUFFER, env, "Tensor.location must be 'gpu-buffer' for IO binding.");
auto gpuBufferValue = tensorObject.Get("gpuBuffer");
// nodejs: tensor.gpuBuffer is no longer a GPUBuffer in nodejs. we assume it is an external object (bind the OrtValue pointer).
ORT_NAPI_THROW_TYPEERROR_IF(!gpuBufferValue.IsExternal(), env, "Tensor.gpuBuffer must be an external object.");
Ort::Value dataValue(gpuBufferValue.As<Napi::External<OrtValue>>().Data());
void* gpuBuffer = dataValue.GetTensorMutableRawData();
dataValue.release();
size_t dataByteLength = DATA_TYPE_ELEMENT_SIZE_MAP[elemType] * elementSize;
return Ort::Value::CreateTensor(webgpu_memory_info, gpuBuffer, dataByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType);
}
}
}
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) {
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value) {
Napi::EscapableHandleScope scope(env);
auto returnValue = Napi::Object::New(env);
auto typeInfo = value.GetTypeInfo();
auto onnxType = typeInfo.GetONNXType();
@ -197,24 +221,26 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) {
// type
auto typeCstr = DATA_TYPE_ID_TO_NAME_MAP[elemType];
ORT_NAPI_THROW_ERROR_IF(typeCstr == nullptr, env, "Tensor type (", elemType, ") is not supported.");
returnValue.Set("type", Napi::String::New(env, typeCstr));
auto type = Napi::String::New(env, typeCstr);
// dims
const size_t dimsCount = tensorTypeAndShapeInfo.GetDimensionsCount();
std::vector<int64_t> dims;
std::vector<int64_t> dimsVector;
if (dimsCount > 0) {
dims = tensorTypeAndShapeInfo.GetShape();
dimsVector = tensorTypeAndShapeInfo.GetShape();
}
auto dimsArray = Napi::Array::New(env, dimsCount);
auto dims = Napi::Array::New(env, dimsCount);
for (uint32_t i = 0; i < dimsCount; i++) {
dimsArray[i] = dims[i];
dims[i] = dimsVector[i];
}
returnValue.Set("dims", dimsArray);
// location
auto memoryInfo = value.GetTensorMemoryInfo();
bool isGpuBuffer = memoryInfo.GetDeviceType() == OrtMemoryInfoDeviceType_GPU &&
memoryInfo.GetAllocatorName() == "WebGPU_Buffer";
// size
auto size = tensorTypeAndShapeInfo.GetElementCount();
returnValue.Set("size", Napi::Number::From(env, size));
// data
if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
@ -234,20 +260,48 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) {
i == size - 1 ? tempBufferLength - tempOffsets[i] : tempOffsets[i + 1] - tempOffsets[i]);
}
}
returnValue.Set("data", Napi::Value(env, stringArray));
// new Tensor("string", stringArray /* string[] */, dims /* number[] */)
return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New({Napi::String::New(env, "string"), stringArray, dims}));
} else {
// number data
// TODO: optimize memory
auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
if (size > 0) {
memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
}
napi_value typedArrayData;
napi_status status =
napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData);
NAPI_THROW_IF_FAILED(env, status, Napi::Value);
returnValue.Set("data", Napi::Value(env, typedArrayData));
}
if (isGpuBuffer) {
// Tensor.fromGpuBuffer(buffer, options)
Napi::Function tensorFromGpuBuffer = InferenceSessionWrap::GetTensorConstructor().Value().Get("fromGpuBuffer").As<Napi::Function>();
OrtValue* underlyingOrtValue = value.release();
return scope.Escape(returnValue);
auto options = Napi::Object::New(env);
options.Set("dataType", type);
options.Set("dims", dims);
options.Set("dispose", Napi::Function::New(
env, [](const Napi::CallbackInfo& info) {
Ort::GetApi().ReleaseValue(reinterpret_cast<OrtValue*>(info.Data()));
return info.Env().Undefined();
},
"dispose", underlyingOrtValue));
options.Set("download", Napi::Function::New(
env, [](const Napi::CallbackInfo& info) {
NAPI_THROW("not implemented");
},
"download", underlyingOrtValue));
return scope.Escape(tensorFromGpuBuffer.Call({Napi::External<OrtValue>::New(env, underlyingOrtValue), options}));
} else {
// TODO: optimize memory
auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
if (size > 0) {
memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
}
napi_value typedArrayData;
napi_status status =
napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData);
NAPI_THROW_IF_FAILED(env, status, Napi::Value);
// new Tensor(type, typedArrayData, dims)
return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New(
{type,
Napi::Value(env, typedArrayData),
dims}));
}
}
}

View file

@ -9,7 +9,32 @@
#include "onnxruntime_cxx_api.h"
// convert a Javascript OnnxValue object to an OrtValue object
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info);
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info);
// convert an OrtValue object to a Javascript OnnxValue object
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value);
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value);
// Data locations a JS `Tensor` may report via its `location` property.
enum DataLocation {
  DATA_LOCATION_NONE = 0,
  DATA_LOCATION_CPU = 1,
  DATA_LOCATION_CPU_PINNED = 2,
  DATA_LOCATION_TEXTURE = 3,
  DATA_LOCATION_GPU_BUFFER = 4,
  DATA_LOCATION_ML_TENSOR = 5
};

// Maps a location string (e.g. "cpu", "gpu-buffer") to its DataLocation value.
// Unknown strings map to DATA_LOCATION_NONE.
inline DataLocation ParseDataLocation(const std::string& location) {
  if (location == "cpu") return DATA_LOCATION_CPU;
  if (location == "cpu-pinned") return DATA_LOCATION_CPU_PINNED;
  if (location == "texture") return DATA_LOCATION_TEXTURE;
  if (location == "gpu-buffer") return DATA_LOCATION_GPU_BUFFER;
  if (location == "ml-tensor") return DATA_LOCATION_ML_TENSOR;
  return DATA_LOCATION_NONE;
}