onnxruntime/nodejs/src/inference_session_wrap.h
Yulong Wang 5dfc91db51
Node.js binding for ONNX Runtime (#3613)
* initial commit for Node.js binding

* add c++ code

* add inference session impl

* e2e working

* add settings.json

* add test data

* adjust binding declaration

* refine tensor constructor declaration

* update tests

* enable onnx tests

* simply refine readme

* refine cpp impl

* refine tests

* formatting

* add linting

* move bin folder

* fix linux build

* manually update test filter list

* update C++ API headers: fix crash in release build

* make (manually) prebuild work

* add test into prepack script

* specify prebuild runtime type (N-API)

* build.ts: update rebuild and include regex

* fix lazy load on electron.js

* update dev version, git link and binary host

* support session options and run options

* bump dev version

* update README

* add 1 example

* move folder

* adjust path

* update document for examples

* rename example 01

* add example 02

* add session option: log severity level

* add example 04

* resolve comments

* fix typo

* remove double guard in header files

* add copyright banner

* move BUILD outside from README

* consume test filter list from onnxruntime
2020-05-05 11:45:12 -07:00

73 lines
2.1 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <string>
#include <vector>

#include <napi.h>

#include "onnxruntime_cxx_api.h"
// class InferenceSessionWrap is an N-API object wrapper for native InferenceSession.
class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
public:
  // Registers the class (constructor + methods) on the module's exports object.
  static Napi::Object Init(Napi::Env env, Napi::Object exports);
  InferenceSessionWrap(const Napi::CallbackInfo &info);

private:
  /**
   * [sync] create the session.
   * @param arg0 either a string (file path) or a Uint8Array
   * @returns nothing
   * @throw error if status code != 0
   */
  Napi::Value LoadModel(const Napi::CallbackInfo &info);

  // following functions have to be called after model is loaded.

  /**
   * [sync] get input names.
   * @param nothing
   * @returns a string array.
   * @throw nothing
   */
  Napi::Value GetInputNames(const Napi::CallbackInfo &info);

  /**
   * [sync] get output names.
   * @param nothing
   * @returns a string array.
   * @throw nothing
   */
  Napi::Value GetOutputNames(const Napi::CallbackInfo &info);

  /**
   * [sync] run the model.
   * @param arg0 input object: all keys must present, value is object
   * @param arg1 output object: at least one key must present, value can be null.
   * @returns an object that every output specified will present and value must be object
   * @throw error if status code != 0
   */
  Napi::Value Run(const Napi::CallbackInfo &info);

  // private members

  // persistent constructor, kept alive so JS can instantiate the class after Init.
  static Napi::FunctionReference constructor;

  // global Ort environment shared by all sessions; OrtEnv() assumes Init has
  // already created it (dereferences without a null check).
  static Ort::Env *ortEnv;
  static Ort::Env &OrtEnv() { return *ortEnv; }

  // session objects
  // initialized_ is brace-initialized so it reads false until the session is
  // actually set up — avoids UB from an uninitialized bool.
  bool initialized_{false};
  std::unique_ptr<Ort::Session> session_;
  std::unique_ptr<Ort::RunOptions> defaultRunOptions_;

  // input/output metadata — presumably populated during LoadModel, since the
  // accessors above require the model to be loaded first; verify in the .cc file.
  std::vector<std::string> inputNames_;
  std::vector<ONNXType> inputTypes_;
  std::vector<ONNXTensorElementDataType> inputTensorElementDataTypes_;
  std::vector<std::string> outputNames_;
  std::vector<ONNXType> outputTypes_;
  std::vector<ONNXTensorElementDataType> outputTensorElementDataTypes_;
};