onnxruntime/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
satyajandhyala 0ca84549ab
[JS/Web] Added uniforms to Reduce, Resize and Split Ops. (#18727)
### Description
<!-- Describe your changes. -->
Added uniforms to Reduce op


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Improve performance.
2023-12-12 11:12:23 -08:00

83 lines
3.2 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TODO: this is the same naive implementation we use for reduce, which has
// performance limitations when the reduced axis is long. Need to add
// an optimized code path for this.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext} from '../types';
import {createReduceProgramInfo, ReduceOp} from './reduce';
/**
 * Validates the tensors handed to ArgMin / ArgMax before a program is built.
 *
 * @param inputs - tensors supplied by the compute context.
 * @throws Error when `inputs` is missing, empty, or holds more than two tensors.
 * @throws Error when the first tensor's data type is not `DataType.float`.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  // A missing input list is treated the same as an empty one.
  const inputCount = inputs ? inputs.length : 0;
  if (inputCount < 1 || inputCount > 2) {
    throw new Error('ArgMinMaxOp op requires 1 or 2 inputs.');
  }

  // Only the first tensor's element type is checked.
  const data = inputs[0];
  if (data.dataType !== DataType.float) {
    throw new Error('Invalid input type.');
  }
};
/**
 * Attributes consumed by the `argMin` and `argMax` operators in this file.
 * Extends AttributeWithCacheKey; its `cacheKey` is used as the program hint by both operators.
 */
export interface ArgMinMaxAttributes extends AttributeWithCacheKey {
// Forwarded unchanged as the last argument of createReduceProgramInfo.
// NOTE(review): presumably mirrors the ONNX `keepdims` attribute - confirm against the caller.
keepDims: boolean;
// The single axis to reduce; passed to createReduceProgramInfo wrapped as `[axis]`.
axis: number;
// When greater than 0 the generated comparison is non-strict (`<=` for argMin, `>=` for argMax),
// so an element equal to the current best also replaces it; otherwise the comparison is strict.
selectLastIndex: number;
}
/**
 * Builds and dispatches the ArgMin program for the first input of `context`.
 *
 * The shader snippets are supplied to `createReduceProgramInfo` through a `ReduceOp` callback;
 * the result is written as `DataType.int64`.
 *
 * @param context - compute context providing the inputs and the `compute` entry point.
 * @param attributes - axis, keepDims, selectLastIndex and the cache key used as the program hint.
 * @throws Error whatever `validateInputs` throws for unsupported inputs.
 */
export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
  validateInputs(context.inputs);

  // NOTE(review): the five returned snippets are positional; their meaning is defined by
  // ReduceOp in ./reduce (presumably: index setup, state declaration, per-element body,
  // post-loop code, output write) - confirm there.
  const buildArgMinSnippets: ReduceOp = (input, output, axes) => {
    // Zero every reduced index; an empty `axes` list means every dimension is reset.
    const resetEveryAxis = axes.length === 0;
    const resetReducedIndices = Array.from({length: input.rank}, (_, dim) => dim)
                                    .filter(dim => resetEveryAxis || axes.indexOf(dim) >= 0)
                                    .map(dim => `input_indices[${dim}] = 0;`)
                                    .join('\n');

    const candidate = input.getByIndices('input_indices');
    // '<=' lets an element equal to the current minimum take over; '<' keeps the current one.
    const comparison = attributes.selectLastIndex > 0 ? '<=' : '<';

    const declareState = `var value = ${candidate};\nvar best_index : i32 = 0;`;
    // `last_index` is not declared here; it is assumed to come from the shader built in ./reduce.
    const updateBest = [
      `if (${candidate} ${comparison} value) {`,
      `value = ${candidate};`,
      'best_index = i32(last_index);',
      '}',
    ].join('\n');
    const writeResult = output.setByOffset('global_idx', 'best_index');

    return [resetReducedIndices, declareState, updateBest, '', writeResult];
  };

  const programInfo = createReduceProgramInfo(
      'ArgMin', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], buildArgMinSnippets,
      [attributes.axis], DataType.int64, attributes.keepDims);
  context.compute(programInfo, {inputs: [0]});
};
/**
 * Builds and dispatches the ArgMax program for the first input of `context`.
 *
 * The shader snippets are supplied to `createReduceProgramInfo` through a `ReduceOp` callback;
 * the result is written as `DataType.int64`.
 *
 * @param context - compute context providing the inputs and the `compute` entry point.
 * @param attributes - axis, keepDims, selectLastIndex and the cache key used as the program hint.
 * @throws Error whatever `validateInputs` throws for unsupported inputs.
 */
export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
  validateInputs(context.inputs);

  // NOTE(review): the five returned snippets are positional; their meaning is defined by
  // ReduceOp in ./reduce (presumably: index setup, state declaration, per-element body,
  // post-loop code, output write) - confirm there.
  const buildArgMaxSnippets: ReduceOp = (input, output, axes) => {
    // Zero every reduced index; an empty `axes` list means every dimension is reset.
    const resetEveryAxis = axes.length === 0;
    const resetReducedIndices = Array.from({length: input.rank}, (_, dim) => dim)
                                    .filter(dim => resetEveryAxis || axes.indexOf(dim) >= 0)
                                    .map(dim => `input_indices[${dim}] = 0;`)
                                    .join('\n');

    const candidate = input.getByIndices('input_indices');
    // '>=' lets an element equal to the current maximum take over; '>' keeps the current one.
    const comparison = attributes.selectLastIndex > 0 ? '>=' : '>';

    const declareState = `var value = ${candidate};\nvar best_index : i32 = 0;`;
    // `last_index` is not declared here; it is assumed to come from the shader built in ./reduce.
    const updateBest = [
      `if (${candidate} ${comparison} value) {`,
      `value = ${candidate};`,
      'best_index = i32(last_index);',
      '}',
    ].join('\n');
    const writeResult = output.setByOffset('global_idx', 'best_index');

    return [resetReducedIndices, declareState, updateBest, '', writeResult];
  };

  // The program name uses the PascalCase operator spelling, consistent with 'ArgMin' in argMin.
  const programInfo = createReduceProgramInfo(
      'ArgMax', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], buildArgMaxSnippets,
      [attributes.axis], DataType.int64, attributes.keepDims);
  context.compute(programInfo, {inputs: [0]});
};
/**
 * Converts a raw attribute record into `ArgMinMaxAttributes`.
 *
 * The record is passed to `createAttributeWithCacheKey` as-is; no field is validated or
 * converted here, so callers are trusted to supply `keepDims`, `axis` and `selectLastIndex`.
 *
 * @param attributes - raw operator attributes.
 * @returns whatever `createAttributeWithCacheKey` produces for the record.
 */
export const parseArgMinMaxAttributes = (attributes: Record<string, unknown>): ArgMinMaxAttributes => {
  const withoutCacheKey = attributes as Omit<ArgMinMaxAttributes, keyof AttributeWithCacheKey>;
  return createAttributeWithCacheKey(withoutCacheKey);
};