mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[js/webgpu] Provide a naive vectorized matmul algorithm (#18758)
### Description This PR provided a vectorized matmul algorithm. In most situations, we still go to the workgroup memory optimized matmul. But for some situations, like N and K are very small, using workgroup optimized matmul can't fully utilize the underlying hardware due to the 32x32 tile size. So for very small N/K, we switch to the naive vectorized matmul algorithm to improve the hardware execution unit usage. With this PR, matmul with input0: [1, 36864, 3], input1: [1, 3, 3], input2: [3] becomes less than 1 ms from 4.34 ms on Intel Gen9 GPUs.
This commit is contained in:
parent
1ad6eb1359
commit
b30e721dc8
3 changed files with 164 additions and 10 deletions
|
|
@ -510,11 +510,7 @@ export const createMatmulProgramInfo =
|
|||
name: 'MatMul',
|
||||
shaderCache: {
|
||||
hint: activationAttributes.activationCacheKey + `${elementsPerThread}` +
|
||||
`${activationAttributes.activation}` +
|
||||
`${activationAttributes.clipMax}` +
|
||||
`${activationAttributes.clipMin}` +
|
||||
`${isVec4}` +
|
||||
`${hasBias}` +
|
||||
`${isChannelsLast}`,
|
||||
inputDependencies
|
||||
},
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
|
|||
import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
|
||||
import {createGroupedConvProgramInfo} from './conv-grouped';
|
||||
import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
|
||||
import {createNaiveMatmulProgramInfo} from './matmul';
|
||||
import {createTransposeProgramInfo} from './transpose';
|
||||
|
||||
export const calculateOutputShape =
|
||||
|
|
@ -195,9 +196,19 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
|
|||
if (hasBias) {
|
||||
matmulInputs.push(inputs[2]);
|
||||
}
|
||||
context.compute(
|
||||
createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
|
||||
{inputs: matmulInputs});
|
||||
const N = matmulOutputShape[2];
|
||||
const K = matmulInputs[0].dims[matmulInputs[0].dims.length - 1];
|
||||
// Tune the threshold.
|
||||
if (N < 8 && K < 8) {
|
||||
context.compute(
|
||||
createNaiveMatmulProgramInfo(
|
||||
matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
|
||||
{inputs: matmulInputs});
|
||||
} else {
|
||||
context.compute(
|
||||
createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
|
||||
{inputs: matmulInputs});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,150 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {BroadcastUtil} from '../../util';
|
||||
import {ComputeContext} from '../types';
|
||||
import {BroadcastUtil, ShapeUtil} from '../../util';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
||||
|
||||
import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
|
||||
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common';
|
||||
import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
|
||||
|
||||
export const createNaiveMatmulProgramInfo =
|
||||
(inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[],
|
||||
reshapedOutputShape?: readonly number[],
|
||||
isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
|
||||
const aShape = inputs[0].dims;
|
||||
const bShape = inputs[1].dims;
|
||||
|
||||
const M = aShape[aShape.length - 2];
|
||||
const N = bShape[bShape.length - 1];
|
||||
const K = aShape[aShape.length - 1];
|
||||
const components = getMaxComponents(N);
|
||||
const aComponents = getMaxComponents(K);
|
||||
const outputNumber = getMaxComponents(M);
|
||||
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
|
||||
const hasBias = inputs.length > 2;
|
||||
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
|
||||
const batchSize = ShapeUtil.size(outerDims);
|
||||
const outputShapeInShader = [batchSize, M, N];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N},
|
||||
{type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
|
||||
...createTensorShapeVariables(bShape)
|
||||
];
|
||||
if (hasBias) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShapeInShader));
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length);
|
||||
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
|
||||
const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
|
||||
const inputVariables = [a, b];
|
||||
let processBias = '';
|
||||
if (hasBias) {
|
||||
const biasComponents = isChannelsLast ? components : 1;
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
|
||||
processBias = `${
|
||||
isChannelsLast ? `value += bias[col / ${biasComponents}];` :
|
||||
`value += ${output.type.value}(bias[row + i]);`}`;
|
||||
}
|
||||
|
||||
const outerDimsA = aShape.slice(0, -2);
|
||||
const outerDimsB = bShape.slice(0, -2);
|
||||
const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
|
||||
const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);
|
||||
const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
|
||||
const rank = variable.rank;
|
||||
const name = variable.name;
|
||||
if (rank === 2) {
|
||||
return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`;
|
||||
}
|
||||
const batchRank = batchDims.rank;
|
||||
let resStr = `var ${name}_indices: ${variable.type.indices};`;
|
||||
for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
|
||||
resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`;
|
||||
}
|
||||
broadCastDims.forEach(i => {
|
||||
resStr += `\n${name}_indices[${i}] = 0;`;
|
||||
});
|
||||
resStr += `${name}_indices[${rank - 2}] = 0u;
|
||||
${name}_indices[${rank - 1}] = 0u;`;
|
||||
return resStr;
|
||||
};
|
||||
|
||||
const calcResult = (): string => {
|
||||
let calcStr = `var a_data: ${a.type.value};`;
|
||||
for (let i = 0; i < aComponents; i++) {
|
||||
calcStr += `
|
||||
let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`;
|
||||
}
|
||||
for (let i = 0; i < outputNumber; i++) {
|
||||
calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`;
|
||||
|
||||
for (let j = 0; j < aComponents; j++) {
|
||||
calcStr += `
|
||||
values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${
|
||||
i}]);\n`;
|
||||
}
|
||||
}
|
||||
return calcStr;
|
||||
};
|
||||
|
||||
return `
|
||||
${
|
||||
shaderHelper.registerUniform('outputSize', 'u32')
|
||||
.registerUniform('M', 'u32')
|
||||
.registerUniform('N', 'u32')
|
||||
.registerUniform('K', 'u32')
|
||||
.registerInternalVariables(batchDims)
|
||||
.declareVariables(...inputVariables, output)}
|
||||
${activationFunction}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
||||
let col = (global_idx % (uniforms.N / ${components})) * ${components};
|
||||
var index1 = global_idx / (uniforms.N / ${components});
|
||||
let stride1 = uniforms.M / ${outputNumber};
|
||||
let row = (index1 % stride1) * ${outputNumber};
|
||||
let batch = index1 / stride1;
|
||||
|
||||
${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
|
||||
${getIndices(a, broadCastADims)}
|
||||
let a_offset = ${a.indicesToOffset('a_indices')};
|
||||
${getIndices(b, broadCastBDims)}
|
||||
let b_offset = ${b.indicesToOffset('b_indices')};
|
||||
var values: array<${output.type.value}, ${outputNumber}>;
|
||||
for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
|
||||
${calcResult()}
|
||||
}
|
||||
for (var i = 0u; i < ${outputNumber}u; i++) {
|
||||
var value = values[i];
|
||||
${processBias}
|
||||
${applyActivation}
|
||||
let cur_indices = ${output.type.indices}(batch, row + i, col);
|
||||
let offset = ${output.indicesToOffset('cur_indices')};
|
||||
${output.setByOffset(`offset / ${components}`, 'value')};
|
||||
}
|
||||
}
|
||||
`;
|
||||
};
|
||||
return {
|
||||
name: 'MatMulNaive',
|
||||
shaderCache: {
|
||||
hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${
|
||||
isChannelsLast}`,
|
||||
inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank']
|
||||
},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource
|
||||
};
|
||||
};
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length !== 2) {
|
||||
|
|
@ -23,5 +163,12 @@ export const matMul = (context: ComputeContext): void => {
|
|||
if (!outputShape) {
|
||||
throw new Error('Can\'t use matmul on the given tensors');
|
||||
}
|
||||
context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
|
||||
const N = outputShape[outputShape.length - 1];
|
||||
const K = context.inputs[0].dims[context.inputs[0].dims.length - 1];
|
||||
if (N < 8 && K < 8) {
|
||||
context.compute(
|
||||
createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
|
||||
} else {
|
||||
context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue