diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 95b464ad9a..698dd89b87 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -14,19 +14,23 @@ const validateInputs = (inputs: readonly TensorView[]): void => { throw new Error('Reduce op requires 1 or 2 inputs.'); } + if (inputs.length === 2 && inputs[1].dims.length !== 1) { + throw new Error('Invalid axes input dims.'); + } + if (inputs[0].dataType !== DataType.float) { throw new Error('Invalid input type.'); } }; export interface ReduceAttributes extends AttributeWithCacheKey { - axes: number[]; keepDims: boolean; noopWithEmptyAxes: boolean; + axes: number[]; } type ReduceOp = (inputs: readonly TensorView[], axes: number[]) => string[]; - +const noOp: ReduceOp = (): string[] => ['', '', 'value = _A[inputIdx];', '']; const createReduceProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: ReduceAttributes, reduceOp: ReduceOp): ProgramInfo => { @@ -36,17 +40,18 @@ const createReduceProgramInfo = const idxCopy: string[] = []; // copy output indexes to input indexes const axes = ShapeUtil.normalizeAxes(attributes.axes, inputs[0].dims.length); + const outputDimsLength = inputs[0].dims.length - (attributes.keepDims ? 0 : axes.length); const ops = reduceOp(inputs, axes); const inputIndicesHelper = createIndicesHelper('input', inputShape); const initInputIdx = (ops[1] === '') ? '' : `let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')};`; let reduceOps = ` let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')}; ${ops[2]};`; - + const reduceOnAllAxes = !attributes.noopWithEmptyAxes && attributes.axes.length === 0; for (let k = 0; k < inputs[0].dims.length; k++) { // if this axis is reduced - if (axes.indexOf(k) >= 0 || axes.length === 0) { - if (attributes.keepDims === true) { + if (reduceOnAllAxes || axes.indexOf(k) >= 0) { + if (attributes.keepDims) { outputShape.push(1); } // else { remove the axis from outputShape; } @@ -56,7 +61,11 @@ const createReduceProgramInfo = ${reduceOps} }`; } else { - idxCopy.push(`inputIndices[${k}] = outputIndices[${outputShape.length}];`); + if (outputDimsLength > 1) { + idxCopy.push(`inputIndices[${k}] = outputIndices[${outputShape.length}];`); + } else { + idxCopy.push(`inputIndices[${k}] = outputIndices;`); + } outputShape.push(inputs[0].dims[k]); } } @@ -97,23 +106,28 @@ const createReduceProgramInfo = }; }; -const createReduceAttributesFromInput = (input: TensorView, attributes: ReduceAttributes): ReduceAttributes => { - const axes: number[] = []; - input.getBigInt64Array().forEach(v => axes.push(Number(v))); - const keepDims = attributes.keepDims; - const noopWithEmptyAxes = attributes.noopWithEmptyAxes; - return createAttributeWithCacheKey({axes, keepDims, noopWithEmptyAxes}); -}; +const createReduceAttributesFromInputs = + (inputs: readonly TensorView[], attributes: ReduceAttributes): ReduceAttributes => { + const axes: number[] = []; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach(v => axes.push(Number(v))); + } + return createAttributeWithCacheKey( + {axes, keepDims: attributes.keepDims, noopWithEmptyAxes: attributes.noopWithEmptyAxes}); + }; const createReduceProgramInfoLoader = (inputs: readonly TensorView[], name: string, attributes: ReduceAttributes, reduceOp: ReduceOp): ProgramInfoLoader => { - const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default]}; + const updatedAttributes: ReduceAttributes = + inputs.length === 1 ? attributes : createReduceAttributesFromInputs(inputs, attributes); + const metadata: + ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; return { ...metadata, get: () => createReduceProgramInfo( - metadata, [inputs[0]], - (inputs.length === 1) ? attributes : createReduceAttributesFromInput(inputs[1], attributes), reduceOp) + metadata, [inputs[0]], updatedAttributes, + updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp) }; };