mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[js/webgpu] Provide a vectorized algorithm for GroupedConv (#18884)
### Description This PR provides a vectorized algorithm for NHWC GroupedConv to improve performance. The aggregate time of GroupedConv in mobilenetv2-12 drops from ~4ms to ~1ms on an Intel Alder Lake machine, which is about a 20% improvement for the whole model.
This commit is contained in:
parent
e58319ebfc
commit
fd6bab4250
3 changed files with 271 additions and 6 deletions
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ProgramInfo} from '../types';
|
||||
import {ProgramInfo, ProgramUniform} from '../types';
|
||||
|
||||
import {inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {calculateOutputShape, ConvAttributes} from './conv';
|
||||
import {getActivationSnippet} from './fuse-utils';
|
||||
|
||||
|
|
@ -95,3 +95,98 @@ export const createGroupedConvProgramInfo =
|
|||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
/**
 * Builds the program info for the vectorized grouped-convolution shader.
 *
 * Each shader invocation produces `outputNumber` adjacent outputs along `outputShape[2]` for one packed
 * channel element (`components` channels of `outputShape[3]`). For every kernel row it first loads the
 * `xNumber` input columns those outputs need into a local array, then accumulates with `fma`.
 *
 * The shader reads the input channel equal to the output channel, indexes the weight as
 * `(w_height, w_width, 0, output_channel)` and never consults `attributes.dilations`, so it is only valid
 * for channels-last data with one input channel per group and dilation 1.
 * NOTE(review): the weight is presumably expected already transposed to
 * [kernelHeight, kernelWidth, 1, channels] - confirm against the caller.
 *
 * @param inputs `[x, w]` or `[x, w, bias]`; the bias is added per output channel when present.
 * @param attributes Convolution attributes; `strides`, `pads` and the activation settings are used.
 * @param outputShape 4-D output shape, decomposed by the shader as [batch, row, col, channel].
 * @returns The program info (shader source generator, cache key and run data) for this convolution.
 */
export const createGroupedConvVectorizeProgramInfo =
    (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[]): ProgramInfo => {
      const hasBias = inputs.length > 2;
      // Vector width used to pack the channel dimension.
      // NOTE(review): getMaxComponents presumably returns the largest supported width dividing the size - confirm.
      const components = getMaxComponents(outputShape[3]);
      // Number of adjacent output columns computed by a single invocation.
      const outputNumber = getMaxComponents(outputShape[2]);
      const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
      // Shapes as seen by the shader: the last dimension is counted in packed elements.
      const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components];
      const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components];
      const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components];

      const programUniforms: ProgramUniform[] = [
        {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides},
        {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape),
        ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader)
      ];
      // Input columns needed for `outputNumber` adjacent outputs: consecutive windows start
      // `strides[1]` apart and each window is `wShape[1]` columns wide.
      const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
        const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value);
        const x = inputVariable('x', inputs[0].dataType, xShape.length, components);
        const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
        const inputVars = [x, w];
        if (hasBias) {
          inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components));
        }
        const processBias = hasBias ? 'value += b[output_channel];' : '';

        // The input row must satisfy BOTH bounds (x_height >= 0 AND x_height < input height). Rows outside
        // the input are padding and must contribute nothing, so the whole kernel row is skipped for them.
        // Using `||` here would let rows past the bottom edge through and read out-of-range memory.
        return `
  ${
            shaderHelper.registerUniform('output_size', 'u32')
                .registerUniform('strides', 'i32', 2)
                .registerUniform('pads', 'i32', 2)
                .declareVariables(...inputVars, output)}
  ${activationFunction}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let width0 = uniforms.output_shape[3];
    let output_channel = global_idx % width0;
    var index1 = global_idx / width0;
    let width1 = uniforms.output_shape[2] / ${outputNumber}u;
    let col = (index1 % width1) * ${outputNumber}u;
    index1 = index1 / width1;
    let row = index1 % uniforms.output_shape[1];
    let batch = index1 / uniforms.output_shape[1];

    let x_corner = vec2<i32>(i32(row), i32(col)) * uniforms.strides - uniforms.pads;

    var x_vals: array<${x.type.value}, ${xNumber}>;
    var values: array<${output.type.value}, ${outputNumber}>;
    let input_channel = output_channel;
    // Use constant instead of uniform can give better performance for w's height/width.
    for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) {
      let x_height = x_corner.x + i32(w_height);
      if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) {
        for (var i = 0; i < ${xNumber}; i++) {
          let x_width = x_corner.y + i;
          if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) {
            x_vals[i] = ${x.get('batch', 'u32(x_height)', 'u32(x_width)', 'input_channel')};
          } else {
            x_vals[i] = ${x.type.value}(0);
          }
        }
        for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) {
          let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')};
          for (var i = 0u; i < ${outputNumber}u; i++) {
            values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]);
          }
        }
      }
    }

    for (var i = 0u; i < ${outputNumber}u; i++) {
      var value = values[i];
      ${processBias}
      ${applyActivation}
      ${output.set('batch', 'row', 'col + i', 'output_channel', 'value')};
    }
  }`;
      };

      return {
        name: 'GroupedConv-Vectorize',
        shaderCache: {
          // Every value baked into the shader text as a constant is part of the key. strides[1] is
          // covered indirectly: together with outputNumber and wShape[1] it determines xNumber, and when
          // outputNumber is 1 the stride only ever multiplies index 0.
          hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`,
          inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank']
        },
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
          programUniforms
        }),
        getShaderSource,
      };
    };
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import {ComputeContext} from '../types';
|
|||
|
||||
import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
|
||||
import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
|
||||
import {createGroupedConvProgramInfo} from './conv-grouped';
|
||||
import {createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo} from './conv-grouped';
|
||||
import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
|
||||
import {createNaiveMatmulProgramInfo} from './matmul';
|
||||
import {createTransposeProgramInfo} from './transpose';
|
||||
|
|
@ -136,12 +136,32 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
|
|||
// check attributes
|
||||
|
||||
// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
|
||||
const isChannelsLast = attributes.format === 'NHWC';
|
||||
if (attributes.group !== 1) {
|
||||
context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes));
|
||||
if (isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 &&
|
||||
attributes.dilations[0] === 1 && attributes.dilations[1] === 1) {
|
||||
const outputShape = calculateOutputShape(
|
||||
inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides,
|
||||
isChannelsLast);
|
||||
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
|
||||
context.compute(
|
||||
createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
|
||||
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
|
||||
if (attributes.wIsConst && !context.kernelCustomData.wT) {
|
||||
context.kernelCustomData.wT = transposedWeight;
|
||||
}
|
||||
const convInputs = [inputs[0], transposedWeight];
|
||||
if (inputs.length === 3) {
|
||||
convInputs.push(inputs[2]);
|
||||
}
|
||||
context.compute(
|
||||
createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {inputs: convInputs});
|
||||
} else {
|
||||
context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const isChannelsLast = attributes.format === 'NHWC';
|
||||
const hasBias = inputs.length === 3;
|
||||
const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2];
|
||||
const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3];
|
||||
|
|
|
|||
|
|
@ -298,7 +298,157 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
|
||||
{
|
||||
"name": "conv - vectorize group - A",
|
||||
"operator": "Conv",
|
||||
"inputShapeDefinitions": "rankOnly",
|
||||
"opset": { "domain": "", "version": 17 },
|
||||
"attributes": [
|
||||
{ "name": "kernel_shape", "data": [1, 1], "type": "ints" },
|
||||
{ "name": "group", "data": 2, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
|
||||
"dims": [1, 2, 3, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.0, 2.0],
|
||||
"dims": [2, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0],
|
||||
"dims": [1, 2, 3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "conv - vectorize group - B",
|
||||
"operator": "Conv",
|
||||
"inputShapeDefinitions": "rankOnly",
|
||||
"opset": { "domain": "", "version": 17 },
|
||||
"attributes": [
|
||||
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
|
||||
{ "name": "group", "data": 3, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
|
||||
19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0
|
||||
],
|
||||
"dims": [1, 3, 3, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
|
||||
"dims": [3, 1, 2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3],
|
||||
"dims": [1, 3, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "conv - vectorize group - C",
|
||||
"operator": "Conv",
|
||||
"inputShapeDefinitions": "rankOnly",
|
||||
"opset": { "domain": "", "version": 17 },
|
||||
"attributes": [
|
||||
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
|
||||
{ "name": "group", "data": 3, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
|
||||
19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
|
||||
],
|
||||
"dims": [1, 3, 3, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
|
||||
"dims": [3, 1, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [34, 44, 54, 74, 84, 94, 386, 412, 438, 490, 516, 542, 1122, 1164, 1206, 1290, 1332, 1374],
|
||||
"dims": [1, 3, 2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "conv - vectorize group - D",
|
||||
"operator": "Conv",
|
||||
"inputShapeDefinitions": "rankOnly",
|
||||
"opset": { "domain": "", "version": 17 },
|
||||
"attributes": [
|
||||
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
|
||||
{ "name": "group", "data": 3, "type": "int" },
|
||||
{ "name": "strides", "data": [2, 2], "type": "ints" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0] strides = [2, 2]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
|
||||
19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
|
||||
],
|
||||
"dims": [1, 3, 3, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
|
||||
"dims": [3, 1, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [34, 54, 386, 438, 1122, 1206],
|
||||
"dims": [1, 3, 1, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "conv - pointwise",
|
||||
"operator": "Conv",
|
||||
|
|
|
|||
Loading…
Reference in a new issue