From 42ba2aed54458a2b645d859efb830a3c431ec896 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 10 Jan 2024 01:34:56 +0800 Subject: [PATCH] [js/webgpu] Support pad uniforms (#19057) ### Description ### Motivation and Context --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 220 +++++++++--------- 2 files changed, 106 insertions(+), 118 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index c182d3c4ea..ed174bc098 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -20,7 +20,7 @@ import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; -import {pad, parsePadAttributes} from './ops/pad'; +import {pad} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; @@ -95,7 +95,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], - ['Pad', [pad, parsePadAttributes]], + ['Pad', [pad]], ['Pow', [binaryOps.pow]], ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 18859e253a..eca3fa7d94 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -1,15 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; +import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common'; -export interface PadAttributes extends AttributeWithCacheKey { +interface PadAttributes { // 0-constant, 1-reflect, 2-edge, 3-wrap readonly mode: number; readonly value: number; @@ -35,27 +34,23 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const getPadConstant = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[], - dataType: string, constantValue: number): string => { - const inputRank = inputDims.length; - - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadConstant = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { break; } - if (k >= ${inputDims[i]}) { + if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { break; } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` - value = ${dataType}(${constantValue}); + return ` + value = ${output.type.value}(uniforms.constant_value); for (var i = 0; i < 1; i++) { var offset = 0; var k = 0; @@ -63,143 +58,143 @@ const getPadConstant = value = x[offset]; } `; - }; +}; -const getPadReflect = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { - const inputRank = inputDims.length; - - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadReflect = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { k = -k; } { - let _2n_1 = ${2 * (inputDims[i] - 1)}; + let _2n_1 = 2 * (i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1); k = k % _2n_1; - if(k >= ${inputDims[i]}) { + if(k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { k = _2n_1 - k; } } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` + return ` var offset = 0; var k = 0; ${block} value = x[offset]; `; - }; +}; -const getPadEdge = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { - const inputRank = inputDims.length; - - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadEdge = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { k = 0; } - if (k >= ${inputDims[i]}) { - k = ${inputDims[i] - 1}; + if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { + k = i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1; } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` + return ` var offset = 0; var k = 0; ${block} value = x[offset]; `; - }; +}; -const getPadWrap = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { - const inputRank = inputDims.length; - - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadWrap = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { - k += ${inputDims[i]}; + k += i32(${getElementAt('uniforms.x_shape', i, inputRank)}]); } - if (k >= ${inputDims[i]}) { - k -= ${inputDims[i]}; + if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { + k -= i32(${getElementAt('uniforms.x_shape', i, inputRank)}); } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` + return ` var offset = 0; var k = 0; ${block} value = x[offset]; `; - }; +}; -const getPadSnippet = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], attributes: PadAttributes, - dataType: string): string => { - switch (attributes.mode) { - case 0: - return getPadConstant(output, inputDims, inputStrides, attributes.pads, dataType, attributes.value); - case 1: - return getPadReflect(output, inputDims, inputStrides, attributes.pads); - case 2: - return getPadEdge(output, inputDims, inputStrides, attributes.pads); - case 3: - return getPadWrap(output, inputDims, inputStrides, attributes.pads); - default: - throw new Error('Invalid mode'); - } - }; - -const generatePadCode = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: PadAttributes, dataType: string): - string => { - const inputDims = inputs[0].dims; - const outputDims = ShapeUtil.padShape(inputDims.slice(), attributes.pads); - const outputSize = ShapeUtil.size(outputDims); - const inputStrides = ShapeUtil.computeStrides(inputDims); - - const output = outputVariable('output', inputs[0].dataType, outputDims); - const input = inputVariable('x', inputs[0].dataType, inputDims); - - const padSnippet = getPadSnippet(output, inputDims, inputStrides, attributes, dataType); - const padCode = ` - ${shaderHelper.declareVariables(input, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - - let indices = ${output.offsetToIndices('global_idx')}; - - var value = ${dataType}(0); - ${padSnippet} - output[global_idx] = value; - }`; - return padCode; - }; +const getPadSnippet = (output: IndicesHelper, inputRank: number, attributes: PadAttributes): string => { + switch (attributes.mode) { + case 0: + return getPadConstant(output, inputRank, attributes.pads.length); + case 1: + return getPadReflect(output, inputRank, attributes.pads.length); + case 2: + return getPadEdge(output, inputRank, attributes.pads.length); + case 3: + return getPadWrap(output, inputRank, attributes.pads.length); + default: + throw new Error('Invalid mode'); + } +}; const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => { const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads); + const inputDims = inputs[0].dims; + const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}]; + if (attributes.mode === 0) { + const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type']; + programUniforms.push({type: tensorDataType, data: attributes.value}); + } + + programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const input = inputVariable('x', inputs[0].dataType, inputDims.length); + const dataType = input.type.value; + const padSnippet = getPadSnippet(output, inputDims.length, attributes); + const uniforms: UniformsArrayType = + [{name: 'output_size', type: 'u32'}, {name: 'pads', type: 'i32', length: attributes.pads.length}]; + if (attributes.mode === 0) { + uniforms.push({name: 'constant_value', type: dataType as UniformDataElementType}); + } + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + + let indices = ${output.offsetToIndices('global_idx')}; + + var value = ${dataType}(0); + ${padSnippet} + output[global_idx] = value; + }`; + }; + return { name: 'Pad', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.mode}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}, + programUniforms }), - getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'), + getShaderSource, }; }; @@ -223,7 +218,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes const pads: number[] = []; updatePads.forEach(v => pads.push(v)); - return createAttributeWithCacheKey({mode: attributes.mode, value, pads}); + return {mode: attributes.mode, value, pads}; } else { return attributes; } @@ -234,10 +229,3 @@ export const pad = (context: ComputeContext, attributes: PadAttributes): void => const updatedAttributes = createPadAttributesFromInputs(context.inputs, attributes); context.compute(createPadProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); }; - -export const parsePadAttributes = (attributes: Record): PadAttributes => { - const mode = attributes.mode as number; - const value = attributes.value as number; - const pads = attributes.pads as number[]; - return createAttributeWithCacheKey({mode, value, pads}); -};