mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-15 01:23:42 +00:00
[js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753)
This commit is contained in:
parent
a2867b911e
commit
656ca66186
9 changed files with 420 additions and 346 deletions
|
|
@ -21,8 +21,8 @@
|
|||
|
||||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ProgramInfo, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
|
||||
import {ConvAttributes} from '../conv';
|
||||
import {getActivationSnippet} from '../fuse-utils';
|
||||
|
||||
|
|
@ -88,10 +88,10 @@ const conv2dCommonSnippet =
|
|||
let outRow = ${row} / outWidth;
|
||||
let outCol = ${row} % outWidth;
|
||||
|
||||
let WRow = ${col} / (filterDims[1] * inChannels);
|
||||
let WCol = ${col} / inChannels % filterDims[1];
|
||||
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
|
||||
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
|
||||
let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels);
|
||||
let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]);
|
||||
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
|
||||
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
|
||||
let xCh = ${col} % inChannels;
|
||||
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
|
||||
// The bounds checking is always needed since we use it to pad zero for
|
||||
|
|
@ -108,7 +108,7 @@ const conv2dCommonSnippet =
|
|||
${readXSnippet}` :
|
||||
`
|
||||
let col = colIn * ${innerElementSizeX};
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
|
||||
if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
|
||||
${readXSnippet}
|
||||
}
|
||||
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
|
||||
|
|
@ -117,7 +117,7 @@ const conv2dCommonSnippet =
|
|||
${readXSnippet}` :
|
||||
`
|
||||
let col = colIn * ${innerElementSizeX};
|
||||
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
|
||||
if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
|
||||
${readXSnippet}
|
||||
}
|
||||
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);
|
||||
|
|
@ -129,9 +129,8 @@ const conv2dCommonSnippet =
|
|||
isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
|
||||
const bType =
|
||||
isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType);
|
||||
const applyActivation = getActivationSnippet(attributes, resType);
|
||||
const userCode = `
|
||||
${activationFunction}
|
||||
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
|
||||
${isChannelsLast ? sampleX : sampleW}
|
||||
}
|
||||
|
|
@ -142,7 +141,7 @@ const conv2dCommonSnippet =
|
|||
|
||||
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
|
||||
if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer)
|
||||
{
|
||||
var value = valueIn;
|
||||
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
|
||||
|
|
@ -181,31 +180,46 @@ export const createConv2DMatMulProgramInfo =
|
|||
LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`);
|
||||
|
||||
const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
|
||||
|
||||
const tileAOuter = workGroupSize[1] * elementsPerThread[1];
|
||||
const tileBOuter = workGroupSize[0] * elementsPerThread[0];
|
||||
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
|
||||
|
||||
const fitAOuter = dimAOuter % tileAOuter === 0;
|
||||
const fitBOuter = dimBOuter % tileBOuter === 0;
|
||||
const fitInner = dimInner % tileInner === 0;
|
||||
|
||||
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
|
||||
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
||||
// TODO: support component 2, 3.
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
const x =
|
||||
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
|
||||
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
|
||||
const inputVariables = [x, w];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
|
||||
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
|
||||
{type: 'int32', data: attributes.dilations}
|
||||
];
|
||||
if (attributes.activation === 'Clip') {
|
||||
programUniforms.push(
|
||||
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
|
||||
}
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
|
||||
if (hasBias) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
inputDependencies.push('rank');
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'},
|
||||
{name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
|
||||
{name: 'dilation', type: 'i32', length: 2}
|
||||
];
|
||||
if (attributes.activation === 'Clip') {
|
||||
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
|
||||
}
|
||||
|
||||
let declareFunctions = `
|
||||
// TODO: support component 2, 3.
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
let declareFunctions = `
|
||||
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
|
||||
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
|
||||
}
|
||||
|
|
@ -213,51 +227,50 @@ export const createConv2DMatMulProgramInfo =
|
|||
let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3));
|
||||
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
|
||||
}`;
|
||||
if (hasBias) {
|
||||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
|
||||
inputVariables.push(bias);
|
||||
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
|
||||
declareFunctions += `
|
||||
const x = inputVariable(
|
||||
'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
|
||||
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
|
||||
const inputVariables = [x, w];
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
|
||||
if (hasBias) {
|
||||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
|
||||
inputVariables.push(bias);
|
||||
declareFunctions += `
|
||||
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
|
||||
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
|
||||
}`;
|
||||
}
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
return {
|
||||
name: 'Conv2DMatMul',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource: (shaderHelper: ShaderHelper) => `
|
||||
}
|
||||
|
||||
return `
|
||||
${utilFunctions('uniforms.result_strides')}
|
||||
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
|
||||
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
|
||||
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
|
||||
${
|
||||
shaderHelper.registerUniform('dimAOuter', 'i32')
|
||||
.registerUniform('dimBOuter', 'i32')
|
||||
.registerUniform('dimInner', 'i32')
|
||||
.declareVariables(...inputVariables, output)}
|
||||
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
|
||||
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
|
||||
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
|
||||
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
|
||||
${declareFunctions}
|
||||
${
|
||||
conv2dCommonSnippet(
|
||||
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1],
|
||||
elementsSize[2], t)}
|
||||
${
|
||||
${
|
||||
isVec4 ?
|
||||
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
|
||||
makeMatMulPackedSource(
|
||||
elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined,
|
||||
sequentialAccessByThreads)}`
|
||||
sequentialAccessByThreads)}`;
|
||||
};
|
||||
return {
|
||||
name: 'Conv2DMatMul',
|
||||
shaderCache: {
|
||||
hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${
|
||||
tileAOuter};${tileBOuter};${tileInner}`,
|
||||
inputDependencies
|
||||
},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@
|
|||
|
||||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ProgramInfo, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
|
||||
import {ConvTransposeAttributes} from '../conv-transpose';
|
||||
import {getActivationSnippet} from '../fuse-utils';
|
||||
|
||||
|
|
@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet =
|
|||
col % outWidth);
|
||||
`;
|
||||
|
||||
const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]';
|
||||
const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]';
|
||||
const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
|
||||
const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
|
||||
const row = isChannelsLast ? 'row' : 'col';
|
||||
const col = isChannelsLast ? 'col' : 'row';
|
||||
|
||||
const readASnippet = `
|
||||
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
|
||||
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
|
||||
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
|
||||
let outRow = ${row} / outWidth;
|
||||
let outCol = ${row} % outWidth;
|
||||
|
||||
let WRow = ${col} / (filterDims[1] * inChannels);
|
||||
let WCol = ${col} / inChannels % filterDims[1];
|
||||
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
|
||||
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
|
||||
let WRow = ${col} / (uniforms.filter_dims[1] * inChannels);
|
||||
let WCol = ${col} / inChannels % uniforms.filter_dims[1];
|
||||
let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]);
|
||||
let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]);
|
||||
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
|
||||
return ${type}(0.0);
|
||||
}
|
||||
|
|
@ -103,25 +103,25 @@ const conv2dTransposeCommonSnippet =
|
|||
|
||||
const sampleA = isChannelsLast ? `
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
|
||||
if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
|
||||
${readASnippet}
|
||||
}
|
||||
return ${type}(0.0);` :
|
||||
`
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
|
||||
if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
|
||||
${readASnippet}
|
||||
}
|
||||
return ${type}(0.0);`;
|
||||
|
||||
const sampleW = `
|
||||
let col = colIn * ${innerElementSize};
|
||||
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
|
||||
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
|
||||
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
|
||||
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
|
||||
let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels);
|
||||
let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1];
|
||||
if (${
|
||||
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
|
||||
'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) {
|
||||
isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' :
|
||||
'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) {
|
||||
let rowInner = row % inChannels;
|
||||
let coord = vec4<i32>(coordX, coordY, col, rowInner);
|
||||
${getWSnippet(innerElementSize)}
|
||||
|
|
@ -129,9 +129,8 @@ const conv2dTransposeCommonSnippet =
|
|||
return ${type}(0.0);
|
||||
`;
|
||||
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(attributes, type);
|
||||
const applyActivation = getActivationSnippet(attributes, type);
|
||||
const userCode = `
|
||||
${activationFunction}
|
||||
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} {
|
||||
${isChannelsLast ? sampleA : sampleW}
|
||||
}
|
||||
|
|
@ -142,7 +141,7 @@ const conv2dTransposeCommonSnippet =
|
|||
|
||||
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) {
|
||||
let col = colIn * ${innerElementSize};
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
|
||||
if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {
|
||||
var value = valueInput;
|
||||
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
|
||||
${coordResSnippet}
|
||||
|
|
@ -186,65 +185,64 @@ export const createConv2DTransposeMatMulProgramInfo =
|
|||
const innerElementSize = isVec4 ? 4 : 1;
|
||||
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
|
||||
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
|
||||
const inputVariables = [x, w];
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
|
||||
const filterDims =
|
||||
[attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
|
||||
const effectiveFilterDims = [
|
||||
filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)),
|
||||
filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1))
|
||||
];
|
||||
const pads = [
|
||||
effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
|
||||
effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2)
|
||||
];
|
||||
|
||||
let declareFunctions = '';
|
||||
if (hasBias) {
|
||||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
|
||||
inputVariables.push(bias);
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
|
||||
declareFunctions += `
|
||||
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
|
||||
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
|
||||
}`;
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
|
||||
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
|
||||
{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
|
||||
];
|
||||
if (attributes.activation === 'Clip') {
|
||||
programUniforms.push(
|
||||
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
|
||||
}
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
|
||||
if (hasBias) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
inputDependencies.push('rank');
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
return {
|
||||
name: 'Conv2DTransposeMatMul',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource: (shaderHelper: ShaderHelper) => `
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
|
||||
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
|
||||
const inputVariables = [x, w];
|
||||
|
||||
let declareFunctions = '';
|
||||
if (hasBias) {
|
||||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
|
||||
inputVariables.push(bias);
|
||||
declareFunctions += `
|
||||
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
|
||||
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
|
||||
}`;
|
||||
}
|
||||
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'},
|
||||
{name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2},
|
||||
{name: 'filter_dims', type: 'i32', length: filterDims.length},
|
||||
{name: 'pads', type: 'i32', length: pads.length}
|
||||
];
|
||||
if (attributes.activation === 'Clip') {
|
||||
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
|
||||
}
|
||||
return `
|
||||
${utilFunctions('uniforms.result_strides')}
|
||||
${
|
||||
shaderHelper.registerUniform('dimAOuter', 'i32')
|
||||
.registerUniform('dimBOuter', 'i32')
|
||||
.registerUniform('dimInner', 'i32')
|
||||
.declareVariables(...inputVariables, output)};
|
||||
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
|
||||
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
|
||||
attributes.kernelShape[isChannelsLast ? 2 : 3]});
|
||||
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
|
||||
${
|
||||
attributes.dilations[0] <= 1 ?
|
||||
0 :
|
||||
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
|
||||
${
|
||||
attributes.dilations[1] <= 1 ?
|
||||
0 :
|
||||
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
|
||||
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
|
||||
attributes.pads[0] + attributes.pads[2]})/2,
|
||||
i32(effectiveFilterDims[1]) - 1 - (${
|
||||
attributes.pads[1] + attributes.pads[3]})/2);
|
||||
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
|
||||
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
|
||||
const dimAOuter : i32 = ${dimAOuter};
|
||||
const dimBOuter : i32 = ${dimBOuter};
|
||||
const dimInner : i32 = ${dimInner};
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
|
||||
${declareFunctions}
|
||||
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
|
||||
${
|
||||
|
|
@ -252,6 +250,18 @@ export const createConv2DTransposeMatMulProgramInfo =
|
|||
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
|
||||
makeMatMulPackedSource(
|
||||
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
|
||||
undefined, sequentialAccessByThreads)}`
|
||||
undefined, sequentialAccessByThreads)}`;
|
||||
};
|
||||
|
||||
return {
|
||||
name: 'Conv2DTransposeMatMul',
|
||||
shaderCache:
|
||||
{hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -20,24 +20,18 @@
|
|||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {ProgramInfo} from '../../types';
|
||||
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
|
||||
import {ConvTransposeAttributes} from '../conv-transpose';
|
||||
|
||||
const createConvTranspose2DOpProgramShaderSource =
|
||||
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes,
|
||||
outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false,
|
||||
dataType: string): string => {
|
||||
const isChannelsLast = attributes.format === 'NHWC';
|
||||
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean,
|
||||
is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType,
|
||||
isChannelsLast = false): string => {
|
||||
const rowDim = isChannelsLast ? 1 : 2;
|
||||
const colDim = isChannelsLast ? 2 : 3;
|
||||
const channelDim = isChannelsLast ? 3 : 1;
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const workPerThread = isVec4 ? 2 : 1;
|
||||
const group = attributes.group;
|
||||
const wShape = inputs[1].dims;
|
||||
const inputChannelsPerGroup = wShape[0] / group;
|
||||
const outputChannelsPerGroup = wShape[1];
|
||||
|
||||
let declareFunctions = `
|
||||
fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) {
|
||||
|
|
@ -50,20 +44,21 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
}`;
|
||||
}
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components);
|
||||
const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components);
|
||||
const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
|
||||
const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components);
|
||||
const inputVariables = [dy, w];
|
||||
if (hasBias) {
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components));
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components));
|
||||
}
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape, components);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
|
||||
|
||||
const codeSnippet4 = `{
|
||||
let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1];
|
||||
let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1];
|
||||
let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1];
|
||||
let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1];
|
||||
let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread};
|
||||
let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4;
|
||||
|
||||
let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(pads);
|
||||
let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(uniforms.pads);
|
||||
|
||||
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
|
||||
// ? = to be determined. : = across all values in that axis.
|
||||
|
|
@ -71,29 +66,29 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
for (var i = 0; i < ${workPerThread}; i++) {
|
||||
dotProd[i] = vec4<${dataType}>(0.0);
|
||||
}
|
||||
for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) {
|
||||
var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x);
|
||||
let wRPerm = filterDims[0] - 1 - wR;
|
||||
if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) ||
|
||||
for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) {
|
||||
var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x);
|
||||
let wRPerm = uniforms.filter_dims[0] - 1 - wR;
|
||||
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) ||
|
||||
fract(dyR) > 0.0 || wRPerm < 0) {
|
||||
continue;
|
||||
}
|
||||
let idyR: u32 = u32(dyR);
|
||||
|
||||
for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) {
|
||||
let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y);
|
||||
let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y);
|
||||
let wCPerm = filterDims[1] - 1 - wC;
|
||||
for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) {
|
||||
let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
|
||||
let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
|
||||
let wCPerm = uniforms.filter_dims[1] - 1 - wC;
|
||||
if (wCPerm < 0) {
|
||||
continue;
|
||||
}
|
||||
var bDyCVal = true;
|
||||
var bDyCVal2 = true;
|
||||
if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) ||
|
||||
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) ||
|
||||
fract(dyC) > 0.0) {
|
||||
bDyCVal = false;
|
||||
}
|
||||
if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) ||
|
||||
if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) ||
|
||||
fract(dyC2) > 0.0) {
|
||||
bDyCVal2 = false;
|
||||
}
|
||||
|
|
@ -101,7 +96,7 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
let idyC: u32 = u32(dyC);
|
||||
let idyC2: u32 = u32(dyC2);
|
||||
if (bDyCVal && bDyCVal2) {
|
||||
let d2Length = outBackprop[3];
|
||||
let d2Length = uniforms.Dy_shape[3];
|
||||
for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) {
|
||||
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
|
||||
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
|
||||
|
|
@ -123,7 +118,7 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
dot(xValue, wValue3));
|
||||
}
|
||||
} else if (bDyCVal) {
|
||||
let d2Length = outBackprop[${channelDim}];
|
||||
let d2Length = uniforms.Dy_shape[${channelDim}];
|
||||
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
|
||||
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
|
||||
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
|
||||
|
|
@ -138,7 +133,7 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
dotProd[0] = dotProd[0] + tmpval;
|
||||
}
|
||||
} else if (bDyCVal2) {
|
||||
let d2Length = outBackprop[3];
|
||||
let d2Length = uniforms.Dy_shape[3];
|
||||
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
|
||||
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
|
||||
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
|
||||
|
|
@ -167,39 +162,39 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
let d1 = ${output.indicesGet('outputIndices', channelDim)};
|
||||
let r = ${output.indicesGet('outputIndices', rowDim)};
|
||||
let c = ${output.indicesGet('outputIndices', colDim)};
|
||||
let dyCorner = vec2<i32>(i32(r), i32(c)) - pads;
|
||||
let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
|
||||
let dyRCorner = dyCorner.x;
|
||||
let dyCCorner = dyCorner.y;
|
||||
let groupId = d1 / ${outputChannelsPerGroup};
|
||||
let wOutChannel = d1 - groupId * ${outputChannelsPerGroup};
|
||||
let groupId = d1 / uniforms.output_channels_per_group;
|
||||
let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
|
||||
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
|
||||
// ? = to be determined. : = across all values in that axis.
|
||||
var dotProd = ${dataType}(0.0);
|
||||
for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) {
|
||||
if (wR % dilations.x != 0) {
|
||||
for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
|
||||
if (wR % uniforms.dilations.x != 0) {
|
||||
continue;
|
||||
}
|
||||
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]);
|
||||
let wRPerm = filterDims.x - 1 - wR / dilations.x;
|
||||
if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 ||
|
||||
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]);
|
||||
let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
|
||||
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 ||
|
||||
wRPerm < 0) {
|
||||
continue;
|
||||
}
|
||||
let idyR: u32 = u32(dyR);
|
||||
|
||||
for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) {
|
||||
if (wC % dilations.y != 0) {
|
||||
for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
|
||||
if (wC % uniforms.dilations.y != 0) {
|
||||
continue;
|
||||
}
|
||||
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y);
|
||||
let wCPerm = filterDims.y - 1 - wC / dilations.y;
|
||||
if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) ||
|
||||
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
|
||||
let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
|
||||
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) ||
|
||||
fract(dyC) > 0.0 || wCPerm < 0) {
|
||||
continue;
|
||||
}
|
||||
let idyC: u32 = u32(dyC);
|
||||
var inputChannel = groupId * ${inputChannelsPerGroup};
|
||||
for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) {
|
||||
var inputChannel = groupId * uniforms.input_channels_per_group;
|
||||
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
|
||||
let xValue = ${
|
||||
isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') :
|
||||
dy.get('batch', 'inputChannel', 'idyR', 'idyC')};
|
||||
|
|
@ -214,27 +209,11 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
`;
|
||||
|
||||
return `
|
||||
${shaderHelper.declareVariables(...inputVariables, output)}
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
|
||||
${declareFunctions}
|
||||
const outShape : vec4<u32> = vec4<u32>(${outputShape.join(',')});
|
||||
const outBackprop : vec4<u32> = vec4<u32>(${inputs[0].dims.join(',')});
|
||||
const strides : vec2<u32> = vec2<u32>(${attributes.strides[0]}, ${attributes.strides[1]});
|
||||
const filterDims : vec2<u32> = vec2<u32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
|
||||
attributes.kernelShape[isChannelsLast ? 2 : 3]});
|
||||
const dilations : vec2<u32> = vec2<u32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
|
||||
const effectiveFilterDims : vec2<u32> = filterDims + vec2<u32>(
|
||||
${
|
||||
attributes.dilations[0] <= 1 ?
|
||||
0 :
|
||||
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
|
||||
${
|
||||
attributes.dilations[1] <= 1 ?
|
||||
0 :
|
||||
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
|
||||
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2,
|
||||
i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2);
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)};
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
|
||||
${isVec4 ? codeSnippet4 : codeSnippet}}`;
|
||||
};
|
||||
|
||||
|
|
@ -257,19 +236,72 @@ export const createConvTranspose2DProgramInfo =
|
|||
];
|
||||
LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
const isChannelsLast = attributes.format === 'NHWC';
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
|
||||
const strides = [attributes.strides[0], attributes.strides[1]];
|
||||
const filterDims =
|
||||
[attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
|
||||
const dilations = [attributes.dilations[0], attributes.dilations[1]];
|
||||
const effectiveFilterDims = [
|
||||
filterDims[0] +
|
||||
(attributes.dilations[0] <= 1 ?
|
||||
0 :
|
||||
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)),
|
||||
filterDims[1] +
|
||||
(attributes.dilations[1] <= 1 ?
|
||||
0 :
|
||||
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1))
|
||||
];
|
||||
const pads = [
|
||||
effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
|
||||
effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2
|
||||
];
|
||||
|
||||
const isVec4 = false;
|
||||
const group = attributes.group;
|
||||
const wShape = inputs[1].dims;
|
||||
const inputChannelsPerGroup = wShape[0] / group;
|
||||
const outputChannelsPerGroup = wShape[1];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims},
|
||||
{type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads},
|
||||
{type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup},
|
||||
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
|
||||
];
|
||||
if (hasBias) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
inputDependencies.push('rank');
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length},
|
||||
{name: 'filter_dims', type: 'u32', length: filterDims.length},
|
||||
{name: 'dilations', type: 'u32', length: filterDims.length},
|
||||
{name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length},
|
||||
{name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'},
|
||||
{name: 'output_channels_per_group', type: 'u32'}
|
||||
];
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
return `${
|
||||
createConvTranspose2DOpProgramShaderSource(
|
||||
shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms,
|
||||
isChannelsLast)}`;
|
||||
};
|
||||
return {
|
||||
name: 'ConvTranspose2D',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
outputs: [{
|
||||
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
|
||||
dataType: inputs[0].dataType
|
||||
}]
|
||||
}],
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource(
|
||||
shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false,
|
||||
dataType),
|
||||
getShaderSource
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@
|
|||
import {TensorView} from '../../../tensor-view';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
|
||||
import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
|
||||
import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';
|
||||
|
||||
import {typeSnippet} from './activation_util';
|
||||
|
|
@ -112,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
|
|||
${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
|
||||
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
|
||||
|
||||
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
|
||||
let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'};
|
||||
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
|
||||
|
||||
var acc: array<vec4<${type}>, rowPerThread>;
|
||||
|
||||
// Loop over shared dimension.
|
||||
let tileRowB = localRow * ${rowPerThreadB};
|
||||
for (var t = 0; t < numTiles; t = t + 1) {
|
||||
for (var t = 0; t < num_tiles; t = t + 1) {
|
||||
// Load one tile of A into local memory.
|
||||
for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
|
||||
let inputRow = tileRow + innerRow;
|
||||
|
|
@ -204,7 +204,7 @@ export const makeMatMulPackedSource =
|
|||
let globalColStart = i32(workgroupId.x) * ${tileBOuter};
|
||||
|
||||
// Loop over shared dimension.
|
||||
for (var t = 0; t < numTiles; t = t + 1) {
|
||||
for (var t = 0; t < num_tiles; t = t + 1) {
|
||||
// Load one tile of A into local memory.
|
||||
for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) {
|
||||
for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) {
|
||||
|
|
@ -260,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA};
|
|||
let tileColA = i32(localId.x) * ${colPerThreadA};
|
||||
let tileRowB = i32(localId.y) * ${rowPerThreadB};
|
||||
// Loop over shared dimension.
|
||||
for (var t = 0; t < numTiles; t = t + 1) {
|
||||
for (var t = 0; t < num_tiles; t = t + 1) {
|
||||
// Load one tile of A into local memory.
|
||||
for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) {
|
||||
for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) {
|
||||
|
|
@ -322,7 +322,8 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
|
|||
@builtin(workgroup_id) workgroupId : vec3<u32>) {
|
||||
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
|
||||
${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
|
||||
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
|
||||
let num_tiles = ${
|
||||
splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'};
|
||||
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
|
||||
|
||||
var acc : array<array<${type}, colPerThread>, rowPerThread>;
|
||||
|
|
@ -379,7 +380,7 @@ const matMulReadWriteFnSource =
|
|||
typeSnippet(component, dataType)} {
|
||||
var value = ${typeSnippet(component, dataType)}(0.0);
|
||||
let col = colIn * ${component};
|
||||
if(row < uniforms.dimAOuter && col < uniforms.dimInner)
|
||||
if(row < uniforms.dim_a_outer && col < uniforms.dim_inner)
|
||||
{
|
||||
${getAIndices()}
|
||||
value = ${aVariable.getByIndices('aIndices')};
|
||||
|
|
@ -391,7 +392,7 @@ const matMulReadWriteFnSource =
|
|||
typeSnippet(component, dataType)} {
|
||||
var value = ${typeSnippet(component, dataType)}(0.0);
|
||||
let col = colIn * ${component};
|
||||
if(row < uniforms.dimInner && col < uniforms.dimBOuter)
|
||||
if(row < uniforms.dim_inner && col < uniforms.dim_b_outer)
|
||||
{
|
||||
${getBIndices()}
|
||||
value = ${bVariable.getByIndices('bIndices')};
|
||||
|
|
@ -401,7 +402,7 @@ const matMulReadWriteFnSource =
|
|||
|
||||
fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) {
|
||||
let col = colIn * ${component};
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
|
||||
if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {
|
||||
var value = valueIn;
|
||||
let coords = vec3<i32>(batch, row, colIn);
|
||||
${
|
||||
|
|
@ -422,16 +423,10 @@ export const createMatmulProgramInfo =
|
|||
isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
|
||||
const aShape = inputs[0].dims;
|
||||
const bShape = inputs[1].dims;
|
||||
|
||||
const outerDimsA = aShape.slice(0, -2);
|
||||
const outerDimsB = bShape.slice(0, -2);
|
||||
|
||||
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
|
||||
const enableBatchUniforms = enableShapesUniforms(outerDims.length);
|
||||
const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims;
|
||||
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
|
||||
const batchSize = ShapeUtil.size(outerDims);
|
||||
|
||||
const dimAOuter = aShape[aShape.length - 2];
|
||||
const dimInner = aShape[aShape.length - 1];
|
||||
const dimBOuter = bShape[bShape.length - 1];
|
||||
|
|
@ -446,72 +441,67 @@ export const createMatmulProgramInfo =
|
|||
Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2])
|
||||
];
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
const components = isVec4 ? 4 : 1;
|
||||
|
||||
const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
|
||||
const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length);
|
||||
const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp;
|
||||
|
||||
const aShapeOrRank = aShapeTemp.length;
|
||||
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
|
||||
const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length);
|
||||
const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp;
|
||||
|
||||
const bShapeOrRank = bShapeTemp.length;
|
||||
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
|
||||
|
||||
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
|
||||
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
|
||||
const inputVariables = [A, B];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
if (enableBatchUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(outerDims));
|
||||
if (activationAttributes.activation === 'Clip') {
|
||||
programUniforms.push(
|
||||
{type: 'float32', data: activationAttributes.clipMax!},
|
||||
{type: 'float32', data: activationAttributes.clipMin!});
|
||||
}
|
||||
if (enableAShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(aShapeTemp));
|
||||
}
|
||||
if (enableBShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(bShapeTemp));
|
||||
}
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims');
|
||||
inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims');
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
|
||||
...createTensorShapeVariables(bShapeTemp));
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
|
||||
|
||||
const hasBias = inputs.length > 2;
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
|
||||
const declareFunctions = matMulReadWriteFnSource(
|
||||
components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
|
||||
isChannelsLast);
|
||||
if (hasBias) {
|
||||
const biasComponents = isChannelsLast ? components : 1;
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
|
||||
inputDependencies.push('rank');
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const batchShapeOrRank = outerDims.length;
|
||||
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
||||
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
|
||||
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
|
||||
const inputVariables = [A, B];
|
||||
if (hasBias) {
|
||||
const biasComponents = isChannelsLast ? components : 1;
|
||||
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
|
||||
}
|
||||
const uniforms: UniformsArrayType =
|
||||
[{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}];
|
||||
if (activationAttributes.activation === 'Clip') {
|
||||
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
|
||||
}
|
||||
const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
|
||||
const declareFunctions = matMulReadWriteFnSource(
|
||||
components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
|
||||
isChannelsLast);
|
||||
return `
|
||||
${
|
||||
shaderHelper.registerUniform('dimAOuter', 'i32')
|
||||
.registerUniform('dimBOuter', 'i32')
|
||||
.registerUniform('dimInner', 'i32')
|
||||
.registerInternalVariables(batchDims)
|
||||
.declareVariables(...inputVariables, output)}
|
||||
${activationFunction}
|
||||
shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables(
|
||||
...inputVariables, output)}
|
||||
${declareFunctions}
|
||||
${
|
||||
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
|
||||
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
|
||||
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
|
||||
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
|
||||
`;
|
||||
// TODO: turn clipMax and clipMin to uniforms.
|
||||
};
|
||||
return {
|
||||
name: 'MatMul',
|
||||
shaderCache: {
|
||||
hint: activationAttributes.activationCacheKey + `${elementsPerThread}` +
|
||||
`${isVec4}` +
|
||||
`${isChannelsLast}`,
|
||||
hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`,
|
||||
inputDependencies
|
||||
},
|
||||
getRunData: () => ({
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ProgramInfo, ProgramUniform} from '../types';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
import {calculateOutputShape, ConvAttributes} from './conv';
|
||||
import {getActivationSnippet} from './fuse-utils';
|
||||
|
||||
|
|
@ -27,52 +27,75 @@ export const createGroupedConvProgramInfo =
|
|||
xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape);
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value);
|
||||
const x = inputVariable('x', inputs[0].dataType, xShape);
|
||||
const w = inputVariable('w', inputs[1].dataType, wShape);
|
||||
const inputVars = [x, w];
|
||||
if (hasBias) {
|
||||
inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims));
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations},
|
||||
{type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]},
|
||||
{type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup}
|
||||
];
|
||||
if (attributes.activation === 'Clip') {
|
||||
programUniforms.push(
|
||||
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
|
||||
}
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
|
||||
...createTensorShapeVariables(outputShape));
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
|
||||
if (hasBias) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
inputDependencies.push('rank');
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const strides: vec2<u32> = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u);
|
||||
const pads: vec2<u32> = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u);
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
|
||||
const applyActivation = getActivationSnippet(attributes, output.type.value);
|
||||
const x = inputVariable('x', inputs[0].dataType, xShape.length);
|
||||
const w = inputVariable('w', inputs[1].dataType, wShape.length);
|
||||
const inputVars = [x, w];
|
||||
if (hasBias) {
|
||||
inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims));
|
||||
}
|
||||
|
||||
${shaderHelper.declareVariables(...inputVars, output)}
|
||||
|
||||
${activationFunction}
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length},
|
||||
{name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2},
|
||||
{name: 'output_channels_per_group', type: 'u32'}
|
||||
];
|
||||
if (attributes.activation === 'Clip') {
|
||||
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
|
||||
}
|
||||
return `
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
|
||||
let outputIndices = ${output.offsetToIndices('global_idx')};
|
||||
let batch: u32 = outputIndices[0];
|
||||
let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}];
|
||||
let xRCCorner: vec2<u32> = vec2<u32>(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${
|
||||
isChannelLast ? 2 : 3}]) * strides - pads;
|
||||
let group_id: u32 = output_channel / ${outputChannelsPerGroup}u;
|
||||
isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads;
|
||||
let group_id: u32 = output_channel / uniforms.output_channels_per_group;
|
||||
|
||||
var value: ${output.type.value} = ${output.type.value}(0);
|
||||
for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) {
|
||||
let input_channel = group_id * ${wShape[1]}u + wInChannel;
|
||||
for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) {
|
||||
let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u;
|
||||
for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
|
||||
let input_channel = group_id * uniforms.w_shape[1] + wInChannel;
|
||||
for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
|
||||
let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
|
||||
|
||||
if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) {
|
||||
if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) {
|
||||
let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u;
|
||||
if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) {
|
||||
for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
|
||||
let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
|
||||
if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let xVal = ${
|
||||
isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') :
|
||||
x.get('batch', 'input_channel', 'xHeight', 'xWidth')};
|
||||
isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') :
|
||||
x.get('batch', 'input_channel', 'xHeight', 'xWidth')};
|
||||
let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')};
|
||||
value += xVal*wVal;
|
||||
}
|
||||
|
|
@ -82,15 +105,17 @@ export const createGroupedConvProgramInfo =
|
|||
${applyActivation}
|
||||
${output.setByOffset('global_idx', 'value')}
|
||||
}`;
|
||||
};
|
||||
return {
|
||||
name: 'GroupedConv',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
shaderCache: {hint: attributes.cacheKey, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{
|
||||
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
|
||||
dataType: inputs[0].dataType
|
||||
}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
|
|
@ -114,7 +139,7 @@ export const createGroupedConvVectorizeProgramInfo =
|
|||
const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value);
|
||||
const applyActivation = getActivationSnippet(attributes, output.type.value);
|
||||
const x = inputVariable('x', inputs[0].dataType, xShape.length, components);
|
||||
const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
|
||||
const inputVars = [x, w];
|
||||
|
|
@ -129,7 +154,6 @@ export const createGroupedConvVectorizeProgramInfo =
|
|||
.registerUniform('strides', 'i32', 2)
|
||||
.registerUniform('pads', 'i32', 2)
|
||||
.declareVariables(...inputVars, output)}
|
||||
${activationFunction}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
let width0 = uniforms.output_shape[3];
|
||||
|
|
@ -179,7 +203,7 @@ export const createGroupedConvVectorizeProgramInfo =
|
|||
return {
|
||||
name: 'GroupedConv-Vectorize',
|
||||
shaderCache: {
|
||||
hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`,
|
||||
hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`,
|
||||
inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank']
|
||||
},
|
||||
getRunData: () => ({
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext} from '../types';
|
||||
|
||||
import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu';
|
||||
|
|
@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes {
|
|||
readonly outputShape: readonly number[];
|
||||
}
|
||||
|
||||
|
||||
const getAdjustedConvTransposeAttributes =
|
||||
<T extends ConvTransposeAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
|
||||
const kernelShape = attributes.kernelShape.slice();
|
||||
|
|
@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes =
|
|||
|
||||
// always return a new object so does not modify the original attributes
|
||||
const newAttributes: T = Object.assign({}, attributes);
|
||||
const cacheKey = attributes.cacheKey + [
|
||||
kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','),
|
||||
dilations.join(',')
|
||||
].join('_');
|
||||
Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey});
|
||||
Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides});
|
||||
return newAttributes;
|
||||
};
|
||||
|
||||
|
|
@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record<string, unknown>
|
|||
const wIsConst = (attributes.wIsConst as () => boolean)();
|
||||
const outputPadding = attributes.outputPadding as [number, number, number, number];
|
||||
const outputShape = attributes.outputShape as [number, number];
|
||||
return createAttributeWithCacheKey({
|
||||
return {
|
||||
autoPad,
|
||||
format,
|
||||
dilations,
|
||||
|
|
@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record<string, unknown>
|
|||
pads,
|
||||
strides,
|
||||
wIsConst,
|
||||
...activationAttributes
|
||||
});
|
||||
...activationAttributes,
|
||||
cacheKey: `${attributes.format};${activationAttributes.activation};`
|
||||
};
|
||||
};
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {PoolConvUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext} from '../types';
|
||||
|
||||
import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
|
||||
|
|
@ -110,7 +110,7 @@ const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inpu
|
|||
|
||||
// always return a new object so does not modify the original attributes
|
||||
const newAttributes: T = Object.assign({}, attributes);
|
||||
Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey});
|
||||
Object.assign(newAttributes, {kernelShape, pads});
|
||||
return newAttributes;
|
||||
};
|
||||
|
||||
|
|
@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAt
|
|||
const strides = attributes.strides as [number, number];
|
||||
const wIsConst = (attributes.w_is_const as () => boolean)();
|
||||
|
||||
return createAttributeWithCacheKey(
|
||||
{autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes});
|
||||
return {
|
||||
autoPad,
|
||||
format,
|
||||
dilations,
|
||||
group,
|
||||
kernelShape,
|
||||
pads,
|
||||
strides,
|
||||
wIsConst,
|
||||
...activationAttributes,
|
||||
cacheKey: `${attributes.format};${activationAttributes.activation};`
|
||||
};
|
||||
};
|
||||
|
||||
const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {
|
||||
|
|
|
|||
|
|
@ -7,30 +7,21 @@ export interface InternalActivationAttributes {
|
|||
readonly activation: string;
|
||||
readonly clipMin?: number;
|
||||
readonly clipMax?: number;
|
||||
readonly activationCacheKey: string;
|
||||
}
|
||||
|
||||
export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string):
|
||||
{activationFunction: string; applyActivation: string} => {
|
||||
switch (attributes.activation) {
|
||||
case 'Relu':
|
||||
return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`};
|
||||
case 'Sigmoid':
|
||||
return {
|
||||
activationFunction: '',
|
||||
applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`
|
||||
};
|
||||
case 'Clip':
|
||||
return {
|
||||
activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${
|
||||
attributes.clipMax!});`,
|
||||
applyActivation: 'value = clamp(value, clip_min_, clip_max_);'
|
||||
};
|
||||
// TODO: adding other activations that can be fused.
|
||||
default:
|
||||
return {activationFunction: '', applyActivation: ''};
|
||||
}
|
||||
};
|
||||
export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => {
|
||||
switch (attributes.activation) {
|
||||
case 'Relu':
|
||||
return `value = max(value, ${valueType}(0.0));`;
|
||||
case 'Sigmoid':
|
||||
return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
|
||||
case 'Clip':
|
||||
return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`;
|
||||
// TODO: adding other activations that can be fused.
|
||||
default:
|
||||
return '';
|
||||
}
|
||||
};
|
||||
|
||||
export const parseInternalActivationAttributes =
|
||||
(attributes: Record<string, unknown>|undefined): InternalActivationAttributes => {
|
||||
|
|
@ -38,7 +29,7 @@ export const parseInternalActivationAttributes =
|
|||
|
||||
if (activation === 'Clip') {
|
||||
const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP];
|
||||
return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`};
|
||||
return {activation, clipMax, clipMin};
|
||||
}
|
||||
return {activation, activationCacheKey: activation};
|
||||
return {activation};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util';
|
|||
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
||||
|
||||
import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
|
||||
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common';
|
||||
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common';
|
||||
import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
|
||||
|
||||
export const createNaiveMatmulProgramInfo =
|
||||
|
|
@ -27,11 +27,19 @@ export const createNaiveMatmulProgramInfo =
|
|||
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
|
||||
const batchSize = ShapeUtil.size(outerDims);
|
||||
const outputShapeInShader = [batchSize, M, N];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N},
|
||||
{type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
|
||||
...createTensorShapeVariables(bShape)
|
||||
{type: 'uint32', data: K}
|
||||
];
|
||||
if (activationAttributes.activation === 'Clip') {
|
||||
programUniforms.push(
|
||||
{type: 'float32', data: activationAttributes.clipMax!},
|
||||
{type: 'float32', data: activationAttributes.clipMin!});
|
||||
}
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
|
||||
...createTensorShapeVariables(bShape));
|
||||
if (hasBias) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
|
||||
}
|
||||
|
|
@ -42,7 +50,7 @@ export const createNaiveMatmulProgramInfo =
|
|||
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
|
||||
const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
|
||||
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
|
||||
const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
|
||||
const inputVariables = [a, b];
|
||||
let processBias = '';
|
||||
if (hasBias) {
|
||||
|
|
@ -57,6 +65,14 @@ export const createNaiveMatmulProgramInfo =
|
|||
const outerDimsB = bShape.slice(0, -2);
|
||||
const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
|
||||
const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'},
|
||||
{name: 'K', type: 'u32'}
|
||||
];
|
||||
if (activationAttributes.activation === 'Clip') {
|
||||
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
|
||||
}
|
||||
|
||||
const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
|
||||
const rank = variable.rank;
|
||||
const name = variable.name;
|
||||
|
|
@ -96,15 +112,10 @@ export const createNaiveMatmulProgramInfo =
|
|||
|
||||
return `
|
||||
${
|
||||
shaderHelper.registerUniform('outputSize', 'u32')
|
||||
.registerUniform('M', 'u32')
|
||||
.registerUniform('N', 'u32')
|
||||
.registerUniform('K', 'u32')
|
||||
.registerInternalVariables(batchDims)
|
||||
.declareVariables(...inputVariables, output)}
|
||||
${activationFunction}
|
||||
shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables(
|
||||
...inputVariables, output)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
let col = (global_idx % (uniforms.N / ${components})) * ${components};
|
||||
var index1 = global_idx / (uniforms.N / ${components});
|
||||
let stride1 = uniforms.M / ${outputNumber};
|
||||
|
|
@ -134,8 +145,7 @@ export const createNaiveMatmulProgramInfo =
|
|||
return {
|
||||
name: 'MatMulNaive',
|
||||
shaderCache: {
|
||||
hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${
|
||||
isChannelsLast}`,
|
||||
hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`,
|
||||
inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank']
|
||||
},
|
||||
getRunData: () => ({
|
||||
|
|
@ -166,9 +176,8 @@ export const matMul = (context: ComputeContext): void => {
|
|||
const N = outputShape[outputShape.length - 1];
|
||||
const K = context.inputs[0].dims[context.inputs[0].dims.length - 1];
|
||||
if (N < 8 && K < 8) {
|
||||
context.compute(
|
||||
createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
|
||||
context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape));
|
||||
} else {
|
||||
context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
|
||||
context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape));
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue