mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
[js/webgpu] Enable the NCHW ConvMatMul path (#17717)
1) Enable pointwise NCHW conv2d by MatMul. 2) Enable non-pointwise NCHW conv2d by convMatMul. 3) Fix bug when `sameSize` is true --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
This commit is contained in:
parent
1bc115719c
commit
db3901ab97
5 changed files with 99 additions and 49 deletions
|
|
@ -163,17 +163,14 @@ export const createConv2DMatMulProgramInfo =
|
|||
const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
|
||||
const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
|
||||
const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
|
||||
const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) ||
|
||||
(outWidth % 4 === 0 && !isChannelsLast)) &&
|
||||
outChannels % 4 === 0;
|
||||
// TODO: enable vec4 for NCHW
|
||||
const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0;
|
||||
|
||||
// TODO: fine tune size
|
||||
const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
|
||||
const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
|
||||
const workGroupSize: [number, number, number] =
|
||||
isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
|
||||
const elementsPerThread =
|
||||
isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1];
|
||||
const workGroupSize: [number, number, number] = [8, 8, 1];
|
||||
const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
|
||||
const dispatch = [
|
||||
Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
|
||||
Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
|
||||
|
|
|
|||
|
|
@ -90,8 +90,8 @@ export const makeMatMulPackedVec4Source =
|
|||
workPerThread[0]} must be 4.`);
|
||||
}
|
||||
return `
|
||||
var<workgroup> mm_Asub : array<array<vec${innerElementSize}<${type}>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
|
||||
var<workgroup> mm_Bsub : array<array<vec4<${type}>, ${tileBOuter / workPerThread[0]}>, ${tileInner}>;
|
||||
var<workgroup> mm_Asub: array<array<vec${innerElementSize}<${type}>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
|
||||
var<workgroup> mm_Bsub: array<array<vec4<${type}>, ${tileBOuter / workPerThread[0]}>, ${tileInner}>;
|
||||
|
||||
const rowPerThread = ${workPerThread[1]};
|
||||
const colPerThread = ${workPerThread[0]};
|
||||
|
|
@ -339,7 +339,8 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
|
|||
};
|
||||
|
||||
const matMulReadWriteFnSource =
|
||||
(component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[]): string => {
|
||||
(component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[],
|
||||
isChannelsLast = false): string => {
|
||||
const batchAVariable = variables[0];
|
||||
const batchBVariable = variables[1];
|
||||
const batchVariable = variables[2];
|
||||
|
|
@ -407,7 +408,10 @@ const matMulReadWriteFnSource =
|
|||
if (row < dimAOuter && col < dimBOuter) {
|
||||
var value = valueIn;
|
||||
let coords = vec3<i32>(batch, row, colIn);
|
||||
${hasBias ? 'value = value + bias[colIn];' : ''}
|
||||
${
|
||||
hasBias ?
|
||||
`value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` :
|
||||
'' }
|
||||
${applyActivation}
|
||||
${outputVariable.setByIndices('vec3<u32>(coords)', 'value')}
|
||||
}
|
||||
|
|
@ -418,7 +422,8 @@ const matMulReadWriteFnSource =
|
|||
|
||||
export const createMatmulProgramInfo =
|
||||
(metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes,
|
||||
outputShape: readonly number[], reshapedOutputShape?: readonly number[]): ProgramInfo => {
|
||||
outputShape: readonly number[], reshapedOutputShape?: readonly number[],
|
||||
isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
|
||||
const aShape = inputs[0].dims;
|
||||
const bShape = inputs[1].dims;
|
||||
|
||||
|
|
@ -457,9 +462,10 @@ export const createMatmulProgramInfo =
|
|||
variables.push(output);
|
||||
const inputVariables = [A, B];
|
||||
const hasBias = inputs.length > 2;
|
||||
const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables);
|
||||
const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, isChannelsLast);
|
||||
if (hasBias) {
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components));
|
||||
const biasComponents = isChannelsLast ? components : 1;
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents));
|
||||
}
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const dimAOuter: i32 = ${dimAOuter};
|
||||
|
|
|
|||
|
|
@ -134,15 +134,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
|
|||
|
||||
// check attributes
|
||||
|
||||
const hasBias = inputs.length === 3;
|
||||
// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
|
||||
const isChannelsLast = attributes.format === 'NHWC';
|
||||
if (!isChannelsLast || attributes.group !== 1) {
|
||||
if (attributes.group !== 1) {
|
||||
context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
|
||||
return;
|
||||
}
|
||||
|
||||
// const batchSize = context.inputs[0].dims[0];
|
||||
const isChannelsLast = attributes.format === 'NHWC';
|
||||
const hasBias = inputs.length === 3;
|
||||
const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2];
|
||||
const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3];
|
||||
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
|
||||
|
|
@ -155,47 +154,59 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
|
|||
const outHeight = outputShape[isChannelsLast ? 1 : 2];
|
||||
const outWidth = outputShape[isChannelsLast ? 2 : 3];
|
||||
const outChannels = outputShape[isChannelsLast ? 3 : 1];
|
||||
const batch = outputShape[0];
|
||||
|
||||
const sameSize =
|
||||
isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && attributes.autoPad === 'VALID';
|
||||
const sameSize = isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth &&
|
||||
attributes.pads[0] === 0 && attributes.pads[1] === 0;
|
||||
if (sameSize ||
|
||||
(weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 &&
|
||||
attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 &&
|
||||
attributes.pads[1] === 0)) {
|
||||
// conv2dByMatMul
|
||||
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
|
||||
context.compute(
|
||||
{
|
||||
...transposeProgramMetadata,
|
||||
cacheHint: weightTransposeAttribute.cacheKey,
|
||||
get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm)
|
||||
},
|
||||
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
|
||||
if (attributes.wIsConst && !context.kernelCustomData.wT) {
|
||||
context.kernelCustomData.wT = transposedWeight;
|
||||
}
|
||||
|
||||
const batch = outputShape[0];
|
||||
let xReshaped, wReshaped, matmulOutputShape;
|
||||
const matmulInputs = [];
|
||||
matmulInputs.push(inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]));
|
||||
matmulInputs.push(transposedWeight.reshape([1, inputChannels, outChannels]));
|
||||
if (isChannelsLast) {
|
||||
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
|
||||
context.compute(
|
||||
{
|
||||
...transposeProgramMetadata,
|
||||
cacheHint: weightTransposeAttribute.cacheKey,
|
||||
get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm)
|
||||
},
|
||||
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
|
||||
if (attributes.wIsConst && !context.kernelCustomData.wT) {
|
||||
context.kernelCustomData.wT = transposedWeight;
|
||||
}
|
||||
if (sameSize) {
|
||||
const sharedDim = inputHeight * inputWidth * inputChannels;
|
||||
xReshaped = inputs[0].reshape([1, batch, sharedDim]);
|
||||
wReshaped = transposedWeight.reshape([1, sharedDim, outChannels]);
|
||||
matmulOutputShape = [1, batch, outChannels];
|
||||
} else {
|
||||
xReshaped = inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]);
|
||||
wReshaped = transposedWeight.reshape([1, inputChannels, outChannels]);
|
||||
matmulOutputShape = [batch, outHeight * outWidth, outChannels];
|
||||
}
|
||||
matmulInputs.push(xReshaped);
|
||||
matmulInputs.push(wReshaped);
|
||||
} else {
|
||||
xReshaped = inputs[0].reshape([batch, inputChannels, inputHeight * inputWidth]);
|
||||
wReshaped = inputs[1].reshape([1, outChannels, inputChannels]);
|
||||
matmulOutputShape = [batch, outChannels, outHeight * outWidth];
|
||||
matmulInputs.push(wReshaped);
|
||||
matmulInputs.push(xReshaped);
|
||||
}
|
||||
if (hasBias) {
|
||||
matmulInputs.push(inputs[2]);
|
||||
}
|
||||
const matmulOutputShape = [batch, outHeight * outWidth, outChannels];
|
||||
context.compute(
|
||||
createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape),
|
||||
createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
|
||||
{inputs: matmulInputs});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: implement conv2dWithIm2Col()
|
||||
|
||||
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
|
||||
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
|
||||
const dimInner = weightHeight * weightWidth * inputChannels;
|
||||
|
||||
const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;
|
||||
|
||||
// STEP.1: transpose weight
|
||||
|
|
@ -214,14 +225,13 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
|
|||
// STEP.2: prepare reshaped inputs
|
||||
const convInputs = [inputs[0], transposedWeight];
|
||||
if (hasBias) {
|
||||
if (!isChannelsLast && inputs[2].dims.length === 1) {
|
||||
convInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
|
||||
} else {
|
||||
convInputs.push(inputs[2]);
|
||||
}
|
||||
convInputs.push(inputs[2]);
|
||||
}
|
||||
|
||||
// STEP.3: compute matmul
|
||||
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
|
||||
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
|
||||
const dimInner = weightHeight * weightWidth * inputChannels;
|
||||
context.compute(
|
||||
createConv2DMatMulProgramInfoLoader(
|
||||
convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
|
||||
|
|
|
|||
|
|
@ -17,11 +17,12 @@ const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
|
|||
|
||||
export const createMatmulProgramInfoLoader =
|
||||
(inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[],
|
||||
reshapedOutputShape?: readonly number[]): ProgramInfoLoader => {
|
||||
reshapedOutputShape?: readonly number[], isChannelsLast = false): ProgramInfoLoader => {
|
||||
const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
|
||||
return {
|
||||
...metadata,
|
||||
get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes, outputShape, reshapedOutputShape)
|
||||
get: () => createMatmulProgramInfo(
|
||||
metadata, inputs, activationAttributes, outputShape, reshapedOutputShape, isChannelsLast)
|
||||
};
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -125,6 +125,42 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "conv with bias addition C",
|
||||
"operator": "Conv",
|
||||
"inputShapeDefinitions": "rankOnly",
|
||||
"opset": { "domain": "", "version": 17 },
|
||||
"attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [3, 1, 2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 3, 4, 5],
|
||||
"dims": [2, 1, 2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [5, 6],
|
||||
"dims": [2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [15, 46, 31, 102, 47, 158],
|
||||
"dims": [3, 2, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "conv - group - A",
|
||||
"operator": "Conv",
|
||||
|
|
|
|||
Loading…
Reference in a new issue