onnxruntime/js/node/src/inference_session_wrap.cc
Yulong Wang 009f342caf
[JS] refactor Javascript/Typescript libraries in ONNX Runtime (#7308)
* working on re-organizing js code for ortweb

* remove dup files

* move folder

* fix common references

* fix common es5

* add webpack to common

* split interfact/impl

* use cjs for node

* add npmignore for common

* update sourcemap config for common

* update node

* adjust folder/path in CI and build

* update folder

* nit: readme

* add bundle for dev

* correct nodejs paths

* enable ORT_API_MANUAL_INIT

* set name for umd library

* correct name for commonjs export

* add priority into registerBackend()

* fix npm ci pwd

* update eslintrc

* revise code

* revert package-lock lockfileVersion 2->1

* update prebuild

* resolve comments

* update document

* revise eslint config

* update eslint for typescript rules

* revert changes by mistake in backend.ts

* add env

* resolve comments
2021-04-16 01:33:10 -07:00

195 lines
7.6 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_cxx_api.h"
#include "common.h"
#include "inference_session_wrap.h"
#include "run_options_helper.h"
#include "session_options_helper.h"
#include "tensor_helper.h"
Napi::FunctionReference InferenceSessionWrap::constructor;
Ort::Env *InferenceSessionWrap::ortEnv;
Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
// create ONNX runtime env
Ort::InitApi();
ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"};
// initialize binding
Napi::HandleScope scope(env);
Napi::Function func = DefineClass(
env, "InferenceSession",
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run),
InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr),
InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)});
constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
exports.Set("InferenceSession", func);
return exports;
}
InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo &info)
: Napi::ObjectWrap<InferenceSessionWrap>(info), initialized_(false), session_(nullptr),
defaultRunOptions_(nullptr) {}
Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
ORT_NAPI_THROW_ERROR_IF(this->initialized_, env, "Model already loaded. Cannot load model multiple times.");
size_t argsLength = info.Length();
ORT_NAPI_THROW_TYPEERROR_IF(argsLength == 0, env, "Expect argument: model file path or buffer.");
try {
defaultRunOptions_.reset(new Ort::RunOptions{});
Ort::SessionOptions sessionOptions;
if (argsLength == 2 && info[0].IsString() && info[1].IsObject()) {
Napi::String value = info[0].As<Napi::String>();
ParseSessionOptions(info[1].As<Napi::Object>(), sessionOptions);
this->session_.reset(new Ort::Session(OrtEnv(),
#ifdef _WIN32
reinterpret_cast<const wchar_t *>(value.Utf16Value().c_str()),
#else
value.Utf8Value().c_str(),
#endif
sessionOptions));
} else if (argsLength == 4 && info[0].IsArrayBuffer() && info[1].IsNumber() && info[2].IsNumber() &&
info[3].IsObject()) {
void *buffer = info[0].As<Napi::ArrayBuffer>().Data();
int64_t bytesOffset = info[1].As<Napi::Number>().Int64Value();
int64_t bytesLength = info[2].As<Napi::Number>().Int64Value();
ParseSessionOptions(info[1].As<Napi::Object>(), sessionOptions);
this->session_.reset(
new Ort::Session(OrtEnv(), reinterpret_cast<char *>(buffer) + bytesOffset, bytesLength, sessionOptions));
} else {
ORT_NAPI_THROW_TYPEERROR(
env,
"Invalid argument: args has to be either (modelPath, options) or (buffer, byteOffset, byteLength, options).");
}
// cache input/output names and types
Ort::AllocatorWithDefaultOptions allocator;
size_t count = session_->GetInputCount();
inputNames_.reserve(count);
for (size_t i = 0; i < count; i++) {
inputNames_.emplace_back(session_->GetInputName(i, allocator));
auto typeInfo = session_->GetInputTypeInfo(i);
auto onnxType = typeInfo.GetONNXType();
inputTypes_.emplace_back(onnxType);
inputTensorElementDataTypes_.emplace_back(onnxType == ONNX_TYPE_TENSOR
? typeInfo.GetTensorTypeAndShapeInfo().GetElementType()
: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
}
count = session_->GetOutputCount();
outputNames_.reserve(count);
for (size_t i = 0; i < count; i++) {
outputNames_.emplace_back(session_->GetOutputName(i, allocator));
auto typeInfo = session_->GetOutputTypeInfo(i);
auto onnxType = typeInfo.GetONNXType();
outputTypes_.emplace_back(onnxType);
outputTensorElementDataTypes_.emplace_back(onnxType == ONNX_TYPE_TENSOR
? typeInfo.GetTensorTypeAndShapeInfo().GetElementType()
: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
}
} catch (Napi::Error const &e) {
throw e;
} catch (std::exception const &e) {
ORT_NAPI_THROW_ERROR(env, e.what());
}
this->initialized_ = true;
return env.Undefined();
}
Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
Napi::EscapableHandleScope scope(env);
return scope.Escape(CreateNapiArrayFrom(env, inputNames_));
}
Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
Napi::EscapableHandleScope scope(env);
return scope.Escape(CreateNapiArrayFrom(env, outputNames_));
}
Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
ORT_NAPI_THROW_TYPEERROR_IF(info.Length() < 2, env, "Expect argument: inputs(feed) and outputs(fetch).");
ORT_NAPI_THROW_TYPEERROR_IF(!info[0].IsObject() || !info[1].IsObject(), env,
"Expect inputs(feed) and outputs(fetch) to be objects.");
ORT_NAPI_THROW_TYPEERROR_IF(info.Length() > 2 && (!info[2].IsObject() || info[2].IsNull()), env,
"'runOptions' must be an object.");
Napi::EscapableHandleScope scope(env);
auto feed = info[0].As<Napi::Object>();
auto fetch = info[1].As<Napi::Object>();
std::vector<const char *> inputNames_cstr;
std::vector<Ort::Value> inputValues;
std::vector<const char *> outputNames_cstr;
std::vector<Ort::Value> outputValues;
std::vector<bool> reuseOutput;
size_t inputIndex = 0;
size_t outputIndex = 0;
try {
for (auto &name : inputNames_) {
if (feed.Has(name)) {
inputIndex++;
inputNames_cstr.push_back(name.c_str());
auto value = feed.Get(name);
inputValues.push_back(NapiValueToOrtValue(env, value));
}
}
for (auto &name : outputNames_) {
if (fetch.Has(name)) {
outputIndex++;
outputNames_cstr.push_back(name.c_str());
auto value = fetch.Get(name);
reuseOutput.push_back(!value.IsNull());
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value));
}
}
Ort::RunOptions runOptions{nullptr};
if (info.Length() > 2) {
runOptions = Ort::RunOptions{};
ParseRunOptions(info[2].As<Napi::Object>(), runOptions);
}
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
Napi::Object result = Napi::Object::New(env);
for (size_t i = 0; i < outputIndex; i++) {
result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i]));
}
return scope.Escape(result);
} catch (Napi::Error const &e) {
throw e;
} catch (std::exception const &e) {
ORT_NAPI_THROW_ERROR(env, e.what());
}
}