mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
150 lines
3.8 KiB
TypeScript
150 lines
3.8 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {expect} from 'chai';
|
|
import {env} from 'onnxruntime-common';
|
|
|
|
import {Backend, InferenceHandler, resolveBackend, SessionHandler} from '../../../../lib/onnxjs/backend';
|
|
import {WebGLInferenceHandler} from '../../../../lib/onnxjs/backends/webgl/inference-handler';
|
|
import {Profiler} from '../../../../lib/onnxjs/instrument';
|
|
import {Tensor} from '../../../../lib/onnxjs/tensor';
|
|
|
|
import {createAscendingArray} from './test-utils';
|
|
|
|
/**
 * A single packed-reshape test case.
 */
interface TestData {
  /** Number of elements in the input; the ascending input data is generated with this length. */
  elementCount: number;
  /** Dimensions of the input tensor handed to the reshape kernel. */
  inputShape: number[];
  /** Target dimensions requested from the reshape kernel. */
  outputShape: number[];
}
|
|
/**
 * Builds the reshape cases exercised by the packed-reshape suite.
 *
 * Each case is listed as an `[inputShape, outputShape]` pair. `elementCount` is the
 * product of the input dimensions, and every output shape in the table has that same
 * number of elements, as a reshape requires. A fresh array is returned on each call.
 */
function getTestData(): TestData[] {
  const shapePairs: Array<[number[], number[]]> = [
    // 2D inputs
    [[4, 4], [2, 8]],
    [[4, 4], [1, 16]],
    [[2, 4], [4, 2]],
    [[2, 4], [1, 8]],
    [[2, 3], [1, 6]],
    [[2, 3], [3, 2]],

    // 3D inputs (the last one reshapes into a 4D output)
    [[2, 2, 4], [4, 2, 2]],
    [[2, 2, 4], [2, 4, 2]],
    [[2, 2, 4], [1, 1, 2, 8]],

    // 4D inputs
    [[2, 2, 2, 4], [4, 2, 2, 2]],
    [[2, 2, 2, 4], [2, 4, 2, 2]],
    [[2, 2, 2, 4], [2, 2, 4, 2]],
    [[2, 2, 2, 4], [2, 1, 4, 4]],

    // adding / dropping trailing unit dimensions (4D -> 2D and 2D -> 4D)
    [[512, 36, 1, 1], [512, 36]],
    [[512, 36], [512, 36, 1, 1]],
  ];

  // Key order (elementCount, inputShape, outputShape) is kept stable because the
  // suite embeds JSON.stringify of each case in its test names.
  return shapePairs.map(([inputShape, outputShape]) => ({
    elementCount: inputShape.reduce((product, dim) => product * dim, 1),
    inputShape,
    outputShape,
  }));
}
|
|
|
|
// WebGL state shared by every case; created once by the suite's `before` hook.
let backend: Backend|undefined;
let sessionHandler: SessionHandler|undefined;
let inferenceHandler: InferenceHandler|undefined;

describe('#UnitTest# - reshape - packed', () => {
  before('Initialize Context', async () => {
    const profiler = Profiler.create();
    backend = await resolveBackend('webgl');
    sessionHandler = backend.createSessionHandler({profiler});
    inferenceHandler = sessionHandler.createInferenceHandler();
  });

  for (const testData of getTestData()) {
    // Each case gets its own suite named after the full test data, with the `it`
    // registered inside it. (Previously the describe callback was empty and the `it`
    // was attached directly to the outer suite, so the per-case grouping did nothing.)
    describe(`Test reshape ${JSON.stringify(testData)}`, () => {
      it(`Test packed reshape kernel ${JSON.stringify(testData.outputShape)}`, () => {
        const webglInferenceHandler = inferenceHandler as WebGLInferenceHandler;

        if (!env.webgl.pack) {
          console.log('Skipping in unpacked texture mode.');
          return;
        }

        const {elementCount, inputShape, outputShape} = testData;

        // Create the input data and tensor.
        const inputData = createAscendingArray(elementCount);
        const inputTensor = new Tensor(inputShape, 'float32', undefined, undefined, inputData);

        // Run the packed reshape kernel and read back its output.
        const resultTensor = webglInferenceHandler.reshapePacked(inputTensor, outputShape);
        const result = resultTensor.data;

        webglInferenceHandler.session.textureManager.glContext.checkError();

        // The test expects the element data back unchanged; only the shape differs.
        expect(result).to.not.equal(null);
        expect(result).to.have.lengthOf(elementCount);
        expect(result).to.deep.equal(inputData);
      });
    });
  }
});
|