mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[js/web] use ApiTensor insteadof onnxjs Tensor in TensorResultValidator (#19358)
### Description use ApiTensor insteadof onnxjs Tensor in TensorResultValidator. Make test runner less depend on onnxjs classes.
This commit is contained in:
parent
3fe2c137ee
commit
70567a4b3a
2 changed files with 13 additions and 17 deletions
|
|
@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001;
|
|||
*/
|
||||
const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now;
|
||||
|
||||
function toInternalTensor(tensor: ort.Tensor): Tensor {
|
||||
return new Tensor(
|
||||
tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType);
|
||||
}
|
||||
function fromInternalTensor(tensor: Tensor): ort.Tensor {
|
||||
return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims);
|
||||
}
|
||||
|
|
@ -330,6 +326,10 @@ export class TensorResultValidator {
|
|||
}
|
||||
|
||||
checkTensorResult(actual: Tensor[], expected: Tensor[]): void {
|
||||
this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor));
|
||||
}
|
||||
|
||||
checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
|
||||
// check output size
|
||||
expect(actual.length, 'size of output tensors').to.equal(expected.length);
|
||||
|
||||
|
|
@ -347,10 +347,6 @@ export class TensorResultValidator {
|
|||
}
|
||||
}
|
||||
|
||||
checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
|
||||
this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor));
|
||||
}
|
||||
|
||||
checkNamedTensorResult(actual: Record<string, ort.Tensor>, expected: Test.NamedTensor[]): void {
|
||||
// check output size
|
||||
expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length);
|
||||
|
|
@ -364,7 +360,7 @@ export class TensorResultValidator {
|
|||
}
|
||||
|
||||
// This function check whether 2 tensors should be considered as 'match' or not
|
||||
areEqual(actual: Tensor, expected: Tensor): boolean {
|
||||
areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean {
|
||||
if (!actual || !expected) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -392,13 +388,13 @@ export class TensorResultValidator {
|
|||
|
||||
switch (actualType) {
|
||||
case 'string':
|
||||
return this.strictEqual(actual.stringData, expected.stringData);
|
||||
return this.strictEqual(actual.data, expected.data);
|
||||
|
||||
case 'float32':
|
||||
case 'float64':
|
||||
return this.floatEqual(
|
||||
actual.numberData as number[] | Float32Array | Float64Array,
|
||||
expected.numberData as number[] | Float32Array | Float64Array);
|
||||
actual.data as number[] | Float32Array | Float64Array,
|
||||
expected.data as number[] | Float32Array | Float64Array);
|
||||
|
||||
case 'uint8':
|
||||
case 'int8':
|
||||
|
|
@ -409,10 +405,8 @@ export class TensorResultValidator {
|
|||
case 'int64':
|
||||
case 'bool':
|
||||
return TensorResultValidator.integerEqual(
|
||||
actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
|
||||
Int32Array,
|
||||
expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
|
||||
Int32Array);
|
||||
actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
|
||||
expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array);
|
||||
|
||||
default:
|
||||
throw new Error('type not implemented or not supported');
|
||||
|
|
|
|||
|
|
@ -893,7 +893,9 @@ describe('New Conv tests', () => {
|
|||
const expected = cpuConv(
|
||||
inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads,
|
||||
testData.strides);
|
||||
if (!validator.areEqual(actual, expected)) {
|
||||
try {
|
||||
validator.checkTensorResult([actual], [expected]);
|
||||
} catch {
|
||||
console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`);
|
||||
console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`);
|
||||
throw new Error('Expected and Actual did not match');
|
||||
|
|
|
|||
Loading…
Reference in a new issue