[js/webgpu] Support pad uniforms (#19057)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Xu Xing 2024-01-10 01:34:56 +08:00 committed by GitHub
parent eb92681bfb
commit 42ba2aed54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 118 deletions

View file

@ -20,7 +20,7 @@ import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
import {pad, parsePadAttributes} from './ops/pad';
import {pad} from './ops/pad';
import * as pool from './ops/pool';
import {range} from './ops/range';
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
@ -95,7 +95,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
['Neg', [unaryOps.neg]],
['Not', [unaryOps.not]],
['Pad', [pad, parsePadAttributes]],
['Pad', [pad]],
['Pow', [binaryOps.pow]],
['Range', [range]],
['Reciprocal', [unaryOps.reciprocal]],

View file

@ -1,15 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common';
export interface PadAttributes extends AttributeWithCacheKey {
interface PadAttributes {
// 0-constant, 1-reflect, 2-edge, 3-wrap
readonly mode: number;
readonly value: number;
@ -35,27 +34,23 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
}
};
const getPadConstant =
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[],
dataType: string, constantValue: number): string => {
const inputRank = inputDims.length;
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
const getPadConstant = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
if (k < 0) {
break;
}
if (k >= ${inputDims[i]}) {
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
break;
}
offset += k * ${inputStrides[i]};
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
`;
}
}
return `
value = ${dataType}(${constantValue});
return `
value = ${output.type.value}(uniforms.constant_value);
for (var i = 0; i < 1; i++) {
var offset = 0;
var k = 0;
@ -63,143 +58,143 @@ const getPadConstant =
value = x[offset];
}
`;
};
};
const getPadReflect =
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => {
const inputRank = inputDims.length;
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
const getPadReflect = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
if (k < 0) {
k = -k;
}
{
let _2n_1 = ${2 * (inputDims[i] - 1)};
let _2n_1 = 2 * (i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1);
k = k % _2n_1;
if(k >= ${inputDims[i]}) {
if(k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
k = _2n_1 - k;
}
}
offset += k * ${inputStrides[i]};
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
`;
}
}
return `
return `
var offset = 0;
var k = 0;
${block}
value = x[offset];
`;
};
};
const getPadEdge =
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => {
const inputRank = inputDims.length;
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
const getPadEdge = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
if (k < 0) {
k = 0;
}
if (k >= ${inputDims[i]}) {
k = ${inputDims[i] - 1};
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
k = i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1;
}
offset += k * ${inputStrides[i]};
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
`;
}
}
return `
return `
var offset = 0;
var k = 0;
${block}
value = x[offset];
`;
};
};
const getPadWrap =
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => {
const inputRank = inputDims.length;
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
const getPadWrap = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
if (k < 0) {
k += ${inputDims[i]};
k += i32(${getElementAt('uniforms.x_shape', i, inputRank)});
}
if (k >= ${inputDims[i]}) {
k -= ${inputDims[i]};
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
k -= i32(${getElementAt('uniforms.x_shape', i, inputRank)});
}
offset += k * ${inputStrides[i]};
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
`;
}
}
return `
return `
var offset = 0;
var k = 0;
${block}
value = x[offset];
`;
};
};
const getPadSnippet =
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], attributes: PadAttributes,
dataType: string): string => {
switch (attributes.mode) {
case 0:
return getPadConstant(output, inputDims, inputStrides, attributes.pads, dataType, attributes.value);
case 1:
return getPadReflect(output, inputDims, inputStrides, attributes.pads);
case 2:
return getPadEdge(output, inputDims, inputStrides, attributes.pads);
case 3:
return getPadWrap(output, inputDims, inputStrides, attributes.pads);
default:
throw new Error('Invalid mode');
}
};
const generatePadCode =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: PadAttributes, dataType: string):
string => {
const inputDims = inputs[0].dims;
const outputDims = ShapeUtil.padShape(inputDims.slice(), attributes.pads);
const outputSize = ShapeUtil.size(outputDims);
const inputStrides = ShapeUtil.computeStrides(inputDims);
const output = outputVariable('output', inputs[0].dataType, outputDims);
const input = inputVariable('x', inputs[0].dataType, inputDims);
const padSnippet = getPadSnippet(output, inputDims, inputStrides, attributes, dataType);
const padCode = `
${shaderHelper.declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let indices = ${output.offsetToIndices('global_idx')};
var value = ${dataType}(0);
${padSnippet}
output[global_idx] = value;
}`;
return padCode;
};
const getPadSnippet = (output: IndicesHelper, inputRank: number, attributes: PadAttributes): string => {
switch (attributes.mode) {
case 0:
return getPadConstant(output, inputRank, attributes.pads.length);
case 1:
return getPadReflect(output, inputRank, attributes.pads.length);
case 2:
return getPadEdge(output, inputRank, attributes.pads.length);
case 3:
return getPadWrap(output, inputRank, attributes.pads.length);
default:
throw new Error('Invalid mode');
}
};
const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => {
const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads);
const inputDims = inputs[0].dims;
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] =
[{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}];
if (attributes.mode === 0) {
const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type'];
programUniforms.push({type: tensorDataType, data: attributes.value});
}
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const getShaderSource = (shaderHelper: ShaderHelper) => {
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
const input = inputVariable('x', inputs[0].dataType, inputDims.length);
const dataType = input.type.value;
const padSnippet = getPadSnippet(output, inputDims.length, attributes);
const uniforms: UniformsArrayType =
[{name: 'output_size', type: 'u32'}, {name: 'pads', type: 'i32', length: attributes.pads.length}];
if (attributes.mode === 0) {
uniforms.push({name: 'constant_value', type: dataType as UniformDataElementType});
}
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let indices = ${output.offsetToIndices('global_idx')};
var value = ${dataType}(0);
${padSnippet}
output[global_idx] = value;
}`;
};
return {
name: 'Pad',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: `${attributes.mode}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'),
getShaderSource,
};
};
@ -223,7 +218,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes
const pads: number[] = [];
updatePads.forEach(v => pads.push(v));
return createAttributeWithCacheKey({mode: attributes.mode, value, pads});
return {mode: attributes.mode, value, pads};
} else {
return attributes;
}
@ -234,10 +229,3 @@ export const pad = (context: ComputeContext, attributes: PadAttributes): void =>
const updatedAttributes = createPadAttributesFromInputs(context.inputs, attributes);
context.compute(createPadProgramInfo(context.inputs, updatedAttributes), {inputs: [0]});
};
export const parsePadAttributes = (attributes: Record<string, unknown>): PadAttributes => {
const mode = attributes.mode as number;
const value = attributes.value as number;
const pads = attributes.pads as number[];
return createAttributeWithCacheKey({mode, value, pads});
};