mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-20 21:40:57 +00:00
### Description
See
454996d496
for manual changes (excluded auto-generated formatting changes)
### Why
Because the tooling around the old clang-format setup is out of date, which
reduces development efficiency.
- The NPM package `clang-format` is already in maintenance mode and has not
been updated in 2 years.
- The VSCode extension for clang-format has not been maintained for a while,
and a recent Node.js security update stopped it from working at all on
Windows.
No one in the community seems interested in fixing those.
Prettier was chosen because it is the most popular TS/JS formatter.
### How to merge
It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that after this PR is merged, other PRs that already passed
CI can merge.
So, make sure there are no new commits before merging this one, and
invalidate JS PRs that have already passed CI, forcing them to merge with the latest main.
270 lines
8.9 KiB
TypeScript
270 lines
8.9 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import { resolveBackend, SessionHandlerType } from './backend';
|
|
import { ExecutionPlan } from './execution-plan';
|
|
import { Graph } from './graph';
|
|
import { Profiler } from './instrument';
|
|
import { Model } from './model';
|
|
import { Operator } from './operators';
|
|
import { Tensor } from './tensor';
|
|
|
|
/**
 * Types that accompany the `Session` class.
 */
export declare namespace Session {
  /**
   * Options accepted by the `Session` constructor.
   */
  export interface Config {
    /** Hint forwarded to `resolveBackend` when a model is loaded. */
    backendHint?: string;
    /** Configuration forwarded to `Profiler.create`. */
    profiler?: Profiler.Config;
  }

  /**
   * State the session hands to the backend when it creates a session handler.
   */
  export interface Context {
    /** The profiler owned by the session. */
    profiler: Readonly<Profiler>;
    /** Tensor types of the graph inputs; filled in by the session during its first run. */
    graphInputTypes?: Tensor.DataType[];
    /** Dims of the input tensors given to the first run; later runs are validated against them. */
    graphInputDims?: Array<ReadonlyArray<number>>;
  }
}
|
|
|
|
/**
 * An inference session: loads a model, resolves each graph node to an operator through the
 * backend's session handler, and executes the resulting execution plan.
 */
export class Session {
  /**
   * Creates an uninitialized session. A model must be loaded with `loadModel` before `run`.
   *
   * @param config Optional backend hint and profiler configuration.
   */
  constructor(config: Session.Config = {}) {
    this._initialized = false;
    this.backendHint = config.backendHint;
    this.profiler = Profiler.create(config.profiler);
    this.context = { profiler: this.profiler, graphInputTypes: [], graphInputDims: [] };
  }

  /** Input names of the loaded model's graph. */
  get inputNames(): readonly string[] {
    return this._model.graph.getInputNames();
  }

  /** Output names of the loaded model's graph. */
  get outputNames(): readonly string[] {
    return this._model.graph.getOutputNames();
  }

  /** Starts the session's profiler. */
  startProfiling(): void {
    this.profiler.start();
  }

  /** Stops the session's profiler. */
  endProfiling(): void {
    this.profiler.stop();
  }

  /**
   * Loads a model and initializes the session with it.
   *
   * - `string`: read from the file system when running under Node.js, otherwise fetched. A path
   *   ending in `.ort` is flagged as ORT format.
   * - `ArrayBuffer`: viewed as bytes starting at `byteOffset` (default 0) for `length` bytes
   *   (default: everything after `byteOffset`).
   * - `Uint8Array` (or any other view): used as is.
   *
   * @throws Error('already initialized') if a model has already been loaded into this session.
   */
  async loadModel(uri: string): Promise<void>;
  async loadModel(buffer: ArrayBuffer, byteOffset?: number, length?: number): Promise<void>;
  async loadModel(buffer: Uint8Array): Promise<void>;
  async loadModel(arg: string | ArrayBuffer | Uint8Array, byteOffset?: number, length?: number): Promise<void> {
    await this.profiler.event('session', 'Session.loadModel', async () => {
      // resolve backend and session handler
      const backend = await resolveBackend(this.backendHint);
      this.sessionHandler = backend.createSessionHandler(this.context);

      this._model = new Model();
      if (typeof arg === 'string') {
        const isOrtFormat = arg.endsWith('.ort');
        if (typeof process !== 'undefined' && process.versions && process.versions.node) {
          // node
          const { readFile } = require('node:fs/promises');
          const buf = await readFile(arg);
          this.initialize(buf, isOrtFormat);
        } else {
          // browser
          const response = await fetch(arg);
          const buf = await response.arrayBuffer();
          this.initialize(new Uint8Array(buf), isOrtFormat);
        }
      } else if (!ArrayBuffer.isView(arg)) {
        // load model from ArrayBuffer
        const offset = byteOffset || 0;
        // The default length must be the bytes remaining after `offset`. Using the full
        // `byteLength` with a non-zero offset requests a view past the end of the buffer,
        // which makes the Uint8Array constructor throw a RangeError.
        const arr = new Uint8Array(arg, offset, length || arg.byteLength - offset);
        this.initialize(arr);
      } else {
        // load model from Uint8array
        this.initialize(arg);
      }
    });
  }

