onnxruntime/js/web/test/unittests/opset.ts
2021-08-12 12:30:49 -07:00

101 lines
3.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {expect} from 'chai';
import {Attribute} from '../../lib/onnxjs/attribute';
import {WEBGL_OP_RESOLVE_RULES} from '../../lib/onnxjs/backends/webgl/op-resolve-rules';
import {Graph} from '../../lib/onnxjs/graph';
import {OpSet, resolveOperator} from '../../lib/onnxjs/opset';
import {Tensor} from '../../lib/onnxjs/tensor';
/**
 * Builds a bare graph node for resolver tests: the given name and op type, no inputs or outputs,
 * and an `Attribute` constructed from `null`.
 */
function createTestGraphNode(name: string, opType: string): Graph.Node {
  const node: Graph.Node = {
    name,
    opType,
    inputs: [],
    outputs: [],
    attributes: new Attribute(null),
  };
  return node;
}
/** Placeholder operator implementation used in rules under test; it produces no output tensors. */
function dummyOpImpl(): Tensor[] {
  const noOutputs: Tensor[] = [];
  return noOutputs;
}
/**
 * Asserts that the given resolve rules never overlap. For every operator type, and for every
 * opset version of the default ('') domain in the inclusive range [versionMin, versionMax], at
 * most one rule may resolve.
 *
 * Each rule is tried in isolation via `resolveOperator`. A rule whose resolution throws is treated
 * as "not applicable" to that version.
 *
 * @param rules the resolve rules to check
 * @param versionMin lowest opset version to check (inclusive)
 * @param versionMax highest opset version to check (inclusive)
 */
function checkConsistency(rules: readonly OpSet.ResolveRule[], versionMin = 1, versionMax = 10) {
  // Group the rules by operator type; only rules of the same type can compete with each other.
  const rulesByType = new Map<string, OpSet.ResolveRule[]>();
  for (const rule of rules) {
    const opType = rule[0];
    let group = rulesByType.get(opType);
    if (!group) {
      group = [];
      rulesByType.set(opType, group);
    }
    group.push(rule);
  }

  rulesByType.forEach((typeRules, type) => {
    // Inclusive upper bound: the original exclusive loop silently skipped the highest version.
    for (let version = versionMin; version <= versionMax; version++) {
      let matched: OpSet.ResolveRule|undefined;
      for (const rule of typeRules) {
        try {
          resolveOperator(createTestGraphNode('', type), [{domain: '', version}], [rule]);
        } catch {
          // Deliberately ignored: resolveOperator throws when this rule does not apply to `version`.
          continue;
        }
        if (matched) {
          expect.fail(`multiple rules overlapped: opType='${type}', domain='', version=${version} (version selectors '${
              matched[2]}' and '${rule[2]}')`);
        }
        matched = rule;
      }
    }
  });
}
describe('#UnitTest# - resolveOperator', () => {
  const nodeAbs = createTestGraphNode('Abs_1', 'Abs');
  const opset7 = [{domain: '', version: 7}];

  // [test title, rules passed to resolveOperator]. Every case here must make resolving the `Abs`
  // node against opset 7 throw a TypeError.
  const failCases: Array<[string, OpSet.ResolveRule[]]> = [
    ['ExpectFail - no rule available', []],
    ['ExpectFail - no matching rule', [['And', '', '7', dummyOpImpl], ['Sub', '', '7', dummyOpImpl]]],
    ['ExpectFail - version not match (exact match)', [['Abs', '', '6', dummyOpImpl]]],
    ['ExpectFail - version not match (minimum version match)', [['Abs', '', '8+', dummyOpImpl]]],
    ['ExpectFail - version not match (range match 1)', [['Abs', '', '4-6', dummyOpImpl]]],
    ['ExpectFail - version not match (range match 2)', [['Abs', '', '8-10', dummyOpImpl]]],
  ];

  // [test title, rules passed to resolveOperator]. Every case here must resolve without throwing.
  const passCases: Array<[string, OpSet.ResolveRule[]]> = [
    ['ExpectPass - version match (exact match)', [['Abs', '', '7', dummyOpImpl]]],
    ['ExpectPass - version match (minimum version match)', [['Abs', '', '5+', dummyOpImpl]]],
    ['ExpectPass - version match (range match 1)', [['Abs', '', '5-7', dummyOpImpl]]],
    ['ExpectPass - version match (range match 2)', [['Abs', '', '6-9', dummyOpImpl]]],
  ];

  for (const [title, ruleList] of failCases) {
    it(title, () => {
      expect(() => {
        resolveOperator(nodeAbs, opset7, ruleList);
      }).to.throw(TypeError);
    });
  }

  for (const [title, ruleList] of passCases) {
    it(title, () => {
      resolveOperator(nodeAbs, opset7, ruleList);
    });
  }
});
describe('#UnitTest# - resolve rules', () => {
  // Keep each WebGL rule's opType, domain and version selector. Replace its implementation with
  // dummyOpImpl and drop any trailing elements, so only the matching data is checked.
  const webglCheckOnlyRules = WEBGL_OP_RESOLVE_RULES.map(
      ([opType, domain, versionSelector]) => [opType, domain, versionSelector, dummyOpImpl] as OpSet.ResolveRule);

  it('Consistency check - onnx.ai - webgl', () => {
    checkConsistency(webglCheckOnlyRules);
  });
});