[js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753)

This commit is contained in:
Xu Xing 2024-01-26 07:37:05 +08:00 committed by GitHub
parent a2867b911e
commit 656ca66186
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 420 additions and 346 deletions

View file

@ -21,8 +21,8 @@
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvAttributes} from '../conv';
import {getActivationSnippet} from '../fuse-utils';
@ -88,10 +88,10 @@ const conv2dCommonSnippet =
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels);
let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]);
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
let xCh = ${col} % inChannels;
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
// The bounds checking is always needed since we use it to pad zero for
@ -108,7 +108,7 @@ const conv2dCommonSnippet =
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
@ -117,7 +117,7 @@ const conv2dCommonSnippet =
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);
@ -129,9 +129,8 @@ const conv2dCommonSnippet =
isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
const bType =
isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType);
const applyActivation = getActivationSnippet(attributes, resType);
const userCode = `
${activationFunction}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
${isChannelsLast ? sampleX : sampleW}
}
@ -142,7 +141,7 @@ const conv2dCommonSnippet =
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
let col = colIn * ${innerElementSize};
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer)
{
var value = valueIn;
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
@ -181,31 +180,46 @@ export const createConv2DMatMulProgramInfo =
LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`);
const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
const tileAOuter = workGroupSize[1] * elementsPerThread[1];
const tileBOuter = workGroupSize[0] * elementsPerThread[0];
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
const fitAOuter = dimAOuter % tileAOuter === 0;
const fitBOuter = dimBOuter % tileBOuter === 0;
const fitInner = dimInner % tileInner === 0;
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const x =
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
const inputVariables = [x, w];
const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
{type: 'int32', data: attributes.dilations}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');
}
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
const getShaderSource = (shaderHelper: ShaderHelper) => {
const uniforms: UniformsArrayType = [
{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'},
{name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
{name: 'dilation', type: 'i32', length: 2}
];
if (attributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
let declareFunctions = `
// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
let declareFunctions = `
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
}
@ -213,51 +227,50 @@ export const createConv2DMatMulProgramInfo =
let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3));
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
}`;
if (hasBias) {
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
declareFunctions += `
const x = inputVariable(
'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
const inputVariables = [x, w];
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
if (hasBias) {
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
programUniforms.push(...createTensorShapeVariables(outputShape));
return {
name: 'Conv2DMatMul',
shaderCache: {hint: attributes.cacheKey},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms,
}),
getShaderSource: (shaderHelper: ShaderHelper) => `
}
return `
${utilFunctions('uniforms.result_strides')}
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)}
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${declareFunctions}
${
conv2dCommonSnippet(
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1],
elementsSize[2], t)}
${
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined,
sequentialAccessByThreads)}`
sequentialAccessByThreads)}`;
};
return {
name: 'Conv2DMatMul',
shaderCache: {
hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${
tileAOuter};${tileBOuter};${tileInner}`,
inputDependencies
},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms,
}),
getShaderSource
};
};

View file

@ -21,8 +21,8 @@
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {getActivationSnippet} from '../fuse-utils';
@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet =
col % outWidth);
`;
const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]';
const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]';
const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
const row = isChannelsLast ? 'row' : 'col';
const col = isChannelsLast ? 'col' : 'row';
const readASnippet = `
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
let WRow = ${col} / (uniforms.filter_dims[1] * inChannels);
let WCol = ${col} / inChannels % uniforms.filter_dims[1];
let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]);
let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]);
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
return ${type}(0.0);
}
@ -103,25 +103,25 @@ const conv2dTransposeCommonSnippet =
const sampleA = isChannelsLast ? `
let col = colIn * ${innerElementSize};
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
${readASnippet}
}
return ${type}(0.0);` :
`
let col = colIn * ${innerElementSize};
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
${readASnippet}
}
return ${type}(0.0);`;
const sampleW = `
let col = colIn * ${innerElementSize};
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels);
let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1];
if (${
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) {
isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' :
'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) {
let rowInner = row % inChannels;
let coord = vec4<i32>(coordX, coordY, col, rowInner);
${getWSnippet(innerElementSize)}
@ -129,9 +129,8 @@ const conv2dTransposeCommonSnippet =
return ${type}(0.0);
`;
const {activationFunction, applyActivation} = getActivationSnippet(attributes, type);
const applyActivation = getActivationSnippet(attributes, type);
const userCode = `
${activationFunction}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} {
${isChannelsLast ? sampleA : sampleW}
}
@ -142,7 +141,7 @@ const conv2dTransposeCommonSnippet =
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) {
let col = colIn * ${innerElementSize};
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {
var value = valueInput;
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
${coordResSnippet}
@ -186,65 +185,64 @@ export const createConv2DTransposeMatMulProgramInfo =
const innerElementSize = isVec4 ? 4 : 1;
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
const inputVariables = [x, w];
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
const filterDims =
[attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
const effectiveFilterDims = [
filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)),
filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1))
];
const pads = [
effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2)
];
let declareFunctions = '';
if (hasBias) {
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');
}
programUniforms.push(...createTensorShapeVariables(outputShape));
return {
name: 'Conv2DTransposeMatMul',
shaderCache: {hint: attributes.cacheKey},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms
}),
getShaderSource: (shaderHelper: ShaderHelper) => `
const getShaderSource = (shaderHelper: ShaderHelper) => {
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
const inputVariables = [x, w];
let declareFunctions = '';
if (hasBias) {
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
const uniforms: UniformsArrayType = [
{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'},
{name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2},
{name: 'filter_dims', type: 'i32', length: filterDims.length},
{name: 'pads', type: 'i32', length: pads.length}
];
if (attributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
return `
${utilFunctions('uniforms.result_strides')}
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)};
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
${
attributes.dilations[0] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
${
attributes.dilations[1] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
attributes.pads[0] + attributes.pads[2]})/2,
i32(effectiveFilterDims[1]) - 1 - (${
attributes.pads[1] + attributes.pads[3]})/2);
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const dimAOuter : i32 = ${dimAOuter};
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
${
@ -252,6 +250,18 @@ export const createConv2DTransposeMatMulProgramInfo =
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
undefined, sequentialAccessByThreads)}`
undefined, sequentialAccessByThreads)}`;
};
return {
name: 'Conv2DTransposeMatMul',
shaderCache:
{hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms
}),
getShaderSource
};
};