  /**
   * Parses the model bytes, lets the session handler observe the graph, resolves the operators
   * and builds the execution plan.
   *
   * @param modelProtoBlob Serialized model bytes.
   * @param isOrtFormat Forwarded to `Model.load`.
   * @throws Error('already initialized') when called a second time.
   */
  private initialize(modelProtoBlob: Uint8Array, isOrtFormat?: boolean): void {
    if (this._initialized) {
      throw new Error('already initialized');
    }

    this.profiler.event('session', 'Session.initialize', () => {
      // load graph; the handler doubles as graph initializer only if it implements transformGraph
      const graphInitializer = this.sessionHandler.transformGraph
        ? (this.sessionHandler as Graph.Initializer)
        : undefined;
      this._model.load(modelProtoBlob, graphInitializer, isOrtFormat);

      // graph is completely initialized at this stage, let the interested handlers know
      if (this.sessionHandler.onGraphInitialized) {
        this.sessionHandler.onGraphInitialized(this._model.graph);
      }
      // initialize each operator in the graph
      this.initializeOps(this._model.graph);

      // instantiate an ExecutionPlan object to be used by the Session object
      this._executionPlan = new ExecutionPlan(this._model.graph, this._ops, this.profiler);
    });

    this._initialized = true;
  }

  /**
   * Runs the model.
   *
   * @param inputs Input tensors, either in graph input order or keyed by graph input name.
   * @returns The output tensors keyed by graph output name.
   * @throws Error if the session is not initialized or the inputs fail validation.
   */
  async run(inputs: Map<string, Tensor> | Tensor[]): Promise<Map<string, Tensor>> {
    if (!this._initialized) {
      throw new Error('session not initialized yet');
    }

    return this.profiler.event('session', 'Session.run', async () => {
      const inputTensors = this.normalizeAndValidateInputs(inputs);

      const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors);

      return this.createOutput(outputTensors);
    });
  }

  /**
   * Converts the inputs to an array in graph input order and validates their shapes and types.
   *
   * On the first successful call the graph input types and the given input dims are cached in
   * the session context; later calls are validated against that cache.
   *
   * @throws Error on a wrong input count, a missing named input, or a shape/type mismatch.
   */
  private normalizeAndValidateInputs(inputs: Map<string, Tensor> | Tensor[]): Tensor[] {
    const modelInputNames = this._model.graph.getInputNames();

    // normalize inputs
    let tensors: Tensor[];
    if (Array.isArray(inputs)) {
      // inputs: Tensor[]
      if (inputs.length !== modelInputNames.length) {
        throw new Error(`incorrect input array length: expected ${modelInputNames.length} but got ${inputs.length}`);
      }
      tensors = inputs;
    } else {
      // inputs: Map<string, Tensor> - convert map to array in graph input order
      if (inputs.size !== modelInputNames.length) {
        throw new Error(`incorrect input map size: expected ${modelInputNames.length} but got ${inputs.size}`);
      }

      tensors = [];
      for (const inputName of modelInputNames) {
        const tensor = inputs.get(inputName);
        if (!tensor) {
          // Report the graph input that is missing. An unqualified `name` here would resolve
          // to the global `name` rather than the input's name.
          throw new Error(`missing input tensor for: '${inputName}'`);
        }
        tensors.push(tensor);
      }
    }

    const context = this.context;
    const cachedTypes = context.graphInputTypes || (context.graphInputTypes = []);
    const cachedDims = context.graphInputDims || (context.graphInputDims = []);

    if (cachedTypes.length === 0 || cachedDims.length === 0) {
      // First session run - graph input data is not cached for the session
      const modelInputIndices = this._model.graph.getInputIndices();
      const modelValues = this._model.graph.getValues();

      const graphInputTypes: Tensor.DataType[] = [];
      const graphInputDims: Array<readonly number[]> = [];
      for (const valueIndex of modelInputIndices) {
        const graphInput = modelValues[valueIndex];
        graphInputTypes.push(graphInput.type!.tensorType);
        graphInputDims.push(graphInput.type!.shape.dims);
      }

      // Validate BEFORE caching: a rejected first run must not leave its shapes behind as the
      // reference that every later run is checked against.
      this.validateInputTensorDims(graphInputDims, tensors, true);
      this.validateInputTensorTypes(graphInputTypes, tensors);

      // cached for second and subsequent runs.
      // Some parts of the framework works on the assumption that the graph and types and shapes are static
      for (let i = 0; i < modelInputIndices.length; ++i) {
        cachedTypes.push(graphInputTypes[i]);
        cachedDims.push(tensors[i].dims);
      }
    } else {
      // Second and subsequent session runs - graph input data is cached for the session
      this.validateInputTensorDims(cachedDims, tensors, false);
      this.validateInputTensorTypes(cachedTypes, tensors);
    }

    return tensors;
  }

  /**
   * Checks each given tensor's type against the expected type at the same index.
   *
   * @throws Error naming the first input whose type differs.
   */
  private validateInputTensorTypes(graphInputTypes: Tensor.DataType[], givenInputs: Tensor[]) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedType = graphInputTypes[i];
      const actualType = givenInputs[i].type;
      if (expectedType !== actualType) {
        throw new Error(`input tensor[${i}] check failed: expected type '${expectedType}' but got ${actualType}`);
      }
    }
  }

  /**
   * Checks each given tensor's dims against the expected dims at the same index.
   *
   * @param noneDimSupported When true, an expected dim of 0 matches any actual dim.
   * @throws Error naming the first input whose shape does not match.
   */
  private validateInputTensorDims(
    graphInputDims: Array<readonly number[]>,
    givenInputs: Tensor[],
    noneDimSupported: boolean,
  ) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedDims = graphInputDims[i];
      const actualDims = givenInputs[i].dims;
      if (!this.compareTensorDims(expectedDims, actualDims, noneDimSupported)) {
        throw new Error(
          `input tensor[${i}] check failed: expected shape '[${expectedDims.join(',')}]' but got [${actualDims.join(
            ',',
          )}]`,
        );
      }
    }
  }

  /**
   * Returns true when both shapes have the same rank and every dim matches. With
   * `noneDimSupported`, an expected dim of 0 is treated as a 'None' dimension that matches
   * anything.
   */
  private compareTensorDims(
    expectedDims: readonly number[],
    actualDims: readonly number[],
    noneDimSupported: boolean,
  ): boolean {
    if (expectedDims.length !== actualDims.length) {
      return false;
    }

    for (let i = 0; i < expectedDims.length; ++i) {
      if (expectedDims[i] !== actualDims[i] && (!noneDimSupported || expectedDims[i] !== 0)) {
        // data shape mis-match AND not a 'None' dimension.
        return false;
      }
    }

    return true;
  }

  /**
   * Pairs the produced tensors with the graph output names, by position.
   *
   * @throws Error when the number of tensors differs from the number of graph outputs.
   */
  private createOutput(outputTensors: Tensor[]): Map<string, Tensor> {
    const modelOutputNames = this._model.graph.getOutputNames();
    if (outputTensors.length !== modelOutputNames.length) {
      throw new Error('expected number of outputs do not match number of generated outputs');
    }

    const output = new Map<string, Tensor>();
    for (let i = 0; i < modelOutputNames.length; ++i) {
      output.set(modelOutputNames[i], outputTensors[i]);
    }

    return output;
  }

  /** Resolves one operator per graph node through the session handler, in node order. */
  private initializeOps(graph: Graph): void {
    const nodes = graph.getNodes();
    this._ops = new Array(nodes.length);

    for (let i = 0; i < nodes.length; i++) {
      this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets, graph);
    }
  }

  private _model: Model;
  private _initialized: boolean;

  private _ops: Operator[];
  private _executionPlan: ExecutionPlan;

  private backendHint?: string;

  private sessionHandler: SessionHandlerType;
  private context: Session.Context;
  private profiler: Readonly<Profiler>;
}
|