From 8b18d48c7c1ac2dde474dc69c5c1f2e4510cf316 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 23 Aug 2023 14:41:35 -0700 Subject: [PATCH] [js/webgpu] make IndicesHelper implementation implicit (#17193) ### Description This change makes it no longer required to call indicesHelper.impl() in shader code. --- .../ops/3rd-party/conv_backprop_webgpu.ts | 11 +- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 2 - js/web/lib/wasm/jsep/webgpu/ops/common.ts | 140 ++++++++++-------- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 3 - .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 3 - js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 2 - .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 8 +- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 6 +- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 6 - js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 3 - js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 2 - js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 2 - js/web/lib/wasm/jsep/webgpu/ops/split.ts | 2 - js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 2 - .../lib/wasm/jsep/webgpu/program-manager.ts | 4 +- js/web/lib/wasm/wasm-common.ts | 16 -- 16 files changed, 98 insertions(+), 114 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 1d490aa902..9343d65c27 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -207,10 +207,7 @@ const createConvTranspose2DOpProgramShaderSource = ${output.setByOffset('global_idx', 'dotProd')}; `; - return ` - ${w.impl('indicesToOffset', 'get')} - ${dy.impl('indicesToOffset', 'get')} - ${output.impl('offsetToIndices')} + const shader = ` ${declareFunctions} ${declareInputs.join('\n')} @group(0) @binding(${declareInputs.length}) var result: array<${isVec4 ? 'vec4' : 'f32'}>; @@ -234,6 +231,12 @@ const createConvTranspose2DOpProgramShaderSource = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; ${isVec4 ? codeSnippet4 : codeSnippet}}`; + + // TODO: use shaderHelper.declareVariables() to declare variables so that those impl() calls can be removed. + return ` ${w.impl()} + ${dy.impl()} + ${output.impl()} + ${shader}`; }; export const createConvTranspose2DProgramInfo = diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 5f3d156466..02b978a381 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -50,8 +50,6 @@ const createBinaryOpProgramShader = }; broadcastImpl = ` - ${output.impl('offsetToIndices')} - fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { return ${calcOffsetImpl(dimsA)}; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index e64c749725..7da57bcb9c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -16,28 +16,6 @@ import {ShapeUtil} from '../../util'; **/ export const WORKGROUP_SIZE = 64; -interface IndicesHelperImplementations { - /** - * implementation of `offsetToIndices` function. - */ - readonly offsetToIndices: string; - - /** - * implementation of `indicesToOffset` function. - */ - readonly indicesToOffset: string; - - /** - * implementation of `set`, `setByIndices` and `setByOffset` function. - */ - readonly set: string; - - /** - * implementation of `get`, `getByIndices` and `getByOffset` function. - */ - readonly get: string; -} - interface IndicesHelperTypes { /** * WGSL type of indices expression @@ -96,12 +74,10 @@ interface IndicesHelperTypes { */ export interface IndicesHelper { /** - * get WGSL code of function implementation for the util functions + * get WGSL code of function implementation for the util functions. * - * @param functions - a list of function names to get implementation for. If not specified, all functions will be - * returned. */ - readonly impl: (...functions: ReadonlyArray) => string; + readonly impl: () => string; /** * get type info @@ -215,9 +191,12 @@ export interface IndicesHelper { readonly shape: readonly number[]; } -const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, string] => { +const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { // return type is [ storage type, runtime type ] or a single string for both switch (type) { + // TODO: enable after "shader-f16" WSGL extension release + // case DataType.float16: + // return components > 1 ? `vec${components}` : 'f16'; case DataType.float: return components > 1 ? `vec${components}` : 'f32'; case DataType.int32: @@ -245,6 +224,11 @@ const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, st } }; +export const tensorTypeToWsglStorageType = (type: DataType, components: 1|2|3|4 = 1) => { + const mappedType = getWgslMappedType(type, components); + return typeof mappedType === 'string' ? mappedType : mappedType[0]; +}; + /** * A helper function to get a IndicesHelper for a given input or output. * @@ -260,13 +244,22 @@ const createIndicesHelper = components: 1|2|3|4): IndicesHelper => { const rank = shape.length; const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; - const mappedType = getWgslValueType(tensorType, components); + const mappedType = getWgslMappedType(tensorType, components); const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1]; const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0]; const type = {indices: indicesType, value: valueType, storage: storageType, tensor: tensorType}; const normalizeDim = (dim: number|string): string => typeof dim === 'string' ? dim : `${dim}u`; + const implementationUsed = { + offsetToIndices: false, + indicesToOffset: false, + set: false, + setByIndices: false, + get: false, + getByIndices: false, + }; + const strides = ShapeUtil.computeStrides(shape); let o2iSnippet = ''; for (let i = 0; i < rank - 1; i++) { @@ -287,7 +280,10 @@ const createIndicesHelper = return indices; }`; - const offsetToIndices = (varOffset: string) => rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; + const offsetToIndices = (varOffset: string) => { + implementationUsed.offsetToIndices = true; + return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; + }; const offsets: string[] = []; if (rank >= 2) { @@ -301,7 +297,10 @@ const createIndicesHelper = return ${offsets.join('+')}; }`; - const indicesToOffset = (varIndices: string) => rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; + const indicesToOffset = (varIndices: string) => { + implementationUsed.indicesToOffset = true; + return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; + }; const indices = (...init: ReadonlyArray) => rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; @@ -357,17 +356,18 @@ const createIndicesHelper = } })(); + const getByIndicesImplementation = rank < 2 ? '' : ` + fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { + return ${name}[i2o_${name}(indices)]; + }`; + const getImplementation = rank < 2 ? '' : (() => { const params = shape.map((_, i) => `d${i}: u32`).join(', '); const dims = shape.map((_, i) => `d${i}`).join(', '); return ` - fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { - return ${name}[i2o_${name}(indices)]; - } fn get_${name}(${params}) -> ${valueType} { return get_${name}ByIndices(${indices(dims)}); - } - `; + }`; })(); const get = (...indices: ReadonlyArray) => { @@ -376,14 +376,16 @@ const createIndicesHelper = } const normalizedIndices = indices.map(normalizeDim).join(','); - const funcName = `get_${name}`; if (rank === 0) { return getByOffset('0u'); } else if (rank === 1) { return getByOffset(normalizedIndices[0]); } else { - return `${funcName}(${normalizedIndices})`; + implementationUsed.get = true; + implementationUsed.getByIndices = true; + implementationUsed.indicesToOffset = true; + return `get_${name}(${normalizedIndices})`; } }; @@ -391,21 +393,24 @@ const createIndicesHelper = if (rank < 2) { return getByOffset(varIndices); } else { + implementationUsed.getByIndices = true; + implementationUsed.indicesToOffset = true; return `get_${name}ByIndices(${varIndices})`; } }; + const setByIndicesImplementation = rank < 2 ? '' : ` + fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) { + ${setByOffset(`i2o_${name}(indices)`, 'value')} + }`; + const setImplementation = rank < 2 ? '' : (() => { const params = shape.map((_, i) => `d${i}: u32`).join(', '); const dims = shape.map((_, i) => `d${i}`).join(', '); return ` - fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) { - ${setByOffset(`i2o_${name}(indices)`, 'value')} - } fn set_${name}(${params}, value: ${valueType}) { set_${name}ByIndices(${indices(dims)}, value); - } - `; + }`; })(); const set = (...indicesAndValue: ReadonlyArray) => { @@ -424,6 +429,9 @@ const createIndicesHelper = } else if (rank === 1) { return setByOffset(normalizedIndices[0], value); } else { + implementationUsed.set = true; + implementationUsed.setByIndices = true; + implementationUsed.indicesToOffset = true; return `set_${name}(${normalizedIndices}, ${value})`; } }; @@ -432,32 +440,34 @@ const createIndicesHelper = if (rank < 2) { return setByOffset(varIndices, value); } else { + implementationUsed.setByIndices = true; + implementationUsed.indicesToOffset = true; return `set_${name}ByIndices(${varIndices}, ${value});`; } }; - const funcImpls = { - offsetToIndices: offsetToIndicesImplementation, - indicesToOffset: indicesToOffsetImplementation, - set: setImplementation, - get: getImplementation, - }; - const impl = (...functions: Array) => { + const impl = () => { const impls = []; - if (functions.length === 0) { - functions.push('offsetToIndices', 'indicesToOffset', 'set', 'get'); + if (implementationUsed.offsetToIndices) { + impls.push(offsetToIndicesImplementation); } - for (const func of functions) { - const impl = funcImpls[func]; - if (impl === undefined) { - throw new Error(`unknown function ${func}`); - } else { - impls.push(impl); - } + if (implementationUsed.indicesToOffset) { + impls.push(indicesToOffsetImplementation); + } + if (implementationUsed.set) { + impls.push(setImplementation); + } + if (implementationUsed.setByIndices) { + impls.push(setByIndicesImplementation); + } + if (implementationUsed.get) { + impls.push(getImplementation); + } + if (implementationUsed.getByIndices) { + impls.push(getByIndicesImplementation); } return impls.join('\n'); }; - impl.toString = () => impl(); return { impl, @@ -552,6 +562,11 @@ export interface ShaderHelper { * @param variables - an array of IndicesHelper for the variables. */ declareVariables(...variables: IndicesHelper[]): string; + + /** + * Get additional implementation that needs to be added to the shader source. + */ + readonly additionalImplementations: string; } class ShaderHelperImpl implements ShaderHelper { @@ -585,6 +600,7 @@ class ShaderHelperImpl implements ShaderHelper { } declareVariable(variable: IndicesHelper, bindingIndex: number): string { + this.indicesHelpers.push(variable); const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; @@ -594,6 +610,12 @@ class ShaderHelperImpl implements ShaderHelper { let i = 0; return variables.filter(v => ShapeUtil.size(v.shape) > 0).map(v => this.declareVariable(v, i++)).join('\n'); } + + private indicesHelpers: IndicesHelper[] = []; + + get additionalImplementations(): string { + return this.indicesHelpers.map(i => i.impl()).join('\n'); + } } export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 8b91b64a09..9b294803d3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -109,9 +109,6 @@ const createConcatProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(...inputVars, output)} - ${inputVars.map(i => i.impl('indicesToOffset', 'get')).join('\n')} - ${output.impl('offsetToIndices')} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateInputIndexImpl(sizeInConcatAxis.length)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7a0e1f01c4..8a794ce16a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -47,9 +47,6 @@ const createGroupedConvProgramInfo = ${shaderHelper.declareVariables(...inputVars, output)} ${activationFunction} - ${output.impl('offsetToIndices')} - ${x.impl('indicesToOffset', 'get')} - ${w.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index b07fe3a90f..2d845775f1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -58,8 +58,6 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} let outputIndices = ${output.offsetToIndices('global_idx')}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 2ce8427bb6..f62c766aa9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface InstanceNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -45,7 +45,7 @@ const createInstanceNormProgramInfo = Got scale size of ${scaleSize} and bias size of ${biasSize}`); } - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const getShaderSource = (shaderHelper: ShaderHelper) => ` const C: u32 = ${C}; @@ -99,7 +99,7 @@ const createInstanceNormNHWCProgramInfo = const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const normCount = C * N; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 48627bfaec..8a9927b25a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface LayerNormAttributes extends AttributeWithCacheKey { axis: number; @@ -54,7 +54,7 @@ const createLayerNormProgramInfo = } } - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9af8fc7b6d..79071d3244 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -128,9 +128,6 @@ const generatePoolingCode = (${attributes.pads.map(i => `${i}u`).join(',')}); const inputDims = array(${inputDims.map(i => `${i}u`).join(',')}); const kernelStrides = array(${kernelStrides.map(i => `${i}u`).join(',')}); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index b645510d83..cb592c838d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -85,9 +85,6 @@ export const createReduceProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset')} - ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} var inputIndices: ${input.type.indices}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 505bae7ce2..1d0b8229a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -484,8 +484,6 @@ const createResizeProgramInfo = } })()}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} if (${noScale}) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 1f881a75ff..4211e52689 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -153,8 +153,6 @@ const createSliceProgramInfo = const steps = array(${steps.map(i => `${i}u`).join(',')}); const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index f5b8a7e3b0..9a150d21ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -87,8 +87,6 @@ const createSplitProgramInfo = const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(input, ...outputs)} - ${input.impl('indicesToOffset', 'offsetToIndices', 'get')} - ${outputs.map(o => o.impl('indicesToOffset', 'set')).join('\n')} const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateOutputIndexImpl(sizeInConcatAxis.length)} ${writeBufferDataImpl(outputs)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 0b0185fc17..ebedc61712 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -64,8 +64,6 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${shaderHelper.declareVariables(input, output)} ${permFunctionBody(perm, rank, input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index b46b35b714..da710b7dc2 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -114,7 +114,9 @@ export class ProgramManager { build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { const device = this.backend.device; - const code = programInfo.getShaderSource(createShaderHelper(normalizedDispatchGroupSize)); + const shaderHelper = createShaderHelper(normalizedDispatchGroupSize); + const userCode = programInfo.getShaderSource(shaderHelper); + const code = `${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({code}); LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`); diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index a89a585906..389773f3e8 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,19 +164,3 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; - -export const tensorTypeToWsglType = (type: DataType) => { - switch (type) { - case DataType.float: - return 'f32'; - // TODO: enable after "shader-f16" WSGL extension release - // case DataType.float16: - // return 'f16'; - case DataType.int32: - return 'i32'; - case DataType.uint32: - return 'u32'; - default: - throw new Error(`Unsupported type: ${type}`); - } -};