View file

@ -20,24 +20,18 @@
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo} from '../../types';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
const createConvTranspose2DOpProgramShaderSource =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes,
outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false,
dataType: string): string => {
const isChannelsLast = attributes.format === 'NHWC';
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean,
is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType,
isChannelsLast = false): string => {
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const channelDim = isChannelsLast ? 3 : 1;
const outputSize = ShapeUtil.size(outputShape);
const workPerThread = isVec4 ? 2 : 1;
const group = attributes.group;
const wShape = inputs[1].dims;
const inputChannelsPerGroup = wShape[0] / group;
const outputChannelsPerGroup = wShape[1];
let declareFunctions = `
fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) {
@ -50,20 +44,21 @@ const createConvTranspose2DOpProgramShaderSource =
}`;
}
const components = isVec4 ? 4 : 1;
const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components);
const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components);
const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components);
const inputVariables = [dy, w];
if (hasBias) {
inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components));
inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components));
}
const output = outputVariable('result', inputs[0].dataType, outputShape, components);
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
const codeSnippet4 = `{
let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1];
let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1];
let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1];
let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1];
let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread};
let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4;
let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(pads);
let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(uniforms.pads);
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
@ -71,29 +66,29 @@ const createConvTranspose2DOpProgramShaderSource =
for (var i = 0; i < ${workPerThread}; i++) {
dotProd[i] = vec4<${dataType}>(0.0);
}
for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) {
var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x);
let wRPerm = filterDims[0] - 1 - wR;
if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) ||
for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) {
var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x);
let wRPerm = uniforms.filter_dims[0] - 1 - wR;
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) ||
fract(dyR) > 0.0 || wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);
for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) {
let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y);
let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y);
let wCPerm = filterDims[1] - 1 - wC;
for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) {
let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let wCPerm = uniforms.filter_dims[1] - 1 - wC;
if (wCPerm < 0) {
continue;
}
var bDyCVal = true;
var bDyCVal2 = true;
if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) ||
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) ||
fract(dyC) > 0.0) {
bDyCVal = false;
}
if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) ||
if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) ||
fract(dyC2) > 0.0) {
bDyCVal2 = false;
}
@ -101,7 +96,7 @@ const createConvTranspose2DOpProgramShaderSource =
let idyC: u32 = u32(dyC);
let idyC2: u32 = u32(dyC2);
if (bDyCVal && bDyCVal2) {
let d2Length = outBackprop[3];
let d2Length = uniforms.Dy_shape[3];
for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
@ -123,7 +118,7 @@ const createConvTranspose2DOpProgramShaderSource =
dot(xValue, wValue3));
}
} else if (bDyCVal) {
let d2Length = outBackprop[${channelDim}];
let d2Length = uniforms.Dy_shape[${channelDim}];
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
@ -138,7 +133,7 @@ const createConvTranspose2DOpProgramShaderSource =
dotProd[0] = dotProd[0] + tmpval;
}
} else if (bDyCVal2) {
let d2Length = outBackprop[3];
let d2Length = uniforms.Dy_shape[3];
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
@ -167,39 +162,39 @@ const createConvTranspose2DOpProgramShaderSource =
let d1 = ${output.indicesGet('outputIndices', channelDim)};
let r = ${output.indicesGet('outputIndices', rowDim)};
let c = ${output.indicesGet('outputIndices', colDim)};
let dyCorner = vec2<i32>(i32(r), i32(c)) - pads;
let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
let dyRCorner = dyCorner.x;
let dyCCorner = dyCorner.y;
let groupId = d1 / ${outputChannelsPerGroup};
let wOutChannel = d1 - groupId * ${outputChannelsPerGroup};
let groupId = d1 / uniforms.output_channels_per_group;
let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd = ${dataType}(0.0);
for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) {
if (wR % dilations.x != 0) {
for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
if (wR % uniforms.dilations.x != 0) {
continue;
}
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]);
let wRPerm = filterDims.x - 1 - wR / dilations.x;
if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 ||
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]);
let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 ||
wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);
for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) {
if (wC % dilations.y != 0) {
for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
if (wC % uniforms.dilations.y != 0) {
continue;
}
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y);
let wCPerm = filterDims.y - 1 - wC / dilations.y;
if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) ||
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) ||
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
let idyC: u32 = u32(dyC);
var inputChannel = groupId * ${inputChannelsPerGroup};
for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) {
var inputChannel = groupId * uniforms.input_channels_per_group;
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
let xValue = ${
isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') :
dy.get('batch', 'inputChannel', 'idyR', 'idyC')};
@ -214,27 +209,11 @@ const createConvTranspose2DOpProgramShaderSource =
`;
return `
${shaderHelper.declareVariables(...inputVariables, output)}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${declareFunctions}
const outShape : vec4<u32> = vec4<u32>(${outputShape.join(',')});
const outBackprop : vec4<u32> = vec4<u32>(${inputs[0].dims.join(',')});
const strides : vec2<u32> = vec2<u32>(${attributes.strides[0]}, ${attributes.strides[1]});
const filterDims : vec2<u32> = vec2<u32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const dilations : vec2<u32> = vec2<u32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const effectiveFilterDims : vec2<u32> = filterDims + vec2<u32>(
${
attributes.dilations[0] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
${
attributes.dilations[1] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2,
i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2);
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)};
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
${isVec4 ? codeSnippet4 : codeSnippet}}`;
};
@ -257,19 +236,72 @@ export const createConvTranspose2DProgramInfo =
];
LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const isChannelsLast = attributes.format === 'NHWC';
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const strides = [attributes.strides[0], attributes.strides[1]];
const filterDims =
[attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
const dilations = [attributes.dilations[0], attributes.dilations[1]];
const effectiveFilterDims = [
filterDims[0] +
(attributes.dilations[0] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)),
filterDims[1] +
(attributes.dilations[1] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1))
];
const pads = [
effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2
];
const isVec4 = false;
const group = attributes.group;
const wShape = inputs[1].dims;
const inputChannelsPerGroup = wShape[0] / group;
const outputChannelsPerGroup = wShape[1];
const programUniforms: ProgramUniform[] = [
{type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims},
{type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads},
{type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup},
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
];
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');
}
programUniforms.push(...createTensorShapeVariables(outputShape));
const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1;
const getShaderSource = (shaderHelper: ShaderHelper) => {
const uniforms: UniformsArrayType = [
{name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length},
{name: 'filter_dims', type: 'u32', length: filterDims.length},
{name: 'dilations', type: 'u32', length: filterDims.length},
{name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length},
{name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'},
{name: 'output_channels_per_group', type: 'u32'}
];
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `${
createConvTranspose2DOpProgramShaderSource(
shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms,
isChannelsLast)}`;
};
return {
name: 'ConvTranspose2D',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies},
getRunData: () => ({
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
outputs: [{
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
dataType: inputs[0].dataType
}]
}],
programUniforms
}),
getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource(
shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false,
dataType),
getShaderSource
};
};

