mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
### Description This change enhances the Node.js binding with the following features: - support for the WebGPU EP - lazy initialization of `OrtEnv` - the ability to initialize ORT with the default log level taken from `ort.env.logLevel` - session options: - `enableProfiling` and `profileFilePrefix`: support profiling. - `externalData`: explicit external data (optional in the Node.js binding) - `optimizedModelFilePath`: allow dumping the optimized model for diagnostic purposes - `preferredOutputLocation`: support IO binding. ====================================================== `Tensor.download()` is not implemented in this PR. Build pipeline updates are not included in this PR.
110 lines
3.1 KiB
C++
110 lines
3.1 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "onnxruntime_cxx_api.h"

#include <memory>
#include <string>
#include <vector>

#include <napi.h>
|
|
|
|
// class InferenceSessionWrap is a N-API object wrapper for native InferenceSession.
|
|
class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
|
|
public:
|
|
static Napi::Object Init(Napi::Env env, Napi::Object exports);
|
|
static Napi::FunctionReference& GetTensorConstructor();
|
|
|
|
InferenceSessionWrap(const Napi::CallbackInfo& info);
|
|
|
|
private:
|
|
/**
|
|
* [sync] initialize ONNX Runtime once.
|
|
*
|
|
* This function must be called before any other functions.
|
|
*
|
|
* @param arg0 a number specifying the log level.
|
|
*
|
|
* @returns undefined
|
|
*/
|
|
static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info);
|
|
|
|
/**
|
|
* [sync] list supported backend list
|
|
* @returns array with objects { "name": "cpu", requirementsInstalled: true }
|
|
*/
|
|
static Napi::Value ListSupportedBackends(const Napi::CallbackInfo& info);
|
|
|
|
/**
|
|
* [sync] create the session.
|
|
* @param arg0 either a string (file path) or a Uint8Array
|
|
* @returns nothing
|
|
* @throw error if status code != 0
|
|
*/
|
|
Napi::Value LoadModel(const Napi::CallbackInfo& info);
|
|
|
|
// following functions have to be called after model is loaded.
|
|
|
|
/**
|
|
* [sync] get input names.
|
|
* @param nothing
|
|
* @returns a string array.
|
|
* @throw nothing
|
|
*/
|
|
Napi::Value GetInputNames(const Napi::CallbackInfo& info);
|
|
/**
|
|
* [sync] get output names.
|
|
* @param nothing
|
|
* @returns a string array.
|
|
* @throw nothing
|
|
*/
|
|
Napi::Value GetOutputNames(const Napi::CallbackInfo& info);
|
|
|
|
/**
|
|
* [sync] run the model.
|
|
* @param arg0 input object: all keys must present, value is object
|
|
* @param arg1 output object: at least one key must present, value can be null.
|
|
* @returns an object that every output specified will present and value must be object
|
|
* @throw error if status code != 0
|
|
*/
|
|
Napi::Value Run(const Napi::CallbackInfo& info);
|
|
|
|
/**
|
|
* [sync] dispose the session.
|
|
* @param nothing
|
|
* @returns nothing
|
|
* @throw nothing
|
|
*/
|
|
Napi::Value Dispose(const Napi::CallbackInfo& info);
|
|
|
|
/**
|
|
* [sync] end the profiling.
|
|
* @param nothing
|
|
* @returns nothing
|
|
* @throw nothing
|
|
*/
|
|
Napi::Value EndProfiling(const Napi::CallbackInfo& info);
|
|
|
|
// private members
|
|
|
|
// persistent constructor
|
|
static Napi::FunctionReference wrappedSessionConstructor;
|
|
static Napi::FunctionReference ortTensorConstructor;
|
|
|
|
// session objects
|
|
bool initialized_;
|
|
bool disposed_;
|
|
std::unique_ptr<Ort::Session> session_;
|
|
std::unique_ptr<Ort::RunOptions> defaultRunOptions_;
|
|
|
|
// input/output metadata
|
|
std::vector<std::string> inputNames_;
|
|
std::vector<ONNXType> inputTypes_;
|
|
std::vector<ONNXTensorElementDataType> inputTensorElementDataTypes_;
|
|
std::vector<std::string> outputNames_;
|
|
std::vector<ONNXType> outputTypes_;
|
|
std::vector<ONNXTensorElementDataType> outputTensorElementDataTypes_;
|
|
|
|
// preferred output locations
|
|
std::vector<int> preferredOutputLocations_;
|
|
std::unique_ptr<Ort::IoBinding> ioBinding_;
|
|
};
|