From 50e6235af111e5113860dfd7a0ece55dc00316a0 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:15:59 -0800 Subject: [PATCH] [js/web] allow ShaderHelper to use internal (non-I/O) variables (#18525) ### Description This PR includes a change that inspired from #18452 to resolve a requirement: a shader may depend on an instance of `IndicesHelper` to generate WGSL code snippet, but the IndicesHelper instance is not necessarily an input/output of the program. So the existing `declareVariables()` function does not work with this scenario. In order to support this requirement, I added this "use" function to `interface ShaderHelper`, which takes a helper-like object as parameter. The hidden implementation `ShaderHelperImpl` class will iterate the helpers and call `impl()` for each. @axinging @qjia7 --- .../ops/3rd-party/matmul_packed_webgpu.ts | 26 ++--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 108 ++++++++++++------ 2 files changed, 83 insertions(+), 51 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 3e52057177..a8f296ea0c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -341,13 +341,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, const matMulReadWriteFnSource = (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], batchShapes: Array, isChannelsLast = false): string => { - const batchAShape = batchShapes[0]; - const batchBShape = batchShapes[1]; - const batchShape = batchShapes[2]; - const batchVariable = variables[0]; - const aVariable = variables[1]; - const bVariable = variables[2]; - const outputVariable = variables[3]; + const [batchAShape, batchBShape, batchShape] = batchShapes; + const [batchVariable, aVariable, bVariable, outputVariable] = variables; const broadCastADims = getBroadcastDims(batchAShape, batchShape); const broadCastBDims = getBroadcastDims(batchBShape, batchShape); const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); @@ -434,9 +429,7 @@ export const createMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const enableBatchUniforms = enableShapesUniforms(outerDims.length); const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; - const batchDims = inputVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1, true); - const variables = [batchDims]; - const batchShapes = [outerDimsA, outerDimsB, outerDims]; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); const dimAOuter = aShape[aShape.length - 2]; @@ -469,10 +462,7 @@ export const createMatmulProgramInfo = const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - variables.push(A); - variables.push(B); - variables.push(output); - const inputVariables = [batchDims, A, B]; + const inputVariables = [A, B]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; if (enableBatchUniforms) { @@ -490,8 +480,9 @@ export const createMatmulProgramInfo = const hasBias = inputs.length > 2; const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = - matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); if (hasBias) { const biasComponents = isChannelsLast ? components : 1; inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); @@ -506,6 +497,7 @@ export const createMatmulProgramInfo = shaderHelper.registerUniform('dimAOuter', 'i32') .registerUniform('dimBOuter', 'i32') .registerUniform('dimInner', 'i32') + .registerInternalVariables(batchDims) .declareVariables(...inputVariables, output)} ${activationFunction} ${declareFunctions} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index f7ae18998b..b7a391ee66 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -58,10 +58,11 @@ interface IndicesHelperTypes { * create an instance of an indices helper: * - `inputVariable()`: create an indices helper instance for an input. * - `outputVariable()`: create an indices helper instance for an output. + * - `internalVariable()`: create an indices helper instance for an internal variable. * * An indices helper instance contains helper functions for the following operations: * - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an - * input or an output) and `shape`(the passed in shape). + * input, an output or an internal variable) and `shape`(the passed in shape). * - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at * runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView). * - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate @@ -192,9 +193,9 @@ export interface IndicesHelper { readonly name: string; /** - * whether the helper is for an input or an output. + * whether the helper is for an input, an output or an internal variable. */ - readonly usage: 'input'|'output'; + readonly usage: 'input'|'output'|'internal'; /** * the rank of the input or output. @@ -210,11 +211,6 @@ export interface IndicesHelper { * a string representing the variable name for the strides of the input or output. */ readonly strides: string; - - /** - * representing variable with uniforms, but without binding. - */ - readonly uniformOnly: boolean; } const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { @@ -335,13 +331,13 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of the input or output. * @param tensorType - the tensor type of the input or output. * @param shapeOrRank - the tensor shape or the rank of the input or output. - * @param isInput - whether the helper is for an input or an output. + * @param usage - the usage of the indices helper. * @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4, - uniformOnly = false): IndicesHelper => { + (name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'], + components: 1|2|3|4): IndicesHelper => { const useUniform = typeof shapeOrRank === 'number'; const rank = useUniform ? shapeOrRank : shapeOrRank.length; const rankIdentity = [...new Array(rank).keys()]; @@ -363,7 +359,7 @@ const createIndicesHelper = getByIndices: false, }; - const uniformPrefix = useUniform || uniformOnly ? 'uniforms.' : ''; + const uniformPrefix = useUniform ? 'uniforms.' : ''; const shape = `${uniformPrefix}${name}_shape`; const strides = `${uniformPrefix}${name}_strides`; let o2iSnippet = ''; @@ -617,12 +613,11 @@ const createIndicesHelper = getByOffset, getByIndices, // isVec4, - usage: isInput ? 'input' : 'output', + usage, name, strides, shape, - rank, - uniformOnly + rank }; }; @@ -636,8 +631,8 @@ const createIndicesHelper = * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, uniformOnly = false): - IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, uniformOnly); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'input', components); /** * Create a IndicesHelper for an output. @@ -650,7 +645,20 @@ export const inputVariable = */ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, false, components); + createIndicesHelper(name, type, shapeOrRank, 'output', components); + +/** + * Create a IndicesHelper for an internal variable. + * + * @param name - the name of the variable. + * @param type - the tensor type of the variable. + * @param shapeOrRank - the tensor shape or the rank of the variable. + * @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1. + * @returns an IndicesHelper for the variable. + */ +export const internalVariable = + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'internal', components); export type UniformsArrayType = Array<{name: string; type: string}>; @@ -703,9 +711,27 @@ export interface ShaderHelper { /** * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. + * + * @param name - the name of the uniform. + * @param type - the type of the uniform. */ registerUniform(name: string, type: string): ShaderHelper; - registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper; + + /** + * A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms. + * + * @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and + * `type`. + */ + registerUniforms(uniforms: UniformsArrayType): ShaderHelper; + + /** + * A helper function to register multiple internal variables. Can be called multiple times to register multiple + * internal variables. + * + * @param variables - an array of IndicesHelper for the variables. + */ + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -740,8 +766,7 @@ class ShaderHelperImpl implements ShaderHelper { `; } - private declareVariable(variable: IndicesHelper, bindingIndex = -1): string { - this.indicesHelpers.push(variable); + private appendVariableUniforms(variable: IndicesHelper): void { if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); @@ -750,24 +775,37 @@ class ShaderHelperImpl implements ShaderHelper { this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); } } - if (variable.uniformOnly) { - return ''; + } + + private declareVariable(variable: IndicesHelper, bindingIndex: number): string { + if (variable.usage === 'internal') { + throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.'); } + this.variables.push(variable); + this.appendVariableUniforms(variable); + const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } declareVariables(...variables: IndicesHelper[]): string { - return variables - .map(v => { - if (v.uniformOnly === true) { - return this.declareVariable(v); - } else { - return this.declareVariable(v, this.variableIndex++); - } - }) - .join('\n'); + return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); + } + + private registerInternalVariable(variable: IndicesHelper): void { + if (variable.usage !== 'internal') { + throw new Error( + 'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.'); + } + + this.internalVariables.push(variable); + this.appendVariableUniforms(variable); + } + + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper { + variables.forEach(v => this.registerInternalVariable(v)); + return this; } registerUniform(name: string, type: string): ShaderHelper { @@ -780,7 +818,8 @@ class ShaderHelperImpl implements ShaderHelper { return this; } - private indicesHelpers: IndicesHelper[] = []; + private internalVariables: IndicesHelper[] = []; + private variables: IndicesHelper[] = []; private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { @@ -802,7 +841,8 @@ class ShaderHelperImpl implements ShaderHelper { * Get additional implementation that needs to be added to the shader source. */ get additionalImplementations(): string { - return this.uniformDeclaration() + this.indicesHelpers.map(i => i.impl()).join('\n'); + return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') + + this.internalVariables.map(i => i.impl()).join('\n'); } }