diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts index 8bc69a9ee5..08ec52430a 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts @@ -16,6 +16,10 @@ import {WebGLReshapePacked} from './reshape-packed'; export class WebGLConvPacked extends Conv { protected artifacts: Artifact[]; protected programInfo: ProgramInfo[]; + private kernelReshape = new WebGLReshapePacked(); + private im2col: WebGLIm2ColPacked; + private matmul = new WebGLMatMulPacked(); + private outputReshape = new WebGLReshapePacked(); run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { const programManager = inferenceHandler.session.programManager; @@ -35,35 +39,36 @@ export class WebGLConvPacked extends Conv { this.kernelShape}, pads:${this.pads}, strides:${this.strides}`); const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides); - const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); - const matmul = new WebGLMatMulPacked(); + if (this.im2col === undefined) { + this.im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); + } if (this.activation) { const attributes = new Attribute(undefined); attributes.set('__internal_activation', 'string', (this.activation)); - matmul.initialize(attributes); + this.matmul.initialize(attributes); } - const reshape = new WebGLReshapePacked(); // shape for kernel reshape const shape = new Tensor([2], 'int32', undefined, undefined, new Int32Array([kshape[0], kshape[1] * kshape[2] * kshape[3]])); if (!this.artifacts) { this.artifacts = []; this.programInfo = []; - this.programInfo[0] = im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]); + this.programInfo[0] = this.im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]); this.artifacts[0] = programManager.build(this.programInfo[0]); - this.programInfo[1] = reshape.createProgramInfo(inferenceHandler, [inputs[1], shape]); + this.programInfo[1] = this.kernelReshape.createProgramInfo(inferenceHandler, [inputs[1], shape]); this.artifacts[1] = programManager.build(this.programInfo[1]); } // run im2col - const runDataIm2col = im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]); + const runDataIm2col = this.im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[0], runDataIm2col); programManager.run(this.artifacts[0], runDataIm2col); const im2colOutput = runDataIm2col.outputTextureData.tensor; // reshape kernel - const runDataKernelReshape = reshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]); + const runDataKernelReshape = + this.kernelReshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[1], runDataKernelReshape); programManager.run(this.artifacts[1], runDataKernelReshape); const kernelReshaped = runDataKernelReshape.outputTextureData.tensor; @@ -72,11 +77,11 @@ export class WebGLConvPacked extends Conv { const hasBias = (inputs.length === 3); assert(this.artifacts.length > 1, () => 'expect at least 2 artifacts created'); if (this.artifacts.length === 2) { - this.programInfo[2] = matmul.createProgramInfo( + this.programInfo[2] = this.matmul.createProgramInfo( inferenceHandler, hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]); this.artifacts[2] = programManager.build(this.programInfo[2]); } - const runDataMatmul = matmul.createRunData( + const runDataMatmul = this.matmul.createRunData( inferenceHandler, this.programInfo[2], hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[2], runDataMatmul); @@ -90,11 +95,11 @@ export class WebGLConvPacked extends Conv { assert(this.artifacts.length > 2, () => 'expect at least 3 artifacts created'); if (this.artifacts.length === 3) { - this.programInfo[3] = reshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]); + this.programInfo[3] = this.outputReshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]); this.artifacts[3] = programManager.build(this.programInfo[3]); } const runDataOutputReshape = - reshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]); + this.outputReshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[3], runDataOutputReshape); programManager.run(this.artifacts[3], runDataOutputReshape); return [runDataOutputReshape.outputTextureData.tensor]; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts index 02adbc0506..ea172cecfa 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts @@ -48,12 +48,11 @@ export class WebGLIm2ColPacked implements WebGLOperator { if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) { offsetY = int(blockIndex / (${this.convOutputShape[rank - 1]})) * ${this.strides[0]} - ${this.pads[1]}; - d0 = offsetY + ${this.dilations[0]} * (int(mod(float(pos), ${kernelSize}.)) / ${wshape[2]} ); + d0 = offsetY + ${this.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]}); if(d0 < ${xshape[rowDim]} && d0 >= 0) { - offsetX = int(mod(float(blockIndex), ${this.convOutputShape[rank - 1]}.) * ${this.strides[1]}. - ${ - this.pads[0]}.); - d1 = offsetX + ${this.dilations[1]} * (int(mod(mod(float(pos), ${kernelSize}.), ${wshape[2]}.))); + offsetX = imod(blockIndex, ${this.convOutputShape[rank - 1]}) * ${this.strides[1]} - ${this.pads[0]}; + d1 = offsetX + ${this.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]}); if(d1 < ${xshape[colDim]} && d1 >= 0) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts b/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts index 8e023536e3..49f09819a1 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts @@ -6,9 +6,8 @@ import {Tensor} from '../../../tensor'; import {ShapeUtil} from '../../../util'; import {getGlsl} from '../glsl-source'; import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, RunData, WebGLOperator} from '../types'; +import {ProgramInfo, RunData, TextureData, WebGLOperator} from '../types'; import {TextureLayout} from '../types'; - import {unpackFromChannel} from './packing-utils'; export class WebGLReshapePacked extends Reshape implements WebGLOperator { @@ -32,18 +31,22 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { // the same between input shape and output shape, the packed reshape can be // treated as no-op. const originInputShape = inputs[0].dims; - const inputShape3D = processDims3D(inputs[0].dims); + this.inputShape3D = processDims3D(inputs[0].dims); let inputLayout: TextureLayout; - if (originInputShape.length === 3) { - inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true); - } else { + inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true); + if (originInputShape.length !== 3) { + const originalInputLayout = inputLayout; // if originShape is not a 3D shape, create texture layout from the processed shape. - inputLayout = - handler.createTextureLayoutFromShape(inputShape3D, 4, inputShape3D, {isPacked: true, reverseWH: true}); + inputLayout = handler.createTextureLayoutFromShape( + this.inputShape3D, 4, this.inputShape3D, {isPacked: true, reverseWH: true}); + // if the processed input shape produces texture layout differnt from the original + // one, the run data has to use the processed (3D) input shape later. + this.needSqueezeInputData = + (inputLayout.height !== originalInputLayout.height) || (inputLayout.width !== originalInputLayout.width); } - const outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData); - const squeezedOutputShape = processDims3D(outputShape); + this.outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData); + const squeezedOutputShape = processDims3D(this.outputShape); this.outputLayout = handler.createTextureLayoutFromShape( squeezedOutputShape, 4, squeezedOutputShape, {isPacked: true, reverseWH: true}); @@ -84,9 +87,10 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { const glsl = getGlsl(handler.session.backend.glContext.version); const shaderSource = ` - ${getReshapedInputCoords(inputShape3D)} + ${getReshapedInputCoords(this.inputShape3D)} ${getFlattenedIndexFrom3D(squeezedOutputShape)} ${unpackFromChannel()} + void main() { ivec3 rc = getOutputCoords(); @@ -97,7 +101,6 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { int cols = ${squeezedOutputShape[1]}; ${mainLoop} - ${glsl.output} = result; } `; @@ -113,14 +116,30 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { }; } createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData { - const inputTDs = - [handler.getOrCreateTextureData(inputs[0], handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false))]; + let inputTDs: [TextureData]; + const originalInputLayout = handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false); + const originalInputTD = handler.getOrCreateTextureData(inputs[0], originalInputLayout, false); + + if (this.needSqueezeInputData) { + const squeezedInputLayout: TextureLayout = { + channels: 1, + height: originalInputLayout.height, + width: originalInputLayout.width, + shape: this.inputShape3D, + strides: ShapeUtil.computeStrides(this.inputShape3D), + unpackedShape: this.inputShape3D, + }; + const squeezedInputTD = + handler.createSharedTextureData(squeezedInputLayout, inputs[0].type, originalInputTD.texture); + inputTDs = [squeezedInputTD]; + + } else { + inputTDs = [originalInputTD]; + } let outputLayout = this.outputLayout; if (outputLayout === undefined) { - const originInputShape = inputs[0].dims; - const outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData); - outputLayout = - handler.createTextureLayoutFromShape(outputShape, 4, outputShape, {isPacked: true, reverseWH: true}); + outputLayout = handler.createTextureLayoutFromShape( + this.outputShape, 4, this.outputShape, {isPacked: true, reverseWH: true}); } // return run data for reshape. Here, we use the original calculate outputLayout to create the real output layout. return { @@ -129,6 +148,9 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { uniformData: {} }; } + protected outputShape: readonly number[]; + private inputShape3D: [number, number, number]; + private needSqueezeInputData = false; private outputLayout: TextureLayout; } diff --git a/js/web/test/unittests/backends/webgl/test-reshape-packed.ts b/js/web/test/unittests/backends/webgl/test-reshape-packed.ts index cfd6e3d459..d39d033142 100644 --- a/js/web/test/unittests/backends/webgl/test-reshape-packed.ts +++ b/js/web/test/unittests/backends/webgl/test-reshape-packed.ts @@ -90,6 +90,11 @@ function getTestData(): TestData[] { inputShape: [2, 2, 2, 4], outputShape: [2, 1, 4, 4], }, + { + elementCount: 18432, + inputShape: [512, 36, 1, 1], + outputShape: [512, 36], + }, { elementCount: 18432, inputShape: [512, 36],