From fd6bab4250c41a7f6498e6fa02ba446bc74e0a8d Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 11 Jan 2024 08:12:43 +0800 Subject: [PATCH] [js/webgpu] Provide a vectorized algorithm for GroupedConv (#18884) ### Description This PR provides a vectorized algorithm for NHWC GroupedConv to improve performance. The aggregate time of GroupedConv in mobilenetv2-12 becomes ~1ms from ~4ms on Intel Alder Lake machine. About 20% improvement for the whole model. --- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 99 +++++++++++- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 26 ++- js/web/test/data/ops/conv.jsonc | 152 +++++++++++++++++- 3 files changed, 271 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 14482272ba..21b4953d3f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo} from '../types'; +import {ProgramInfo, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {getActivationSnippet} from './fuse-utils'; @@ -95,3 +95,98 @@ export const createGroupedConvProgramInfo = getShaderSource, }; }; + +export const createGroupedConvVectorizeProgramInfo = + (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[]): ProgramInfo => { + const hasBias = inputs.length > 2; + const components = getMaxComponents(outputShape[3]); + const outputNumber = getMaxComponents(outputShape[2]); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components]; + const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components]; + const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; + + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides}, + {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape), + ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader) + ]; + const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length, components); + const w = inputVariable('w', inputs[1].dataType, wShape.length, components); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); + } + const processBias = hasBias ? 'value += b[output_channel];' : ''; + + return ` + ${ + shaderHelper.registerUniform('output_size', 'u32') + .registerUniform('strides', 'i32', 2) + .registerUniform('pads', 'i32', 2) + .declareVariables(...inputVars, output)} + ${activationFunction} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let width0 = uniforms.output_shape[3]; + let output_channel = global_idx % width0; + var index1 = global_idx / width0; + let width1 = uniforms.output_shape[2] / ${outputNumber}u; + let col = (index1 % width1) * ${outputNumber}u; + index1 = index1 / width1; + let row = index1 % uniforms.output_shape[1]; + let batch = index1 / uniforms.output_shape[1]; + + let x_corner = vec2(i32(row), i32(col)) * uniforms.strides - uniforms.pads; + + var x_vals: array<${x.type.value}, ${xNumber}>; + var values: array<${output.type.value}, ${outputNumber}>; + let input_channel = output_channel; + // Use constant instead of uniform can give better performance for w's height/width. + for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { + let x_height = x_corner.x + i32(w_height); + if (x_height >= 0 || u32(x_height) < uniforms.x_shape[1]) { + for (var i = 0; i < ${xNumber}; i++) { + let x_width = x_corner.y + i; + if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { + x_vals[i] = ${x.get('batch', 'u32(x_height)', 'u32(x_width)', 'input_channel')}; + } else { + x_vals[i] = ${x.type.value}(0); + } + } + for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { + let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; + for (var i = 0u; i < ${outputNumber}u; i++) { + values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]); + } + } + } + } + + for (var i = 0u; i < ${outputNumber}u; i++) { + var value = values[i]; + ${processBias} + ${applyActivation} + ${output.set('batch', 'row', 'col + i', 'output_channel', 'value')}; + } + }`; + }; + + return { + name: 'GroupedConv-Vectorize', + shaderCache: { + hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource, + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 33a5db7ff6..cb40a9f08d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -8,7 +8,7 @@ import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createGroupedConvProgramInfo} from './conv-grouped'; +import {createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo} from './conv-grouped'; import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; import {createNaiveMatmulProgramInfo} from './matmul'; import {createTransposeProgramInfo} from './transpose'; @@ -136,12 +136,32 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // check attributes // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ + const isChannelsLast = attributes.format === 'NHWC'; if (attributes.group !== 1) { - context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); + if (isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 && + attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { + const outputShape = calculateOutputShape( + inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, + isChannelsLast); + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + const convInputs = [inputs[0], transposedWeight]; + if (inputs.length === 3) { + convInputs.push(inputs[2]); + } + context.compute( + createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {inputs: convInputs}); + } else { + context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); + } return; } - const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 2e8eaaba19..cc10df5864 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -298,7 +298,157 @@ } ] }, - + { + "name": "conv - vectorize group - A", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, + { "name": "group", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0], + "dims": [1, 2, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0], + "dims": [2, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - B", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [0.1, 0.2, 0.3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - C", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [34, 44, 54, 74, 84, 94, 386, 412, 438, 490, 516, 542, 1122, 1164, 1206, 1290, 1332, 1374], + "dims": [1, 3, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - D", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "strides", "data": [2, 2], "type": "ints" } + ], + "cases": [ + { + "name": "T[0] strides = [2, 2]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [34, 54, 386, 438, 1122, 1206], + "dims": [1, 3, 1, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "conv - pointwise", "operator": "Conv",