mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
### Description
Make the error message friendlier when `isOrtFormat` is undefined
(i.e. when `onnxruntime.InferenceSession.create` is called with an ArrayBuffer or
Uint8Array).
### Motivation and Context
I was trying to run my ONNX model with the WebGL EP, but it gave me the error
"Cannot read properties of null (reading 'irVersion')".
Using a debugger, I found that the actual error was `int64 is not supported`,
but that error was invisible to me.
So I changed it to show both errors when `isOrtFormat` is undefined.
<s>I haven't written unit test yet, so I'm making it draft. (I have no
idea about how do I test this though...)</s>
[d62d942](d62d9425ba)
82 lines
2.6 KiB
TypeScript
82 lines
2.6 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {flatbuffers} from 'flatbuffers';
|
|
|
|
import {Graph} from './graph';
|
|
import {OpSet} from './opset';
|
|
import {onnxruntime} from './ort-schema/flatbuffers/ort-generated';
|
|
import {onnx} from './ort-schema/protobuf/onnx';
|
|
import {LongUtil} from './util';
|
|
|
|
import ortFbs = onnxruntime.experimental.fbs;
|
|
|
|
/**
 * In-memory representation of a loaded model: its graph plus the operator sets it imports.
 *
 * An instance starts out empty; `graph` and `opsets` are `undefined` until `load()` succeeds.
 */
export class Model {
  // empty model
  constructor() {}

  /**
   * Parses a serialized model and populates `graph` and `opsets`.
   *
   * @param buf Serialized model bytes.
   * @param graphInitializer Optional initializer forwarded unchanged to `Graph.from`.
   * @param isOrtFormat `true` to parse only as ORT format (flatbuffers), `false` to parse only as ONNX format
   *     (protobuf). When omitted, ONNX format is tried first and ORT format is tried if that fails.
   * @throws The parser's own error when `isOrtFormat` is given explicitly. When it is omitted and both attempts
   *     fail, an `Error` whose message contains both the ONNX failure and the ORT failure.
   */
  load(buf: Uint8Array, graphInitializer?: Graph.Initializer, isOrtFormat?: boolean): void {
    // `unknown` rather than `Error`: a catch variable can hold any thrown value, and this one is only ever
    // interpolated into the combined message below.
    let onnxError: unknown;
    if (!isOrtFormat) {
      // isOrtFormat === false || isOrtFormat === undefined
      try {
        this.loadFromOnnxFormat(buf, graphInitializer);
        return;
      } catch (e) {
        if (isOrtFormat !== undefined) {
          // The caller asked for ONNX format explicitly: do not fall back, surface the real error.
          throw e;
        }
        // Format unknown: remember the ONNX failure and fall through to the ORT attempt.
        onnxError = e;
      }
    }

    try {
      this.loadFromOrtFormat(buf, graphInitializer);
    } catch (e) {
      if (isOrtFormat !== undefined) {
        throw e;
      }
      // Tried both formats and failed (when isOrtFormat === undefined)
      throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`);
    }
  }

  /**
   * Decodes `buf` as an ONNX `ModelProto` and stores its opsets and graph on this instance.
   *
   * @throws Error when the model's IR version is lower than 3; also propagates anything thrown by the decoder
   *     or by `Graph.from`.
   */
  private loadFromOnnxFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
    const modelProto = onnx.ModelProto.decode(buf);
    const irVersion = LongUtil.longToNumber(modelProto.irVersion);
    if (irVersion < 3) {
      throw new Error('only support ONNX model with IR_VERSION>=3');
    }

    const opsets: OpSet[] =
        modelProto.opsetImport.map(i => ({domain: i.domain as string, version: LongUtil.longToNumber(i.version!)}));
    const graph = Graph.from(modelProto.graph!, graphInitializer);

    // Commit only once every step has succeeded, so a failed load never leaves the model half-populated.
    this._opsets = opsets;
    this._graph = graph;
  }

  /**
   * Reads `buf` as an ORT format flatbuffer and stores its opsets and graph on this instance.
   *
   * @throws Error when the buffer has no model table, when the model's IR version is lower than 3, or when an
   *     opset import entry is missing; also propagates anything thrown by the flatbuffers accessors or by
   *     `Graph.from`.
   */
  private loadFromOrtFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
    const fb = new flatbuffers.ByteBuffer(buf);
    const ortModel = ortFbs.InferenceSession.getRootAsInferenceSession(fb).model();
    if (!ortModel) {
      // Without this guard the next line fails with the opaque
      // "Cannot read properties of null (reading 'irVersion')".
      throw new Error('invalid ORT format model: the model table is missing');
    }
    const irVersion = LongUtil.longToNumber(ortModel.irVersion());
    if (irVersion < 3) {
      throw new Error('only support ONNX model with IR_VERSION>=3');
    }

    const opsets: OpSet[] = [];
    for (let i = 0; i < ortModel.opsetImportLength(); i++) {
      const opsetId = ortModel.opsetImport(i);
      if (!opsetId) {
        throw new Error(`invalid ORT format model: opset import #${i} is missing`);
      }
      opsets.push({domain: opsetId.domain() as string, version: LongUtil.longToNumber(opsetId.version()!)});
    }
    const graph = Graph.from(ortModel.graph()!, graphInitializer);

    // Commit only once every step has succeeded, so a failed load never leaves the model half-populated.
    this._opsets = opsets;
    this._graph = graph;
  }

  private _graph: Graph;
  /** The graph built by the last successful `load()`; `undefined` before that. */
  get graph(): Graph {
    return this._graph;
  }

  private _opsets: OpSet[];
  /** The opset imports (domain and version) read by the last successful `load()`; `undefined` before that. */
  get opsets(): readonly OpSet[] {
    return this._opsets;
  }
}
|