mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
support WebGPU EP in Node.js binding (#22660)
### Description

This change enhances the Node.js binding with the following features:

- support WebGPU EP
- lazy initialization of `OrtEnv`
- the ability to initialize ORT with the default log level taken from `ort.env.logLevel`
- session options:
  - `enableProfiling` and `profileFilePrefix`: support profiling.
  - `externalData`: explicit external data (optional in Node.js binding)
  - `optimizedModelFilePath`: allow dumping the optimized model for diagnosis purposes
  - `preferredOutputLocation`: support IO binding.

======================================================

`Tensor.download()` is not implemented in this PR.
Build pipeline update is not included in this PR.
This commit is contained in:
parent
6c21ab7337
commit
bd5dbf86fe
10 changed files with 480 additions and 64 deletions
|
|
@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.11)
|
|||
|
||||
project (onnxruntime-node)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
add_compile_definitions(NAPI_VERSION=${napi_build_version})
|
||||
add_compile_definitions(ORT_API_MANUAL_INIT)
|
||||
|
|
@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api)
|
|||
|
||||
# optional providers
|
||||
option(USE_DML "Build with DirectML support" OFF)
|
||||
option(USE_WEBGPU "Build with WebGPU support" OFF)
|
||||
option(USE_CUDA "Build with CUDA support" OFF)
|
||||
option(USE_TENSORRT "Build with TensorRT support" OFF)
|
||||
option(USE_COREML "Build with CoreML support" OFF)
|
||||
|
|
@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF)
|
|||
if(USE_DML)
|
||||
add_compile_definitions(USE_DML=1)
|
||||
endif()
|
||||
if(USE_WEBGPU)
|
||||
add_compile_definitions(USE_WEBGPU=1)
|
||||
endif()
|
||||
if(USE_CUDA)
|
||||
add_compile_definitions(USE_CUDA=1)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -3,12 +3,14 @@
|
|||
|
||||
import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common';
|
||||
|
||||
import { Binding, binding } from './binding';
|
||||
import { Binding, binding, initOrt } from './binding';
|
||||
|
||||
class OnnxruntimeSessionHandler implements InferenceSessionHandler {
|
||||
#inferenceSession: Binding.InferenceSession;
|
||||
|
||||
constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) {
|
||||
initOrt();
|
||||
|
||||
this.#inferenceSession = new binding.InferenceSession();
|
||||
if (typeof pathOrBuffer === 'string') {
|
||||
this.#inferenceSession.loadModel(pathOrBuffer, options);
|
||||
|
|
@ -27,10 +29,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
|
|||
readonly outputNames: string[];
|
||||
|
||||
startProfiling(): void {
  // startProfiling is a no-op.
  //
  // if sessionOptions.enableProfiling is true, profiling will be enabled when the model is loaded.
}
|
||||
endProfiling(): void {
  // Forward the request to the native inference session held by this handler.
  this.#inferenceSession.endProfiling();
}
|
||||
|
||||
async run(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { InferenceSession, OnnxValue } from 'onnxruntime-common';
|
||||
import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common';
|
||||
|
||||
type SessionOptions = InferenceSession.SessionOptions;
|
||||
type FeedsType = {
|
||||
|
|
@ -28,6 +28,8 @@ export declare namespace Binding {
|
|||
|
||||
run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType;
|
||||
|
||||
endProfiling(): void;
|
||||
|
||||
dispose(): void;
|
||||
}
|
||||
|
||||
|
|
@ -48,4 +50,35 @@ export const binding =
|
|||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
InferenceSession: Binding.InferenceSessionConstructor;
|
||||
listSupportedBackends: () => Binding.SupportedBackend[];
|
||||
initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void;
|
||||
};
|
||||
|
||||
let ortInitialized = false;

/**
 * Lazily initialize ONNX Runtime in the native binding.
 *
 * The first successful call maps `env.logLevel` to its numeric severity
 * (2 when no level is set) and passes it, together with the `Tensor`
 * constructor, to `binding.initOrtOnce`. Subsequent calls are no-ops.
 *
 * The "initialized" flag is only set after initialization succeeds, so a call
 * that throws (for example on an unsupported log level) can be retried instead
 * of leaving every later call silently skipping initialization.
 *
 * @throws Error if `env.logLevel` is set to an unsupported value.
 */
export const initOrt = (): void => {
  if (ortInitialized) {
    return;
  }

  const severityByName = new Map<string, number>([
    ['verbose', 0],
    ['info', 1],
    ['warning', 2],
    ['error', 3],
    ['fatal', 4],
  ]);

  // Same value as 'warning'; used when no log level is configured.
  let logLevel = 2;
  if (env.logLevel) {
    const severity = severityByName.get(env.logLevel);
    if (severity === undefined) {
      throw new Error(`Unsupported log level: ${env.logLevel}`);
    }
    logLevel = severity;
  }

  binding.initOrtOnce(logLevel, Tensor);

  // Mark as initialized only once the native call has completed without throwing.
  ortInitialized = true;
};
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator'];
|
|||
const REBUILD = !!buildArgs.rebuild;
|
||||
// --use_dml
|
||||
const USE_DML = !!buildArgs.use_dml;
|
||||
// --use_webgpu
|
||||
const USE_WEBGPU = !!buildArgs.use_webgpu;
|
||||
// --use_cuda
|
||||
const USE_CUDA = !!buildArgs.use_cuda;
|
||||
// --use_tensorrt
|
||||
|
|
@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') {
|
|||
if (USE_DML) {
|
||||
args.push('--CDUSE_DML=ON');
|
||||
}
|
||||
if (USE_WEBGPU) {
|
||||
args.push('--CDUSE_WEBGPU=ON');
|
||||
}
|
||||
if (USE_CUDA) {
|
||||
args.push('--CDUSE_CUDA=ON');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,12 @@
|
|||
#include "tensor_helper.h"
|
||||
#include <string>
|
||||
|
||||
Napi::FunctionReference InferenceSessionWrap::constructor;
|
||||
Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor;
|
||||
Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor;
|
||||
|
||||
// Accessor for the static `ortTensorConstructor` function reference.
Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
  return InferenceSessionWrap::ortTensorConstructor;
}
|
||||
|
||||
Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
|
||||
#if defined(USE_DML) && defined(_WIN32)
|
||||
|
|
@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
|
|||
Ort::Global<void>::api_ == nullptr, env,
|
||||
"Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version "
|
||||
"ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library).");
|
||||
auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"};
|
||||
env.SetInstanceData(ortEnv);
|
||||
|
||||
// initialize binding
|
||||
Napi::HandleScope scope(env);
|
||||
|
||||
Napi::Function func = DefineClass(
|
||||
env, "InferenceSession",
|
||||
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run),
|
||||
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel),
|
||||
InstanceMethod("run", &InferenceSessionWrap::Run),
|
||||
InstanceMethod("dispose", &InferenceSessionWrap::Dispose),
|
||||
InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling),
|
||||
InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr),
|
||||
InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)});
|
||||
|
||||
constructor = Napi::Persistent(func);
|
||||
constructor.SuppressDestruct();
|
||||
wrappedSessionConstructor = Napi::Persistent(func);
|
||||
wrappedSessionConstructor.SuppressDestruct();
|
||||
exports.Set("InferenceSession", func);
|
||||
|
||||
Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends);
|
||||
exports.Set("listSupportedBackends", listSupportedBackends);
|
||||
|
||||
Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce);
|
||||
exports.Set("initOrtOnce", initOrtOnce);
|
||||
|
||||
return exports;
|
||||
}
|
||||
|
||||
/**
 * [sync] Initialize ONNX Runtime for this N-API environment.
 *
 * Creates the Ort::Env (stored as instance data on the N-API environment) if
 * it does not exist yet, and keeps a persistent reference to the JavaScript
 * Tensor constructor.
 *
 * @param info[0] a number specifying the log level (an OrtLoggingLevel value).
 * @param info[1] the JavaScript Tensor constructor (a function).
 * @returns undefined
 * @throws TypeError if the arguments are missing or have the wrong type.
 * @throws RangeError if the log level is not a valid OrtLoggingLevel.
 */
Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) {
  Napi::Env env = info.Env();
  Napi::HandleScope scope(env);

  // Validate before converting: As<T>() performs no type checking on its own.
  ORT_NAPI_THROW_TYPEERROR_IF(info.Length() < 2 || !info[0].IsNumber() || !info[1].IsFunction(), env,
                              "Expect arguments: logLevel (number) and tensorConstructor (function).");

  int log_level = info[0].As<Napi::Number>().Int32Value();
  ORT_NAPI_THROW_RANGEERROR_IF(log_level < ORT_LOGGING_LEVEL_VERBOSE || log_level > ORT_LOGGING_LEVEL_FATAL, env,
                               "Invalid log level: ", log_level);

  // Create the environment only if none has been attached to this N-API environment yet.
  Ort::Env* ortEnv = env.GetInstanceData<Ort::Env>();
  if (ortEnv == nullptr) {
    ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"};
    env.SetInstanceData(ortEnv);
  }

  Napi::Function tensorConstructor = info[1].As<Napi::Function>();
  ortTensorConstructor = Napi::Persistent(tensorConstructor);
  ortTensorConstructor.SuppressDestruct();

  return env.Undefined();
}
|
||||
|
||||
InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info)
|
||||
: Napi::ObjectWrap<InferenceSessionWrap>(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {}
|
||||
|
||||
|
|
@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
|
|||
? typeInfo.GetTensorTypeAndShapeInfo().GetElementType()
|
||||
: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
}
|
||||
|
||||
// cache preferred output locations
|
||||
ParsePreferredOutputLocations(info[argsLength - 1].As<Napi::Object>(), outputNames_, preferredOutputLocations_);
|
||||
if (preferredOutputLocations_.size() > 0) {
|
||||
ioBinding_ = std::make_unique<Ort::IoBinding>(*session_);
|
||||
}
|
||||
} catch (Napi::Error const& e) {
|
||||
throw e;
|
||||
} catch (std::exception const& e) {
|
||||
|
|
@ -167,7 +201,8 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
|
|||
std::vector<bool> reuseOutput;
|
||||
size_t inputIndex = 0;
|
||||
size_t outputIndex = 0;
|
||||
OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release();
|
||||
Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault};
|
||||
|
||||
try {
|
||||
for (auto& name : inputNames_) {
|
||||
|
|
@ -175,7 +210,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
|
|||
inputIndex++;
|
||||
inputNames_cstr.push_back(name.c_str());
|
||||
auto value = feed.Get(name);
|
||||
inputValues.push_back(NapiValueToOrtValue(env, value, memory_info));
|
||||
inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
|
||||
}
|
||||
}
|
||||
for (auto& name : outputNames_) {
|
||||
|
|
@ -184,7 +219,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
|
|||
outputNames_cstr.push_back(name.c_str());
|
||||
auto value = fetch.Get(name);
|
||||
reuseOutput.push_back(!value.IsNull());
|
||||
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info));
|
||||
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -193,19 +228,47 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
|
|||
runOptions = Ort::RunOptions{};
|
||||
ParseRunOptions(info[2].As<Napi::Object>(), runOptions);
|
||||
}
|
||||
if (preferredOutputLocations_.size() == 0) {
|
||||
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
|
||||
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
|
||||
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
|
||||
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
|
||||
|
||||
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
|
||||
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
|
||||
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
|
||||
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
|
||||
Napi::Object result = Napi::Object::New(env);
|
||||
|
||||
Napi::Object result = Napi::Object::New(env);
|
||||
for (size_t i = 0; i < outputIndex; i++) {
|
||||
result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i])));
|
||||
}
|
||||
return scope.Escape(result);
|
||||
} else {
|
||||
// IO binding
|
||||
ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env,
|
||||
"Preferred output locations must have the same size as output names.");
|
||||
|
||||
for (size_t i = 0; i < outputIndex; i++) {
|
||||
result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i]));
|
||||
for (size_t i = 0; i < inputIndex; i++) {
|
||||
ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]);
|
||||
}
|
||||
for (size_t i = 0; i < outputIndex; i++) {
|
||||
// TODO: support preallocated output tensor (outputValues[i])
|
||||
|
||||
if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) {
|
||||
ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo);
|
||||
} else {
|
||||
ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo);
|
||||
}
|
||||
}
|
||||
|
||||
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_);
|
||||
|
||||
auto outputs = ioBinding_->GetOutputValues();
|
||||
ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch.");
|
||||
|
||||
Napi::Object result = Napi::Object::New(env);
|
||||
for (size_t i = 0; i < outputIndex; i++) {
|
||||
result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i])));
|
||||
}
|
||||
return scope.Escape(result);
|
||||
}
|
||||
|
||||
return scope.Escape(result);
|
||||
} catch (Napi::Error const& e) {
|
||||
throw e;
|
||||
} catch (std::exception const& e) {
|
||||
|
|
@ -218,6 +281,8 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
|
|||
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
|
||||
ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
|
||||
|
||||
this->ioBinding_.reset(nullptr);
|
||||
|
||||
this->defaultRunOptions_.reset(nullptr);
|
||||
this->session_.reset(nullptr);
|
||||
|
||||
|
|
@ -225,6 +290,20 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
|
|||
return env.Undefined();
|
||||
}
|
||||
|
||||
/**
 * [sync] End profiling on this session.
 *
 * @returns a JavaScript string built from the value returned by
 *          Ort::Session::EndProfilingAllocated.
 * @throws Error if the session is not initialized or has been disposed.
 */
Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) {
  Napi::Env env = info.Env();
  ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
  ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");

  Napi::EscapableHandleScope scope(env);

  // The returned smart pointer owns the string and frees it through the allocator.
  Ort::AllocatorWithDefaultOptions defaultAllocator;
  auto profileFileName = session_->EndProfilingAllocated(defaultAllocator);

  return scope.Escape(Napi::String::From(env, profileFileName.get()));
}
|
||||
|
||||
Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) {
|
||||
Napi::Env env = info.Env();
|
||||
Napi::EscapableHandleScope scope(env);
|
||||
|
|
@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo
|
|||
#ifdef USE_DML
|
||||
result.Set(result.Length(), createObject("dml", true));
|
||||
#endif
|
||||
#ifdef USE_WEBGPU
|
||||
result.Set(result.Length(), createObject("webgpu", true));
|
||||
#endif
|
||||
#ifdef USE_CUDA
|
||||
result.Set(result.Length(), createObject("cuda", false));
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -12,9 +12,22 @@
|
|||
class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
|
||||
public:
|
||||
static Napi::Object Init(Napi::Env env, Napi::Object exports);
|
||||
static Napi::FunctionReference& GetTensorConstructor();
|
||||
|
||||
InferenceSessionWrap(const Napi::CallbackInfo& info);
|
||||
|
||||
private:
|
||||
/**
|
||||
* [sync] initialize ONNX Runtime once.
|
||||
*
|
||||
* This function must be called before any other functions.
|
||||
*
|
||||
* @param arg0 a number specifying the log level.
* @param arg1 the JavaScript Tensor constructor function.
|
||||
*
|
||||
* @returns undefined
|
||||
*/
|
||||
static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info);
|
||||
|
||||
/**
|
||||
* [sync] list supported backend list
|
||||
* @returns array with objects { "name": "cpu", requirementsInstalled: true }
|
||||
|
|
@ -63,10 +76,19 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
|
|||
*/
|
||||
Napi::Value Dispose(const Napi::CallbackInfo& info);
|
||||
|
||||
/**
|
||||
* [sync] end the profiling.
|
||||
* @param nothing
|
||||
* @returns a string: the profile file name reported by the session
|
||||
* @throw error if the session is not initialized or is already disposed
|
||||
*/
|
||||
Napi::Value EndProfiling(const Napi::CallbackInfo& info);
|
||||
|
||||
// private members
|
||||
|
||||
// persistent constructor
|
||||
static Napi::FunctionReference constructor;
|
||||
static Napi::FunctionReference wrappedSessionConstructor;
|
||||
static Napi::FunctionReference ortTensorConstructor;
|
||||
|
||||
// session objects
|
||||
bool initialized_;
|
||||
|
|
@ -81,4 +103,8 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
|
|||
std::vector<std::string> outputNames_;
|
||||
std::vector<ONNXType> outputTypes_;
|
||||
std::vector<ONNXTensorElementDataType> outputTensorElementDataTypes_;
|
||||
|
||||
// preferred output locations
|
||||
std::vector<int> preferredOutputLocations_;
|
||||
std::unique_ptr<Ort::IoBinding> ioBinding_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -6,15 +6,20 @@
|
|||
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#include <filesystem>
|
||||
|
||||
#include "common.h"
|
||||
#include "session_options_helper.h"
|
||||
#include "tensor_helper.h"
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_provider_options.h"
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
#endif
|
||||
#ifdef USE_WEBGPU
|
||||
#include "core/providers/webgpu/webgpu_provider_factory.h"
|
||||
#endif
|
||||
#ifdef USE_TENSORRT
|
||||
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
||||
#endif
|
||||
|
|
@ -36,7 +41,12 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
|
|||
Napi::Value epValue = epList[i];
|
||||
std::string name;
|
||||
int deviceId = 0;
|
||||
#ifdef USE_COREML
|
||||
int coreMlFlags = 0;
|
||||
#endif
|
||||
#ifdef USE_WEBGPU
|
||||
std::unordered_map<std::string, std::string> webgpu_options;
|
||||
#endif
|
||||
if (epValue.IsString()) {
|
||||
name = epValue.As<Napi::String>().Utf8Value();
|
||||
} else if (!epValue.IsObject() || epValue.IsNull() || !epValue.As<Napi::Object>().Has("name") ||
|
||||
|
|
@ -49,9 +59,23 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
|
|||
if (obj.Has("deviceId")) {
|
||||
deviceId = obj.Get("deviceId").As<Napi::Number>();
|
||||
}
|
||||
#ifdef USE_COREML
|
||||
if (obj.Has("coreMlFlags")) {
|
||||
coreMlFlags = obj.Get("coreMlFlags").As<Napi::Number>();
|
||||
}
|
||||
#endif
|
||||
#ifdef USE_WEBGPU
|
||||
for (const auto& nameIter : obj.GetPropertyNames()) {
|
||||
Napi::Value nameVar = nameIter.second;
|
||||
std::string name = nameVar.As<Napi::String>().Utf8Value();
|
||||
if (name != "name") {
|
||||
Napi::Value valueVar = obj.Get(nameVar);
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!valueVar.IsString(), epList.Env(), "Invalid argument: sessionOptions.executionProviders must be a string or an object with property 'name'.");
|
||||
std::string value = valueVar.As<Napi::String>().Utf8Value();
|
||||
webgpu_options[name] = value;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// CPU execution provider
|
||||
|
|
@ -77,6 +101,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
|
|||
} else if (name == "dml") {
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId));
|
||||
#endif
|
||||
#ifdef USE_WEBGPU
|
||||
} else if (name == "webgpu") {
|
||||
sessionOptions.AppendExecutionProvider("WebGPU", webgpu_options);
|
||||
#endif
|
||||
#ifdef USE_COREML
|
||||
} else if (name == "coreml") {
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags));
|
||||
|
|
@ -95,6 +123,22 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
|
|||
}
|
||||
}
|
||||
|
||||
// Recursively flatten a (possibly nested) JavaScript object into session config entries.
//
// Nested keys are joined with '.', starting from `prefix`. Every leaf value must be a
// string; each leaf is added through Ort::SessionOptions::AddConfigEntry.
// Throws a TypeError when a leaf value is not a string.
void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) {
  for (const auto& property : obj) {
    const std::string qualifiedKey = prefix + property.first.As<Napi::String>().Utf8Value();
    Napi::Value propertyValue = property.second;

    if (propertyValue.IsObject()) {
      // Descend into the nested object, extending the dotted key path.
      IterateExtraOptions(qualifiedKey + ".", propertyValue.As<Napi::Object>(), sessionOptions);
      continue;
    }

    ORT_NAPI_THROW_TYPEERROR_IF(!propertyValue.IsString(), obj.Env(),
                                "Invalid argument: sessionOptions.extra value must be a string in Node.js binding.");
    const std::string configValue = propertyValue.As<Napi::String>().Utf8Value();
    sessionOptions.AddConfigEntry(qualifiedKey.c_str(), configValue.c_str());
  }
}
|
||||
|
||||
void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) {
|
||||
// Execution provider
|
||||
if (options.Has("executionProviders")) {
|
||||
|
|
@ -162,6 +206,28 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio
|
|||
}
|
||||
}
|
||||
|
||||
// optimizedModelFilePath
|
||||
if (options.Has("optimizedModelFilePath")) {
|
||||
auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(),
|
||||
"Invalid argument: sessionOptions.optimizedModelFilePath must be a string.");
|
||||
#ifdef _WIN32
|
||||
auto str = optimizedModelFilePathValue.As<Napi::String>().Utf16Value();
|
||||
std::filesystem::path optimizedModelFilePath{std::wstring{str.begin(), str.end()}};
|
||||
#else
|
||||
std::filesystem::path optimizedModelFilePath{optimizedModelFilePathValue.As<Napi::String>().Utf8Value()};
|
||||
#endif
|
||||
sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str());
|
||||
}
|
||||
|
||||
// extra
|
||||
if (options.Has("extra")) {
|
||||
auto extraValue = options.Get("extra");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(),
|
||||
"Invalid argument: sessionOptions.extra must be an object.");
|
||||
IterateExtraOptions("", extraValue.As<Napi::Object>(), sessionOptions);
|
||||
}
|
||||
|
||||
// execution mode
|
||||
if (options.Has("executionMode")) {
|
||||
auto executionModeValue = options.Get("executionMode");
|
||||
|
|
@ -195,4 +261,118 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio
|
|||
|
||||
sessionOptions.SetLogSeverityLevel(static_cast<int>(logLevelNumber));
|
||||
}
|
||||
|
||||
// Profiling
|
||||
if (options.Has("enableProfiling")) {
|
||||
auto enableProfilingValue = options.Get("enableProfiling");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!enableProfilingValue.IsBoolean(), options.Env(),
|
||||
"Invalid argument: sessionOptions.enableProfiling must be a boolean value.");
|
||||
|
||||
if (enableProfilingValue.As<Napi::Boolean>().Value()) {
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!options.Has("profileFilePrefix"), options.Env(),
|
||||
"Invalid argument: sessionOptions.profileFilePrefix is required"
|
||||
" when sessionOptions.enableProfiling is set to true.");
|
||||
auto profileFilePrefixValue = options.Get("profileFilePrefix");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!profileFilePrefixValue.IsString(), options.Env(),
|
||||
"Invalid argument: sessionOptions.profileFilePrefix must be a string."
|
||||
" when sessionOptions.enableProfiling is set to true.");
|
||||
#ifdef _WIN32
|
||||
auto str = profileFilePrefixValue.As<Napi::String>().Utf16Value();
|
||||
std::basic_string<ORTCHAR_T> profileFilePrefix = std::wstring{str.begin(), str.end()};
|
||||
#else
|
||||
std::basic_string<ORTCHAR_T> profileFilePrefix = profileFilePrefixValue.As<Napi::String>().Utf8Value();
|
||||
#endif
|
||||
sessionOptions.EnableProfiling(profileFilePrefix.c_str());
|
||||
} else {
|
||||
sessionOptions.DisableProfiling();
|
||||
}
|
||||
}
|
||||
|
||||
// external data
|
||||
if (options.Has("externalData")) {
|
||||
auto externalDataValue = options.Get("externalData");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(),
|
||||
"Invalid argument: sessionOptions.externalData must be an array.");
|
||||
auto externalData = externalDataValue.As<Napi::Array>();
|
||||
std::vector<std::basic_string<ORTCHAR_T>> paths;
|
||||
std::vector<char*> buffs;
|
||||
std::vector<size_t> sizes;
|
||||
|
||||
for (const auto& kvp : externalData) {
|
||||
Napi::Value value = kvp.second;
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(),
|
||||
"Invalid argument: sessionOptions.externalData value must be an object in Node.js binding.");
|
||||
Napi::Object obj = value.As<Napi::Object>();
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(),
|
||||
"Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding.");
|
||||
#ifdef _WIN32
|
||||
auto path = obj.Get("path").As<Napi::String>().Utf16Value();
|
||||
paths.push_back(std::wstring{path.begin(), path.end()});
|
||||
#else
|
||||
auto path = obj.Get("path").As<Napi::String>().Utf8Value();
|
||||
paths.push_back(path);
|
||||
#endif
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") ||
|
||||
!obj.Get("data").IsBuffer() ||
|
||||
!(obj.Get("data").IsTypedArray() && obj.Get("data").As<Napi::TypedArray>().TypedArrayType() == napi_uint8_array),
|
||||
options.Env(),
|
||||
"Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding.");
|
||||
|
||||
auto data = obj.Get("data");
|
||||
if (data.IsBuffer()) {
|
||||
buffs.push_back(data.As<Napi::Buffer<char>>().Data());
|
||||
sizes.push_back(data.As<Napi::Buffer<char>>().Length());
|
||||
} else {
|
||||
auto typedArray = data.As<Napi::TypedArray>();
|
||||
buffs.push_back(reinterpret_cast<char*>(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset());
|
||||
sizes.push_back(typedArray.ByteLength());
|
||||
}
|
||||
}
|
||||
sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse `preferredOutputLocation` from a JavaScript session options object.
//
// `preferredOutputLocation` may be:
//   - absent, null or undefined: nothing is requested;
//   - a string: one location applied to every output;
//   - an object: a map from output name to location string (unlisted outputs default to CPU).
//
// On return `preferredOutputLocations` is either empty, or holds one DataLocation value
// per entry of `outputNames` (same order). It is left empty unless a GPU-buffer or
// ML-tensor location is requested by the string form, or a non-CPU location by the object form.
//
// Throws a TypeError when the option has an unsupported type, names an unknown output,
// or contains an unrecognized location string.
void ParsePreferredOutputLocations(const Napi::Object options, const std::vector<std::string>& outputNames, std::vector<int>& preferredOutputLocations) {
  // Start from a clean state: resize() below would otherwise keep any stale entries.
  preferredOutputLocations.clear();

  if (!options.Has("preferredOutputLocation")) {
    return;
  }

  auto polValue = options.Get("preferredOutputLocation");
  if (polValue.IsNull() || polValue.IsUndefined()) {
    return;
  }

  if (polValue.IsString()) {
    DataLocation location = ParseDataLocation(polValue.As<Napi::String>().Utf8Value());
    ORT_NAPI_THROW_TYPEERROR_IF(location == DATA_LOCATION_NONE, options.Env(),
                                "Invalid argument: preferredOutputLocation must be an object or a valid string.");

    // Any other valid location leaves the vector empty.
    if (location == DATA_LOCATION_GPU_BUFFER || location == DATA_LOCATION_ML_TENSOR) {
      preferredOutputLocations.resize(outputNames.size(), location);
    }
  } else if (polValue.IsObject()) {
    preferredOutputLocations.resize(outputNames.size(), DATA_LOCATION_CPU);

    auto pol = polValue.As<Napi::Object>();
    for (const auto& nameIter : pol.GetPropertyNames()) {
      Napi::Value nameVar = nameIter.second;
      std::string name = nameVar.As<Napi::String>().Utf8Value();

      // Every key must be one of the model's output names.
      auto it = std::find(outputNames.begin(), outputNames.end(), name);
      ORT_NAPI_THROW_TYPEERROR_IF(it == outputNames.end(), options.Env(),
                                  "Invalid argument: \"", name, "\" is not a valid output name.");

      Napi::Value value = pol.Get(nameVar);
      DataLocation location = DATA_LOCATION_NONE;
      ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString() || (location = ParseDataLocation(value.As<Napi::String>().Utf8Value())) == DATA_LOCATION_NONE,
                                  options.Env(),
                                  "Invalid argument: preferredOutputLocation[\"", name, "\"] must be a valid string.");

      size_t index = it - outputNames.begin();
      preferredOutputLocations[index] = location;
    }

    // If every output ends up on CPU, report "no preference" with an empty vector.
    if (std::all_of(preferredOutputLocations.begin(), preferredOutputLocations.end(), [](int loc) { return loc == DATA_LOCATION_CPU; })) {
      preferredOutputLocations.clear();
    }
  } else {
    ORT_NAPI_THROW_TYPEERROR(options.Env(), "Invalid argument: preferredOutputLocation must be an object or a valid string.");
  }
}
|
||||
|
|
|
|||
|
|
@ -11,3 +11,6 @@ struct SessionOptions;
|
|||
|
||||
// parse a Javascript session options object and fill the native SessionOptions object.
|
||||
void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions);
|
||||
|
||||
// parse a Javascript session options object and prepare the preferred output locations.
|
||||
void ParsePreferredOutputLocations(const Napi::Object options, const std::vector<std::string>& outputNames, std::vector<int>& preferredOutputLocations);
|
||||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "common.h"
|
||||
#include "tensor_helper.h"
|
||||
#include "inference_session_wrap.h"
|
||||
|
||||
// make sure consistent with origin definition
|
||||
static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == 0, "definition not consistent with OnnxRuntime");
|
||||
|
|
@ -100,7 +101,7 @@ const std::unordered_map<std::string, ONNXTensorElementDataType> DATA_TYPE_NAME_
|
|||
{"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"string", ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, {"bool", ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, {"float16", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16}, {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}};
|
||||
|
||||
// currently only support tensor
|
||||
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info) {
|
||||
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info) {
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), env, "Tensor must be an object.");
|
||||
|
||||
// check 'dims'
|
||||
|
|
@ -110,6 +111,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo*
|
|||
|
||||
auto dimsArray = dimsValue.As<Napi::Array>();
|
||||
auto len = dimsArray.Length();
|
||||
size_t elementSize = 1;
|
||||
std::vector<int64_t> dims;
|
||||
if (len > 0) {
|
||||
dims.reserve(len);
|
||||
|
|
@ -122,17 +124,26 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo*
|
|||
"Tensor.dims[", i, "] is invalid: ", dimDouble);
|
||||
int64_t dim = static_cast<int64_t>(dimDouble);
|
||||
dims.push_back(dim);
|
||||
elementSize *= dim;
|
||||
}
|
||||
}
|
||||
|
||||
// check 'location'
|
||||
auto tensorLocationValue = tensorObject.Get("location");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!tensorLocationValue.IsString(), env, "Tensor.location must be a string.");
|
||||
DataLocation tensorLocation = ParseDataLocation(tensorLocationValue.As<Napi::String>().Utf8Value());
|
||||
ORT_NAPI_THROW_RANGEERROR_IF(tensorLocation == DATA_LOCATION_NONE, env, "Tensor.location is not supported.");
|
||||
|
||||
// check 'data' and 'type'
|
||||
auto tensorDataValue = tensorObject.Get("data");
|
||||
auto tensorTypeValue = tensorObject.Get("type");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!tensorTypeValue.IsString(), env, "Tensor.type must be a string.");
|
||||
|
||||
auto tensorTypeString = tensorTypeValue.As<Napi::String>().Utf8Value();
|
||||
|
||||
if (tensorTypeString == "string") {
|
||||
auto tensorDataValue = tensorObject.Get("data");
|
||||
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_CPU, env, "Tensor.location must be 'cpu' for string tensors.");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsArray(), env, "Tensor.data must be an array for string tensors.");
|
||||
|
||||
auto tensorDataArray = tensorDataValue.As<Napi::Array>();
|
||||
|
|
@ -162,29 +173,42 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo*
|
|||
auto v = DATA_TYPE_NAME_TO_ID_MAP.find(tensorTypeString);
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(v == DATA_TYPE_NAME_TO_ID_MAP.end(), env,
|
||||
"Tensor.type is not supported: ", tensorTypeString);
|
||||
|
||||
ONNXTensorElementDataType elemType = v->second;
|
||||
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env,
|
||||
"Tensor.data must be a typed array for numeric tensor.");
|
||||
if (tensorLocation == DATA_LOCATION_CPU) {
|
||||
auto tensorDataValue = tensorObject.Get("data");
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env,
|
||||
"Tensor.data must be a typed array for numeric tensor.");
|
||||
|
||||
auto tensorDataTypedArray = tensorDataValue.As<Napi::TypedArray>();
|
||||
auto typedArrayType = tensorDataValue.As<Napi::TypedArray>().TypedArrayType();
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env,
|
||||
"Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ",
|
||||
tensorTypeString, " tensors, but got typed array (", typedArrayType, ").");
|
||||
auto tensorDataTypedArray = tensorDataValue.As<Napi::TypedArray>();
|
||||
auto typedArrayType = tensorDataValue.As<Napi::TypedArray>().TypedArrayType();
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env,
|
||||
"Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ",
|
||||
tensorTypeString, " tensors, but got typed array (", typedArrayType, ").");
|
||||
|
||||
char* buffer = reinterpret_cast<char*>(tensorDataTypedArray.ArrayBuffer().Data());
|
||||
size_t bufferByteOffset = tensorDataTypedArray.ByteOffset();
|
||||
size_t bufferByteLength = tensorDataTypedArray.ByteLength();
|
||||
return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength,
|
||||
dims.empty() ? nullptr : &dims[0], dims.size(), elemType);
|
||||
char* buffer = reinterpret_cast<char*>(tensorDataTypedArray.ArrayBuffer().Data());
|
||||
size_t bufferByteOffset = tensorDataTypedArray.ByteOffset();
|
||||
size_t bufferByteLength = tensorDataTypedArray.ByteLength();
|
||||
return Ort::Value::CreateTensor(cpu_memory_info, buffer + bufferByteOffset, bufferByteLength,
|
||||
dims.empty() ? nullptr : &dims[0], dims.size(), elemType);
|
||||
} else {
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_GPU_BUFFER, env, "Tensor.location must be 'gpu-buffer' for IO binding.");
|
||||
|
||||
auto gpuBufferValue = tensorObject.Get("gpuBuffer");
|
||||
// Node.js: tensor.gpuBuffer is not a real GPUBuffer here. We assume it is an external object that wraps the OrtValue pointer.
|
||||
ORT_NAPI_THROW_TYPEERROR_IF(!gpuBufferValue.IsExternal(), env, "Tensor.gpuBuffer must be an external object.");
|
||||
Ort::Value dataValue(gpuBufferValue.As<Napi::External<OrtValue>>().Data());
|
||||
void* gpuBuffer = dataValue.GetTensorMutableRawData();
|
||||
dataValue.release();
|
||||
|
||||
size_t dataByteLength = DATA_TYPE_ELEMENT_SIZE_MAP[elemType] * elementSize;
|
||||
return Ort::Value::CreateTensor(webgpu_memory_info, gpuBuffer, dataByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) {
|
||||
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value) {
|
||||
Napi::EscapableHandleScope scope(env);
|
||||
auto returnValue = Napi::Object::New(env);
|
||||
|
||||
auto typeInfo = value.GetTypeInfo();
|
||||
auto onnxType = typeInfo.GetONNXType();
|
||||
|
|
@ -197,24 +221,26 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) {
|
|||
// type
|
||||
auto typeCstr = DATA_TYPE_ID_TO_NAME_MAP[elemType];
|
||||
ORT_NAPI_THROW_ERROR_IF(typeCstr == nullptr, env, "Tensor type (", elemType, ") is not supported.");
|
||||
|
||||
returnValue.Set("type", Napi::String::New(env, typeCstr));
|
||||
auto type = Napi::String::New(env, typeCstr);
|
||||
|
||||
// dims
|
||||
const size_t dimsCount = tensorTypeAndShapeInfo.GetDimensionsCount();
|
||||
std::vector<int64_t> dims;
|
||||
std::vector<int64_t> dimsVector;
|
||||
if (dimsCount > 0) {
|
||||
dims = tensorTypeAndShapeInfo.GetShape();
|
||||
dimsVector = tensorTypeAndShapeInfo.GetShape();
|
||||
}
|
||||
auto dimsArray = Napi::Array::New(env, dimsCount);
|
||||
auto dims = Napi::Array::New(env, dimsCount);
|
||||
for (uint32_t i = 0; i < dimsCount; i++) {
|
||||
dimsArray[i] = dims[i];
|
||||
dims[i] = dimsVector[i];
|
||||
}
|
||||
returnValue.Set("dims", dimsArray);
|
||||
|
||||
// location
|
||||
auto memoryInfo = value.GetTensorMemoryInfo();
|
||||
bool isGpuBuffer = memoryInfo.GetDeviceType() == OrtMemoryInfoDeviceType_GPU &&
|
||||
memoryInfo.GetAllocatorName() == "WebGPU_Buffer";
|
||||
|
||||
// size
|
||||
auto size = tensorTypeAndShapeInfo.GetElementCount();
|
||||
returnValue.Set("size", Napi::Number::From(env, size));
|
||||
|
||||
// data
|
||||
if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
|
|
@ -234,20 +260,48 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) {
|
|||
i == size - 1 ? tempBufferLength - tempOffsets[i] : tempOffsets[i + 1] - tempOffsets[i]);
|
||||
}
|
||||
}
|
||||
returnValue.Set("data", Napi::Value(env, stringArray));
|
||||
|
||||
// new Tensor("string", stringArray /* string[] */, dims /* number[] */)
|
||||
return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New({Napi::String::New(env, "string"), stringArray, dims}));
|
||||
} else {
|
||||
// number data
|
||||
// TODO: optimize memory
|
||||
auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
|
||||
if (size > 0) {
|
||||
memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
|
||||
}
|
||||
napi_value typedArrayData;
|
||||
napi_status status =
|
||||
napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData);
|
||||
NAPI_THROW_IF_FAILED(env, status, Napi::Value);
|
||||
returnValue.Set("data", Napi::Value(env, typedArrayData));
|
||||
}
|
||||
if (isGpuBuffer) {
|
||||
// Tensor.fromGpuBuffer(buffer, options)
|
||||
Napi::Function tensorFromGpuBuffer = InferenceSessionWrap::GetTensorConstructor().Value().Get("fromGpuBuffer").As<Napi::Function>();
|
||||
OrtValue* underlyingOrtValue = value.release();
|
||||
|
||||
return scope.Escape(returnValue);
|
||||
auto options = Napi::Object::New(env);
|
||||
options.Set("dataType", type);
|
||||
options.Set("dims", dims);
|
||||
options.Set("dispose", Napi::Function::New(
|
||||
env, [](const Napi::CallbackInfo& info) {
|
||||
Ort::GetApi().ReleaseValue(reinterpret_cast<OrtValue*>(info.Data()));
|
||||
return info.Env().Undefined();
|
||||
},
|
||||
"dispose", underlyingOrtValue));
|
||||
options.Set("download", Napi::Function::New(
|
||||
env, [](const Napi::CallbackInfo& info) {
|
||||
NAPI_THROW("not implemented");
|
||||
},
|
||||
"download", underlyingOrtValue));
|
||||
|
||||
return scope.Escape(tensorFromGpuBuffer.Call({Napi::External<OrtValue>::New(env, underlyingOrtValue), options}));
|
||||
} else {
|
||||
// TODO: optimize memory
|
||||
auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
|
||||
if (size > 0) {
|
||||
memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]);
|
||||
}
|
||||
napi_value typedArrayData;
|
||||
napi_status status =
|
||||
napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData);
|
||||
NAPI_THROW_IF_FAILED(env, status, Napi::Value);
|
||||
|
||||
// new Tensor(type, typedArrayData, dims)
|
||||
return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New(
|
||||
{type,
|
||||
Napi::Value(env, typedArrayData),
|
||||
dims}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,32 @@
|
|||
#include "onnxruntime_cxx_api.h"
|
||||
|
||||
// convert a Javascript OnnxValue object to an OrtValue object
|
||||
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info);
|
||||
Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info);
|
||||
|
||||
// convert an OrtValue object to a Javascript OnnxValue object
|
||||
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value);
|
||||
Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value);
|
||||
|
||||
// Identifies where a tensor's data lives. Each non-NONE value corresponds to
// one location string accepted by ParseDataLocation(); the matching string is
// noted next to each enumerator.
enum DataLocation {
|
||||
DATA_LOCATION_NONE = 0,  // returned by ParseDataLocation() for any unrecognized string
|
||||
DATA_LOCATION_CPU = 1,  // "cpu"
|
||||
DATA_LOCATION_CPU_PINNED = 2,  // "cpu-pinned"
|
||||
DATA_LOCATION_TEXTURE = 3,  // "texture"
|
||||
DATA_LOCATION_GPU_BUFFER = 4,  // "gpu-buffer"
|
||||
DATA_LOCATION_ML_TENSOR = 5  // "ml-tensor"
|
||||
};
|
||||
|
||||
// Translates the textual form of a tensor location into its DataLocation value.
//
// The comparison is exact (case-sensitive, no trimming). A string that matches
// none of the known names yields DATA_LOCATION_NONE, which callers can treat as
// "unsupported location".
inline DataLocation ParseDataLocation(const std::string& location) {
  struct NamedLocation {
    const char* name;
    DataLocation value;
  };

  // Known location names, listed in the same order as the enum values.
  static const NamedLocation kKnownLocations[] = {
      {"cpu", DATA_LOCATION_CPU},
      {"cpu-pinned", DATA_LOCATION_CPU_PINNED},
      {"texture", DATA_LOCATION_TEXTURE},
      {"gpu-buffer", DATA_LOCATION_GPU_BUFFER},
      {"ml-tensor", DATA_LOCATION_ML_TENSOR},
  };

  for (const NamedLocation& candidate : kKnownLocations) {
    if (location == candidate.name) {
      return candidate.value;
    }
  }

  return DATA_LOCATION_NONE;
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue