mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[js/webgpu] make IndicesHelper implementation implicit (#17193)
### Description This change makes it no longer required to call indicesHelper.impl() in shader code.
This commit is contained in:
parent
aed7c6ffc7
commit
8b18d48c7c
16 changed files with 98 additions and 114 deletions
|
|
@ -207,10 +207,7 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
${output.setByOffset('global_idx', 'dotProd')};
|
||||
`;
|
||||
|
||||
return `
|
||||
${w.impl('indicesToOffset', 'get')}
|
||||
${dy.impl('indicesToOffset', 'get')}
|
||||
${output.impl('offsetToIndices')}
|
||||
const shader = `
|
||||
${declareFunctions}
|
||||
${declareInputs.join('\n')}
|
||||
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;
|
||||
|
|
@ -234,6 +231,12 @@ const createConvTranspose2DOpProgramShaderSource =
|
|||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)};
|
||||
${isVec4 ? codeSnippet4 : codeSnippet}}`;
|
||||
|
||||
// TODO: use shaderHelper.declareVariables() to declare variables so that those impl() calls can be removed.
|
||||
return ` ${w.impl()}
|
||||
${dy.impl()}
|
||||
${output.impl()}
|
||||
${shader}`;
|
||||
};
|
||||
|
||||
export const createConvTranspose2DProgramInfo =
|
||||
|
|
|
|||
|
|
@ -50,8 +50,6 @@ const createBinaryOpProgramShader =
|
|||
};
|
||||
|
||||
broadcastImpl = `
|
||||
${output.impl('offsetToIndices')}
|
||||
|
||||
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
|
||||
return ${calcOffsetImpl(dimsA)};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,28 +16,6 @@ import {ShapeUtil} from '../../util';
|
|||
**/
|
||||
export const WORKGROUP_SIZE = 64;
|
||||
|
||||
interface IndicesHelperImplementations {
|
||||
/**
|
||||
* implementation of `offsetToIndices` function.
|
||||
*/
|
||||
readonly offsetToIndices: string;
|
||||
|
||||
/**
|
||||
* implementation of `indicesToOffset` function.
|
||||
*/
|
||||
readonly indicesToOffset: string;
|
||||
|
||||
/**
|
||||
* implementation of `set`, `setByIndices` and `setByOffset` function.
|
||||
*/
|
||||
readonly set: string;
|
||||
|
||||
/**
|
||||
* implementation of `get`, `getByIndices` and `getByOffset` function.
|
||||
*/
|
||||
readonly get: string;
|
||||
}
|
||||
|
||||
interface IndicesHelperTypes {
|
||||
/**
|
||||
* WGSL type of indices expression
|
||||
|
|
@ -96,12 +74,10 @@ interface IndicesHelperTypes {
|
|||
*/
|
||||
export interface IndicesHelper {
|
||||
/**
|
||||
* get WGSL code of function implementation for the util functions
|
||||
* get WGSL code of function implementation for the util functions.
|
||||
*
|
||||
* @param functions - a list of function names to get implementation for. If not specified, all functions will be
|
||||
* returned.
|
||||
*/
|
||||
readonly impl: (...functions: ReadonlyArray<keyof IndicesHelperImplementations>) => string;
|
||||
readonly impl: () => string;
|
||||
|
||||
/**
|
||||
* get type info
|
||||
|
|
@ -215,9 +191,12 @@ export interface IndicesHelper {
|
|||
readonly shape: readonly number[];
|
||||
}
|
||||
|
||||
const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, string] => {
|
||||
const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => {
|
||||
// return type is [ storage type, runtime type ] or a single string for both
|
||||
switch (type) {
|
||||
// TODO: enable after "shader-f16" WSGL extension release
|
||||
// case DataType.float16:
|
||||
// return components > 1 ? `vec${components}<f16>` : 'f16';
|
||||
case DataType.float:
|
||||
return components > 1 ? `vec${components}<f32>` : 'f32';
|
||||
case DataType.int32:
|
||||
|
|
@ -245,6 +224,11 @@ const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, st
|
|||
}
|
||||
};
|
||||
|
||||
export const tensorTypeToWsglStorageType = (type: DataType, components: 1|2|3|4 = 1) => {
|
||||
const mappedType = getWgslMappedType(type, components);
|
||||
return typeof mappedType === 'string' ? mappedType : mappedType[0];
|
||||
};
|
||||
|
||||
/**
|
||||
* A helper function to get a IndicesHelper for a given input or output.
|
||||
*
|
||||
|
|
@ -260,13 +244,22 @@ const createIndicesHelper =
|
|||
components: 1|2|3|4): IndicesHelper => {
|
||||
const rank = shape.length;
|
||||
const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}<u32>` : `array<u32, ${rank}>`;
|
||||
const mappedType = getWgslValueType(tensorType, components);
|
||||
const mappedType = getWgslMappedType(tensorType, components);
|
||||
const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1];
|
||||
const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0];
|
||||
const type = {indices: indicesType, value: valueType, storage: storageType, tensor: tensorType};
|
||||
|
||||
const normalizeDim = (dim: number|string): string => typeof dim === 'string' ? dim : `${dim}u`;
|
||||
|
||||
const implementationUsed = {
|
||||
offsetToIndices: false,
|
||||
indicesToOffset: false,
|
||||
set: false,
|
||||
setByIndices: false,
|
||||
get: false,
|
||||
getByIndices: false,
|
||||
};
|
||||
|
||||
const strides = ShapeUtil.computeStrides(shape);
|
||||
let o2iSnippet = '';
|
||||
for (let i = 0; i < rank - 1; i++) {
|
||||
|
|
@ -287,7 +280,10 @@ const createIndicesHelper =
|
|||
return indices;
|
||||
}`;
|
||||
|
||||
const offsetToIndices = (varOffset: string) => rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
|
||||
const offsetToIndices = (varOffset: string) => {
|
||||
implementationUsed.offsetToIndices = true;
|
||||
return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
|
||||
};
|
||||
|
||||
const offsets: string[] = [];
|
||||
if (rank >= 2) {
|
||||
|
|
@ -301,7 +297,10 @@ const createIndicesHelper =
|
|||
return ${offsets.join('+')};
|
||||
}`;
|
||||
|
||||
const indicesToOffset = (varIndices: string) => rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
|
||||
const indicesToOffset = (varIndices: string) => {
|
||||
implementationUsed.indicesToOffset = true;
|
||||
return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
|
||||
};
|
||||
|
||||
const indices = (...init: ReadonlyArray<number|string>) =>
|
||||
rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`;
|
||||
|
|
@ -357,17 +356,18 @@ const createIndicesHelper =
|
|||
}
|
||||
})();
|
||||
|
||||
const getByIndicesImplementation = rank < 2 ? '' : `
|
||||
fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
|
||||
return ${name}[i2o_${name}(indices)];
|
||||
}`;
|
||||
|
||||
const getImplementation = rank < 2 ? '' : (() => {
|
||||
const params = shape.map((_, i) => `d${i}: u32`).join(', ');
|
||||
const dims = shape.map((_, i) => `d${i}`).join(', ');
|
||||
return `
|
||||
fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
|
||||
return ${name}[i2o_${name}(indices)];
|
||||
}
|
||||
fn get_${name}(${params}) -> ${valueType} {
|
||||
return get_${name}ByIndices(${indices(dims)});
|
||||
}
|
||||
`;
|
||||
}`;
|
||||
})();
|
||||
|
||||
const get = (...indices: ReadonlyArray<number|string>) => {
|
||||
|
|
@ -376,14 +376,16 @@ const createIndicesHelper =
|
|||
}
|
||||
|
||||
const normalizedIndices = indices.map(normalizeDim).join(',');
|
||||
const funcName = `get_${name}`;
|
||||
|
||||
if (rank === 0) {
|
||||
return getByOffset('0u');
|
||||
} else if (rank === 1) {
|
||||
return getByOffset(normalizedIndices[0]);
|
||||
} else {
|
||||
return `${funcName}(${normalizedIndices})`;
|
||||
implementationUsed.get = true;
|
||||
implementationUsed.getByIndices = true;
|
||||
implementationUsed.indicesToOffset = true;
|
||||
return `get_${name}(${normalizedIndices})`;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -391,21 +393,24 @@ const createIndicesHelper =
|
|||
if (rank < 2) {
|
||||
return getByOffset(varIndices);
|
||||
} else {
|
||||
implementationUsed.getByIndices = true;
|
||||
implementationUsed.indicesToOffset = true;
|
||||
return `get_${name}ByIndices(${varIndices})`;
|
||||
}
|
||||
};
|
||||
|
||||
const setByIndicesImplementation = rank < 2 ? '' : `
|
||||
fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
|
||||
${setByOffset(`i2o_${name}(indices)`, 'value')}
|
||||
}`;
|
||||
|
||||
const setImplementation = rank < 2 ? '' : (() => {
|
||||
const params = shape.map((_, i) => `d${i}: u32`).join(', ');
|
||||
const dims = shape.map((_, i) => `d${i}`).join(', ');
|
||||
return `
|
||||
fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
|
||||
${setByOffset(`i2o_${name}(indices)`, 'value')}
|
||||
}
|
||||
fn set_${name}(${params}, value: ${valueType}) {
|
||||
set_${name}ByIndices(${indices(dims)}, value);
|
||||
}
|
||||
`;
|
||||
}`;
|
||||
})();
|
||||
|
||||
const set = (...indicesAndValue: ReadonlyArray<number|string>) => {
|
||||
|
|
@ -424,6 +429,9 @@ const createIndicesHelper =
|
|||
} else if (rank === 1) {
|
||||
return setByOffset(normalizedIndices[0], value);
|
||||
} else {
|
||||
implementationUsed.set = true;
|
||||
implementationUsed.setByIndices = true;
|
||||
implementationUsed.indicesToOffset = true;
|
||||
return `set_${name}(${normalizedIndices}, ${value})`;
|
||||
}
|
||||
};
|
||||
|
|
@ -432,32 +440,34 @@ const createIndicesHelper =
|
|||
if (rank < 2) {
|
||||
return setByOffset(varIndices, value);
|
||||
} else {
|
||||
implementationUsed.setByIndices = true;
|
||||
implementationUsed.indicesToOffset = true;
|
||||
return `set_${name}ByIndices(${varIndices}, ${value});`;
|
||||
}
|
||||
};
|
||||
|
||||
const funcImpls = {
|
||||
offsetToIndices: offsetToIndicesImplementation,
|
||||
indicesToOffset: indicesToOffsetImplementation,
|
||||
set: setImplementation,
|
||||
get: getImplementation,
|
||||
};
|
||||
const impl = (...functions: Array<keyof IndicesHelperImplementations>) => {
|
||||
const impl = () => {
|
||||
const impls = [];
|
||||
if (functions.length === 0) {
|
||||
functions.push('offsetToIndices', 'indicesToOffset', 'set', 'get');
|
||||
if (implementationUsed.offsetToIndices) {
|
||||
impls.push(offsetToIndicesImplementation);
|
||||
}
|
||||
for (const func of functions) {
|
||||
const impl = funcImpls[func];
|
||||
if (impl === undefined) {
|
||||
throw new Error(`unknown function ${func}`);
|
||||
} else {
|
||||
impls.push(impl);
|
||||
}
|
||||
if (implementationUsed.indicesToOffset) {
|
||||
impls.push(indicesToOffsetImplementation);
|
||||
}
|
||||
if (implementationUsed.set) {
|
||||
impls.push(setImplementation);
|
||||
}
|
||||
if (implementationUsed.setByIndices) {
|
||||
impls.push(setByIndicesImplementation);
|
||||
}
|
||||
if (implementationUsed.get) {
|
||||
impls.push(getImplementation);
|
||||
}
|
||||
if (implementationUsed.getByIndices) {
|
||||
impls.push(getByIndicesImplementation);
|
||||
}
|
||||
return impls.join('\n');
|
||||
};
|
||||
impl.toString = () => impl();
|
||||
|
||||
return {
|
||||
impl,
|
||||
|
|
@ -552,6 +562,11 @@ export interface ShaderHelper {
|
|||
* @param variables - an array of IndicesHelper for the variables.
|
||||
*/
|
||||
declareVariables(...variables: IndicesHelper[]): string;
|
||||
|
||||
/**
|
||||
* Get additional implementation that needs to be added to the shader source.
|
||||
*/
|
||||
readonly additionalImplementations: string;
|
||||
}
|
||||
|
||||
class ShaderHelperImpl implements ShaderHelper {
|
||||
|
|
@ -585,6 +600,7 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
}
|
||||
|
||||
declareVariable(variable: IndicesHelper, bindingIndex: number): string {
|
||||
this.indicesHelpers.push(variable);
|
||||
const access = variable.usage === 'input' ? 'read' : 'read_write';
|
||||
const storageType = variable.type.storage;
|
||||
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
|
||||
|
|
@ -594,6 +610,12 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
let i = 0;
|
||||
return variables.filter(v => ShapeUtil.size(v.shape) > 0).map(v => this.declareVariable(v, i++)).join('\n');
|
||||
}
|
||||
|
||||
private indicesHelpers: IndicesHelper[] = [];
|
||||
|
||||
get additionalImplementations(): string {
|
||||
return this.indicesHelpers.map(i => i.impl()).join('\n');
|
||||
}
|
||||
}
|
||||
|
||||
export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper =>
|
||||
|
|
|
|||
|
|
@ -109,9 +109,6 @@ const createConcatProgramInfo =
|
|||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.declareVariables(...inputVars, output)}
|
||||
|
||||
${inputVars.map(i => i.impl('indicesToOffset', 'get')).join('\n')}
|
||||
${output.impl('offsetToIndices')}
|
||||
|
||||
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
|
||||
${calculateInputIndexImpl(sizeInConcatAxis.length)}
|
||||
|
||||
|
|
|
|||
|
|
@ -47,9 +47,6 @@ const createGroupedConvProgramInfo =
|
|||
${shaderHelper.declareVariables(...inputVars, output)}
|
||||
|
||||
${activationFunction}
|
||||
${output.impl('offsetToIndices')}
|
||||
${x.impl('indicesToOffset', 'get')}
|
||||
${w.impl('indicesToOffset', 'get')}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
|
|
|
|||
|
|
@ -58,8 +58,6 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten
|
|||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const inputShape = ${input.indices(...inputShape)};
|
||||
${shaderHelper.declareVariables(input, output)}
|
||||
${output.impl('offsetToIndices')}
|
||||
${input.impl('indicesToOffset', 'get')}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
let outputIndices = ${output.offsetToIndices('global_idx')};
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
|
||||
|
||||
import {ShaderHelper} from './common';
|
||||
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
|
||||
export interface InstanceNormAttributes extends AttributeWithCacheKey {
|
||||
epsilon: number;
|
||||
|
|
@ -45,7 +45,7 @@ const createInstanceNormProgramInfo =
|
|||
Got scale size of ${scaleSize} and bias size of ${biasSize}`);
|
||||
}
|
||||
|
||||
const dataType = tensorTypeToWsglType(inputs[0].dataType);
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const C: u32 = ${C};
|
||||
|
|
@ -99,7 +99,7 @@ const createInstanceNormNHWCProgramInfo =
|
|||
const C = xShape[xShape.length - 1];
|
||||
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
|
||||
|
||||
const dataType = tensorTypeToWsglType(inputs[0].dataType);
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
||||
const normCount = C * N;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
|
||||
|
||||
import {ShaderHelper} from './common';
|
||||
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
|
||||
export interface LayerNormAttributes extends AttributeWithCacheKey {
|
||||
axis: number;
|
||||
|
|
@ -54,7 +54,7 @@ const createLayerNormProgramInfo =
|
|||
}
|
||||
}
|
||||
|
||||
const dataType = tensorTypeToWsglType(inputs[0].dataType);
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
||||
const hasMeanDataOutput = outputCount > 1;
|
||||
const hasInvStdOutput = outputCount > 2;
|
||||
|
|
|
|||
|
|
@ -128,9 +128,6 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
|
|||
const poolingCode = `
|
||||
${shaderHelper.declareVariables(x, output)}
|
||||
|
||||
${output.impl('offsetToIndices')}
|
||||
${x.impl('indicesToOffset')}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
|
||||
|
|
@ -179,9 +176,6 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
|
|||
const poolingCode = `
|
||||
${shaderHelper.declareVariables(x, output)}
|
||||
|
||||
${output.impl('offsetToIndices')}
|
||||
${x.impl('indicesToOffset')}
|
||||
|
||||
const pads = array<u32, ${padsRank}>(${attributes.pads.map(i => `${i}u`).join(',')});
|
||||
const inputDims = array<u32, ${rank}>(${inputDims.map(i => `${i}u`).join(',')});
|
||||
const kernelStrides = array<u32, ${stridesRank}>(${kernelStrides.map(i => `${i}u`).join(',')});
|
||||
|
|
|
|||
|
|
@ -85,9 +85,6 @@ export const createReduceProgramInfo =
|
|||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.declareVariables(input, output)}
|
||||
|
||||
${output.impl('offsetToIndices')}
|
||||
${input.impl('indicesToOffset')}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
var inputIndices: ${input.type.indices};
|
||||
|
|
|
|||
|
|
@ -484,8 +484,6 @@ const createResizeProgramInfo =
|
|||
}
|
||||
})()};
|
||||
${shaderHelper.declareVariables(input, output)}
|
||||
${output.impl('offsetToIndices')}
|
||||
${input.impl('indicesToOffset')}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
if (${noScale}) {
|
||||
|
|
|
|||
|
|
@ -153,8 +153,6 @@ const createSliceProgramInfo =
|
|||
const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});
|
||||
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
|
||||
|
||||
${output.impl('offsetToIndices')}
|
||||
${input.impl('indicesToOffset', 'get')}
|
||||
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
|
|
|
|||
|
|
@ -87,8 +87,6 @@ const createSplitProgramInfo =
|
|||
const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.declareVariables(input, ...outputs)}
|
||||
${input.impl('indicesToOffset', 'offsetToIndices', 'get')}
|
||||
${outputs.map(o => o.impl('indicesToOffset', 'set')).join('\n')}
|
||||
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
|
||||
${calculateOutputIndexImpl(sizeInConcatAxis.length)}
|
||||
${writeBufferDataImpl(outputs)}
|
||||
|
|
|
|||
|
|
@ -64,8 +64,6 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
${shaderHelper.declareVariables(input, output)}
|
||||
|
||||
${permFunctionBody(perm, rank, input, output)}
|
||||
${output.impl('offsetToIndices')}
|
||||
${input.impl('indicesToOffset', 'get')}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
|
|
|
|||
|
|
@ -114,7 +114,9 @@ export class ProgramManager {
|
|||
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
|
||||
const device = this.backend.device;
|
||||
|
||||
const code = programInfo.getShaderSource(createShaderHelper(normalizedDispatchGroupSize));
|
||||
const shaderHelper = createShaderHelper(normalizedDispatchGroupSize);
|
||||
const userCode = programInfo.getShaderSource(shaderHelper);
|
||||
const code = `${shaderHelper.additionalImplementations}\n${userCode}`;
|
||||
const shaderModule = device.createShaderModule({code});
|
||||
LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`);
|
||||
|
||||
|
|
|
|||
|
|
@ -164,19 +164,3 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
|
|||
throw new Error(`unsupported logging level: ${logLevel}`);
|
||||
}
|
||||
};
|
||||
|
||||
export const tensorTypeToWsglType = (type: DataType) => {
|
||||
switch (type) {
|
||||
case DataType.float:
|
||||
return 'f32';
|
||||
// TODO: enable after "shader-f16" WSGL extension release
|
||||
// case DataType.float16:
|
||||
// return 'f16';
|
||||
case DataType.int32:
|
||||
return 'i32';
|
||||
case DataType.uint32:
|
||||
return 'u32';
|
||||
default:
|
||||
throw new Error(`Unsupported type: ${type}`);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue