mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-15 20:50:42 +00:00
* align ios version with onnxruntime-mobile-c * support 'file://' in iOS * fix lint error
166 lines
5.2 KiB
TypeScript
166 lines
5.2 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {Buffer} from 'buffer';
|
|
import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common';
|
|
import {Platform} from 'react-native';
|
|
|
|
import {binding, Binding} from './binding';
|
|
|
|
// Tensor data variants backed by a TypedArray — i.e. every Tensor.DataType
// except string tensors, whose data is a string[].
type SupportedTypedArray = Exclude<Tensor.DataType, string[]>;
|
|
|
|
const tensorTypeToTypedArray = (type: Tensor.Type):|Float32ArrayConstructor|Int8ArrayConstructor|Int16ArrayConstructor|
|
|
Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor => {
|
|
switch (type) {
|
|
case 'float32':
|
|
return Float32Array;
|
|
case 'int8':
|
|
return Int8Array;
|
|
case 'int16':
|
|
return Int16Array;
|
|
case 'int32':
|
|
return Int32Array;
|
|
case 'bool':
|
|
return Int8Array;
|
|
case 'float64':
|
|
return Float64Array;
|
|
case 'int64':
|
|
/* global BigInt64Array */
|
|
/* eslint no-undef: ["error", { "typeof": true }] */
|
|
return BigInt64Array;
|
|
default:
|
|
throw new Error(`unsupported type: ${type}`);
|
|
}
|
|
};
|
|
|
|
const normalizePath = (path: string): string => {
|
|
// remove 'file://' prefix in iOS
|
|
if (Platform.OS === 'ios' && path.toLowerCase().startsWith('file://')) {
|
|
return path.substring(7);
|
|
}
|
|
|
|
return path;
|
|
};
|
|
|
|
class OnnxruntimeSessionHandler implements SessionHandler {
|
|
#inferenceSession: Binding.InferenceSession;
|
|
#key: string;
|
|
|
|
inputNames: string[];
|
|
outputNames: string[];
|
|
|
|
constructor(path: string) {
|
|
this.#inferenceSession = binding;
|
|
this.#key = normalizePath(path);
|
|
this.inputNames = [];
|
|
this.outputNames = [];
|
|
}
|
|
|
|
async loadModel(options: InferenceSession.SessionOptions): Promise<void> {
|
|
try {
|
|
// load a model
|
|
const results: Binding.ModelLoadInfoType = await this.#inferenceSession.loadModel(this.#key, options);
|
|
// resolve promise if onnxruntime session is successfully created
|
|
if (results.key !== this.#key) {
|
|
throw new Error('Session key is invalid');
|
|
}
|
|
|
|
this.inputNames = results.inputNames;
|
|
this.outputNames = results.outputNames;
|
|
} catch (e) {
|
|
throw new Error(`Can't load a model: ${(e as Error).message}`);
|
|
}
|
|
}
|
|
|
|
async dispose(): Promise<void> {
|
|
return Promise.resolve();
|
|
}
|
|
|
|
startProfiling(): void {
|
|
// TODO: implement profiling
|
|
}
|
|
endProfiling(): void {
|
|
// TODO: implement profiling
|
|
}
|
|
|
|
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
|
|
Promise<SessionHandler.ReturnType> {
|
|
const outputNames: Binding.FetchesType = [];
|
|
for (const name in fetches) {
|
|
if (Object.prototype.hasOwnProperty.call(fetches, name)) {
|
|
if (fetches[name]) {
|
|
throw new Error(
|
|
'Preallocated output is not supported and only names as string array is allowed as parameter');
|
|
}
|
|
outputNames.push(name);
|
|
}
|
|
}
|
|
const input = this.encodeFeedsType(feeds);
|
|
const results: Binding.ReturnType = await this.#inferenceSession.run(this.#key, input, outputNames, options);
|
|
const output = this.decodeReturnType(results);
|
|
return output;
|
|
}
|
|
|
|
encodeFeedsType(feeds: SessionHandler.FeedsType): Binding.FeedsType {
|
|
const returnValue: {[name: string]: Binding.EncodedTensorType} = {};
|
|
for (const key in feeds) {
|
|
if (Object.hasOwnProperty.call(feeds, key)) {
|
|
let data: string|string[];
|
|
|
|
if (Array.isArray(feeds[key].data)) {
|
|
data = feeds[key].data as string[];
|
|
} else {
|
|
// Base64-encode tensor data
|
|
const buffer = (feeds[key].data as SupportedTypedArray).buffer;
|
|
data = Buffer.from(buffer, 0, buffer.byteLength).toString('base64');
|
|
}
|
|
|
|
returnValue[key] = {
|
|
dims: feeds[key].dims,
|
|
type: feeds[key].type,
|
|
data,
|
|
};
|
|
}
|
|
}
|
|
return returnValue;
|
|
}
|
|
|
|
decodeReturnType(results: Binding.ReturnType): SessionHandler.ReturnType {
|
|
const returnValue: SessionHandler.ReturnType = {};
|
|
|
|
for (const key in results) {
|
|
if (Object.hasOwnProperty.call(results, key)) {
|
|
let tensorData: Tensor.DataType;
|
|
if (Array.isArray(results[key].data)) {
|
|
tensorData = results[key].data as string[];
|
|
} else {
|
|
const buffer: Buffer = Buffer.from(results[key].data as string, 'base64');
|
|
const typedArray = tensorTypeToTypedArray(results[key].type as Tensor.Type);
|
|
tensorData = new typedArray(buffer.buffer, buffer.byteOffset, buffer.length / typedArray.BYTES_PER_ELEMENT);
|
|
}
|
|
|
|
returnValue[key] = new Tensor(results[key].type as Tensor.Type, tensorData, results[key].dims);
|
|
}
|
|
}
|
|
|
|
return returnValue;
|
|
}
|
|
}
|
|
|
|
class OnnxruntimeBackend implements Backend {
|
|
async init(): Promise<void> {
|
|
return Promise.resolve();
|
|
}
|
|
|
|
async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
|
|
Promise<SessionHandler> {
|
|
if (typeof pathOrBuffer !== 'string') {
|
|
throw new Error('Uint8Array is not supported');
|
|
}
|
|
const handler = new OnnxruntimeSessionHandler(pathOrBuffer);
|
|
await handler.loadModel(options || {});
|
|
return handler;
|
|
}
|
|
}
|
|
|
|
// Single shared backend instance exported for consumers of this module.
export const onnxruntimeBackend = new OnnxruntimeBackend();
|