mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
[js/webgpu] Refactor matmul conv to support uniforms for matmul (#18452)
This change refactors the matmul- and conv-related programs to support shape uniforms. Currently, only the matmul shape uniforms are fully enabled. TODOs: add input dependencies for the conv-related programs, and turn clipMax and clipMin into uniforms.
This commit is contained in:
parent
42c6799c59
commit
fa106942a7
5 changed files with 174 additions and 98 deletions
|
|
@ -21,9 +21,8 @@
|
|||
|
||||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {ProgramInfo} from '../../types';
|
||||
import {tensorTypeToWsglStorageType} from '../common';
|
||||
import {ProgramInfo, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
|
||||
import {ConvAttributes} from '../conv';
|
||||
import {getActivationSnippet} from '../fuse-utils';
|
||||
|
||||
|
|
@ -50,9 +49,9 @@ const conv2dCommonSnippet =
|
|||
const getWSnippet = (innerElementSize: number) => {
|
||||
switch (innerElementSize) {
|
||||
case 1:
|
||||
return 'return w[row * wShape[3] + colIn];';
|
||||
return 'return w[row * i32(uniforms.w_shape[3]) + colIn];';
|
||||
case 4:
|
||||
return 'return w[row * wShape[3] / 4 + colIn];';
|
||||
return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];';
|
||||
default:
|
||||
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
|
||||
}
|
||||
|
|
@ -79,13 +78,13 @@ const conv2dCommonSnippet =
|
|||
col % outWidth);
|
||||
`;
|
||||
|
||||
const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]';
|
||||
const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]';
|
||||
const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
|
||||
const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
|
||||
const row = isChannelsLast ? 'row' : 'col';
|
||||
const col = isChannelsLast ? 'col' : 'row';
|
||||
const readXSnippet = `
|
||||
let inChannels = wShape[2];
|
||||
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
|
||||
let inChannels = i32(uniforms.w_shape[2]);
|
||||
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
|
||||
let outRow = ${row} / outWidth;
|
||||
let outCol = ${row} % outWidth;
|
||||
|
||||
|
|
@ -99,7 +98,7 @@ const conv2dCommonSnippet =
|
|||
// the 'same' padding type.
|
||||
if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) {
|
||||
${coordASnippet}
|
||||
let xIndex = getIndexFromCoords4D(coord, xShape);
|
||||
let xIndex = getIndexFromCoords4D(coord, vec4<i32>(uniforms.x_shape));
|
||||
${getXSnippet(innerElementSizeX)}
|
||||
}
|
||||
return resData;`;
|
||||
|
|
@ -109,7 +108,7 @@ const conv2dCommonSnippet =
|
|||
${readXSnippet}` :
|
||||
`
|
||||
let col = colIn * ${innerElementSizeX};
|
||||
if (row < dimAOuter && col < dimInner) {
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
|
||||
${readXSnippet}
|
||||
}
|
||||
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
|
||||
|
|
@ -118,7 +117,7 @@ const conv2dCommonSnippet =
|
|||
${readXSnippet}` :
|
||||
`
|
||||
let col = colIn * ${innerElementSizeX};
|
||||
if (row < dimInner && col < dimBOuter) {
|
||||
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
|
||||
${readXSnippet}
|
||||
}
|
||||
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);
|
||||
|
|
@ -143,10 +142,10 @@ const conv2dCommonSnippet =
|
|||
|
||||
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < dimAOuter && col < dimBOuter)
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
|
||||
{
|
||||
var value = valueIn;
|
||||
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
|
||||
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
|
||||
${coordResSnippet}
|
||||
${biasSnippet(addBias)}
|
||||
${applyActivation}
|
||||
|
|
@ -194,10 +193,17 @@ export const createConv2DMatMulProgramInfo =
|
|||
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
|
||||
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
||||
const declareInputs = [
|
||||
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`,
|
||||
`@group(0) @binding(1) var<storage, read> w: array<${isVec4 ? `vec4<${t}>` : t}>;`
|
||||
];
|
||||
// TODO: support component 2, 3.
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
|
||||
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
|
||||
const inputVariables = [x, w];
|
||||
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
|
||||
|
||||
let declareFunctions = `
|
||||
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
|
||||
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
|
||||
|
|
@ -207,41 +213,40 @@ export const createConv2DMatMulProgramInfo =
|
|||
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
|
||||
}`;
|
||||
if (hasBias) {
|
||||
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? `vec4<${t}>` : t}>;`);
|
||||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
|
||||
inputVariables.push(bias);
|
||||
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
|
||||
declareFunctions += `
|
||||
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
|
||||
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
|
||||
}`;
|
||||
}
|
||||
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
return {
|
||||
name: 'Conv2DMatMul',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource: () => `
|
||||
${utilFunctions}
|
||||
getShaderSource: (shaderHelper: ShaderHelper) => `
|
||||
${utilFunctions('uniforms.result_strides')}
|
||||
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
|
||||
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
|
||||
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
|
||||
${declareInputs.join('')}
|
||||
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
|
||||
isVec4 ? `vec4<${t}>` : t}>;
|
||||
//@group(0) @binding(${declareInputs.length + 1}) var<uniform> uniforms: Uniforms;
|
||||
|
||||
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
|
||||
const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
|
||||
const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
|
||||
const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
|
||||
${
|
||||
shaderHelper.registerUniform('dimAOuter', 'i32')
|
||||
.registerUniform('dimBOuter', 'i32')
|
||||
.registerUniform('dimInner', 'i32')
|
||||
.declareVariables(...inputVariables, output)}
|
||||
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
|
||||
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
|
||||
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
|
||||
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
|
||||
const dimAOuter : i32 = ${dimAOuter};
|
||||
const dimBOuter : i32 = ${dimBOuter};
|
||||
const dimInner : i32 = ${dimInner};
|
||||
${declareFunctions}
|
||||
${
|
||||
conv2dCommonSnippet(
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@
|
|||
|
||||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {ProgramInfo} from '../../types';
|
||||
import {ProgramInfo, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
|
||||
import {ConvTransposeAttributes} from '../conv-transpose';
|
||||
import {getActivationSnippet} from '../fuse-utils';
|
||||
|
||||
|
|
@ -36,16 +36,16 @@ const conv2dTransposeCommonSnippet =
|
|||
const getWSnippet = (innerElementSize: number) => {
|
||||
switch (innerElementSize) {
|
||||
case 1:
|
||||
return 'return W[getIndexFromCoords4D(coord, wShape)];';
|
||||
return 'return w[getIndexFromCoords4D(coord, vec4<i32>(uniforms.w_shape))];';
|
||||
case 4:
|
||||
return `
|
||||
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
|
||||
let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
|
||||
let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
|
||||
let v0 = W[getIndexFromCoords4D(coord, wShape)];
|
||||
let v1 = W[getIndexFromCoords4D(coord1, wShape)];
|
||||
let v2 = W[getIndexFromCoords4D(coord2, wShape)];
|
||||
let v3 = W[getIndexFromCoords4D(coord3, wShape)];
|
||||
let v0 = w[getIndexFromCoords4D(coord, vec4<i32>(uniforms.w_shape))];
|
||||
let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
|
||||
let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
|
||||
let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
|
||||
return vec4<f32>(v0, v1, v2, v3);
|
||||
`;
|
||||
default:
|
||||
|
|
@ -81,7 +81,7 @@ const conv2dTransposeCommonSnippet =
|
|||
|
||||
const readASnippet = `
|
||||
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
|
||||
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
|
||||
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
|
||||
let outRow = ${row} / outWidth;
|
||||
let outCol = ${row} % outWidth;
|
||||
|
||||
|
|
@ -99,17 +99,17 @@ const conv2dTransposeCommonSnippet =
|
|||
let iXC = i32(xC);
|
||||
let xCh = ${col} % inChannels;
|
||||
${coordASnippet}
|
||||
return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`;
|
||||
return x[getIndexFromCoords4D(coord, vec4<i32>(uniforms.x_shape))/${innerElementSize}];`;
|
||||
|
||||
const sampleA = isChannelsLast ? `
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < dimAOuter && col < dimInner) {
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
|
||||
${readASnippet}
|
||||
}
|
||||
return ${type}(0.0);` :
|
||||
`
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < dimInner && col < dimBOuter) {
|
||||
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
|
||||
${readASnippet}
|
||||
}
|
||||
return ${type}(0.0);`;
|
||||
|
|
@ -120,8 +120,8 @@ const conv2dTransposeCommonSnippet =
|
|||
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
|
||||
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
|
||||
if (${
|
||||
isChannelsLast ? 'row < dimInner && col < dimBOuter' :
|
||||
'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) {
|
||||
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
|
||||
'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) {
|
||||
let rowInner = row % inChannels;
|
||||
let coord = vec4<i32>(coordX, coordY, col, rowInner);
|
||||
${getWSnippet(innerElementSize)}
|
||||
|
|
@ -142,13 +142,13 @@ const conv2dTransposeCommonSnippet =
|
|||
|
||||
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) {
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < dimAOuter && col < dimBOuter) {
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
|
||||
var value = valueInput;
|
||||
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
|
||||
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
|
||||
${coordResSnippet}
|
||||
${biasSnippet(addBias)}
|
||||
${applyActivation}
|
||||
result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value;
|
||||
result[getIndexFromCoords4D(coords, vec4<i32>(uniforms.result_shape))/${innerElementSize}] = value;
|
||||
}
|
||||
}`;
|
||||
return userCode;
|
||||
|
|
@ -185,37 +185,46 @@ export const createConv2DTransposeMatMulProgramInfo =
|
|||
|
||||
const innerElementSize = isVec4 ? 4 : 1;
|
||||
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
|
||||
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
|
||||
const inputVariables = [x, w];
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
|
||||
|
||||
|
||||
const declareInputs = [
|
||||
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`,
|
||||
'@group(0) @binding(1) var<storage, read> W: array<f32>;'
|
||||
];
|
||||
let declareFunctions = '';
|
||||
if (hasBias) {
|
||||
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
|
||||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
|
||||
inputVariables.push(bias);
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
|
||||
declareFunctions += `
|
||||
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
|
||||
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
|
||||
}`;
|
||||
}
|
||||
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
return {
|
||||
name: 'Conv2DTransposeMatMul',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource: () => `
|
||||
${utilFunctions}
|
||||
${declareInputs.join('\n')}
|
||||
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
|
||||
isVec4 ? 'vec4<f32>' : 'f32'}>;
|
||||
getShaderSource: (shaderHelper: ShaderHelper) => `
|
||||
${utilFunctions('uniforms.result_strides')}
|
||||
${
|
||||
shaderHelper.registerUniform('dimAOuter', 'i32')
|
||||
.registerUniform('dimBOuter', 'i32')
|
||||
.registerUniform('dimInner', 'i32')
|
||||
.declareVariables(...inputVariables, output)};
|
||||
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
|
||||
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
|
||||
const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
|
||||
const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
|
||||
const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
|
||||
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
|
||||
attributes.kernelShape[isChannelsLast ? 2 : 3]});
|
||||
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
|
||||
|
|
|
|||
|
|
@ -19,13 +19,13 @@
|
|||
//
|
||||
// modified to fit the needs of the project
|
||||
|
||||
export const utilFunctions = `
|
||||
export const utilFunctions = (strideStr: string) => (`
|
||||
fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {
|
||||
return dot(coords, vec4<i32>(
|
||||
shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
|
||||
}
|
||||
fn getOutputIndexFromCoords(coords : vec4<i32>) -> i32 {
|
||||
return dot(coords, vec4<i32>(
|
||||
outShapeStrides.x, outShapeStrides.y, outShapeStrides.z, 1));
|
||||
i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1));
|
||||
}
|
||||
`;
|
||||
`);
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@
|
|||
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {ProgramInfo} from '../../types';
|
||||
import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
|
||||
import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';
|
||||
|
||||
import {typeSnippet} from './activation_util';
|
||||
|
|
@ -112,7 +112,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
|
|||
${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
|
||||
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
|
||||
|
||||
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
|
||||
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
|
||||
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
|
||||
|
||||
var acc: array<vec4<${type}>, rowPerThread>;
|
||||
|
|
@ -322,7 +322,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
|
|||
@builtin(workgroup_id) workgroupId : vec3<u32>) {
|
||||
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
|
||||
${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
|
||||
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
|
||||
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
|
||||
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
|
||||
|
||||
var acc : array<array<${type}, colPerThread>, rowPerThread>;
|
||||
|
|
@ -384,7 +384,7 @@ const matMulReadWriteFnSource =
|
|||
typeSnippet(component, dataType)} {
|
||||
var value = ${typeSnippet(component, dataType)}(0.0);
|
||||
let col = colIn * ${component};
|
||||
if(row < dimAOuter && col < dimInner)
|
||||
if(row < uniforms.dimAOuter && col < uniforms.dimInner)
|
||||
{
|
||||
${getAIndices()}
|
||||
value = ${aVariable.getByIndices('aIndices')};
|
||||
|
|
@ -396,7 +396,7 @@ const matMulReadWriteFnSource =
|
|||
typeSnippet(component, dataType)} {
|
||||
var value = ${typeSnippet(component, dataType)}(0.0);
|
||||
let col = colIn * ${component};
|
||||
if(row < dimInner && col < dimBOuter)
|
||||
if(row < uniforms.dimInner && col < uniforms.dimBOuter)
|
||||
{
|
||||
${getBIndices()}
|
||||
value = ${bVariable.getByIndices('bIndices')};
|
||||
|
|
@ -406,7 +406,7 @@ const matMulReadWriteFnSource =
|
|||
|
||||
fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) {
|
||||
let col = colIn * ${component};
|
||||
if (row < dimAOuter && col < dimBOuter) {
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
|
||||
var value = valueIn;
|
||||
let coords = vec3<i32>(batch, row, colIn);
|
||||
${
|
||||
|
|
@ -430,8 +430,11 @@ export const createMatmulProgramInfo =
|
|||
|
||||
const outerDimsA = aShape.slice(0, -2);
|
||||
const outerDimsB = bShape.slice(0, -2);
|
||||
|
||||
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
|
||||
const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims);
|
||||
const enableBatchUniforms = enableShapesUniforms(outerDims.length);
|
||||
const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims;
|
||||
const batchDims = inputVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1, true);
|
||||
const variables = [batchDims];
|
||||
const batchShapes = [outerDimsA, outerDimsB, outerDims];
|
||||
const batchSize = ShapeUtil.size(outerDims);
|
||||
|
|
@ -452,39 +455,81 @@ export const createMatmulProgramInfo =
|
|||
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components);
|
||||
const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components);
|
||||
const output =
|
||||
outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components);
|
||||
|
||||
const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
|
||||
const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length);
|
||||
const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp;
|
||||
|
||||
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
|
||||
const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length);
|
||||
const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp;
|
||||
|
||||
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
|
||||
|
||||
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
|
||||
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
|
||||
variables.push(A);
|
||||
variables.push(B);
|
||||
variables.push(output);
|
||||
const inputVariables = [A, B];
|
||||
const inputVariables = [batchDims, A, B];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
if (enableBatchUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(outerDims));
|
||||
}
|
||||
if (enableAShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(aShapeTemp));
|
||||
}
|
||||
if (enableBShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(bShapeTemp));
|
||||
}
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims');
|
||||
inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims');
|
||||
|
||||
const hasBias = inputs.length > 2;
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
|
||||
const declareFunctions =
|
||||
matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast);
|
||||
if (hasBias) {
|
||||
const biasComponents = isChannelsLast ? components : 1;
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents));
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
|
||||
inputDependencies.push('rank');
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const dimAOuter: i32 = ${dimAOuter};
|
||||
const dimBOuter: i32 = ${dimBOuter};
|
||||
const dimInner: i32 = ${dimInner};
|
||||
${shaderHelper.declareVariables(...inputVariables, output)}
|
||||
${
|
||||
shaderHelper.registerUniform('dimAOuter', 'i32')
|
||||
.registerUniform('dimBOuter', 'i32')
|
||||
.registerUniform('dimInner', 'i32')
|
||||
.declareVariables(...inputVariables, output)}
|
||||
${activationFunction}
|
||||
${declareFunctions}
|
||||
${
|
||||
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
|
||||
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
|
||||
${batchDims.impl()}`;
|
||||
`;
|
||||
// TODO: turn clipMax and clipMin to uniforms.
|
||||
return {
|
||||
name: 'MatMul',
|
||||
shaderCache: {hint: activationAttributes.activationCacheKey},
|
||||
shaderCache: {
|
||||
hint: activationAttributes.activationCacheKey + `${elementsPerThread}` +
|
||||
`${activationAttributes.activation}` +
|
||||
`${activationAttributes.clipMax}` +
|
||||
`${activationAttributes.clipMin}` +
|
||||
`${isVec4}` +
|
||||
`${hasBias}` +
|
||||
`${isChannelsLast}`,
|
||||
inputDependencies
|
||||
},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -210,6 +210,11 @@ export interface IndicesHelper {
|
|||
* a string representing the variable name for the strides of the input or output.
|
||||
*/
|
||||
readonly strides: string;
|
||||
|
||||
/**
|
||||
* representing variable with uniforms, but without binding.
|
||||
*/
|
||||
readonly uniformOnly: boolean;
|
||||
}
|
||||
|
||||
const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => {
|
||||
|
|
@ -335,8 +340,8 @@ export const sumVector = (name: string, components: number) => {
|
|||
* vec4.
|
||||
*/
|
||||
const createIndicesHelper =
|
||||
(name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean,
|
||||
components: 1|2|3|4): IndicesHelper => {
|
||||
(name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4,
|
||||
uniformOnly = false): IndicesHelper => {
|
||||
const useUniform = typeof shapeOrRank === 'number';
|
||||
const rank = useUniform ? shapeOrRank : shapeOrRank.length;
|
||||
const rankIdentity = [...new Array(rank).keys()];
|
||||
|
|
@ -358,7 +363,7 @@ const createIndicesHelper =
|
|||
getByIndices: false,
|
||||
};
|
||||
|
||||
const uniformPrefix = useUniform ? 'uniforms.' : '';
|
||||
const uniformPrefix = useUniform || uniformOnly ? 'uniforms.' : '';
|
||||
const shape = `${uniformPrefix}${name}_shape`;
|
||||
const strides = `${uniformPrefix}${name}_strides`;
|
||||
let o2iSnippet = '';
|
||||
|
|
@ -616,7 +621,8 @@ const createIndicesHelper =
|
|||
name,
|
||||
strides,
|
||||
shape,
|
||||
rank
|
||||
rank,
|
||||
uniformOnly
|
||||
};
|
||||
};
|
||||
|
||||
|
|
@ -630,8 +636,8 @@ const createIndicesHelper =
|
|||
* @returns an IndicesHelper for the input.
|
||||
*/
|
||||
export const inputVariable =
|
||||
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
|
||||
createIndicesHelper(name, type, shapeOrRank, true, components);
|
||||
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, uniformOnly = false):
|
||||
IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, uniformOnly);
|
||||
|
||||
/**
|
||||
* Create a IndicesHelper for an output.
|
||||
|
|
@ -734,7 +740,7 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
`;
|
||||
}
|
||||
|
||||
private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
|
||||
private declareVariable(variable: IndicesHelper, bindingIndex = -1): string {
|
||||
this.indicesHelpers.push(variable);
|
||||
if (variable.rank !== 0) {
|
||||
if (variable.shape.startsWith('uniforms.')) {
|
||||
|
|
@ -744,13 +750,24 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices});
|
||||
}
|
||||
}
|
||||
if (variable.uniformOnly) {
|
||||
return '';
|
||||
}
|
||||
const access = variable.usage === 'input' ? 'read' : 'read_write';
|
||||
const storageType = variable.type.storage;
|
||||
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
|
||||
}
|
||||
|
||||
declareVariables(...variables: IndicesHelper[]): string {
|
||||
return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n');
|
||||
return variables
|
||||
.map(v => {
|
||||
if (v.uniformOnly === true) {
|
||||
return this.declareVariable(v);
|
||||
} else {
|
||||
return this.declareVariable(v, this.variableIndex++);
|
||||
}
|
||||
})
|
||||
.join('\n');
|
||||
}
|
||||
|
||||
registerUniform(name: string, type: string): ShaderHelper {
|
||||
|
|
|
|||
Loading…
Reference in a new issue