[js/webgpu] Add multidimensional(>4) uniform support (#18546)

This change removes the check of enableShapesUniforms. When all uses of
this are removed, enableShapesUniforms can be removed too.
This commit is contained in:
Xu Xing 2023-12-01 09:10:33 +08:00 committed by GitHub
parent 73a2eb82eb
commit 73d9b03509
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 95 deletions

View file

@ -338,51 +338,26 @@ export class WebGpuBackend {
let uniformBufferBinding: GPUBindingResource|undefined;
if (programUniforms) {
let currentOffset = 0;
let preLength = 0;
const offsets: number[] = [];
let maxAlignmentOfField = 1;
programUniforms.forEach(v => {
const data = typeof v.data === 'number' ? [v.data] : v.data;
if (data.length === 0) {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
let baseAlignment: number;
switch (data.length) {
case 1:
baseAlignment = 4;
break;
case 2:
baseAlignment = 8;
break;
case 3:
baseAlignment = 16;
break;
case 4:
baseAlignment = 16;
break;
case 5:
baseAlignment = 16;
break;
case 6:
baseAlignment = 16;
break;
default:
throw new Error(`unsupported data length: ${data.length}`);
}
if (preLength === 5 || preLength === 6) {
baseAlignment = 16;
}
if (baseAlignment > maxAlignmentOfField) {
maxAlignmentOfField = baseAlignment;
}
const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
preLength = data.length;
offsets.push(currentOffset);
currentOffset += data.length * 4;
// When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
// Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
// SizeOf(vec4<i32|u32|f32>).
currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
});
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
// maxAlignmentOfField to 16 since the underlying buffer has been rounded up to 16.
const maxAlignmentOfField = 16;
currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField;
const arrayBuffer = new ArrayBuffer(currentOffset);
programUniforms.forEach((v, i) => {

View file

@ -325,6 +325,20 @@ export const sumVector = (name: string, components: number) => {
return name;
};
/**
* A helper function that returns uniform element at index.
* @param name - the name of uniform element.
* @param index - the index of uniform element.
* @param length - the length of uniform element.
*/
export const getUniformElementAt = (name: string, index: number|string, length: number): string => {
if (typeof (index) === 'string') {
return length > 4 ? `${name}[(${index}) / 4][(${index}) % 4]` : length > 1 ? `${name}[${index}]` : name;
} else {
return length > 4 ? `${name}[${Math.floor(index / 4)}][${index % 4}]` : length > 1 ? `${name}[${index}]` : name;
}
};
/**
* A helper function to get a IndicesHelper for a given input or output.
*
@ -362,11 +376,12 @@ const createIndicesHelper =
const uniformPrefix = useUniform ? 'uniforms.' : '';
const shape = `${uniformPrefix}${name}_shape`;
const strides = `${uniformPrefix}${name}_strides`;
let o2iSnippet = '';
for (let i = 0; i < rank - 1; i++) {
o2iSnippet += `
let dim${i} = current / ${strides}[${i}];
let rest${i} = current % ${strides}[${i}];
let dim${i} = current / ${getUniformElementAt(strides, i, rank)};
let rest${i} = current % ${getUniformElementAt(strides, i, rank)};
indices[${i}] = dim${i};
current = rest${i};
`;
@ -389,7 +404,7 @@ const createIndicesHelper =
const offsets: string[] = [];
if (rank >= 2) {
for (let i = rank - 1; i >= 0; i--) {
offsets.push(`${strides}[${i}] * (indices[${i}])`);
offsets.push(`${getUniformElementAt(strides, i, rank)} * (indices[${i}])`);
}
}
@ -660,7 +675,8 @@ export const internalVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, 'internal', components);
export type UniformsArrayType = Array<{name: string; type: string}>;
export type UniformDataElementType = 'u32'|'f32'|'i32';
export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;
/**
* A ShaderHelper is a helper class for generating WGSL code.
@ -714,8 +730,9 @@ export interface ShaderHelper {
*
* @param name - the name of the uniform.
* @param type - the type of the uniform.
* @param length - the length of the uniform, default to 1 when it is not provided.
*/
registerUniform(name: string, type: string): ShaderHelper;
registerUniform(name: string, type: string, length?: number): ShaderHelper;
/**
* A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms.
@ -769,10 +786,10 @@ class ShaderHelperImpl implements ShaderHelper {
private appendVariableUniforms(variable: IndicesHelper): void {
if (variable.rank !== 0) {
if (variable.shape.startsWith('uniforms.')) {
this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices});
this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank});
}
if (variable.strides.startsWith('uniforms.')) {
this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices});
this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank});
}
}
}
@ -808,8 +825,8 @@ class ShaderHelperImpl implements ShaderHelper {
return this;
}
registerUniform(name: string, type: string): ShaderHelper {
this.uniforms.push({name, type});
registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper {
this.uniforms.push({name, type, length});
return this;
}
@ -827,8 +844,13 @@ class ShaderHelperImpl implements ShaderHelper {
}
const uniformSnippets: string[] = [];
for (const {name, type} of this.uniforms) {
uniformSnippets.push(`${name}:${type}`);
for (const {name, type, length} of this.uniforms) {
if (length && length > 4) {
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
} else {
const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
uniformSnippets.push(`${name}:${typeTemp}`);
}
}
return `
@ -872,5 +894,5 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
return dims;
};
// TODO: remove this limitation once >4D dims are supported by uniform.
export const enableShapesUniforms = (rank: number): boolean => rank <= 4;
// TODO: remove this when all related uses have been removed.
export const enableShapesUniforms = (_rank: number): boolean => true;

View file

@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
import {createTensorShapeVariables, getUniformElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
export interface SliceAttributes extends AttributeWithCacheKey {
readonly starts: number[];
@ -77,20 +77,15 @@ const fixStartEndValues =
};
const calculateInputIndicesImpl =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
enableInputShapeUniforms: boolean): string =>
`fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]):
string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
var inputIndices: ${input.type.indices};
var carry = 0u;
for (var i = ${inputShape.length}; i >= 0; i--) {
let input_shape_i = ${
enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'};
let steps_i = ${
enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'};
let signs_i = ${
enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'};
let starts_i = ${
enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'};
let input_shape_i = ${getUniformElementAt('uniforms.input_shape', 'i', inputShape.length)};
let steps_i = ${getUniformElementAt('uniforms.steps', 'i', inputShape.length)};
let signs_i = ${getUniformElementAt('uniforms.signs', 'i', inputShape.length)};
let starts_i = ${getUniformElementAt('uniforms.starts', 'i', inputShape.length)};
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex = outputIndex * steps_i + starts_i + carry;
carry = inputIndex / input_shape_i;
@ -145,47 +140,29 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
}
});
// Output rank is expected to be less than or equal to the input rank.
const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length);
const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims;
const outputShape = inputShape.slice(0);
axes.forEach((axis, _) => {
outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]);
});
const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape;
const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType};
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);
const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank);
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length);
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] = [];
const uniforms: UniformsArrayType = [];
if (enableShapeUniforms) {
uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}<u32>` : 'u32'});
uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}<i32>` : 'i32'});
uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}<u32>` : 'u32'});
programUniforms.push({type: 'uint32', data: starts});
programUniforms.push({type: 'int32', data: signs});
programUniforms.push({type: 'uint32', data: steps});
}
uniforms.push({name: 'outputSize', type: 'u32'});
programUniforms.push({type: 'uint32', data: outputSize});
if (enableShapeUniforms) {
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
}
const uniforms: UniformsArrayType = [
{name: 'outputSize', type: 'u32'}, {name: 'starts', type: 'u32', length: starts.length},
{name: 'signs', type: 'i32', length: signs.length}, {name: 'steps', type: 'u32', length: steps.length}
];
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs},
{type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(outputShape)
];
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
${enableShapeUniforms ? '' : [
`const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});`,
`const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});`,
`const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});`,
`const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});`
].join('\n')}
${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)}
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
let outputIndices = ${output.offsetToIndices('global_idx')};
@ -194,11 +171,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
}`;
return {
name: 'Slice',
shaderCache: {
hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` :
`${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`,
inputDependencies: [enableShapeUniforms ? 'rank' : 'dims']
},
shaderCache: {hint: `${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank']},
getShaderSource,
getRunData: () => ({
outputs: [outputTensorInfo],