mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
101 lines
3.5 KiB
TypeScript
101 lines
3.5 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {expect} from 'chai';
|
|
|
|
import {Attribute} from '../../lib/onnxjs/attribute';
|
|
import {WEBGL_OP_RESOLVE_RULES} from '../../lib/onnxjs/backends/webgl/op-resolve-rules';
|
|
import {Graph} from '../../lib/onnxjs/graph';
|
|
import {OpSet, resolveOperator} from '../../lib/onnxjs/opset';
|
|
import {Tensor} from '../../lib/onnxjs/tensor';
|
|
|
|
function createTestGraphNode(name: string, opType: string): Graph.Node {
|
|
return {name, opType, inputs: [], outputs: [], attributes: new Attribute(null)};
|
|
}
|
|
|
|
function dummyOpImpl(): Tensor[] {
|
|
return [];
|
|
}
|
|
|
|
function checkConsistency(rules: readonly OpSet.ResolveRule[]) {
|
|
const VERSION_MIN = 1, VERSION_MAX = 10;
|
|
const typeRules = new Map<string, OpSet.ResolveRule[]>();
|
|
rules.forEach(rule => {
|
|
let ruleSet = typeRules.get(rule[0]);
|
|
if (!ruleSet) {
|
|
ruleSet = [];
|
|
typeRules.set(rule[0], ruleSet);
|
|
}
|
|
ruleSet.push(rule);
|
|
});
|
|
|
|
typeRules.forEach((rules, type) => {
|
|
for (let i = VERSION_MIN; i < VERSION_MAX; i++) {
|
|
let match = false;
|
|
for (const r of rules) {
|
|
try {
|
|
resolveOperator(createTestGraphNode('', type), [{domain: '', version: i}], [r]);
|
|
} catch {
|
|
continue;
|
|
}
|
|
expect(match, `multiple rules overlapped: opType='${type}', domain='', version=${i}`).to.be.false;
|
|
match = true;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
describe('#UnitTest# - resolveOperator', () => {
|
|
const nodeAbs = createTestGraphNode('Abs_1', 'Abs');
|
|
const opset7 = [{domain: '', version: 7}];
|
|
it('ExpectFail - no rule available', () => {
|
|
expect(() => {
|
|
resolveOperator(nodeAbs, opset7, []);
|
|
}).to.throw(TypeError);
|
|
});
|
|
it('ExpectFail - no matching rule', () => {
|
|
expect(() => {
|
|
resolveOperator(nodeAbs, opset7, [['And', '', '7', dummyOpImpl], ['Sub', '', '7', dummyOpImpl]]);
|
|
}).to.throw(TypeError);
|
|
});
|
|
it('ExpectFail - version not match (exact match)', () => {
|
|
expect(() => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '6', dummyOpImpl]]);
|
|
}).to.throw(TypeError);
|
|
});
|
|
it('ExpectFail - version not match (minimum version match)', () => {
|
|
expect(() => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '8+', dummyOpImpl]]);
|
|
}).to.throw(TypeError);
|
|
});
|
|
it('ExpectFail - version not match (range match 1)', () => {
|
|
expect(() => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '4-6', dummyOpImpl]]);
|
|
}).to.throw(TypeError);
|
|
});
|
|
it('ExpectFail - version not match (range match 2)', () => {
|
|
expect(() => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '8-10', dummyOpImpl]]);
|
|
}).to.throw(TypeError);
|
|
});
|
|
it('ExpectPass - version match (exact match)', () => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '7', dummyOpImpl]]);
|
|
});
|
|
it('ExpectPass - version match (minimum version match)', () => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '5+', dummyOpImpl]]);
|
|
});
|
|
it('ExpectPass - version match (range match 1)', () => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '5-7', dummyOpImpl]]);
|
|
});
|
|
it('ExpectPass - version match (range match 2)', () => {
|
|
resolveOperator(nodeAbs, opset7, [['Abs', '', '6-9', dummyOpImpl]]);
|
|
});
|
|
});
|
|
|
|
describe('#UnitTest# - resolve rules', () => {
|
|
const webglCheckOnlyRules =
|
|
WEBGL_OP_RESOLVE_RULES.map(rule => [rule[0], rule[1], rule[2], dummyOpImpl] as OpSet.ResolveRule);
|
|
it('Consistency check - onnx.ai - webgl', () => {
|
|
checkConsistency(webglCheckOnlyRules);
|
|
});
|
|
});
|