mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[js/web] optimize reduce related operators (#17957)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
8978bdc59d
commit
8d48d3e9cc
2 changed files with 380 additions and 11 deletions
266
js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
Normal file
266
js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types';
|
||||
|
||||
import {inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createReduceAttributesFromInputs, ReduceAttributes} from './reduce';
|
||||
import {createTransposeProgramInfo} from './transpose';
|
||||
|
||||
const reduceOps: {[key: string]: string} = {
|
||||
max: 'select(bestValue, candidate, candidate > bestValue)',
|
||||
min: 'select(bestValue, candidate, candidate < bestValue)',
|
||||
mean: 'bestValue + candidate',
|
||||
sum: 'bestValue + candidate',
|
||||
prod: 'bestValue * candidate',
|
||||
sumSquare: 'bestValue + candidate * candidate',
|
||||
logSumExp: 'bestValue + exp(candidate)',
|
||||
l1: 'bestValue + abs(candidate)',
|
||||
l2: 'bestValue + candidate * candidate',
|
||||
logSum: 'bestValue + candidate'
|
||||
};
|
||||
|
||||
const reduceSharedOps: {[key: string]: string} = {
|
||||
max: 'select(bestValue, candidate, candidate > bestValue)',
|
||||
min: 'select(bestValue, candidate, candidate < bestValue)',
|
||||
mean: 'bestValue + candidate',
|
||||
sum: 'bestValue + candidate',
|
||||
prod: 'bestValue * candidate',
|
||||
sumSquare: 'bestValue + candidate',
|
||||
logSumExp: 'bestValue + candidate',
|
||||
l1: 'bestValue + candidate',
|
||||
l2: 'bestValue + candidate',
|
||||
logSum: 'bestValue + candidate'
|
||||
};
|
||||
|
||||
const reduceInitValues: {[key: string]: string} = {
|
||||
max: '_A[offset]',
|
||||
min: '_A[offset]',
|
||||
mean: '0',
|
||||
sum: '0',
|
||||
prod: '1',
|
||||
sumSquare: '0',
|
||||
logSumExp: '0',
|
||||
l1: '0',
|
||||
l2: '0',
|
||||
logSum: '0'
|
||||
};
|
||||
|
||||
const reduceOutputValues: {[key: string]: string} = {
|
||||
max: 'bestValue',
|
||||
min: 'bestValue',
|
||||
sum: 'bestValue',
|
||||
prod: 'bestValue',
|
||||
sumSquare: 'bestValue',
|
||||
logSumExp: 'log(bestValue)',
|
||||
l1: 'bestValue',
|
||||
l2: 'sqrt(bestValue)',
|
||||
logSum: 'log(bestValue)'
|
||||
};
|
||||
|
||||
const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => {
|
||||
const res = [];
|
||||
for (let i = rank - numInnerAxes; i < rank; ++i) {
|
||||
res.push(i);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly number[]): [number[], number[]] => {
|
||||
const outputShape = [];
|
||||
const rank = shape.length;
|
||||
for (let dim = 0; dim < rank; dim++) {
|
||||
if (axes.indexOf(dim) === -1) {
|
||||
outputShape.push(shape[dim]);
|
||||
}
|
||||
}
|
||||
const reduceShape = axes.map(dim => shape[dim]);
|
||||
return [outputShape, reduceShape];
|
||||
};
|
||||
|
||||
const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => {
|
||||
const rank = shape.length + axes.length;
|
||||
const expandShape = [];
|
||||
let shapeIdx = 0;
|
||||
for (let dim = 0; dim < rank; dim++) {
|
||||
if (axes.indexOf(dim) === -1) {
|
||||
expandShape.push(shape[shapeIdx++]);
|
||||
} else {
|
||||
expandShape.push(1);
|
||||
}
|
||||
}
|
||||
return expandShape;
|
||||
};
|
||||
|
||||
const areAxesInnerMostDims = (axes: number[], rank: number): boolean => {
|
||||
for (let i = 0; i < axes.length; ++i) {
|
||||
if (axes[axes.length - i - 1] !== rank - 1 - i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
const getAxesPermutation = (axes: number[], rank: number): number[] => {
|
||||
const res = [];
|
||||
if (!areAxesInnerMostDims(axes, rank)) {
|
||||
for (let i = 0; i < rank; ++i) {
|
||||
if (axes.indexOf(i) === -1) {
|
||||
res.push(i);
|
||||
}
|
||||
}
|
||||
axes.forEach(axis => res.push(axis));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
export const createReduceSharedProgramInfo =
|
||||
(name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceType: string,
|
||||
outputDataType: DataType, outputShape: number[], reduceShape: number[]): ProgramInfo => {
|
||||
const inputShape = inputs[0].dims;
|
||||
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const reduceSize = ShapeUtil.size(reduceShape);
|
||||
|
||||
const input = inputVariable('_A', inputs[0].dataType, inputShape);
|
||||
const output = outputVariable('output', outputDataType, outputShape);
|
||||
|
||||
const workgroupSize = 32;
|
||||
|
||||
const sharedMemorySnippet = `
|
||||
var<workgroup> aBestValues : array<${output.type.storage}, ${workgroupSize}>;
|
||||
`;
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)}
|
||||
${sharedMemorySnippet}
|
||||
fn DIV_CEIL(a : u32, b : u32) -> u32 {
|
||||
return ((a - 1u) / b + 1u);
|
||||
}
|
||||
${shaderHelper.mainStart(workgroupSize)}
|
||||
let local_idx = local_id.x;
|
||||
|
||||
let outputIndex = global_idx / ${workgroupSize};
|
||||
let offset = outputIndex * uniforms.reduceSize;
|
||||
|
||||
var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]});
|
||||
let Length = uniforms.reduceSize;
|
||||
for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
|
||||
let candidate = ${output.type.storage}(${input.getByOffset('offset + k')});
|
||||
bestValue = ${reduceOps[reduceType]};
|
||||
}
|
||||
aBestValues[local_idx] = bestValue;
|
||||
workgroupBarrier();
|
||||
|
||||
var reduceSize = min(Length, ${workgroupSize}u);
|
||||
for (var currentSize = reduceSize / 2u; reduceSize > 1u;
|
||||
currentSize = reduceSize / 2u) {
|
||||
let interval = DIV_CEIL(reduceSize, 2u);
|
||||
if (local_idx < currentSize) {
|
||||
let candidate = aBestValues[local_idx + interval];
|
||||
bestValue = ${reduceSharedOps[reduceType]};
|
||||
aBestValues[local_idx] = bestValue;
|
||||
}
|
||||
reduceSize = interval;
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
if (local_idx == 0u) {
|
||||
${
|
||||
output.setByOffset(
|
||||
'outputIndex',
|
||||
`${
|
||||
reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` :
|
||||
`${reduceOutputValues[reduceType]}`}`)};
|
||||
}
|
||||
}`;
|
||||
|
||||
// One work group is responsible for only one element of output.
|
||||
return {
|
||||
name,
|
||||
shaderCache,
|
||||
getShaderSource,
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: outputDataType}],
|
||||
dispatchGroup: {x: outputSize},
|
||||
programUniforms: [{type: 'uint32', data: reduceSize}]
|
||||
}),
|
||||
};
|
||||
};
|
||||
|
||||
const reduceCommon =
|
||||
(context: ComputeContext, name: string, attributes: ReduceAttributes,
|
||||
reduceType: 'sum'|'sumSquare'|'prod'|'min'|'max'|'mean'|'logSumExp'|'l1'|'l2'|'logSum'): void => {
|
||||
const updatedAttributes: ReduceAttributes =
|
||||
context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes);
|
||||
|
||||
let updatedAxes = updatedAttributes.axes;
|
||||
if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) {
|
||||
updatedAxes = context.inputs[0].dims.map((s, i) => i);
|
||||
}
|
||||
const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length);
|
||||
|
||||
let axes = normalizeAxes;
|
||||
let input = context.inputs[0];
|
||||
const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length);
|
||||
if (permutedAxes.length > 0) {
|
||||
input = context.compute(
|
||||
createTransposeProgramInfo(context.inputs[0], permutedAxes), {inputs: [0], outputs: [-1]})[0];
|
||||
axes = getInnerMostAxes(axes.length, input.dims.length);
|
||||
}
|
||||
|
||||
const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes);
|
||||
let finalOutputShape = outputShape;
|
||||
if (updatedAttributes.keepDims) {
|
||||
finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes);
|
||||
}
|
||||
|
||||
context.compute(
|
||||
createReduceSharedProgramInfo(
|
||||
name, {hint: updatedAttributes.cacheKey, inputDependencies: ['type']}, [input], reduceType,
|
||||
context.inputs[0].dataType, finalOutputShape, reduceShape),
|
||||
{inputs: [input]});
|
||||
};
|
||||
|
||||
export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceMeanShared', attributes, 'mean');
|
||||
};
|
||||
|
||||
export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceL1Shared', attributes, 'l1');
|
||||
};
|
||||
|
||||
export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceL2Shared', attributes, 'l2');
|
||||
};
|
||||
|
||||
export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp');
|
||||
};
|
||||
|
||||
export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceMaxShared', attributes, 'max');
|
||||
};
|
||||
|
||||
export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceMinShared', attributes, 'min');
|
||||
};
|
||||
|
||||
export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceProdShared', attributes, 'prod');
|
||||
};
|
||||
|
||||
export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceSumShared', attributes, 'sum');
|
||||
};
|
||||
|
||||
export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare');
|
||||
};
|
||||
|
||||
export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum');
|
||||
};
|
||||
|
|
@ -8,6 +8,7 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-w
|
|||
import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types';
|
||||
|
||||
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length === 0 || inputs.length > 2) {
|
||||
|
|
@ -106,7 +107,7 @@ export const createReduceProgramInfo =
|
|||
};
|
||||
};
|
||||
|
||||
const createReduceAttributesFromInputs =
|
||||
export const createReduceAttributesFromInputs =
|
||||
(inputs: readonly TensorView[], attributes: ReduceAttributes): ReduceAttributes => {
|
||||
const axes: number[] = [];
|
||||
if (inputs[1].dims[0] > 0) {
|
||||
|
|
@ -131,7 +132,7 @@ const runReduceProgram =
|
|||
{inputs: [0]});
|
||||
};
|
||||
|
||||
export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output) =>
|
||||
[`var value = ${output.type.storage}(0);`,
|
||||
|
|
@ -142,7 +143,7 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut
|
|||
runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output) =>
|
||||
[`var value = ${output.type.storage}(0);`,
|
||||
|
|
@ -153,7 +154,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes):
|
|||
runReduceProgram(context, 'ReduceL1', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output) =>
|
||||
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
|
||||
|
|
@ -164,7 +165,7 @@ export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes):
|
|||
runReduceProgram(context, 'ReduceL2', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output) =>
|
||||
[`var value = ${output.type.storage}(0);`,
|
||||
|
|
@ -175,7 +176,7 @@ export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttri
|
|||
runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, _output, axes) => {
|
||||
const idxZero = [];
|
||||
|
|
@ -195,7 +196,7 @@ export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes)
|
|||
runReduceProgram(context, 'ReduceMax', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output, axes) => {
|
||||
let size = 1.0;
|
||||
|
|
@ -216,7 +217,7 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes
|
|||
runReduceProgram(context, 'ReduceMean', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, _output, axes) => {
|
||||
const idxZero = [];
|
||||
|
|
@ -236,7 +237,7 @@ export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes)
|
|||
runReduceProgram(context, 'ReduceMin', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output) =>
|
||||
[`var value = ${output.type.storage}(1);`,
|
||||
|
|
@ -247,7 +248,7 @@ export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes
|
|||
runReduceProgram(context, 'ReduceProd', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output) =>
|
||||
[`var value = ${output.type.storage}(0);`,
|
||||
|
|
@ -258,7 +259,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes)
|
|||
runReduceProgram(context, 'ReduceSum', attributes, reduceOp);
|
||||
};
|
||||
|
||||
export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const reduceOp: ReduceOp = (input, output) =>
|
||||
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
|
||||
|
|
@ -269,5 +270,107 @@ export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttri
|
|||
runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp);
|
||||
};
|
||||
|
||||
const useNaiveReduceMethod =
|
||||
(shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => {
|
||||
if (axes.length === 0) {
|
||||
return noopWithEmptyAxes ? true : false;
|
||||
}
|
||||
|
||||
let outputSize = 1;
|
||||
let reduceSize = 1;
|
||||
for (let dim = 0; dim < axes.length; dim++) {
|
||||
if (axes.indexOf(dim) === -1) {
|
||||
outputSize *= shape[dim];
|
||||
} else {
|
||||
reduceSize *= shape[dim];
|
||||
}
|
||||
}
|
||||
|
||||
// The condition data is very rough, although considering the count of Execution Unit (EU), the potential
|
||||
// work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments
|
||||
// on some machines.
|
||||
return reduceSize < 32 && outputSize > 1024 ? true : false;
|
||||
};
|
||||
|
||||
export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceMeanNaive(context, attributes);
|
||||
} else {
|
||||
reduceMeanShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceL1Naive(context, attributes);
|
||||
} else {
|
||||
reduceL1Shared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceL2Naive(context, attributes);
|
||||
} else {
|
||||
reduceL2Shared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceLogSumExpNaive(context, attributes);
|
||||
} else {
|
||||
reduceLogSumExpShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceMaxNaive(context, attributes);
|
||||
} else {
|
||||
reduceMaxShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceMinNaive(context, attributes);
|
||||
} else {
|
||||
reduceMinShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceProdNaive(context, attributes);
|
||||
} else {
|
||||
reduceProdShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceSumNaive(context, attributes);
|
||||
} else {
|
||||
reduceSumShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceSumSquareNaive(context, attributes);
|
||||
} else {
|
||||
reduceSumSquareShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
|
||||
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
|
||||
reduceLogSumNaive(context, attributes);
|
||||
} else {
|
||||
reduceLogSumShared(context, attributes);
|
||||
}
|
||||
};
|
||||
|
||||
export const parseReduceAttributes = (attributes: Record<string, unknown>): ReduceAttributes =>
|
||||
createAttributeWithCacheKey(attributes as Omit<ReduceAttributes, keyof AttributeWithCacheKey>);
|
||||
|
|
|
|||
Loading…
Reference in a new issue