mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[js/webgpu] Add multidimensional(>4) uniform support (#18546)
This change removes the checks of `enableShapesUniforms`. Once all remaining uses of it have been removed, `enableShapesUniforms` itself can be removed too.
This commit is contained in:
parent
73a2eb82eb
commit
73d9b03509
3 changed files with 65 additions and 95 deletions
|
|
@ -338,51 +338,26 @@ export class WebGpuBackend {
|
|||
let uniformBufferBinding: GPUBindingResource|undefined;
|
||||
if (programUniforms) {
|
||||
let currentOffset = 0;
|
||||
let preLength = 0;
|
||||
const offsets: number[] = [];
|
||||
let maxAlignmentOfField = 1;
|
||||
|
||||
programUniforms.forEach(v => {
|
||||
const data = typeof v.data === 'number' ? [v.data] : v.data;
|
||||
if (data.length === 0) {
|
||||
return;
|
||||
}
|
||||
// https://www.w3.org/TR/WGSL/#alignof
|
||||
let baseAlignment: number;
|
||||
switch (data.length) {
|
||||
case 1:
|
||||
baseAlignment = 4;
|
||||
break;
|
||||
case 2:
|
||||
baseAlignment = 8;
|
||||
break;
|
||||
case 3:
|
||||
baseAlignment = 16;
|
||||
break;
|
||||
case 4:
|
||||
baseAlignment = 16;
|
||||
break;
|
||||
case 5:
|
||||
baseAlignment = 16;
|
||||
break;
|
||||
case 6:
|
||||
baseAlignment = 16;
|
||||
break;
|
||||
default:
|
||||
throw new Error(`unsupported data length: ${data.length}`);
|
||||
}
|
||||
|
||||
if (preLength === 5 || preLength === 6) {
|
||||
baseAlignment = 16;
|
||||
}
|
||||
if (baseAlignment > maxAlignmentOfField) {
|
||||
maxAlignmentOfField = baseAlignment;
|
||||
}
|
||||
const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
|
||||
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
|
||||
preLength = data.length;
|
||||
offsets.push(currentOffset);
|
||||
currentOffset += data.length * 4;
|
||||
// When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
|
||||
// Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
|
||||
// SizeOf(vec4<i32|u32|f32>).
|
||||
currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
|
||||
});
|
||||
|
||||
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
|
||||
// maxAlignmentOfField to 16 since the underlying buffer has been rounded up to 16.
|
||||
const maxAlignmentOfField = 16;
|
||||
currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField;
|
||||
const arrayBuffer = new ArrayBuffer(currentOffset);
|
||||
programUniforms.forEach((v, i) => {
|
||||
|
|
|
|||
|
|
@ -325,6 +325,20 @@ export const sumVector = (name: string, components: number) => {
|
|||
return name;
|
||||
};
|
||||
|
||||
/**
 * Builds the WGSL expression that reads one element of a uniform value.
 *
 * The expression depends on how many elements the uniform holds:
 * - `length` of at most 1: the uniform is a scalar, so the bare `name` is returned and `index` is ignored.
 * - `length` of 2 to 4: the uniform is a single vector, addressed as `name[index]`.
 * - `length` above 4: the uniform is addressed as a sequence of 4-wide vectors, i.e.
 *   `name[index / 4][index % 4]`.
 *
 * @param name - the name of the uniform (including any prefix such as `uniforms.`).
 * @param index - the element index; either a number known at shader-generation time, or a string
 *     holding a WGSL expression that is evaluated inside the shader.
 * @param length - the total number of elements held by the uniform.
 * @returns the WGSL source text that accesses the requested element.
 */
export const getUniformElementAt = (name: string, index: number|string, length: number): string => {
  // Scalar uniform: there is nothing to index into.
  if (!(length > 1)) {
    return name;
  }

  // A single vec2/vec3/vec4: one subscript is enough for both literal and runtime indices.
  if (!(length > 4)) {
    return `${name}[${index}]`;
  }

  // More than four elements: pick the 4-wide vector first, then the lane inside it.
  if (typeof index === 'number') {
    // The index is known now, so fold the division and remainder into literals.
    const vectorSlot = Math.floor(index / 4);
    const lane = index % 4;
    return `${name}[${vectorSlot}][${lane}]`;
  }

  // The index is a shader expression; parenthesize it so operator precedence inside it is preserved.
  return `${name}[(${index}) / 4][(${index}) % 4]`;
};
|
||||
|
||||
/**
|
||||
* A helper function to get a IndicesHelper for a given input or output.
|
||||
*
|
||||
|
|
@ -362,11 +376,12 @@ const createIndicesHelper =
|
|||
const uniformPrefix = useUniform ? 'uniforms.' : '';
|
||||
const shape = `${uniformPrefix}${name}_shape`;
|
||||
const strides = `${uniformPrefix}${name}_strides`;
|
||||
|
||||
let o2iSnippet = '';
|
||||
for (let i = 0; i < rank - 1; i++) {
|
||||
o2iSnippet += `
|
||||
let dim${i} = current / ${strides}[${i}];
|
||||
let rest${i} = current % ${strides}[${i}];
|
||||
let dim${i} = current / ${getUniformElementAt(strides, i, rank)};
|
||||
let rest${i} = current % ${getUniformElementAt(strides, i, rank)};
|
||||
indices[${i}] = dim${i};
|
||||
current = rest${i};
|
||||
`;
|
||||
|
|
@ -389,7 +404,7 @@ const createIndicesHelper =
|
|||
const offsets: string[] = [];
|
||||
if (rank >= 2) {
|
||||
for (let i = rank - 1; i >= 0; i--) {
|
||||
offsets.push(`${strides}[${i}] * (indices[${i}])`);
|
||||
offsets.push(`${getUniformElementAt(strides, i, rank)} * (indices[${i}])`);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -660,7 +675,8 @@ export const internalVariable =
|
|||
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
|
||||
createIndicesHelper(name, type, shapeOrRank, 'internal', components);
|
||||
|
||||
export type UniformsArrayType = Array<{name: string; type: string}>;
|
||||
export type UniformDataElementType = 'u32'|'f32'|'i32';
|
||||
export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;
|
||||
|
||||
/**
|
||||
* A ShaderHelper is a helper class for generating WGSL code.
|
||||
|
|
@ -714,8 +730,9 @@ export interface ShaderHelper {
|
|||
*
|
||||
* @param name - the name of the uniform.
|
||||
* @param type - the type of the uniform.
|
||||
* @param length - the length of the uniform, default to 1 when it is not provided.
|
||||
*/
|
||||
registerUniform(name: string, type: string): ShaderHelper;
|
||||
registerUniform(name: string, type: string, length?: number): ShaderHelper;
|
||||
|
||||
/**
|
||||
* A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms.
|
||||
|
|
@ -769,10 +786,10 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
private appendVariableUniforms(variable: IndicesHelper): void {
|
||||
if (variable.rank !== 0) {
|
||||
if (variable.shape.startsWith('uniforms.')) {
|
||||
this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices});
|
||||
this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank});
|
||||
}
|
||||
if (variable.strides.startsWith('uniforms.')) {
|
||||
this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices});
|
||||
this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -808,8 +825,8 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
return this;
|
||||
}
|
||||
|
||||
registerUniform(name: string, type: string): ShaderHelper {
|
||||
this.uniforms.push({name, type});
|
||||
registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper {
|
||||
this.uniforms.push({name, type, length});
|
||||
return this;
|
||||
}
|
||||
|
||||
|
|
@ -827,8 +844,13 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
}
|
||||
|
||||
const uniformSnippets: string[] = [];
|
||||
for (const {name, type} of this.uniforms) {
|
||||
uniformSnippets.push(`${name}:${type}`);
|
||||
for (const {name, type, length} of this.uniforms) {
|
||||
if (length && length > 4) {
|
||||
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
|
||||
} else {
|
||||
const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
|
||||
uniformSnippets.push(`${name}:${typeTemp}`);
|
||||
}
|
||||
}
|
||||
|
||||
return `
|
||||
|
|
@ -872,5 +894,5 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
|
|||
return dims;
|
||||
};
|
||||
|
||||
// TODO: remove this limitation once >4D dims are supported by uniform.
|
||||
export const enableShapesUniforms = (rank: number): boolean => rank <= 4;
|
||||
/**
 * Reports whether tensor shapes should be passed to the shader as uniforms.
 *
 * This always answers `true`, whatever the rank.
 *
 * TODO: remove this when all related uses have been removed.
 *
 * @param _rank - the tensor rank; accepted for call-site compatibility and ignored.
 * @returns always `true`.
 */
export const enableShapesUniforms = (_rank: number): boolean => {
  return true;
};
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
import {createTensorShapeVariables, getUniformElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
|
||||
export interface SliceAttributes extends AttributeWithCacheKey {
|
||||
readonly starts: number[];
|
||||
|
|
@ -77,20 +77,15 @@ const fixStartEndValues =
|
|||
};
|
||||
|
||||
const calculateInputIndicesImpl =
|
||||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
|
||||
enableInputShapeUniforms: boolean): string =>
|
||||
`fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
|
||||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]):
|
||||
string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
|
||||
var inputIndices: ${input.type.indices};
|
||||
var carry = 0u;
|
||||
for (var i = ${inputShape.length}; i >= 0; i--) {
|
||||
let input_shape_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'};
|
||||
let steps_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'};
|
||||
let signs_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'};
|
||||
let starts_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'};
|
||||
let input_shape_i = ${getUniformElementAt('uniforms.input_shape', 'i', inputShape.length)};
|
||||
let steps_i = ${getUniformElementAt('uniforms.steps', 'i', inputShape.length)};
|
||||
let signs_i = ${getUniformElementAt('uniforms.signs', 'i', inputShape.length)};
|
||||
let starts_i = ${getUniformElementAt('uniforms.starts', 'i', inputShape.length)};
|
||||
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
|
||||
var inputIndex = outputIndex * steps_i + starts_i + carry;
|
||||
carry = inputIndex / input_shape_i;
|
||||
|
|
@ -145,47 +140,29 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
|
|||
}
|
||||
});
|
||||
// Output rank is expected to be less than or equal to the input rank.
|
||||
const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length);
|
||||
const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims;
|
||||
|
||||
const outputShape = inputShape.slice(0);
|
||||
axes.forEach((axis, _) => {
|
||||
outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]);
|
||||
});
|
||||
const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape;
|
||||
|
||||
const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType};
|
||||
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);
|
||||
const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank);
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
|
||||
const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] = [];
|
||||
const uniforms: UniformsArrayType = [];
|
||||
if (enableShapeUniforms) {
|
||||
uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}<u32>` : 'u32'});
|
||||
uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}<i32>` : 'i32'});
|
||||
uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}<u32>` : 'u32'});
|
||||
programUniforms.push({type: 'uint32', data: starts});
|
||||
programUniforms.push({type: 'int32', data: signs});
|
||||
programUniforms.push({type: 'uint32', data: steps});
|
||||
}
|
||||
uniforms.push({name: 'outputSize', type: 'u32'});
|
||||
programUniforms.push({type: 'uint32', data: outputSize});
|
||||
if (enableShapeUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'outputSize', type: 'u32'}, {name: 'starts', type: 'u32', length: starts.length},
|
||||
{name: 'signs', type: 'i32', length: signs.length}, {name: 'steps', type: 'u32', length: steps.length}
|
||||
];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs},
|
||||
{type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
|
||||
${enableShapeUniforms ? '' : [
|
||||
`const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});`,
|
||||
`const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});`,
|
||||
`const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});`,
|
||||
`const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});`
|
||||
].join('\n')}
|
||||
|
||||
${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)}
|
||||
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
||||
let outputIndices = ${output.offsetToIndices('global_idx')};
|
||||
|
|
@ -194,11 +171,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
|
|||
}`;
|
||||
return {
|
||||
name: 'Slice',
|
||||
shaderCache: {
|
||||
hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` :
|
||||
`${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`,
|
||||
inputDependencies: [enableShapeUniforms ? 'rank' : 'dims']
|
||||
},
|
||||
shaderCache: {hint: `${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank']},
|
||||
getShaderSource,
|
||||
getRunData: () => ({
|
||||
outputs: [outputTensorInfo],
|
||||
|
|
|
|||
Loading…
Reference in a new issue