mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
[Web/JS] Fixing two bugs in reshape_pack and im2col_pack (#7689)
* fixing two bugs in reshape_pack and im2col_pack * minor fix * fix lint complaints
This commit is contained in:
parent
79854dda8f
commit
d3c4b70ede
4 changed files with 65 additions and 34 deletions
|
|
@ -16,6 +16,10 @@ import {WebGLReshapePacked} from './reshape-packed';
|
|||
export class WebGLConvPacked extends Conv {
|
||||
protected artifacts: Artifact[];
|
||||
protected programInfo: ProgramInfo[];
|
||||
private kernelReshape = new WebGLReshapePacked();
|
||||
private im2col: WebGLIm2ColPacked;
|
||||
private matmul = new WebGLMatMulPacked();
|
||||
private outputReshape = new WebGLReshapePacked();
|
||||
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
const programManager = inferenceHandler.session.programManager;
|
||||
|
|
@ -35,35 +39,36 @@ export class WebGLConvPacked extends Conv {
|
|||
this.kernelShape}, pads:${this.pads}, strides:${this.strides}`);
|
||||
|
||||
const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
|
||||
const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
|
||||
const matmul = new WebGLMatMulPacked();
|
||||
if (this.im2col === undefined) {
|
||||
this.im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
|
||||
}
|
||||
if (this.activation) {
|
||||
const attributes = new Attribute(undefined);
|
||||
attributes.set('__internal_activation', 'string', (this.activation));
|
||||
matmul.initialize(attributes);
|
||||
this.matmul.initialize(attributes);
|
||||
}
|
||||
const reshape = new WebGLReshapePacked();
|
||||
// shape for kernel reshape
|
||||
const shape =
|
||||
new Tensor([2], 'int32', undefined, undefined, new Int32Array([kshape[0], kshape[1] * kshape[2] * kshape[3]]));
|
||||
if (!this.artifacts) {
|
||||
this.artifacts = [];
|
||||
this.programInfo = [];
|
||||
this.programInfo[0] = im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]);
|
||||
this.programInfo[0] = this.im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]);
|
||||
this.artifacts[0] = programManager.build(this.programInfo[0]);
|
||||
|
||||
this.programInfo[1] = reshape.createProgramInfo(inferenceHandler, [inputs[1], shape]);
|
||||
this.programInfo[1] = this.kernelReshape.createProgramInfo(inferenceHandler, [inputs[1], shape]);
|
||||
this.artifacts[1] = programManager.build(this.programInfo[1]);
|
||||
}
|
||||
|
||||
// run im2col
|
||||
const runDataIm2col = im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]);
|
||||
const runDataIm2col = this.im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]);
|
||||
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[0], runDataIm2col);
|
||||
programManager.run(this.artifacts[0], runDataIm2col);
|
||||
const im2colOutput = runDataIm2col.outputTextureData.tensor;
|
||||
|
||||
// reshape kernel
|
||||
const runDataKernelReshape = reshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]);
|
||||
const runDataKernelReshape =
|
||||
this.kernelReshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]);
|
||||
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[1], runDataKernelReshape);
|
||||
programManager.run(this.artifacts[1], runDataKernelReshape);
|
||||
const kernelReshaped = runDataKernelReshape.outputTextureData.tensor;
|
||||
|
|
@ -72,11 +77,11 @@ export class WebGLConvPacked extends Conv {
|
|||
const hasBias = (inputs.length === 3);
|
||||
assert(this.artifacts.length > 1, () => 'expect at least 2 artifacts created');
|
||||
if (this.artifacts.length === 2) {
|
||||
this.programInfo[2] = matmul.createProgramInfo(
|
||||
this.programInfo[2] = this.matmul.createProgramInfo(
|
||||
inferenceHandler, hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]);
|
||||
this.artifacts[2] = programManager.build(this.programInfo[2]);
|
||||
}
|
||||
const runDataMatmul = matmul.createRunData(
|
||||
const runDataMatmul = this.matmul.createRunData(
|
||||
inferenceHandler, this.programInfo[2],
|
||||
hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]);
|
||||
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[2], runDataMatmul);
|
||||
|
|
@ -90,11 +95,11 @@ export class WebGLConvPacked extends Conv {
|
|||
|
||||
assert(this.artifacts.length > 2, () => 'expect at least 3 artifacts created');
|
||||
if (this.artifacts.length === 3) {
|
||||
this.programInfo[3] = reshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]);
|
||||
this.programInfo[3] = this.outputReshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]);
|
||||
this.artifacts[3] = programManager.build(this.programInfo[3]);
|
||||
}
|
||||
const runDataOutputReshape =
|
||||
reshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]);
|
||||
this.outputReshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]);
|
||||
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[3], runDataOutputReshape);
|
||||
programManager.run(this.artifacts[3], runDataOutputReshape);
|
||||
return [runDataOutputReshape.outputTextureData.tensor];
|
||||
|
|
|
|||
|
|
@ -48,12 +48,11 @@ export class WebGLIm2ColPacked implements WebGLOperator {
|
|||
|
||||
if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) {
|
||||
offsetY = int(blockIndex / (${this.convOutputShape[rank - 1]})) * ${this.strides[0]} - ${this.pads[1]};
|
||||
d0 = offsetY + ${this.dilations[0]} * (int(mod(float(pos), ${kernelSize}.)) / ${wshape[2]} );
|
||||
d0 = offsetY + ${this.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]});
|
||||
|
||||
if(d0 < ${xshape[rowDim]} && d0 >= 0) {
|
||||
offsetX = int(mod(float(blockIndex), ${this.convOutputShape[rank - 1]}.) * ${this.strides[1]}. - ${
|
||||
this.pads[0]}.);
|
||||
d1 = offsetX + ${this.dilations[1]} * (int(mod(mod(float(pos), ${kernelSize}.), ${wshape[2]}.)));
|
||||
offsetX = imod(blockIndex, ${this.convOutputShape[rank - 1]}) * ${this.strides[1]} - ${this.pads[0]};
|
||||
d1 = offsetX + ${this.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]});
|
||||
|
||||
if(d1 < ${xshape[colDim]} && d1 >= 0) {
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@ import {Tensor} from '../../../tensor';
|
|||
import {ShapeUtil} from '../../../util';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
import {ProgramInfo, RunData, TextureData, WebGLOperator} from '../types';
|
||||
import {TextureLayout} from '../types';
|
||||
|
||||
import {unpackFromChannel} from './packing-utils';
|
||||
|
||||
export class WebGLReshapePacked extends Reshape implements WebGLOperator {
|
||||
|
|
@ -32,18 +31,22 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
|
|||
// the same between input shape and output shape, the packed reshape can be
|
||||
// treated as no-op.
|
||||
const originInputShape = inputs[0].dims;
|
||||
const inputShape3D = processDims3D(inputs[0].dims);
|
||||
this.inputShape3D = processDims3D(inputs[0].dims);
|
||||
let inputLayout: TextureLayout;
|
||||
if (originInputShape.length === 3) {
|
||||
inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true);
|
||||
} else {
|
||||
inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true);
|
||||
if (originInputShape.length !== 3) {
|
||||
const originalInputLayout = inputLayout;
|
||||
// if originShape is not a 3D shape, create texture layout from the processed shape.
|
||||
inputLayout =
|
||||
handler.createTextureLayoutFromShape(inputShape3D, 4, inputShape3D, {isPacked: true, reverseWH: true});
|
||||
inputLayout = handler.createTextureLayoutFromShape(
|
||||
this.inputShape3D, 4, this.inputShape3D, {isPacked: true, reverseWH: true});
|
||||
// if the processed input shape produces a texture layout different from the original
|
||||
// one, the run data has to use the processed (3D) input shape later.
|
||||
this.needSqueezeInputData =
|
||||
(inputLayout.height !== originalInputLayout.height) || (inputLayout.width !== originalInputLayout.width);
|
||||
}
|
||||
|
||||
const outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData);
|
||||
const squeezedOutputShape = processDims3D(outputShape);
|
||||
this.outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData);
|
||||
const squeezedOutputShape = processDims3D(this.outputShape);
|
||||
|
||||
this.outputLayout = handler.createTextureLayoutFromShape(
|
||||
squeezedOutputShape, 4, squeezedOutputShape, {isPacked: true, reverseWH: true});
|
||||
|
|
@ -84,9 +87,10 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
|
|||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
|
||||
const shaderSource = `
|
||||
${getReshapedInputCoords(inputShape3D)}
|
||||
${getReshapedInputCoords(this.inputShape3D)}
|
||||
${getFlattenedIndexFrom3D(squeezedOutputShape)}
|
||||
${unpackFromChannel()}
|
||||
|
||||
void main() {
|
||||
ivec3 rc = getOutputCoords();
|
||||
|
||||
|
|
@ -97,7 +101,6 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
|
|||
int cols = ${squeezedOutputShape[1]};
|
||||
|
||||
${mainLoop}
|
||||
|
||||
${glsl.output} = result;
|
||||
}
|
||||
`;
|
||||
|
|
@ -113,14 +116,30 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
|
|||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs =
|
||||
[handler.getOrCreateTextureData(inputs[0], handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false))];
|
||||
let inputTDs: [TextureData];
|
||||
const originalInputLayout = handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false);
|
||||
const originalInputTD = handler.getOrCreateTextureData(inputs[0], originalInputLayout, false);
|
||||
|
||||
if (this.needSqueezeInputData) {
|
||||
const squeezedInputLayout: TextureLayout = {
|
||||
channels: 1,
|
||||
height: originalInputLayout.height,
|
||||
width: originalInputLayout.width,
|
||||
shape: this.inputShape3D,
|
||||
strides: ShapeUtil.computeStrides(this.inputShape3D),
|
||||
unpackedShape: this.inputShape3D,
|
||||
};
|
||||
const squeezedInputTD =
|
||||
handler.createSharedTextureData(squeezedInputLayout, inputs[0].type, originalInputTD.texture);
|
||||
inputTDs = [squeezedInputTD];
|
||||
|
||||
} else {
|
||||
inputTDs = [originalInputTD];
|
||||
}
|
||||
let outputLayout = this.outputLayout;
|
||||
if (outputLayout === undefined) {
|
||||
const originInputShape = inputs[0].dims;
|
||||
const outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData);
|
||||
outputLayout =
|
||||
handler.createTextureLayoutFromShape(outputShape, 4, outputShape, {isPacked: true, reverseWH: true});
|
||||
outputLayout = handler.createTextureLayoutFromShape(
|
||||
this.outputShape, 4, this.outputShape, {isPacked: true, reverseWH: true});
|
||||
}
|
||||
// return run data for reshape. Here, we use the originally calculated outputLayout to create the real output layout.
|
||||
return {
|
||||
|
|
@ -129,6 +148,9 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
|
|||
uniformData: {}
|
||||
};
|
||||
}
|
||||
protected outputShape: readonly number[];
|
||||
private inputShape3D: [number, number, number];
|
||||
private needSqueezeInputData = false;
|
||||
private outputLayout: TextureLayout;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -90,6 +90,11 @@ function getTestData(): TestData[] {
|
|||
inputShape: [2, 2, 2, 4],
|
||||
outputShape: [2, 1, 4, 4],
|
||||
},
|
||||
{
|
||||
elementCount: 18432,
|
||||
inputShape: [512, 36, 1, 1],
|
||||
outputShape: [512, 36],
|
||||
},
|
||||
{
|
||||
elementCount: 18432,
|
||||
inputShape: [512, 36],
|
||||
|
|
|
|||
Loading…
Reference in a new issue