// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { expect } from 'chai';

import { Attribute } from '../../lib/onnxjs/attribute';
import { WEBGL_OP_RESOLVE_RULES } from '../../lib/onnxjs/backends/webgl/op-resolve-rules';
import { Graph } from '../../lib/onnxjs/graph';
import { OpSet, resolveOperator } from '../../lib/onnxjs/opset';
import { Tensor } from '../../lib/onnxjs/tensor';

/**
 * Builds a minimal graph node for operator-resolution tests.
 * Only `name` and `opType` are set from the arguments; inputs/outputs are empty and the
 * attribute set is constructed from `null`.
 */
function createTestGraphNode(name: string, opType: string): Graph.Node {
  return { name, opType, inputs: [], outputs: [], attributes: new Attribute(null) };
}

/** Placeholder operator implementation: these tests only check whether a rule resolves, not what it runs. */
function dummyOpImpl(): Tensor[] {
  return [];
}

/**
 * Asserts that, for every operator type, at most one rule resolves for any given (domain, opset version).
 *
 * Rules are grouped by operator type. For every domain declared by that group's rules and every version
 * in the inclusive range [versionMin, versionMax], each rule is tried on its own via `resolveOperator`;
 * if two rules succeed for the same probe, the chai assertion fails.
 *
 * @param rules the resolve rules to check.
 * @param versionMin lowest opset version probed (inclusive). Defaults to 1.
 * @param versionMax highest opset version probed (inclusive). Defaults to 10.
 * @throws chai `AssertionError` when two rules overlap; rethrows any error from `resolveOperator`
 *         other than the `TypeError` it uses to signal "no matching rule".
 */
function checkConsistency(rules: readonly OpSet.ResolveRule[], versionMin = 1, versionMax = 10) {
  const rulesByOpType = new Map<string, OpSet.ResolveRule[]>();
  for (const rule of rules) {
    const group = rulesByOpType.get(rule[0]);
    if (group) {
      group.push(rule);
    } else {
      rulesByOpType.set(rule[0], [rule]);
    }
  }

  rulesByOpType.forEach((opRules, opType) => {
    const node = createTestGraphNode('', opType);
    // Probe every domain the rules declare. Probing only domain '' would make rules registered under
    // another domain fail to resolve for every version, so they would never be checked at all.
    const domains = Array.from(new Set(opRules.map((rule) => rule[1])));
    for (const domain of domains) {
      for (let version = versionMin; version <= versionMax; version++) {
        const opsets = [{ domain, version }];
        let matched = false;
        for (const rule of opRules) {
          try {
            resolveOperator(node, opsets, [rule]);
          } catch (e) {
            // "No matching rule" is reported as a TypeError (see the ExpectFail tests below);
            // anything else is a genuine failure and must not be mistaken for a non-match.
            if (e instanceof TypeError) {
              continue;
            }
            throw e;
          }
          expect(matched, `multiple rules overlapped: opType='${opType}', domain='${domain}', version=${version}`).to.be
            .false;
          matched = true;
        }
      }
    }
  });
}

describe('#UnitTest# - resolveOperator', () => {
  const nodeAbs = createTestGraphNode('Abs_1', 'Abs');
  const opset7 = [{ domain: '', version: 7 }];

  it('ExpectFail - no rule available', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, []);
    }).to.throw(TypeError);
  });
  it('ExpectFail - no matching rule', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [
        ['And', '', '7', dummyOpImpl],
        ['Sub', '', '7', dummyOpImpl],
      ]);
    }).to.throw(TypeError);
  });
  it('ExpectFail - version not match (exact match)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '6', dummyOpImpl]]);
    }).to.throw(TypeError);
  });
  it('ExpectFail - version not match (minimum version match)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '8+', dummyOpImpl]]);
    }).to.throw(TypeError);
  });
  it('ExpectFail - version not match (range match 1)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '4-6', dummyOpImpl]]);
    }).to.throw(TypeError);
  });
  it('ExpectFail - version not match (range match 2)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '8-10', dummyOpImpl]]);
    }).to.throw(TypeError);
  });

  // The ExpectPass cases state their expectation explicitly instead of relying on an uncaught throw.
  it('ExpectPass - version match (exact match)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '7', dummyOpImpl]]);
    }).to.not.throw();
  });
  it('ExpectPass - version match (minimum version match)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '5+', dummyOpImpl]]);
    }).to.not.throw();
  });
  it('ExpectPass - version match (range match 1)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '5-7', dummyOpImpl]]);
    }).to.not.throw();
  });
  it('ExpectPass - version match (range match 2)', () => {
    expect(() => {
      resolveOperator(nodeAbs, opset7, [['Abs', '', '6-9', dummyOpImpl]]);
    }).to.not.throw();
  });
});

describe('#UnitTest# - resolve rules', () => {
  // Keep only each rule's opType, domain and version selector; the implementation is irrelevant to
  // resolution, so it is replaced with the dummy.
  const webglCheckOnlyRules = WEBGL_OP_RESOLVE_RULES.map(
    (rule): OpSet.ResolveRule => [rule[0], rule[1], rule[2], dummyOpImpl],
  );

  it('Consistency check - onnx.ai - webgl', () => {
    checkConsistency(webglCheckOnlyRules);
  });

  it('Consistency check - detects overlapping rules in a non-default domain', () => {
    expect(() => {
      checkConsistency([
        ['FusedConv', 'com.microsoft', '1+', dummyOpImpl],
        ['FusedConv', 'com.microsoft', '2-3', dummyOpImpl],
      ]);
    }).to.throw();
  });

  it('Consistency check - accepts disjoint version ranges', () => {
    expect(() => {
      checkConsistency([
        ['Clip', '', '6-10', dummyOpImpl],
        ['Clip', '', '11+', dummyOpImpl],
      ]);
    }).to.not.throw();
  });
});