diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 38dc14f236..014d9d02f6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -646,6 +646,8 @@ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, false, components); +export type UniformsArrayType = Array<{name: string; type: string}>; + /** * A ShaderHelper is a helper class for generating WGSL code. */ @@ -697,6 +699,7 @@ export interface ShaderHelper { * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. */ registerUniform(name: string, type: string): ShaderHelper; + registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -755,8 +758,13 @@ class ShaderHelperImpl implements ShaderHelper { return this; } + registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper { + this.uniforms = this.uniforms.concat(additionalUniforms); + return this; + } + private indicesHelpers: IndicesHelper[] = []; - private uniforms: Array<{name: string; type: string}> = []; + private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { return ''; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index d607351f69..7458579bf4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -77,17 +77,26 @@ const fixStartEndValues = }; const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): - string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], + enableInputShapeUniforms: boolean): string => + `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { var inputIndices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { + let input_shape_i = ${ + enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'}; + let steps_i = ${ + enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'}; + let signs_i = ${ + enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'}; + let starts_i = ${ + enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'}; var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex = outputIndex * steps[i] + starts[i] + carry; - carry = inputIndex / inputShape[i]; - inputIndex = inputIndex % inputShape[i]; - if (signs[i] < 0) { - inputIndex = inputShape[i] - inputIndex - 1u + starts[i]; + var inputIndex = outputIndex * steps_i + starts_i + carry; + carry = inputIndex / input_shape_i; + inputIndex = inputIndex % input_shape_i; + if (signs_i < 0) { + inputIndex = input_shape_i - inputIndex - 1u + starts_i; } ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex; } @@ -110,6 +119,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); + if (axes.length !== starts.length || axes.length !== ends.length) { + throw new Error('start, ends and axes should have the same number of elements'); + } + if (axes.length !== inputShape.length) { for (let i = 0; i < inputShape.length; ++i) { if (!axes.includes(i)) { @@ -131,40 +144,66 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice array[i] = -step; } }); + // Output rank is expected to be less than or equal to the input rank. + const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims; const outputShape = inputShape.slice(0); axes.forEach((axis, _) => { outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); + const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape; const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; - const output = outputVariable('output', inputs[0].dataType, outputShape); - const input = inputVariable('input', inputs[0].dataType, inputShape); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank); const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = []; + const uniforms: UniformsArrayType = []; + if (enableShapeUniforms) { + uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}` : 'u32'}); + uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}` : 'i32'}); + uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}` : 'u32'}); + programUniforms.push({type: 'uint32', data: starts}); + programUniforms.push({type: 'int32', data: signs}); + programUniforms.push({type: 'uint32', data: steps}); + } + uniforms.push({name: 'outputSize', type: 'u32'}); + programUniforms.push({type: 'uint32', data: outputSize}); + if (enableShapeUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + } const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} - const signs = array(${signs.map(i => `${i}i`).join(',')}); - const starts = array(${starts.map(i => `${i}u`).join(',')}); - const ends = array(${ends.map(i => `${i}u`).join(',')}); - const steps = array(${steps.map(i => `${i}u`).join(',')}); - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${enableShapeUniforms ? '' : [ + `const signs = array(${signs.map(i => `${i}i`).join(',')});`, + `const starts = array(${starts.map(i => `${i}u`).join(',')});`, + `const steps = array(${steps.map(i => `${i}u`).join(',')});`, + `const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});` + ].join('\n')} - ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} + ${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; let inputIndices = calculateInputIndices(outputIndices); ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; return { name: 'Slice', - shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`}, + shaderCache: { + hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` : + `${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`, + inputDependencies: [enableShapeUniforms ? 'rank' : 'dims'] + }, getShaderSource, getRunData: () => ({ outputs: [outputTensorInfo], dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/test/data/ops/slice.jsonc b/js/web/test/data/ops/slice.jsonc index 9c90817a80..beef154a29 100644 --- a/js/web/test/data/ops/slice.jsonc +++ b/js/web/test/data/ops/slice.jsonc @@ -21,6 +21,29 @@ } ] }, + { + "name": "Slice float32 with input[0] dim > 4", + "operator": "Slice", + "attributes": [], + "cases": [ + { + "name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)", + "inputs": [ + { + "data": [ + 0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692 + ], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + }, + { "data": [3], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" } + ], + "outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }] + } + ] + }, { "name": "Slice int32", "operator": "Slice",