diff --git a/js/.vscode/settings.json b/js/.vscode/settings.json index 15eacc675a..4948899ec6 100644 --- a/js/.vscode/settings.json +++ b/js/.vscode/settings.json @@ -46,5 +46,11 @@ }, "typescript.tsdk": "node_modules/typescript/lib", "git.detectSubmodules": false, - "cmake.configureOnOpen": false + "cmake.configureOnOpen": false, + "json.schemas": [ + { + "fileMatch": ["web/test/data/ops/*.jsonc"], + "url": "./web/test/op-test-schema.json" + } + ] } diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 382c2eba73..3c5a2881db 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -234,7 +234,7 @@ async function main() { } const test = testIds && testIds.length > 0 ? allTests[testIds[0]] : undefined; - const condition = test && typeof test !== 'string' ? test.condition : undefined; + const platformCondition = test && typeof test !== 'string' ? test.platformCondition : undefined; const opsetVersion = folder.split('/')[0]; const category = `node-${opsetVersion}-${backend}`; @@ -243,14 +243,16 @@ async function main() { modelTests = []; opsetTests.set(category, modelTests); } - modelTests.push(modelTestFromFolder(path.resolve(TEST_DATA_MODEL_NODE_ROOT, folder), backend, condition, times)); + modelTests.push( + modelTestFromFolder(path.resolve(TEST_DATA_MODEL_NODE_ROOT, folder), backend, platformCondition, times)); } return Array.from(opsetTests.keys()).map(category => ({name: category, tests: opsetTests.get(category)!})); } function modelTestFromFolder( - testDataRootFolder: string, backend: string, condition?: Test.Condition, times?: number): Test.ModelTest { + testDataRootFolder: string, backend: string, platformCondition?: Test.PlatformCondition, + times?: number): Test.ModelTest { if (times === 0) { npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`); return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []}; @@ -326,7 +328,7 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), condition, modelUrl, backend, cases}; + return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases}; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -385,7 +387,7 @@ async function main() { // field 'verbose' and 'backend' is not set for (const test of tests) { test.backend = backend; - test.opsets = test.opsets || [{domain: '', version: MAX_OPSET_VERSION}]; + test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION}; } npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Op', '==============================================================='); diff --git a/js/web/test/data/ops/_example.jsonc b/js/web/test/data/ops/_example.jsonc new file mode 100644 index 0000000000..1c9f306a4c --- /dev/null +++ b/js/web/test/data/ops/_example.jsonc @@ -0,0 +1,103 @@ +// This file is an example of an operator test file. +// +// In this file, we demonstrate how to write a test file for ONNX operators. +// There are 2 operator tests defined in this file: +// +// - "Simple Abs test example": a simple operator test for Abs operator. This example shows how to write a simple test with minimal properties. +// +// - "Conv2D with padding": a simple operator test for Conv operator with padding. This example shows how to write a test with all optional properties. +// + +// test file starts with an array of test objects. +[ + // this is the first operator test object (Abs example). + { + "name": "Simple Abs op test example", // name of the test + "operator": "Abs", // OpType of the operator + "cases": [ + // in this example, we only have one test case. + { + // name of the test case + "name": "3D float32 test", + "inputs": [ + // specify the input tensor + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, -1, -2, -3, -4, -5, -6, -7, -8, 101, 102, 103, 104], + "dims": [2, 3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 1, 2, 3, 4, 5, 6, 7, 8, 101, 102, 103, 104], + "dims": [2, 3, 4], + "type": "float32" + } + ] + } + ] + }, + // this is the second operator test object (Conv example). + { + // name of the test + "name": "Conv op test example", + + // OpType of the operator + "operator": "Conv", + + // [optional] specify the attributes of the operator + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + + // [optional] specify a regex pattern to match the platform description. + // + // If not specified, the test will run on all platforms. + // Otherwise, the test will only run on platforms that match the pattern. + "platformCondition": "", + + // [optional] specify input shape definitions. + // + // Sometimes, input shape definitions can offer shape information for ONNX Runtime to optimize its inferencing behavior. + // For example, ORT will transform a NCHW Conv operator into a NHWC operator when the input shape is 4 dimensional. + // If the input shape dimension is unknown, ORT will not perform this optimization. + // + // In operator test, we can specify input shape definitions to test the optimized behavior. + // + // The array of input shape definitions should have the same length as the number of model's inputs. + // + "inputShapeDefinitions": [ + // input 0 shape definition. use semantic names to specify the dynamic dimensions. + ["__input_0_dim_0", "__input_0_dim_1", "__input_0_dim_2", "__input_0_dim_3"], + // input 1 shape definition. use numbers to specify the static dimensions. + [1, 1, 2, 2] + ], + + // [optional] specify the opset of the operator. + "opset": { "domain": "", "version": 13 }, + + // test cases is required. + "cases": [ + { + "name": "NCHW Conv2D test", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/gelu.jsonc b/js/web/test/data/ops/gelu.jsonc index 79e4335c2d..b1546353bf 100644 --- a/js/web/test/data/ops/gelu.jsonc +++ b/js/web/test/data/ops/gelu.jsonc @@ -2,7 +2,7 @@ { "name": "gelu", "operator": "Gelu", - "opsets": [{ "domain": "com.microsoft", "version": 1 }], + "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [], "cases": [ { @@ -16,7 +16,7 @@ ], "outputs": [ { - "data": [1.0, 0, 0, 2.0], + "data": [0.8413447141647339, -0.04550027847290039, 0, 1.9544997215270996], "dims": [2, 2], "type": "float32" } @@ -33,7 +33,7 @@ ], "outputs": [ { - "data": [1.0], + "data": [0.8413447141647339], "dims": [], "type": "float32" } diff --git a/js/web/test/data/ops/pad-big.jsonc b/js/web/test/data/ops/pad-big.jsonc index b014f77659..601e1d58a4 100644 --- a/js/web/test/data/ops/pad-big.jsonc +++ b/js/web/test/data/ops/pad-big.jsonc @@ -2,7 +2,7 @@ { "name": "constant 2D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "reflect", "type": "string" }, { "name": "pads", "data": [0, 0, 1, 1, 0, 0, 1, 1], "type": "ints" } diff --git a/js/web/test/data/ops/pad.jsonc b/js/web/test/data/ops/pad.jsonc index 1705eee9b0..62414213b1 100644 --- a/js/web/test/data/ops/pad.jsonc +++ b/js/web/test/data/ops/pad.jsonc @@ -2,7 +2,7 @@ { "name": "constant 2D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "constant", "type": "string" }, { "name": "value", "data": 1.2, "type": "float" }, @@ -35,7 +35,7 @@ { "name": "constant 3D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "constant", "type": "string" }, { "name": "value", "data": 2.3, "type": "float" }, @@ -79,7 +79,7 @@ { "name": "Reflect 1D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "reflect", "type": "string" }, { "name": "pads", "data": [5, 7], "type": "ints" } @@ -107,7 +107,7 @@ { "name": "Reflect 2D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "reflect", "type": "string" }, { "name": "pads", "data": [3, 2, 2, 5], "type": "ints" } @@ -139,7 +139,7 @@ { "name": "Reflect 3D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "reflect", "type": "string" }, { "name": "pads", "data": [1, 2, 2, 2, 3, 1], "type": "ints" } @@ -182,7 +182,7 @@ { "name": "Edge 2D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "edge", "type": "string" }, { "name": "pads", "data": [3, 2, 2, 3], "type": "ints" } @@ -214,7 +214,7 @@ { "name": "Edge 3D", "operator": "Pad", - "opsets": [{ "domain": "", "version": 10 }], + "opset": { "domain": "", "version": 10 }, "attributes": [ { "name": "mode", "data": "edge", "type": "string" }, { "name": "pads", "data": [1, 2, 2, 2, 3, 1], "type": "ints" } diff --git a/js/web/test/data/ops/pow-big-number.jsonc b/js/web/test/data/ops/pow-big-number.jsonc index 5a87fe15b8..17693fa2d4 100644 --- a/js/web/test/data/ops/pow-big-number.jsonc +++ b/js/web/test/data/ops/pow-big-number.jsonc @@ -3,7 +3,7 @@ "name": "Pow with no attributes - big number", "operator": "Pow", "attributes": [], - "condition": "^((?!iOS).)*$", // does NOT contains 'iOS': large number cannot be handled in a half_float environment + "platformCondition": "^((?!iOS).)*$", // does NOT contains 'iOS': large number cannot be handled in a half_float environment "cases": [ { "name": "T[2,4] T[3,2,4]", diff --git a/js/web/test/data/ops/resize-pack.jsonc b/js/web/test/data/ops/resize-pack.jsonc index c2df2f9dab..7b9a2ef96d 100644 --- a/js/web/test/data/ops/resize-pack.jsonc +++ b/js/web/test/data/ops/resize-pack.jsonc @@ -2,12 +2,7 @@ { "name": "ResizeBilinearPacked with mode half_pixel", "operator": "Resize", - "opsets": [ - { - "domain": "", - "version": "11" - } - ], + "opset": { "domain": "", "version": 11 }, "attributes": [ // { "name": "scales", "data": [1.0, 1.0, 2.0, 3.0], "type": "floats" }, { @@ -54,12 +49,7 @@ { "name": "ResizeBilinearPacked with mode align_corners", "operator": "Resize", - "opsets": [ - { - "domain": "", - "version": "11" - } - ], + "opset": { "domain": "", "version": 11 }, "attributes": [ { "name": "coordinate_transformation_mode", @@ -105,12 +95,7 @@ { "name": "ResizeBilinearPacked with asymmetric", "operator": "Resize", - "opsets": [ - { - "domain": "", - "version": "11" - } - ], + "opset": { "domain": "", "version": 11 }, "attributes": [ { "name": "coordinate_transformation_mode", diff --git a/js/web/test/data/ops/split.jsonc b/js/web/test/data/ops/split.jsonc index a173f10471..46fc323cc6 100644 --- a/js/web/test/data/ops/split.jsonc +++ b/js/web/test/data/ops/split.jsonc @@ -2,7 +2,7 @@ { "name": "Split on Axis 0", "operator": "Split", - "opsets": [{ "domain": "", "version": 12 }], + "opset": { "domain": "", "version": 12 }, "attributes": [ { "name": "axis", "data": 0, "type": "int" }, { "name": "split", "data": [2, 4], "type": "ints" } @@ -35,7 +35,7 @@ { "name": "Split on Axis 1 - 2D", "operator": "Split", - "opsets": [{ "domain": "", "version": 12 }], + "opset": { "domain": "", "version": 12 }, "attributes": [ { "name": "axis", "data": 1, "type": "int" }, { "name": "split", "data": [2, 4], "type": "ints" } diff --git a/js/web/test/op-test-schema.json b/js/web/test/op-test-schema.json new file mode 100644 index 0000000000..aa08e29386 --- /dev/null +++ b/js/web/test/op-test-schema.json @@ -0,0 +1,282 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema", + "type": "array", + "items": { + "properties": { + "name": { + "type": "string", + "title": "Name", + "description": "the name of the test case" + }, + "operator": { + "type": "string", + "pattern": "[A-Z][a-zA-Z]*", + "title": "Operator", + "description": "the operator to use for the test case" + }, + "attributes": { + "type": "array", + "description": "the attributes to use for the test case", + "items": { + "type": "object", + "oneOf": [ + { + "properties": { + "name": { + "type": "string", + "description": "the name of the attribute" + }, + "type": { + "const": "int", + "description": "the type of the attribute" + }, + "data": { + "type": "integer", + "description": "the value of the attribute" + } + }, + "required": ["name", "data", "type"], + "additionalProperties": false + }, + { + "properties": { + "name": { + "type": "string", + "description": "the name of the attribute" + }, + "type": { + "const": "ints", + "description": "the type of the attribute" + }, + "data": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "the value of the attribute" + } + }, + "required": ["name", "data", "type"], + "additionalProperties": false + }, + { + "properties": { + "name": { + "type": "string", + "description": "the name of the attribute" + }, + "type": { + "const": "float", + "description": "the type of the attribute" + }, + "data": { + "type": "number", + "description": "the value of the attribute" + } + }, + "required": ["name", "data", "type"], + "additionalProperties": false + }, + { + "properties": { + "name": { + "type": "string", + "description": "the name of the attribute" + }, + "type": { + "const": "floats", + "description": "the type of the attribute" + }, + "data": { + "type": "array", + "items": { + "type": "number" + }, + "description": "the value of the attribute" + } + }, + "required": ["name", "data", "type"], + "additionalProperties": false + }, + { + "properties": { + "name": { + "type": "string", + "description": "the name of the attribute" + }, + "type": { + "const": "string", + "description": "the type of the attribute" + }, + "data": { + "type": "string", + "description": "the value of the attribute" + } + }, + "required": ["name", "data", "type"], + "additionalProperties": false + }, + { + "properties": { + "name": { + "type": "string", + "description": "the name of the attribute" + }, + "type": { + "const": "strings", + "description": "the type of the attribute" + }, + "data": { + "type": "array", + "items": { + "type": "string" + }, + "description": "the value of the attribute" + } + }, + "required": ["name", "data", "type"], + "additionalProperties": false + } + ] + } + }, + "opset": { + "type": "object", + "description": "opset is an optional field that specifies the opset to use for the test case. If not specified, the latest opset of \"\"(onnx.ai) is used.", + "properties": { + "domain": { + "type": "string", + "description": "the domain of the opset" + }, + "version": { + "type": "integer", + "description": "the version of the opset" + } + }, + "required": ["domain", "version"], + "additionalProperties": false + }, + "cases": { + "type": "array", + "description": "the test cases", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "the name of the test case" + }, + "inputs": { + "type": "array", + "description": "the test case inputs", + "items": { + "properties": { + "type": { + "enum": [ + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "bool", + "string" + ] + }, + "data": { + "type": "array", + "items": { + "type": ["number", "string", "boolean"] + } + }, + "dims": { + "type": "array", + "items": { + "type": "integer", + "minimum": 0 + } + } + }, + "required": ["type", "data", "dims"], + "additionalProperties": false + } + }, + "outputs": { + "type": "array", + "description": "the test case outputs", + "items": { + "properties": { + "type": { + "enum": [ + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "bool", + "string" + ] + }, + "data": { + "type": "array", + "items": { + "type": ["number", "string", "boolean"] + } + }, + "dims": { + "type": "array", + "items": { + "type": "integer", + "minimum": 0 + } + } + }, + "required": ["type", "data", "dims"], + "additionalProperties": false + } + } + }, + "required": ["name", "inputs", "outputs"], + "additionalProperties": false + } + }, + "inputShapeDefinitions": { + "description": "inputShapeDefinitions is an optional field that specifies the shapes constraints for the test case inputs. It can be one of the following:\n - \"none\": no shape constraints for the test case inputs.\n - \"rankOnly\": the rank of the test case inputs are specified automatically, but not the shape.\n - \"static\": the shape of the test case inputs are fully specified automatically.\n - an array of shapes: the shapes constraints for the test case inputs. shape can be represented by an array, whose element is either a number for a static dimension or a string for a semantic(dynamic) dimension.", + "oneOf": [ + { + "type": "array", + "items": { + "oneOf": [ + { + "type": "array", + "items": { + "type": ["integer", "string"] + } + }, + { "type": "null" } + ] + } + }, + { + "enum": ["none", "rankOnly", "static"] + } + ] + }, + "platformCondition": { + "type": "string", + "description": "the condition for the test case, a regex string applied on platform name. If not specified, the test will run on all platforms. Otherwise, the test will only run on platforms that match the pattern. see https://github.com/bestiejs/platform.js/" + } + }, + "required": ["name", "operator", "cases"], + "additionalProperties": false + } +} diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c253aeff30..0fd848838e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -152,7 +152,7 @@ "test_softmax_example", { "name": "test_softmax_large_number", - "condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment + "platformCondition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment }, "test_sub_bcast", "test_sub_example", @@ -183,7 +183,7 @@ "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_min_keepdims_random", { "name": "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_default_axes_keepdims_example", - "condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment + "platformCondition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment }, "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_default_axes_keepdims_random", "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_do_not_keepdims_example", diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 614dc4e16d..d19a4a7b0e 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -86,14 +86,14 @@ function shouldSkipTest(test: Test.ModelTest|Test.OperatorTest) { if (!test.cases || test.cases.length === 0) { return true; } - if (!test.condition) { + if (!test.platformCondition) { return false; } if (!platform.description) { throw new Error('failed to check current platform'); } - const regex = new RegExp(test.condition); + const regex = new RegExp(test.platformCondition); return !regex.test(platform.description); } @@ -149,14 +149,16 @@ for (const group of ORT_WEB_TEST_CONFIG.op) { }); after('Dispose Context', async () => { - if (ORT_WEB_TEST_CONFIG.profile) { - if (context instanceof ProtoOpTestContext) { - context.session.endProfiling(); - } else { - OpTestContext.profiler.stop(); + if (context) { + if (ORT_WEB_TEST_CONFIG.profile) { + if (context instanceof ProtoOpTestContext) { + context.session.endProfiling(); + } else { + OpTestContext.profiler.stop(); + } } + await context.dispose(); } - await context.dispose(); }); for (const testCase of test.cases) { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index d923837326..5552a8e299 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -390,7 +390,7 @@ export class TensorResultValidator { case 'uint32': case 'int64': case 'bool': - return this.integerEqual( + return TensorResultValidator.integerEqual( actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | @@ -463,7 +463,7 @@ export class TensorResultValidator { return true; } - integerEqual( + static integerEqual( actual: number[]|Uint8Array|Int8Array|Uint16Array|Int16Array|Uint32Array|Int32Array, expected: number[]|Uint8Array|Int8Array|Uint16Array|Int16Array|Uint32Array|Int32Array): boolean { if (actual.length !== expected.length) { @@ -551,8 +551,8 @@ export class OpTestContext { } createOperator(): Operator { return initializeOperator( - this.sessionHandler, this.opTest.operator, this.opTest.attributes, - this.opTest.opsets ?? [{domain: '', version: 7}]); + this.sessionHandler, this.opTest.operator, this.opTest.attributes || [], + [this.opTest.opset ?? {domain: '', version: 7}]); } async dispose(): Promise { @@ -575,9 +575,9 @@ export class ProtoOpTestContext { session: ort.InferenceSession; readonly backendHint: string; constructor(test: Test.OperatorTest) { - const opsetImport = test.opsets!.map(opset => onnx.OperatorSetIdProto.create(opset)); + const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; - const attribute = test.attributes!.map(attr => { + const attribute = (test.attributes || []).map(attr => { const protoAttr = onnx.AttributeProto.create({name: attr.name}); switch (attr.type) { case 'float': @@ -623,23 +623,70 @@ export class ProtoOpTestContext { const model = onnx.ModelProto.create(); model.irVersion = onnx.Version.IR_VERSION; - model.opsetImport = opsetImport; + model.opsetImport.push(opsetImport); model.graph = onnx.GraphProto.create(); model.graph.node = [onnx.NodeProto.create({ input: test.cases[0].inputs!.map((_, i) => `input_${i}`), output: test.cases[0].outputs!.map((_, i) => `output_${i}`), opType: operator, + domain: test.opset?.domain, name: operator, attribute })]; - model.graph.input = test.cases[0].inputs!.map((input, i) => onnx.ValueInfoProto.create({ - name: `input_${i}`, - type: onnx.TypeProto.create({ - tensorType: onnx.TypeProto.Tensor.create({elemType: tensorDataTypeStringToEnum(input.type)}), - }), - })); + // normalize input shape definitions + let normalizedInputShapeDefinitions: ReadonlyArray; + if (!test.inputShapeDefinitions || test.inputShapeDefinitions === 'none') { + // if inputShapeDefinitions is not specified, use undefined for all inputs + normalizedInputShapeDefinitions = new Array(inputCount).fill(undefined); + } else if (test.inputShapeDefinitions === 'rankOnly') { + // if inputShapeDefinitions is 'rankOnly', use semantic names for all inputs. This means only rank is specified. + normalizedInputShapeDefinitions = + test.cases[0].inputs!.map((input, i) => input.dims.map((_, j) => `_input_${i}_d${j}`)); + + // check if all test cases have the same rank for each inputs + if (test.cases.some( + testCase => + testCase.inputs!.some((input, i) => input.dims.length !== test.cases[0].inputs![i].dims.length))) { + throw new Error(`Test cases for test: ${test.name} [${ + test.operator}] must have the same rank for each inputs in different test cases`); + } + } else if (test.inputShapeDefinitions === 'static') { + // if inputShapeDefinitions is 'static', use the shape of the first test case for all inputs. + normalizedInputShapeDefinitions = test.cases[0].inputs!.map(input => input.dims); + + // check if all test cases have the same shape for each inputs + if (test.cases.some( + testCase => testCase.inputs!.some( + (input, i) => TensorResultValidator.integerEqual(input.dims, test.cases[0].inputs![i].dims)))) { + throw new Error(`Test cases for test: ${test.name} [${ + test.operator}] must have the same shape for each inputs in different test cases`); + } + } else { + // if inputShapeDefinitions is specified as an array, use it as is. + // check if inputShapeDefinitions has the same number of inputs as test cases + if (test.inputShapeDefinitions && test.inputShapeDefinitions.length !== inputCount) { + throw new Error( + `Input shape definitions for test: ${test.name} [${test.operator}] must have the same number of inputs`); + } + normalizedInputShapeDefinitions = test.inputShapeDefinitions; + } + + model.graph.input = test.cases[0].inputs!.map((input, i) => { + const shapeDefinition = normalizedInputShapeDefinitions[i]; + const shape = shapeDefinition ? onnx.TensorShapeProto.create({ + dim: shapeDefinition.map( + dim => onnx.TensorShapeProto.Dimension.create(typeof dim === 'string' ? {dimParam: dim} : {dimValue: dim})) + }) : + undefined; + return onnx.ValueInfoProto.create({ + name: `input_${i}`, + type: onnx.TypeProto.create({ + tensorType: onnx.TypeProto.Tensor.create({elemType: tensorDataTypeStringToEnum(input.type), shape}), + }), + }); + }); model.graph.output = test.cases[0].outputs!.map((output, i) => onnx.ValueInfoProto.create({ name: `output_${i}`, @@ -652,6 +699,19 @@ export class ProtoOpTestContext { this.backendHint = test.backend!; this.loadedData = onnx.ModelProto.encode(model).finish(); + + // in debug mode, open a new tab in browser for the generated onnx model. + if (ort.env.debug) { + const modelFile = + new File([this.loadedData], `op_test_generated_model_${test.name}.onnx`, {type: 'application/octet-stream'}); + const modelTempUrl = URL.createObjectURL(modelFile); + const a = document.createElement('a'); + a.href = modelTempUrl; + a.download = modelFile.name; + a.target = '_blank'; + a.click(); + URL.revokeObjectURL(modelTempUrl); + } } async init(): Promise { this.session = await ort.InferenceSession.create(this.loadedData, {executionProviders: [this.backendHint]}); diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index e6afdcafd7..b86ac4e50c 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -33,7 +33,7 @@ export declare namespace Test { * Represent a string to describe the current environment. * Used in ModelTest and OperatorTest to determine whether to run the test or not. */ - export type Condition = string; + export type PlatformCondition = string; export interface ModelTestCase { name: string; @@ -46,7 +46,7 @@ export declare namespace Test { name: string; modelUrl: string; backend?: string; // value should be populated at build time - condition?: Condition; + platformCondition?: PlatformCondition; cases: readonly ModelTestCase[]; } @@ -66,13 +66,16 @@ export declare namespace Test { version: number; } + export type InputShapeDefinition = ReadonlyArray; + export interface OperatorTest { name: string; operator: string; - opsets?: readonly OperatorTestOpsetImport[]; + inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray; + opset?: OperatorTestOpsetImport; backend?: string; // value should be populated at build time - condition?: Condition; - attributes: readonly AttributeValue[]; + platformCondition?: PlatformCondition; + attributes?: readonly AttributeValue[]; cases: readonly OperatorTestCase[]; } @@ -86,7 +89,7 @@ export declare namespace Test { export type TestName = string; export interface TestDescription { name: string; - condition: Condition; + platformCondition: PlatformCondition; } export type Test = TestName|TestDescription; }