[js/webgpu] make IndicesHelper implementation implicit (#17193)

### Description
This change makes it no longer required to call indicesHelper.impl() in
shader code.
This commit is contained in:
Yulong Wang 2023-08-23 14:41:35 -07:00 committed by GitHub
parent aed7c6ffc7
commit 8b18d48c7c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 98 additions and 114 deletions

View file

@ -207,10 +207,7 @@ const createConvTranspose2DOpProgramShaderSource =
${output.setByOffset('global_idx', 'dotProd')};
`;
return `
${w.impl('indicesToOffset', 'get')}
${dy.impl('indicesToOffset', 'get')}
${output.impl('offsetToIndices')}
const shader = `
${declareFunctions}
${declareInputs.join('\n')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;
@ -234,6 +231,12 @@ const createConvTranspose2DOpProgramShaderSource =
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)};
${isVec4 ? codeSnippet4 : codeSnippet}}`;
// TODO: use shaderHelper.declareVariables() to declare variables so that those impl() calls can be removed.
return ` ${w.impl()}
${dy.impl()}
${output.impl()}
${shader}`;
};
export const createConvTranspose2DProgramInfo =

View file

@ -50,8 +50,6 @@ const createBinaryOpProgramShader =
};
broadcastImpl = `
${output.impl('offsetToIndices')}
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsA)};
}

View file

@ -16,28 +16,6 @@ import {ShapeUtil} from '../../util';
**/
export const WORKGROUP_SIZE = 64;
interface IndicesHelperImplementations {
/**
* implementation of `offsetToIndices` function.
*/
readonly offsetToIndices: string;
/**
* implementation of `indicesToOffset` function.
*/
readonly indicesToOffset: string;
/**
* implementation of `set`, `setByIndices` and `setByOffset` function.
*/
readonly set: string;
/**
* implementation of `get`, `getByIndices` and `getByOffset` function.
*/
readonly get: string;
}
interface IndicesHelperTypes {
/**
* WGSL type of indices expression
@ -96,12 +74,10 @@ interface IndicesHelperTypes {
*/
export interface IndicesHelper {
/**
* get WGSL code of function implementation for the util functions
* get WGSL code of function implementation for the util functions.
*
* @param functions - a list of function names to get implementation for. If not specified, all functions will be
* returned.
*/
readonly impl: (...functions: ReadonlyArray<keyof IndicesHelperImplementations>) => string;
readonly impl: () => string;
/**
* get type info
@ -215,9 +191,12 @@ export interface IndicesHelper {
readonly shape: readonly number[];
}
const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, string] => {
const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => {
// return type is [ storage type, runtime type ] or a single string for both
switch (type) {
// TODO: enable after "shader-f16" WSGL extension release
// case DataType.float16:
// return components > 1 ? `vec${components}<f16>` : 'f16';
case DataType.float:
return components > 1 ? `vec${components}<f32>` : 'f32';
case DataType.int32:
@ -245,6 +224,11 @@ const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, st
}
};
export const tensorTypeToWsglStorageType = (type: DataType, components: 1|2|3|4 = 1) => {
const mappedType = getWgslMappedType(type, components);
return typeof mappedType === 'string' ? mappedType : mappedType[0];
};
/**
* A helper function to get a IndicesHelper for a given input or output.
*
@ -260,13 +244,22 @@ const createIndicesHelper =
components: 1|2|3|4): IndicesHelper => {
const rank = shape.length;
const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}<u32>` : `array<u32, ${rank}>`;
const mappedType = getWgslValueType(tensorType, components);
const mappedType = getWgslMappedType(tensorType, components);
const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1];
const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0];
const type = {indices: indicesType, value: valueType, storage: storageType, tensor: tensorType};
const normalizeDim = (dim: number|string): string => typeof dim === 'string' ? dim : `${dim}u`;
const implementationUsed = {
offsetToIndices: false,
indicesToOffset: false,
set: false,
setByIndices: false,
get: false,
getByIndices: false,
};
const strides = ShapeUtil.computeStrides(shape);
let o2iSnippet = '';
for (let i = 0; i < rank - 1; i++) {
@ -287,7 +280,10 @@ const createIndicesHelper =
return indices;
}`;
const offsetToIndices = (varOffset: string) => rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
const offsetToIndices = (varOffset: string) => {
implementationUsed.offsetToIndices = true;
return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
};
const offsets: string[] = [];
if (rank >= 2) {
@ -301,7 +297,10 @@ const createIndicesHelper =
return ${offsets.join('+')};
}`;
const indicesToOffset = (varIndices: string) => rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
const indicesToOffset = (varIndices: string) => {
implementationUsed.indicesToOffset = true;
return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
};
const indices = (...init: ReadonlyArray<number|string>) =>
rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`;
@ -357,17 +356,18 @@ const createIndicesHelper =
}
})();
const getByIndicesImplementation = rank < 2 ? '' : `
fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
return ${name}[i2o_${name}(indices)];
}`;
const getImplementation = rank < 2 ? '' : (() => {
const params = shape.map((_, i) => `d${i}: u32`).join(', ');
const dims = shape.map((_, i) => `d${i}`).join(', ');
return `
fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
return ${name}[i2o_${name}(indices)];
}
fn get_${name}(${params}) -> ${valueType} {
return get_${name}ByIndices(${indices(dims)});
}
`;
}`;
})();
const get = (...indices: ReadonlyArray<number|string>) => {
@ -376,14 +376,16 @@ const createIndicesHelper =
}
const normalizedIndices = indices.map(normalizeDim).join(',');
const funcName = `get_${name}`;
if (rank === 0) {
return getByOffset('0u');
} else if (rank === 1) {
return getByOffset(normalizedIndices[0]);
} else {
return `${funcName}(${normalizedIndices})`;
implementationUsed.get = true;
implementationUsed.getByIndices = true;
implementationUsed.indicesToOffset = true;
return `get_${name}(${normalizedIndices})`;
}
};
@ -391,21 +393,24 @@ const createIndicesHelper =
if (rank < 2) {
return getByOffset(varIndices);
} else {
implementationUsed.getByIndices = true;
implementationUsed.indicesToOffset = true;
return `get_${name}ByIndices(${varIndices})`;
}
};
const setByIndicesImplementation = rank < 2 ? '' : `
fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
${setByOffset(`i2o_${name}(indices)`, 'value')}
}`;
const setImplementation = rank < 2 ? '' : (() => {
const params = shape.map((_, i) => `d${i}: u32`).join(', ');
const dims = shape.map((_, i) => `d${i}`).join(', ');
return `
fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
${setByOffset(`i2o_${name}(indices)`, 'value')}
}
fn set_${name}(${params}, value: ${valueType}) {
set_${name}ByIndices(${indices(dims)}, value);
}
`;
}`;
})();
const set = (...indicesAndValue: ReadonlyArray<number|string>) => {
@ -424,6 +429,9 @@ const createIndicesHelper =
} else if (rank === 1) {
return setByOffset(normalizedIndices[0], value);
} else {
implementationUsed.set = true;
implementationUsed.setByIndices = true;
implementationUsed.indicesToOffset = true;
return `set_${name}(${normalizedIndices}, ${value})`;
}
};
@ -432,32 +440,34 @@ const createIndicesHelper =
if (rank < 2) {
return setByOffset(varIndices, value);
} else {
implementationUsed.setByIndices = true;
implementationUsed.indicesToOffset = true;
return `set_${name}ByIndices(${varIndices}, ${value});`;
}
};
const funcImpls = {
offsetToIndices: offsetToIndicesImplementation,
indicesToOffset: indicesToOffsetImplementation,
set: setImplementation,
get: getImplementation,
};
const impl = (...functions: Array<keyof IndicesHelperImplementations>) => {
const impl = () => {
const impls = [];
if (functions.length === 0) {
functions.push('offsetToIndices', 'indicesToOffset', 'set', 'get');
if (implementationUsed.offsetToIndices) {
impls.push(offsetToIndicesImplementation);
}
for (const func of functions) {
const impl = funcImpls[func];
if (impl === undefined) {
throw new Error(`unknown function ${func}`);
} else {
impls.push(impl);
}
if (implementationUsed.indicesToOffset) {
impls.push(indicesToOffsetImplementation);
}
if (implementationUsed.set) {
impls.push(setImplementation);
}
if (implementationUsed.setByIndices) {
impls.push(setByIndicesImplementation);
}
if (implementationUsed.get) {
impls.push(getImplementation);
}
if (implementationUsed.getByIndices) {
impls.push(getByIndicesImplementation);
}
return impls.join('\n');
};
impl.toString = () => impl();
return {
impl,
@ -552,6 +562,11 @@ export interface ShaderHelper {
* @param variables - an array of IndicesHelper for the variables.
*/
declareVariables(...variables: IndicesHelper[]): string;
/**
* Get additional implementation that needs to be added to the shader source.
*/
readonly additionalImplementations: string;
}
class ShaderHelperImpl implements ShaderHelper {
@ -585,6 +600,7 @@ class ShaderHelperImpl implements ShaderHelper {
}
declareVariable(variable: IndicesHelper, bindingIndex: number): string {
this.indicesHelpers.push(variable);
const access = variable.usage === 'input' ? 'read' : 'read_write';
const storageType = variable.type.storage;
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
@ -594,6 +610,12 @@ class ShaderHelperImpl implements ShaderHelper {
let i = 0;
return variables.filter(v => ShapeUtil.size(v.shape) > 0).map(v => this.declareVariable(v, i++)).join('\n');
}
private indicesHelpers: IndicesHelper[] = [];
get additionalImplementations(): string {
return this.indicesHelpers.map(i => i.impl()).join('\n');
}
}
export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper =>

View file

@ -109,9 +109,6 @@ const createConcatProgramInfo =
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(...inputVars, output)}
${inputVars.map(i => i.impl('indicesToOffset', 'get')).join('\n')}
${output.impl('offsetToIndices')}
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
${calculateInputIndexImpl(sizeInConcatAxis.length)}

View file

@ -47,9 +47,6 @@ const createGroupedConvProgramInfo =
${shaderHelper.declareVariables(...inputVars, output)}
${activationFunction}
${output.impl('offsetToIndices')}
${x.impl('indicesToOffset', 'get')}
${w.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}

View file

@ -58,8 +58,6 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten
const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = ${input.indices(...inputShape)};
${shaderHelper.declareVariables(input, output)}
${output.impl('offsetToIndices')}
${input.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let outputIndices = ${output.offsetToIndices('global_idx')};

View file

@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
import {ShaderHelper} from './common';
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
export interface InstanceNormAttributes extends AttributeWithCacheKey {
epsilon: number;
@ -45,7 +45,7 @@ const createInstanceNormProgramInfo =
Got scale size of ${scaleSize} and bias size of ${biasSize}`);
}
const dataType = tensorTypeToWsglType(inputs[0].dataType);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const getShaderSource = (shaderHelper: ShaderHelper) => `
const C: u32 = ${C};
@ -99,7 +99,7 @@ const createInstanceNormNHWCProgramInfo =
const C = xShape[xShape.length - 1];
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
const dataType = tensorTypeToWsglType(inputs[0].dataType);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const normCount = C * N;
const getShaderSource = (shaderHelper: ShaderHelper) => `

