onnxruntime/js/web/test/unittests/backends/webgl/test-reshape-packed.ts
Yulong Wang af21a04977
[js] upgrade async@3.2.3 /js/ (#11421)
* [js] upgrade async@3.2.3 /js/

* format code
2022-05-03 23:41:36 -07:00

150 lines
3.8 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {expect} from 'chai';
import {env} from 'onnxruntime-common';
import {Backend, InferenceHandler, resolveBackend, SessionHandler} from '../../../../lib/onnxjs/backend';
import {WebGLInferenceHandler} from '../../../../lib/onnxjs/backends/webgl/inference-handler';
import {Profiler} from '../../../../lib/onnxjs/instrument';
import {Tensor} from '../../../../lib/onnxjs/tensor';
import {createAscendingArray} from './test-utils';
/** A single packed-reshape test case. */
interface TestData {
  /** Shape of the tensor handed to the reshape kernel. */
  inputShape: number[];
  /** Target shape requested from the reshape kernel. */
  outputShape: number[];
  /**
   * Number of elements in the input tensor. It is used to size the generated input data
   * and to check the output length, so it should equal the product of `inputShape`.
   */
  elementCount: number;
}
/**
 * Builds the list of packed-reshape test cases.
 *
 * Each case pairs an input shape with an output shape of the same total size.
 * `elementCount` is computed as the product of the input dimensions. A fresh array
 * (with fresh shape arrays) is returned on every call.
 */
function getTestData(): TestData[] {
  // Each entry is [inputShape, outputShape].
  const shapePairs: Array<[number[], number[]]> = [
    // 2D tensors
    [[4, 4], [2, 8]],
    [[4, 4], [1, 16]],
    [[2, 4], [4, 2]],
    [[2, 4], [1, 8]],
    [[2, 3], [1, 6]],
    [[2, 3], [3, 2]],
    // 3D tensors (the last case also reshapes to a 4D output)
    [[2, 2, 4], [4, 2, 2]],
    [[2, 2, 4], [2, 4, 2]],
    [[2, 2, 4], [1, 1, 2, 8]],
    // 4D tensors
    [[2, 2, 2, 4], [4, 2, 2, 2]],
    [[2, 2, 2, 4], [2, 4, 2, 2]],
    [[2, 2, 2, 4], [2, 2, 4, 2]],
    [[2, 2, 2, 4], [2, 1, 4, 4]],
    // rank changes that drop / add trailing unit dimensions
    [[512, 36, 1, 1], [512, 36]],
    [[512, 36], [512, 36, 1, 1]],
  ];

  // Key order (elementCount, inputShape, outputShape) is kept stable because the
  // suite embeds JSON.stringify(testData) in test titles.
  return shapePairs.map(([inputShape, outputShape]) => ({
    elementCount: inputShape.reduce((product, dim) => product * dim, 1),
    inputShape,
    outputShape,
  }));
}
// Shared WebGL state for the suite below. All three are assigned in its `before`
// hook; of these, the test cases themselves only read `inferenceHandler`.
let backend: Backend|undefined,
    sessionhandler: SessionHandler|undefined,
    inferenceHandler: InferenceHandler|undefined;
/**
 * Unit tests for WebGLInferenceHandler.reshapePacked.
 *
 * For every case from getTestData(), a 'float32' tensor of `inputShape` is built
 * from createAscendingArray(elementCount) and reshaped to `outputShape`. The
 * output data must match the input data element-for-element, i.e. reshape must
 * not reorder or drop values. Cases return early (and log) when
 * `env.webgl.pack` is false.
 */
describe('#UnitTest# - reshape - packed', () => {
  before('Initialize Context', async () => {
    const profiler = Profiler.create();
    backend = await resolveBackend('webgl');
    sessionhandler = backend.createSessionHandler({profiler});
    inferenceHandler = sessionhandler.createInferenceHandler();
  });

  for (const testData of getTestData()) {
    // One suite per case, so the full case description (element count, input and
    // output shapes) groups its test in the report. The `it` must be declared
    // inside this callback; an empty describe() registers no tests at all.
    describe(`Test reshape ${JSON.stringify(testData)}`, () => {
      it(`Test packed reshape kernel ${JSON.stringify(testData.outputShape)}`, () => {
        const webglInferenceHandler = inferenceHandler as WebGLInferenceHandler;
        if (!env.webgl.pack) {
          console.log('Skipping in unpacked texture mode.');
          return;
        }
        const {elementCount, inputShape, outputShape} = testData;

        // create input data and tensor.
        const inputData = createAscendingArray(elementCount);
        const inputTensor = new Tensor(inputShape, 'float32', undefined, undefined, inputData);

        // run kernel and get output
        const resultTensor = webglInferenceHandler.reshapePacked(inputTensor, outputShape);
        const result = resultTensor.data;
        // Surface any GL error raised while executing the kernel.
        webglInferenceHandler.session.textureManager.glContext.checkError();

        // verify result: same element count and identical values in the same order.
        expect(result).to.not.equal(null);
        expect(result).to.have.lengthOf(elementCount);
        expect(result).to.deep.equal(inputData);
      });
    });
  }
});