View file

@ -22,7 +22,7 @@
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';
import {typeSnippet} from './activation_util';
@ -112,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'};
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
var acc: array<vec4<${type}>, rowPerThread>;
// Loop over shared dimension.
let tileRowB = localRow * ${rowPerThreadB};
for (var t = 0; t < numTiles; t = t + 1) {
for (var t = 0; t < num_tiles; t = t + 1) {
// Load one tile of A into local memory.
for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
let inputRow = tileRow + innerRow;
@ -204,7 +204,7 @@ export const makeMatMulPackedSource =
let globalColStart = i32(workgroupId.x) * ${tileBOuter};
// Loop over shared dimension.
for (var t = 0; t < numTiles; t = t + 1) {
for (var t = 0; t < num_tiles; t = t + 1) {
// Load one tile of A into local memory.
for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) {
for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) {
@ -260,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA};
let tileColA = i32(localId.x) * ${colPerThreadA};
let tileRowB = i32(localId.y) * ${rowPerThreadB};
// Loop over shared dimension.
for (var t = 0; t < numTiles; t = t + 1) {
for (var t = 0; t < num_tiles; t = t + 1) {
// Load one tile of A into local memory.
for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) {
for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) {
@ -322,7 +322,8 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
@builtin(workgroup_id) workgroupId : vec3<u32>) {
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
let num_tiles = ${
splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'};
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
var acc : array<array<${type}, colPerThread>, rowPerThread>;
@ -379,7 +380,7 @@ const matMulReadWriteFnSource =
typeSnippet(component, dataType)} {
var value = ${typeSnippet(component, dataType)}(0.0);
let col = colIn * ${component};
if(row < uniforms.dimAOuter && col < uniforms.dimInner)
if(row < uniforms.dim_a_outer && col < uniforms.dim_inner)
{
${getAIndices()}
value = ${aVariable.getByIndices('aIndices')};
@ -391,7 +392,7 @@ const matMulReadWriteFnSource =
typeSnippet(component, dataType)} {
var value = ${typeSnippet(component, dataType)}(0.0);
let col = colIn * ${component};
if(row < uniforms.dimInner && col < uniforms.dimBOuter)
if(row < uniforms.dim_inner && col < uniforms.dim_b_outer)
{
${getBIndices()}
value = ${bVariable.getByIndices('bIndices')};
@ -401,7 +402,7 @@ const matMulReadWriteFnSource =
fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) {
let col = colIn * ${component};
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {
var value = valueIn;
let coords = vec3<i32>(batch, row, colIn);
${
@ -422,16 +423,10 @@ export const createMatmulProgramInfo =
isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;
const outerDimsA = aShape.slice(0, -2);
const outerDimsB = bShape.slice(0, -2);
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
const enableBatchUniforms = enableShapesUniforms(outerDims.length);
const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims;
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
const batchSize = ShapeUtil.size(outerDims);
const dimAOuter = aShape[aShape.length - 2];
const dimInner = aShape[aShape.length - 1];
const dimBOuter = bShape[bShape.length - 1];
@ -446,72 +441,67 @@ export const createMatmulProgramInfo =
Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2])
];
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const components = isVec4 ? 4 : 1;
const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length);
const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp;
const aShapeOrRank = aShapeTemp.length;
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length);
const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp;
const bShapeOrRank = bShapeTemp.length;
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
const inputVariables = [A, B];
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
if (enableBatchUniforms) {
programUniforms.push(...createTensorShapeVariables(outerDims));
if (activationAttributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: activationAttributes.clipMax!},
{type: 'float32', data: activationAttributes.clipMin!});
}
if (enableAShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(aShapeTemp));
}
if (enableBShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(bShapeTemp));
}
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims');
inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims');
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
...createTensorShapeVariables(bShapeTemp));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const hasBias = inputs.length > 2;
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
const declareFunctions = matMulReadWriteFnSource(
components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
isChannelsLast);
if (hasBias) {
const biasComponents = isChannelsLast ? components : 1;
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');
}
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
const getShaderSource = (shaderHelper: ShaderHelper) => `
const getShaderSource = (shaderHelper: ShaderHelper) => {
const batchShapeOrRank = outerDims.length;
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
const inputVariables = [A, B];
if (hasBias) {
const biasComponents = isChannelsLast ? components : 1;
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
}
const uniforms: UniformsArrayType =
[{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}];
if (activationAttributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
const declareFunctions = matMulReadWriteFnSource(
components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
isChannelsLast);
return `
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.registerInternalVariables(batchDims)
.declareVariables(...inputVariables, output)}
${activationFunction}
shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables(
...inputVariables, output)}
${declareFunctions}
${
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
`;
// TODO: turn clipMax and clipMin to uniforms.
};
return {
name: 'MatMul',
shaderCache: {
hint: activationAttributes.activationCacheKey + `${elementsPerThread}` +
`${isVec4}` +
`${isChannelsLast}`,
hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`,
inputDependencies
},
getRunData: () => ({

View file

@ -3,9 +3,9 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ProgramInfo, ProgramUniform} from '../types';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
import {calculateOutputShape, ConvAttributes} from './conv';
import {getActivationSnippet} from './fuse-utils';
@ -27,52 +27,75 @@ export const createGroupedConvProgramInfo =
xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast);
const outputSize = ShapeUtil.size(outputShape);
const output = outputVariable('output', inputs[0].dataType, outputShape);
const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value);
const x = inputVariable('x', inputs[0].dataType, xShape);
const w = inputVariable('w', inputs[1].dataType, wShape);
const inputVars = [x, w];
if (hasBias) {
inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims));
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations},
{type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]},
{type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');
}
programUniforms.push(...createTensorShapeVariables(outputShape));
const getShaderSource = (shaderHelper: ShaderHelper) => `
const strides: vec2<u32> = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u);
const pads: vec2<u32> = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u);
const getShaderSource = (shaderHelper: ShaderHelper) => {
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
const applyActivation = getActivationSnippet(attributes, output.type.value);
const x = inputVariable('x', inputs[0].dataType, xShape.length);
const w = inputVariable('w', inputs[1].dataType, wShape.length);
const inputVars = [x, w];
if (hasBias) {
inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims));
}
${shaderHelper.declareVariables(...inputVars, output)}
${activationFunction}
const uniforms: UniformsArrayType = [
{name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length},
{name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2},
{name: 'output_channels_per_group', type: 'u32'}
];
if (attributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let outputIndices = ${output.offsetToIndices('global_idx')};
let batch: u32 = outputIndices[0];
let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}];
let xRCCorner: vec2<u32> = vec2<u32>(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${
isChannelLast ? 2 : 3}]) * strides - pads;
let group_id: u32 = output_channel / ${outputChannelsPerGroup}u;
isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads;
let group_id: u32 = output_channel / uniforms.output_channels_per_group;
var value: ${output.type.value} = ${output.type.value}(0);
for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) {
let input_channel = group_id * ${wShape[1]}u + wInChannel;
for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) {
let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u;
for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
let input_channel = group_id * uniforms.w_shape[1] + wInChannel;
for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) {
if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) {
continue;
}
for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) {
let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u;
if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) {
for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) {
continue;
}
let xVal = ${
isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') :
x.get('batch', 'input_channel', 'xHeight', 'xWidth')};
isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') :
x.get('batch', 'input_channel', 'xHeight', 'xWidth')};
let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')};
value += xVal*wVal;
}
@ -82,15 +105,17 @@ export const createGroupedConvProgramInfo =
${applyActivation}
${output.setByOffset('global_idx', 'value')}
}`;
};
return {
name: 'GroupedConv',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: attributes.cacheKey, inputDependencies},
getRunData: () => ({
outputs: [{
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
dataType: inputs[0].dataType
}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource,
};
@ -114,7 +139,7 @@ export const createGroupedConvVectorizeProgramInfo =
const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
const getShaderSource = (shaderHelper: ShaderHelper) => {
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value);
const applyActivation = getActivationSnippet(attributes, output.type.value);
const x = inputVariable('x', inputs[0].dataType, xShape.length, components);
const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
const inputVars = [x, w];
@ -129,7 +154,6 @@ export const createGroupedConvVectorizeProgramInfo =
.registerUniform('strides', 'i32', 2)
.registerUniform('pads', 'i32', 2)
.declareVariables(...inputVars, output)}
${activationFunction}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let width0 = uniforms.output_shape[3];
@ -179,7 +203,7 @@ export const createGroupedConvVectorizeProgramInfo =
return {
name: 'GroupedConv-Vectorize',
shaderCache: {
hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`,
hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`,
inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank']
},
getRunData: () => ({

View file

@ -2,7 +2,6 @@
// Licensed under the MIT License.
import {TensorView} from '../../tensor-view';
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext} from '../types';
import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu';
@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes {
readonly outputShape: readonly number[];
}
const getAdjustedConvTransposeAttributes =
<T extends ConvTransposeAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
const kernelShape = attributes.kernelShape.slice();
@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes =
// always return a new object so does not modify the original attributes
const newAttributes: T = Object.assign({}, attributes);
const cacheKey = attributes.cacheKey + [
kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','),
dilations.join(',')
].join('_');
Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey});
Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides});
return newAttributes;
};
@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record<string, unknown>
const wIsConst = (attributes.wIsConst as () => boolean)();
const outputPadding = attributes.outputPadding as [number, number, number, number];
const outputShape = attributes.outputShape as [number, number];
return createAttributeWithCacheKey({
return {
autoPad,
format,
dilations,
@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record<string, unknown>
pads,
strides,
wIsConst,
...activationAttributes
});
...activationAttributes,
cacheKey: `${attributes.format};${activationAttributes.activation};`
};
};
const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {

View file

@ -3,7 +3,7 @@
import {TensorView} from '../../tensor-view';
import {PoolConvUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext} from '../types';
import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
@ -110,7 +110,7 @@ const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inpu
// always return a new object so does not modify the original attributes
const newAttributes: T = Object.assign({}, attributes);
Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey});
Object.assign(newAttributes, {kernelShape, pads});
return newAttributes;
};
@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAt
const strides = attributes.strides as [number, number];
const wIsConst = (attributes.w_is_const as () => boolean)();
return createAttributeWithCacheKey(
{autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes});
return {
autoPad,
format,
dilations,
group,
kernelShape,
pads,
strides,
wIsConst,
...activationAttributes,
cacheKey: `${attributes.format};${activationAttributes.activation};`
};
};
const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {

View file

@ -7,30 +7,21 @@ export interface InternalActivationAttributes {
readonly activation: string;
readonly clipMin?: number;
readonly clipMax?: number;
readonly activationCacheKey: string;
}
export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string):
{activationFunction: string; applyActivation: string} => {
switch (attributes.activation) {
case 'Relu':
return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`};
case 'Sigmoid':
return {
activationFunction: '',
applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`
};
case 'Clip':
return {
activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${
attributes.clipMax!});`,
applyActivation: 'value = clamp(value, clip_min_, clip_max_);'
};
// TODO: adding other activations that can be fused.
default:
return {activationFunction: '', applyActivation: ''};
}
};
export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => {
switch (attributes.activation) {
case 'Relu':
return `value = max(value, ${valueType}(0.0));`;
case 'Sigmoid':
return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
case 'Clip':
return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`;
// TODO: adding other activations that can be fused.
default:
return '';
}
};
export const parseInternalActivationAttributes =
(attributes: Record<string, unknown>|undefined): InternalActivationAttributes => {
@ -38,7 +29,7 @@ export const parseInternalActivationAttributes =
if (activation === 'Clip') {
const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP];
return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`};
return {activation, clipMax, clipMin};
}
return {activation, activationCacheKey: activation};
return {activation};
};

View file

@ -6,7 +6,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common';
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common';
import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
export const createNaiveMatmulProgramInfo =
@ -27,11 +27,19 @@ export const createNaiveMatmulProgramInfo =
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
const batchSize = ShapeUtil.size(outerDims);
const outputShapeInShader = [batchSize, M, N];
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N},
{type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
...createTensorShapeVariables(bShape)
{type: 'uint32', data: K}
];
if (activationAttributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: activationAttributes.clipMax!},
{type: 'float32', data: activationAttributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
...createTensorShapeVariables(bShape));
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
}
@ -42,7 +50,7 @@ export const createNaiveMatmulProgramInfo =
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
const inputVariables = [a, b];
let processBias = '';
if (hasBias) {
@ -57,6 +65,14 @@ export const createNaiveMatmulProgramInfo =
const outerDimsB = bShape.slice(0, -2);
const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);
const uniforms: UniformsArrayType = [
{name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'K', type: 'u32'}
];
if (activationAttributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
const rank = variable.rank;
const name = variable.name;
@ -96,15 +112,10 @@ export const createNaiveMatmulProgramInfo =
return `
${
shaderHelper.registerUniform('outputSize', 'u32')
.registerUniform('M', 'u32')
.registerUniform('N', 'u32')
.registerUniform('K', 'u32')
.registerInternalVariables(batchDims)
.declareVariables(...inputVariables, output)}
${activationFunction}
shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables(
...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let col = (global_idx % (uniforms.N / ${components})) * ${components};
var index1 = global_idx / (uniforms.N / ${components});
let stride1 = uniforms.M / ${outputNumber};
@ -134,8 +145,7 @@ export const createNaiveMatmulProgramInfo =
return {
name: 'MatMulNaive',
shaderCache: {
hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${
isChannelsLast}`,
hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`,
inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank']
},
getRunData: () => ({
@ -166,9 +176,8 @@ export const matMul = (context: ComputeContext): void => {
const N = outputShape[outputShape.length - 1];
const K = context.inputs[0].dims[context.inputs[0].dims.length - 1];
if (N < 8 && K < 8) {
context.compute(
createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape));
} else {
context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape));
}
};