View file

@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
import {ShaderHelper} from './common';
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
export interface LayerNormAttributes extends AttributeWithCacheKey {
axis: number;
@ -54,7 +54,7 @@ const createLayerNormProgramInfo =
}
}
const dataType = tensorTypeToWsglType(inputs[0].dataType);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const hasMeanDataOutput = outputCount > 1;
const hasInvStdOutput = outputCount > 2;

View file

@ -128,9 +128,6 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
const poolingCode = `
${shaderHelper.declareVariables(x, output)}
${output.impl('offsetToIndices')}
${x.impl('indicesToOffset')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
@ -179,9 +176,6 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
const poolingCode = `
${shaderHelper.declareVariables(x, output)}
${output.impl('offsetToIndices')}
${x.impl('indicesToOffset')}
const pads = array<u32, ${padsRank}>(${attributes.pads.map(i => `${i}u`).join(',')});
const inputDims = array<u32, ${rank}>(${inputDims.map(i => `${i}u`).join(',')});
const kernelStrides = array<u32, ${stridesRank}>(${kernelStrides.map(i => `${i}u`).join(',')});

View file

@ -85,9 +85,6 @@ export const createReduceProgramInfo =
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, output)}
${output.impl('offsetToIndices')}
${input.impl('indicesToOffset')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
var inputIndices: ${input.type.indices};

View file

@ -484,8 +484,6 @@ const createResizeProgramInfo =
}
})()};
${shaderHelper.declareVariables(input, output)}
${output.impl('offsetToIndices')}
${input.impl('indicesToOffset')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
if (${noScale}) {

View file

@ -153,8 +153,6 @@ const createSliceProgramInfo =
const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
${output.impl('offsetToIndices')}
${input.impl('indicesToOffset', 'get')}
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}

View file

@ -87,8 +87,6 @@ const createSplitProgramInfo =
const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, ...outputs)}
${input.impl('indicesToOffset', 'offsetToIndices', 'get')}
${outputs.map(o => o.impl('indicesToOffset', 'set')).join('\n')}
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
${calculateOutputIndexImpl(sizeInConcatAxis.length)}
${writeBufferDataImpl(outputs)}

View file

@ -64,8 +64,6 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
${shaderHelper.declareVariables(input, output)}
${permFunctionBody(perm, rank, input, output)}
${output.impl('offsetToIndices')}
${input.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}

View file

@ -114,7 +114,9 @@ export class ProgramManager {
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
const device = this.backend.device;
const code = programInfo.getShaderSource(createShaderHelper(normalizedDispatchGroupSize));
const shaderHelper = createShaderHelper(normalizedDispatchGroupSize);
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({code});
LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`);

View file

@ -164,19 +164,3 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
throw new Error(`unsupported logging level: ${logLevel}`);
}
};
export const tensorTypeToWsglType = (type: DataType) => {
switch (type) {
case DataType.float:
return 'f32';
// TODO: enable after "shader-f16" WSGL extension release
// case DataType.float16:
// return 'f16';
case DataType.int32:
return 'i32';
case DataType.uint32:
return 'u32';
default:
throw new Error(`Unsupported type: ${type}`);
}
};