Mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-16 21:00:14 +00:00).
# Motivation

Currently, ORT minimal builds use kernel def hashes to map from nodes to kernels to execute when loading the model. As the kernel def hashes must be known ahead of time, this works for statically registered kernels. This works well for the CPU EP. For this approach to work, the kernel def hashes must also be known at ORT format model conversion time, which means the EP with statically registered kernels must also be enabled then. This is not an issue for the always-available CPU EP. However, we do not want to require that any EP which statically registers kernels is always available too. Consequently, we explore another approach to match nodes to kernels that does not rely on kernel def hashes. An added benefit of this is the possibility of moving away from kernel def hashes completely, which would eliminate the maintenance burden of keeping the hashes stable.

# Approach

In a full build, ORT uses some information from the ONNX op schema to match a node to a kernel. We want to avoid including the ONNX op schema in a minimal build to reduce binary size. Essentially, we take the necessary information from the ONNX op schema and make it available in a minimal build. We decouple the ONNX op schema from the kernel matching logic. The kernel matching logic instead relies on per-op information which can either be obtained from the ONNX op schema or another source. This per-op information must be available in a minimal build when there are no ONNX op schemas. We put it in the ORT format model. Existing uses of kernel def hashes to look up kernels are replaced with the updated kernel matching logic. We no longer store kernel def hashes in the ORT format model’s session state and runtime optimization representations. We no longer keep the logic to generate and ensure stability of kernel def hashes.
104 lines
4.3 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import assert from 'assert';
|
|
import {InferenceSession, Tensor} from 'onnxruntime-common';
|
|
import * as path from 'path';
|
|
|
|
import {assertDataEqual, TEST_DATA_ROOT} from '../test-utils';
|
|
|
|
|
|
const MODEL_TEST_TYPES_CASES:
|
|
Array<{model: string; type: Tensor.Type; input0: Tensor.DataType; expectedOutput0: Tensor.DataType}> = [
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_bool.onnx'),
|
|
type: 'bool',
|
|
input0: Uint8Array.from([1, 0, 0, 1, 0]),
|
|
expectedOutput0: Uint8Array.from([1, 0, 0, 1, 0])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_double.onnx'),
|
|
type: 'float64',
|
|
input0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
|
|
expectedOutput0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_float.onnx'),
|
|
type: 'float32',
|
|
input0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
|
|
expectedOutput0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_int8.onnx'),
|
|
type: 'int8',
|
|
input0: Int8Array.from([1, -2, 3, 4, -5]),
|
|
expectedOutput0: Int8Array.from([1, -2, 3, 4, -5])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_int16.onnx'),
|
|
type: 'int16',
|
|
input0: Int16Array.from([1, -2, 3, 4, -5]),
|
|
expectedOutput0: Int16Array.from([1, -2, 3, 4, -5])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_int32.onnx'),
|
|
type: 'int32',
|
|
input0: Int32Array.from([1, -2, 3, 4, -5]),
|
|
expectedOutput0: Int32Array.from([1, -2, 3, 4, -5])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_int64.onnx'),
|
|
type: 'int64',
|
|
input0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]),
|
|
expectedOutput0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_string.onnx'),
|
|
type: 'string',
|
|
input0: ['a', 'b', 'c', 'd', 'e'],
|
|
expectedOutput0: ['a', 'b', 'c', 'd', 'e']
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_uint8.onnx'),
|
|
type: 'uint8',
|
|
input0: Uint8Array.from([1, 2, 3, 4, 5]),
|
|
expectedOutput0: Uint8Array.from([1, 2, 3, 4, 5])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_uint16.onnx'),
|
|
type: 'uint16',
|
|
input0: Uint16Array.from([1, 2, 3, 4, 5]),
|
|
expectedOutput0: Uint16Array.from([1, 2, 3, 4, 5])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_uint32.onnx'),
|
|
type: 'uint32',
|
|
input0: Uint32Array.from([1, 2, 3, 4, 5]),
|
|
expectedOutput0: Uint32Array.from([1, 2, 3, 4, 5])
|
|
},
|
|
{
|
|
model: path.join(TEST_DATA_ROOT, 'test_types_uint64.onnx'),
|
|
type: 'uint64',
|
|
input0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]),
|
|
expectedOutput0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)])
|
|
},
|
|
];
|
|
|
|
describe('E2E Tests - simple E2E tests', () => {
|
|
MODEL_TEST_TYPES_CASES.forEach(testCase => {
|
|
it(`${testCase.model}`, async () => {
|
|
const session = await InferenceSession.create(testCase.model);
|
|
const output = await session.run({'input': new Tensor(testCase.type, testCase.input0, [1, 5])});
|
|
assert(Object.prototype.hasOwnProperty.call(output, 'output'), '\'output\' should be in the result object.');
|
|
assert(output.output instanceof Tensor, 'result[output] should be a Tensor object.');
|
|
assert.strictEqual(output.output.size, 5, `output size expected 5, got ${output.output.size}.`);
|
|
assert.strictEqual(
|
|
output.output.type, testCase.type, `tensor type expected ${testCase.type}, got ${output.output.type}.`);
|
|
assert.strictEqual(
|
|
Object.getPrototypeOf(output.output.data), Object.getPrototypeOf(testCase.expectedOutput0),
|
|
`tensor data expected ${Object.getPrototypeOf(testCase.expectedOutput0).constructor.name}, got ${
|
|
Object.getPrototypeOf(output.output.data).constructor.name}`);
|
|
assertDataEqual(testCase.type, output.output.data, testCase.expectedOutput0);
|
|
});
|
|
});
|
|
});
|