diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 4ee1fd5442..bb86f147c9 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -338,51 +338,26 @@ export class WebGpuBackend { let uniformBufferBinding: GPUBindingResource|undefined; if (programUniforms) { let currentOffset = 0; - let preLength = 0; const offsets: number[] = []; - let maxAlignmentOfField = 1; + programUniforms.forEach(v => { const data = typeof v.data === 'number' ? [v.data] : v.data; if (data.length === 0) { return; } // https://www.w3.org/TR/WGSL/#alignof - let baseAlignment: number; - switch (data.length) { - case 1: - baseAlignment = 4; - break; - case 2: - baseAlignment = 8; - break; - case 3: - baseAlignment = 16; - break; - case 4: - baseAlignment = 16; - break; - case 5: - baseAlignment = 16; - break; - case 6: - baseAlignment = 16; - break; - default: - throw new Error(`unsupported data length: ${data.length}`); - } - - if (preLength === 5 || preLength === 6) { - baseAlignment = 16; - } - if (baseAlignment > maxAlignmentOfField) { - maxAlignmentOfField = baseAlignment; - } + const baseAlignment = data.length <= 2 ? data.length * 4 : 16; currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; - preLength = data.length; offsets.push(currentOffset); - currentOffset += data.length * 4; + // When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N = + // Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N * + // SizeOf(vec4<i32|u32|f32>). + currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; }); + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // maxAlignmentOfField to 16 since the underlying buffer has been rounded up to 16. 
+ const maxAlignmentOfField = 16; currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField; const arrayBuffer = new ArrayBuffer(currentOffset); programUniforms.forEach((v, i) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index b7a391ee66..af7202903d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -325,6 +325,20 @@ export const sumVector = (name: string, components: number) => { return name; }; +/** + * A helper function that returns uniform element at index. + * @param name - the name of uniform element. + * @param index - the index of uniform element. + * @param length - the length of uniform element. + */ +export const getUniformElementAt = (name: string, index: number|string, length: number): string => { + if (typeof (index) === 'string') { + return length > 4 ? `${name}[(${index}) / 4][(${index}) % 4]` : length > 1 ? `${name}[${index}]` : name; + } else { + return length > 4 ? `${name}[${Math.floor(index / 4)}][${index % 4}]` : length > 1 ? `${name}[${index}]` : name; + } +}; + /** * A helper function to get a IndicesHelper for a given input or output. * @@ -362,11 +376,12 @@ const createIndicesHelper = const uniformPrefix = useUniform ? 'uniforms.' 
: ''; const shape = `${uniformPrefix}${name}_shape`; const strides = `${uniformPrefix}${name}_strides`; + let o2iSnippet = ''; for (let i = 0; i < rank - 1; i++) { o2iSnippet += ` - let dim${i} = current / ${strides}[${i}]; - let rest${i} = current % ${strides}[${i}]; + let dim${i} = current / ${getUniformElementAt(strides, i, rank)}; + let rest${i} = current % ${getUniformElementAt(strides, i, rank)}; indices[${i}] = dim${i}; current = rest${i}; `; @@ -389,7 +404,7 @@ const createIndicesHelper = const offsets: string[] = []; if (rank >= 2) { for (let i = rank - 1; i >= 0; i--) { - offsets.push(`${strides}[${i}] * (indices[${i}])`); + offsets.push(`${getUniformElementAt(strides, i, rank)} * (indices[${i}])`); } } @@ -660,7 +675,8 @@ export const internalVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformsArrayType = Array<{name: string; type: string}>; +export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** * A ShaderHelper is a helper class for generating WGSL code. @@ -714,8 +730,9 @@ export interface ShaderHelper { * * @param name - the name of the uniform. * @param type - the type of the uniform. + * @param length - the length of the uniform, default to 1 when it is not provided. */ - registerUniform(name: string, type: string): ShaderHelper; + registerUniform(name: string, type: string, length?: number): ShaderHelper; /** * A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms. 
@@ -769,10 +786,10 @@ class ShaderHelperImpl implements ShaderHelper { private appendVariableUniforms(variable: IndicesHelper): void { if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { - this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); + this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank}); } if (variable.strides.startsWith('uniforms.')) { - this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); + this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank}); } } } @@ -808,8 +825,8 @@ class ShaderHelperImpl implements ShaderHelper { return this; } - registerUniform(name: string, type: string): ShaderHelper { - this.uniforms.push({name, type}); + registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper { + this.uniforms.push({name, type, length}); return this; } @@ -827,8 +844,13 @@ class ShaderHelperImpl implements ShaderHelper { } const uniformSnippets: string[] = []; - for (const {name, type} of this.uniforms) { - uniformSnippets.push(`${name}:${type}`); + for (const {name, type, length} of this.uniforms) { + if (length && length > 4) { + uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`); + } else { + const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; + uniformSnippets.push(`${name}:${typeTemp}`); + } } return ` @@ -872,5 +894,5 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly return dims; }; -// TODO: remove this limitation once >4D dims are supported by uniform. -export const enableShapesUniforms = (rank: number): boolean => rank <= 4; +// TODO: remove this when all related uses have been removed. 
+export const enableShapesUniforms = (_rank: number): boolean => true; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 7458579bf4..aa68cd0b2c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getUniformElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -77,20 +77,15 @@ const fixStartEndValues = }; const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - enableInputShapeUniforms: boolean): string => - `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): + string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { var inputIndices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { - let input_shape_i = ${ - enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'}; - let steps_i = ${ - enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'}; - let signs_i = ${ - enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? 
'[i]' : ''}` : 'signs[i]'}; - let starts_i = ${ - enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'}; + let input_shape_i = ${getUniformElementAt('uniforms.input_shape', 'i', inputShape.length)}; + let steps_i = ${getUniformElementAt('uniforms.steps', 'i', inputShape.length)}; + let signs_i = ${getUniformElementAt('uniforms.signs', 'i', inputShape.length)}; + let starts_i = ${getUniformElementAt('uniforms.starts', 'i', inputShape.length)}; var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; var inputIndex = outputIndex * steps_i + starts_i + carry; carry = inputIndex / input_shape_i; @@ -145,47 +140,29 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice } }); // Output rank is expected to be less than or equal to the input rank. - const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length); - const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims; - const outputShape = inputShape.slice(0); axes.forEach((axis, _) => { outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); - const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape; - const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; - const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); - const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank); + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length); const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = []; - const uniforms: UniformsArrayType = []; - if (enableShapeUniforms) { - uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}` : 'u32'}); - uniforms.push({name: 'signs', type: signs.length > 1 ? 
`vec${signs.length}` : 'i32'}); - uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}` : 'u32'}); - programUniforms.push({type: 'uint32', data: starts}); - programUniforms.push({type: 'int32', data: signs}); - programUniforms.push({type: 'uint32', data: steps}); - } - uniforms.push({name: 'outputSize', type: 'u32'}); - programUniforms.push({type: 'uint32', data: outputSize}); - if (enableShapeUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + const uniforms: UniformsArrayType = [ + {name: 'outputSize', type: 'u32'}, {name: 'starts', type: 'u32', length: starts.length}, + {name: 'signs', type: 'i32', length: signs.length}, {name: 'steps', type: 'u32', length: steps.length} + ]; + + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, + {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(outputShape) + ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} - ${enableShapeUniforms ? 
'' : [ - `const signs = array(${signs.map(i => `${i}i`).join(',')});`, - `const starts = array(${starts.map(i => `${i}u`).join(',')});`, - `const steps = array(${steps.map(i => `${i}u`).join(',')});`, - `const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});` - ].join('\n')} - - ${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)} + ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; @@ -194,11 +171,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice }`; return { name: 'Slice', - shaderCache: { - hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` : - `${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`, - inputDependencies: [enableShapeUniforms ? 'rank' : 'dims'] - }, + shaderCache: {hint: `${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: [outputTensorInfo],