onnxruntime/js/web/lib/wasm/session-handler.ts
Sunghoon da5f24bd2d
Support additional session options and run options in WebAssembly (#7712)
* add all session options and run options in C API except AddInitializer and AddFreeDimensionOverride

* remove unnecessary comment

* change extra session and run options to object notation

* resolve comments

* use an optional chaining for options

* resolve comments
2021-05-17 14:57:19 -07:00

370 lines
12 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {onnx} from 'onnx-proto';
import {env, InferenceSession, SessionHandler, Tensor, TypedTensor} from 'onnxruntime-common';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {getInstance} from './wasm-factory';
/**
 * Module-wide flag recording whether `_OrtInit` has already succeeded.
 * `loadModel` checks it so that the runtime is initialized only once, no matter how many sessions are created.
 */
let ortInit = false;
/**
 * Translates a tensor element type name (for example `'float32'`) into the matching
 * `onnx.TensorProto.DataType` enum member.
 *
 * @param type - element type name of a tensor.
 * @returns the enum member that corresponds to `type`.
 * @throws Error if `type` is not one of the names handled below.
 */
const tensorDataTypeStringToEnum = (type: string): onnx.TensorProto.DataType => {
  const DataType = onnx.TensorProto.DataType;

  if (type === 'int8') {
    return DataType.INT8;
  }
  if (type === 'uint8') {
    return DataType.UINT8;
  }
  if (type === 'bool') {
    return DataType.BOOL;
  }
  if (type === 'int16') {
    return DataType.INT16;
  }
  if (type === 'uint16') {
    return DataType.UINT16;
  }
  if (type === 'int32') {
    return DataType.INT32;
  }
  if (type === 'uint32') {
    return DataType.UINT32;
  }
  if (type === 'float32') {
    return DataType.FLOAT;
  }
  if (type === 'float64') {
    return DataType.DOUBLE;
  }
  if (type === 'string') {
    return DataType.STRING;
  }
  if (type === 'int64') {
    return DataType.INT64;
  }
  if (type === 'uint64') {
    return DataType.UINT64;
  }

  throw new Error(`unsupported data type: ${type}`);
};
/**
 * Translates an `onnx.TensorProto.DataType` enum member into the tensor element type name used by the
 * JavaScript `Tensor` API. This is the inverse of `tensorDataTypeStringToEnum`.
 *
 * 64-bit integer types map to `'int64'` / `'uint64'`. Mapping them to the 32-bit names would make the
 * caller size the output buffer with a 4-byte typed array while the data holds 8-byte elements, which
 * truncates and scrambles the values.
 *
 * @param typeProto - data type reported for a tensor.
 * @returns the element type name that corresponds to `typeProto`.
 * @throws Error if `typeProto` is not one of the enum members handled below.
 */
const tensorDataTypeEnumToString = (typeProto: onnx.TensorProto.DataType): Tensor.Type => {
  switch (typeProto) {
    case onnx.TensorProto.DataType.INT8:
      return 'int8';
    case onnx.TensorProto.DataType.UINT8:
      return 'uint8';
    case onnx.TensorProto.DataType.BOOL:
      return 'bool';
    case onnx.TensorProto.DataType.INT16:
      return 'int16';
    case onnx.TensorProto.DataType.UINT16:
      return 'uint16';
    case onnx.TensorProto.DataType.INT32:
      return 'int32';
    case onnx.TensorProto.DataType.UINT32:
      return 'uint32';
    case onnx.TensorProto.DataType.FLOAT:
      return 'float32';
    case onnx.TensorProto.DataType.DOUBLE:
      return 'float64';
    case onnx.TensorProto.DataType.STRING:
      return 'string';
    case onnx.TensorProto.DataType.INT64:
      return 'int64';
    case onnx.TensorProto.DataType.UINT64:
      return 'uint64';

    default:
      // The enum's reverse mapping turns the numeric value back into its member name for the message.
      throw new Error(`unsupported data type: ${onnx.TensorProto.DataType[typeProto]}`);
  }
};
/**
 * Picks the typed-array constructor used to hold the elements of a numeric (or boolean) tensor.
 *
 * @param type - tensor element type name.
 * @returns the typed-array constructor for `type`; `'bool'` shares `Uint8Array` with `'uint8'`.
 * @throws Error for any type without a typed-array representation here (for example `'string'`).
 */
const numericTensorTypeToTypedArray = (type: Tensor.Type): Float32ArrayConstructor|Uint8ArrayConstructor|
    Int8ArrayConstructor|Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|
    Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => {
  // Floating point.
  if (type === 'float32') {
    return Float32Array;
  }
  if (type === 'float64') {
    return Float64Array;
  }

  // 8-bit: booleans are stored one byte per element.
  if (type === 'uint8' || type === 'bool') {
    return Uint8Array;
  }
  if (type === 'int8') {
    return Int8Array;
  }

  // 16-bit.
  if (type === 'uint16') {
    return Uint16Array;
  }
  if (type === 'int16') {
    return Int16Array;
  }

  // 32-bit.
  if (type === 'int32') {
    return Int32Array;
  }
  if (type === 'uint32') {
    return Uint32Array;
  }

  // 64-bit.
  if (type === 'int64') {
    return BigInt64Array;
  }
  if (type === 'uint64') {
    return BigUint64Array;
  }

  throw new Error(`unsupported type: ${type}`);
};
/**
 * Converts a logging level name into the number handed to `_OrtInit`.
 *
 * @param logLevel - one of `'verbose'`, `'info'`, `'warning'`, `'error'`, `'fatal'`.
 * @returns 0 for `'verbose'` up to 4 for `'fatal'`.
 * @throws Error if `logLevel` is not one of the five names above.
 */
const getLogLevel = (logLevel: 'verbose'|'info'|'warning'|'error'|'fatal'): number => {
  // The position of each name in this list is its numeric level.
  const orderedLevels = ['verbose', 'info', 'warning', 'error', 'fatal'];
  const level = orderedLevels.indexOf(logLevel);
  if (level < 0) {
    throw new Error(`unsupported logging level: ${logLevel}`);
  }
  return level;
};
/**
 * `SessionHandler` implementation that drives a model through the WebAssembly module returned by
 * `getInstance()`.
 *
 * Every value exchanged with the module is a number: either a handle (session, tensor, options) or an
 * offset into the module's heap. Each method below is responsible for releasing whatever it allocates.
 */
export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
  /** Handle returned by `_OrtCreateSession`; reset to 0 once `dispose()` has released it. */
  private sessionHandle: number;

  /** Input names, in the index order used by `_OrtGetInputName`. */
  inputNames: string[];
  /** Heap offsets of the UTF-8 input names returned by `_OrtGetInputName`; freed in `dispose()`. */
  private inputNamesUTF8Encoded: number[];

  /** Output names, in the index order used by `_OrtGetOutputName`. */
  outputNames: string[];
  /** Heap offsets of the UTF-8 output names returned by `_OrtGetOutputName`; freed in `dispose()`. */
  private outputNamesUTF8Encoded: number[];

  /**
   * Creates the session for a serialized model and records its input and output names.
   *
   * @param model - bytes of the model; copied into the module's heap for the duration of the call.
   * @param options - session options forwarded to `setSessionOptions`.
   * @throws Error if runtime initialization, session creation, or a name lookup fails.
   */
  loadModel(model: Uint8Array, options?: InferenceSession.SessionOptions): void {
    const wasm = getInstance();
    if (!ortInit) {
      // One-time runtime initialization, shared by every handler created from this module.
      const errorCode = wasm._OrtInit(env.wasm.numThreads!, getLogLevel(env.logLevel!));
      if (errorCode !== 0) {
        throw new Error(`Can't initialize onnxruntime. error code = ${errorCode}`);
      }
      ortInit = true;
    }

    const modelDataOffset = wasm._malloc(model.byteLength);
    let sessionOptionsHandle = 0;
    let allocs: number[] = [];

    try {
      [sessionOptionsHandle, allocs] = setSessionOptions(options);

      wasm.HEAPU8.set(model, modelDataOffset);
      this.sessionHandle = wasm._OrtCreateSession(modelDataOffset, model.byteLength, sessionOptionsHandle);
      if (this.sessionHandle === 0) {
        throw new Error('Can\'t create a session');
      }
    } finally {
      // The model copy and the options are only needed while the session is being created.
      wasm._free(modelDataOffset);
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
      allocs.forEach(wasm._free);
    }

    const inputCount = wasm._OrtGetInputCount(this.sessionHandle);
    const outputCount = wasm._OrtGetOutputCount(this.sessionHandle);

    this.inputNames = [];
    this.inputNamesUTF8Encoded = [];
    this.outputNames = [];
    this.outputNamesUTF8Encoded = [];
    for (let i = 0; i < inputCount; i++) {
      const name = wasm._OrtGetInputName(this.sessionHandle, i);
      if (name === 0) {
        throw new Error('Can\'t get an input name');
      }
      // Keep the encoded name alive: run() passes these offsets straight back to the module.
      this.inputNamesUTF8Encoded.push(name);
      this.inputNames.push(wasm.UTF8ToString(name));
    }
    for (let i = 0; i < outputCount; i++) {
      const name = wasm._OrtGetOutputName(this.sessionHandle, i);
      if (name === 0) {
        throw new Error('Can\'t get an output name');
      }
      this.outputNamesUTF8Encoded.push(name);
      this.outputNames.push(wasm.UTF8ToString(name));
    }
  }

  /**
   * Frees the encoded input/output names and releases the session.
   * Each resource is cleared after it is released, so calling this more than once does not release twice.
   */
  async dispose(): Promise<void> {
    const wasm = getInstance();
    if (this.inputNamesUTF8Encoded) {
      this.inputNamesUTF8Encoded.forEach(wasm._OrtFree);
      this.inputNamesUTF8Encoded = [];
    }
    if (this.outputNamesUTF8Encoded) {
      this.outputNamesUTF8Encoded.forEach(wasm._OrtFree);
      this.outputNamesUTF8Encoded = [];
    }
    if (this.sessionHandle) {
      wasm._OrtReleaseSession(this.sessionHandle);
      this.sessionHandle = 0;
    }
  }

  /**
   * Runs the model once.
   *
   * @param feeds - input tensors keyed by input name; string tensors are rejected.
   * @param fetches - keys select which outputs to produce (pre-allocated outputs are not supported yet).
   * @param options - run options forwarded to `setRunOptions`.
   * @returns the requested outputs keyed by output name; each tensor owns a copy of its data.
   * @throws Error if a feed or fetch name is unknown, a tensor can't be created, `_OrtRun` fails, or
   *     an output can't be read.
   * @throws TypeError if an input or an output is a string tensor.
   */
  async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
      Promise<SessionHandler.ReturnType> {
    const wasm = getInstance();

    // Resolve feed names to input indices. Nothing is allocated yet, so throwing here needs no cleanup.
    const inputArray: Tensor[] = [];
    const inputIndices: number[] = [];
    for (const [name, tensor] of Object.entries(feeds)) {
      const index = this.inputNames.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid input '${name}'`);
      }
      if (tensor.type === 'string') {
        // TODO: support string tensor
        throw new TypeError('string tensor is not supported');
      }
      inputArray.push(tensor);
      inputIndices.push(index);
    }

    // Resolve fetch names to output indices.
    const outputIndices: number[] = [];
    for (const name of Object.keys(fetches)) {
      // TODO: support pre-allocated output
      const index = this.outputNames.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid output '${name}'`);
      }
      outputIndices.push(index);
    }

    const inputCount = inputIndices.length;
    const outputCount = outputIndices.length;

    let runOptionsHandle = 0;
    let allocs: number[] = [];

    // Everything recorded in these two lists is released by the outer `finally`.
    const inputValues: number[] = [];
    const inputDataOffsets: number[] = [];

    try {
      [runOptionsHandle, allocs] = setRunOptions(options);

      // create input tensors
      for (let i = 0; i < inputCount; i++) {
        const data = inputArray[i].data;
        if (Array.isArray(data)) {
          // string tensor
          throw new TypeError('string tensor is not supported');
        }

        // Copy the tensor bytes into the module's heap; record the offset first so it is always freed.
        const dataOffset = wasm._malloc(data.byteLength);
        inputDataOffsets.push(dataOffset);
        wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength), dataOffset);

        // The dimensions only need to live until the tensor is created, so they go on the module's stack.
        const dims = inputArray[i].dims;
        const stack = wasm.stackSave();
        const dimsOffset = wasm.stackAlloc(4 * dims.length);
        try {
          let dimIndex = dimsOffset / 4;
          dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
          const tensor = wasm._OrtCreateTensor(
              tensorDataTypeStringToEnum(inputArray[i].type), dataOffset, data.byteLength, dimsOffset, dims.length);
          if (tensor === 0) {
            throw new Error('Can\'t create a tensor');
          }
          inputValues.push(tensor);
        } finally {
          wasm.stackRestore(stack);
        }
      }

      // Four stack arrays of 4-byte slots: input handles, input names, output handles, output names.
      const beforeRunStack = wasm.stackSave();
      const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
      const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
      const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
      const outputNamesOffset = wasm.stackAlloc(outputCount * 4);

      try {
        let inputValuesIndex = inputValuesOffset / 4;
        let inputNamesIndex = inputNamesOffset / 4;
        let outputValuesIndex = outputValuesOffset / 4;
        let outputNamesIndex = outputNamesOffset / 4;
        for (let i = 0; i < inputCount; i++) {
          wasm.HEAPU32[inputValuesIndex++] = inputValues[i];
          wasm.HEAPU32[inputNamesIndex++] = this.inputNamesUTF8Encoded[inputIndices[i]];
        }
        for (let i = 0; i < outputCount; i++) {
          // Output slots start as 0; they are read back as tensor handles after a successful run.
          wasm.HEAPU32[outputValuesIndex++] = 0;
          wasm.HEAPU32[outputNamesIndex++] = this.outputNamesUTF8Encoded[outputIndices[i]];
        }

        // support RunOptions
        const runErrorCode = wasm._OrtRun(
            this.sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
            outputValuesOffset, runOptionsHandle);
        if (runErrorCode !== 0) {
          throw new Error(`failed to call OrtRun(). error code = ${runErrorCode}.`);
        }

        const output: {[name: string]: Tensor} = {};

        // Index of the first output handle whose release is not yet guaranteed by an inner `finally`.
        let firstUnreleased = 0;
        try {
          for (let i = 0; i < outputCount; i++) {
            const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];

            const beforeGetTensorDataStack = wasm.stackSave();
            // stack allocate 4 pointer value
            const tensorDataOffset = wasm.stackAlloc(4 * 4);
            // From here on the `finally` directly below releases this handle.
            firstUnreleased = i + 1;
            try {
              const errorCode = wasm._OrtGetTensorData(
                  tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
              if (errorCode !== 0) {
                throw new Error(`Can't get a tensor data. error code = ${errorCode}`);
              }
              // The four slots hold, in order: data type, data offset, dims offset, number of dims.
              let tensorDataIndex = tensorDataOffset / 4;
              const dataType = wasm.HEAPU32[tensorDataIndex++];
              const dataOffset: number = wasm.HEAPU32[tensorDataIndex++];
              const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
              const dimsLength = wasm.HEAPU32[tensorDataIndex++];
              const dims = [];
              for (let j = 0; j < dimsLength; j++) {
                dims.push(wasm.HEAPU32[dimsOffset / 4 + j]);
              }
              // The dimensions have been copied out, so their buffer can be freed right away.
              wasm._OrtFree(dimsOffset);

              const type = tensorDataTypeEnumToString(dataType);
              if (type === 'string') {
                // string tensor
                throw new TypeError('string tensor is not supported');
              } else {
                const typedArray = numericTensorTypeToTypedArray(type);
                // A tensor without dimensions is a scalar holding exactly one element.
                const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b);
                const t = new Tensor(type, new typedArray(size), dims) as TypedTensor<Exclude<Tensor.Type, 'string'>>;
                // Copy the bytes out of the heap: the native tensor is released before this method returns.
                new Uint8Array(t.data.buffer, t.data.byteOffset, t.data.byteLength)
                    .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + t.data.byteLength));
                output[this.outputNames[outputIndices[i]]] = t;
              }
            } finally {
              wasm.stackRestore(beforeGetTensorDataStack);
              wasm._OrtReleaseTensor(tensor);
            }
          }
        } finally {
          // If reading an output threw, the outputs after it were never visited; release them here so
          // their handles do not leak. After a complete loop this range is empty.
          for (let i = firstUnreleased; i < outputCount; i++) {
            wasm._OrtReleaseTensor(wasm.HEAPU32[outputValuesOffset / 4 + i]);
          }
        }

        return output;
      } finally {
        wasm.stackRestore(beforeRunStack);
      }
    } finally {
      inputValues.forEach(wasm._OrtReleaseTensor);
      inputDataOffsets.forEach(wasm._free);

      wasm._OrtReleaseRunOptions(runOptionsHandle);
      allocs.forEach(wasm._free);
    }
  }

  /** Profiling is not implemented; this call does nothing. */
  startProfiling(): void {
    // TODO: implement profiling
  }

  /** Profiling is not implemented; this call does nothing. */
  endProfiling(): void {
    // TODO: implement profiling
  }
}