js/webgpu: argmax,argmin,softmax support (#16882)

argmax and argmin are similar to reduce. Eventually we need to add
optimized flavors of the shader.

softmax is optimized but only works on the last axis for now which
should be the common use case.

todo: enable more ut for argmax/argmin
This commit is contained in:
Guenther Schmuelling 2023-08-02 18:16:19 -07:00 committed by GitHub
parent 506ddb3d5d
commit 0df2e14038
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 544 additions and 47 deletions

View file

@ -14,6 +14,8 @@ Do not modify directly.*
| Acos | ai.onnx(7+) | |
| Acosh | ai.onnx(9+) | |
| Add | ai.onnx(7-12,13,14+) | |
| ArgMax | ai.onnx(1-10,11-12,13+) | |
| ArgMin | ai.onnx(1-10,11-12,13+) | |
| Asin | ai.onnx(7+) | |
| Asinh | ai.onnx(9+) | |
| Atan | ai.onnx(7+) | |
@ -64,6 +66,7 @@ Do not modify directly.*
| Sin | ai.onnx(7+) | |
| Sinh | ai.onnx(9+) | |
| Slice | ai.onnx(1-9,10,11-12,13+) | |
| Softmax | ai.onnx(1-10,11-12,13+) | |
| Split | ai.onnx(1,2-10,11-12,13-17,18+) | |
| Sqrt | ai.onnx(6-12,13+) | |
| Squeeze | ai.onnx(1-10,11-12,13+) | |

View file

@ -22,6 +22,7 @@ export declare namespace Tensor {
uint16: Uint16Array;
int32: Int32Array;
uint32: Uint32Array;
int64: BigInt64Array;
}
export type DataType = keyof DataTypeMap;
@ -409,6 +410,8 @@ function dataviewConstructor(type: Tensor.DataType) {
return Int32Array;
case 'uint32':
return Uint32Array;
case 'int64':
return BigInt64Array;
case 'float32':
return Float32Array;
case 'float64':

View file

@ -210,7 +210,7 @@ export class BroadcastUtil {
// both inputs are scalars
if (outputShape.length === 0) {
c.set([], op(a.get([]), b.get([])));
c.set([], op(a.get([]) as number, b.get([]) as number));
}
// atleast one input is a non-scalar
@ -223,11 +223,11 @@ export class BroadcastUtil {
let isAScalar = false;
let isBScalar = false;
if (a.dims.length === 0) {
valA = a.get([]);
valA = a.get([]) as number;
isAScalar = true;
}
if (b.dims.length === 0) {
valB = b.get([]);
valB = b.get([]) as number;
isBScalar = true;
}
let rest: number;
@ -242,11 +242,11 @@ export class BroadcastUtil {
if (!isAScalar) {
// map outputIndices (which is actually broadcasted) to the originalIndices
BroadcastUtil.fillIndex(outputIndices, a.dims, originalIndicesA);
valA = a.get(originalIndicesA);
valA = a.get(originalIndicesA) as number;
}
if (!isBScalar) {
BroadcastUtil.fillIndex(outputIndices, b.dims, originalIndicesB);
valB = b.get(originalIndicesB);
valB = b.get(originalIndicesB) as number;
}
c.set(outputIndices, op(valA, valB));

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
@ -13,6 +14,7 @@ import * as pool from './ops/pool';
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
@ -27,6 +29,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Acos', [unaryOps.acos]],
['Acosh', [unaryOps.acosh]],
['Add', [binaryOps.add]],
['ArgMax', [argMax, parseArgMinMaxAttributes]],
['ArgMin', [argMin, parseArgMinMaxAttributes]],
['Asin', [unaryOps.asin]],
['Asinh', [unaryOps.asinh]],
['Atan', [unaryOps.atan]],
@ -77,6 +81,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Slice', [slice, parseSliceAttributes]],
['Split', [split, parseSplitAttributes]],
['Sqrt', [unaryOps.sqrt]],
['Softmax', [softmax, parseSoftmaxAttributes]],
['Sub', [binaryOps.sub]],
['Tan', [unaryOps.tan]],
['Tanh', [unaryOps.tanh]],

View file

@ -0,0 +1,156 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TODO: this is the same naive implementation we use for reduce that has
// performance limitations when the reduced axis is long. Need to add
// a optimized codepath for this.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {createIndicesHelper, ShaderHelper} from './common';
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length === 0 || inputs.length > 2) {
throw new Error('ArgMinMaxOp op requires 1 or 2 inputs.');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Invalid input type.');
}
};
export interface ArgMinMaxAttributes extends AttributeWithCacheKey {
keepDims: boolean;
axes: number;
selectLastIndex: number;
}
type ArgMinMaxOp = (inputs: readonly TensorView[], axes: number[]) => string[];
const createReduceProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: ArgMinMaxAttributes,
argMinMaxOp: ArgMinMaxOp): ProgramInfo => {
const outputShape: number[] = [];
const inputShape = inputs[0].dims;
const idxCopy: string[] = []; // copy output indexes to input indexes
const axes = ShapeUtil.normalizeAxes([attributes.axes], inputs[0].dims.length);
const outputDimsLength = inputs[0].dims.length - (attributes.keepDims ? 0 : axes.length);
const ops = argMinMaxOp(inputs, axes);
const inputIndicesHelper = createIndicesHelper('input', inputShape);
const initInputIdx = (ops[1] === '') ? '' : `let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')};`;
let reduceOps = `
let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')};
${ops[2]};`;
for (let k = 0; k < inputs[0].dims.length; k++) {
// if this axis is reduced
if (axes.indexOf(k) >= 0) {
if (attributes.keepDims) {
outputShape.push(1);
}
// loop over the d-th axis
reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) {
let lastIndex = j${k};
inputIndices[${k}] = lastIndex;
${reduceOps}
}`;
} else {
if (outputDimsLength > 1) {
idxCopy.push(`inputIndices[${k}] = outputIndices[${outputShape.length}];`);
} else {
idxCopy.push(`inputIndices[${k}] = outputIndices;`);
}
outputShape.push(inputs[0].dims[k]);
}
}
const outputIndicesHelper = createIndicesHelper('output', outputShape);
const outputSize = ShapeUtil.size(outputShape);
const dataType = 'f32';
const getShaderSource = (shaderHelper: ShaderHelper) => `
@group(0) @binding(0) var<storage, read> _A : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> output : array<i32>;
${outputIndicesHelper.o2iImpl}
${inputIndicesHelper.i2oImpl}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')}
${idxCopy.join('\n')}
${ops[0]} // init ops
${initInputIdx}
${ops[1]}
${reduceOps}
${ops[3]} // final values
output[global_idx*2] = bestIndex; // result it int64
}`;
return {
...metadata,
getShaderSource,
outputs: [{dims: outputShape, dataType: DataType.int64, gpuDataType: GpuDataType.default}],
dispatchGroup: () => ({x: Math.ceil(outputSize / 64)})
};
};
const createArgMinMaxAttributesFromInputs =
(inputs: readonly TensorView[], attributes: ArgMinMaxAttributes): ArgMinMaxAttributes =>
createAttributeWithCacheKey(
{axes: attributes.axes, keepDims: attributes.keepDims, selectLastIndex: attributes.selectLastIndex});
const createReduceProgramInfoLoader =
(inputs: readonly TensorView[], name: string, attributes: ArgMinMaxAttributes, reduceOp: ArgMinMaxOp):
ProgramInfoLoader => {
const updatedAttributes: ArgMinMaxAttributes =
inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(inputs, attributes);
const metadata:
ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
return {...metadata, get: () => createReduceProgramInfo(metadata, [inputs[0]], updatedAttributes, reduceOp)};
};
export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
validateInputs(context.inputs);
const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => {
const idxZero = [];
for (let k = 0; k < inputs[0].dims.length; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
idxZero.push(`inputIndices[${k}] = 0;`); // first element
}
}
return [
`${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;',
'if (_A[inputIdx] < value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', ''
];
};
context.compute(createReduceProgramInfoLoader(context.inputs, 'ArgMin', attributes, argMinMaxOp), {inputs: [0]});
};
export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
validateInputs(context.inputs);
const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => {
const idxZero = [];
for (let k = 0; k < inputs[0].dims.length; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
idxZero.push(`inputIndices[${k}] = 0;`); // first element
}
}
return [
`${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;',
'if (_A[inputIdx] > value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', ''
];
};
context.compute(createReduceProgramInfoLoader(context.inputs, 'argMax', attributes, argMinMaxOp), {inputs: [0]});
};
export const parseArgMinMaxAttributes = (attributes: Record<string, unknown>): ArgMinMaxAttributes =>
createAttributeWithCacheKey(attributes as Omit<ArgMinMaxAttributes, keyof AttributeWithCacheKey>);

View file

@ -0,0 +1,147 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TODO: this is the same naive implementation we use for reduce that has
// performance limitations when the reduced axis is long. Need to add
// a optimized codepath for this.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
import {ShaderHelper} from './common';
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Softmax op requires 1 input.');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Softmax input needs to be float.');
}
};
export interface SoftmaxAttributes extends AttributeWithCacheKey {
readonly axis: number;
}
export const softmaxProgramMetadata = {
name: 'Softmax',
inputTypes: [GpuDataType.default]
};
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
const dataType = 'f32';
const shape = input.dims;
const outputSize = ShapeUtil.size(shape);
const WG = 64;
let axis = attributes.axis;
if (axis < 0) {
axis = shape.length + axis;
}
if (axis < shape.length - 1) {
throw new Error('softmax only supports last axis for now.');
}
const cols = shape[axis];
const rows = outputSize / cols;
const getShaderSource = (_shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
var<workgroup> threadShared : array<${dataType}, ${WG}>;
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> result : array<${dataType}>;
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
let index = row * row_stride + col;
return x[index];
}
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
let index = row * row_stride + col;
result[index] = value;
}
@compute @workgroup_size(${WG}, 1, 1)
fn main(@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(global_invocation_id) global_id : vec3u) {
let gindex = i32(global_id.x);
let lindex = i32(local_id.x);
const wg = ${WG};
let row = gindex / wg;
let cols = ${cols};
let row_stride : i32 = ${cols};
// find the rows max
var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
for (var col = lindex; col < cols; col += wg) {
let value = getValue(row, col, row_stride);
threadMax = max(threadMax, value);
}
if (lindex < cols) {
threadShared[lindex] = threadMax;
}
workgroupBarrier();
var reduceSize = min(cols, wg);
for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {
reduceSize = currSize + (reduceSize & 1);
if (lindex < currSize) {
threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);
}
workgroupBarrier();
}
if (lindex == 0) {
rowMaxShared = threadShared[0];
}
workgroupBarrier();
// find the rows sum
var threadSum = 0.0;
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;
}
threadShared[lindex] = threadSum;
workgroupBarrier();
for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {
if (lindex < currSize) {
threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];
}
workgroupBarrier();
}
if (lindex == 0) {
rowSumShared = threadShared[0];
}
workgroupBarrier();
// calculate final value for each element in the row
for (var col = lindex; col < cols; col += wg) {
let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
setValue(row, col, row_stride, value);
}
}`;
return {
...softmaxProgramMetadata,
outputs: [{dims: shape, dataType: input.dataType, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({x: rows})
};
};
export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
validateInputs(context.inputs);
context.compute({
...softmaxProgramMetadata,
cacheHint: attributes.cacheKey,
get: () => createSoftmaxProgramInfo(context.inputs[0], attributes)
});
};
export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes =>
createAttributeWithCacheKey({axis: attributes.axis as number});

View file

@ -305,38 +305,38 @@
// "test_and2d",
// "test_and3d",
// "test_and4d",
// // "test_argmax_default_axis_example_select_last_index",
// // "test_argmax_default_axis_example",
// // "test_argmax_default_axis_random_select_last_index",
// // "test_argmax_default_axis_random",
// // "test_argmax_keepdims_example_select_last_index",
// // "test_argmax_keepdims_example",
// // "test_argmax_keepdims_random_select_last_index",
// // "test_argmax_keepdims_random",
// // "test_argmax_negative_axis_keepdims_example_select_last_index",
// // "test_argmax_negative_axis_keepdims_example",
// // "test_argmax_negative_axis_keepdims_random_select_last_index",
// // "test_argmax_negative_axis_keepdims_random",
// // "test_argmax_no_keepdims_example_select_last_index",
// // "test_argmax_no_keepdims_example",
// // "test_argmax_no_keepdims_random_select_last_index",
// // "test_argmax_no_keepdims_random",
// // "test_argmin_default_axis_example_select_last_index",
// // "test_argmin_default_axis_example",
// // "test_argmin_default_axis_random_select_last_index",
// // "test_argmin_default_axis_random",
// // "test_argmin_keepdims_example_select_last_index",
// // "test_argmin_keepdims_example",
// // "test_argmin_keepdims_random_select_last_index",
// // "test_argmin_keepdims_random",
// // "test_argmin_negative_axis_keepdims_example_select_last_index",
// // "test_argmin_negative_axis_keepdims_example",
// // "test_argmin_negative_axis_keepdims_random_select_last_index",
// // "test_argmin_negative_axis_keepdims_random",
// // "test_argmin_no_keepdims_example_select_last_index",
// // "test_argmin_no_keepdims_example",
// // "test_argmin_no_keepdims_random_select_last_index",
// // "test_argmin_no_keepdims_random",
// "test_argmax_default_axis_example_select_last_index",
"test_argmax_default_axis_example",
// "test_argmax_default_axis_random_select_last_index",
"test_argmax_default_axis_random",
// "test_argmax_keepdims_example_select_last_index",
"test_argmax_keepdims_example",
// "test_argmax_keepdims_random_select_last_index",
"test_argmax_keepdims_random",
// "test_argmax_negative_axis_keepdims_example_select_last_index",
"test_argmax_negative_axis_keepdims_example",
// "test_argmax_negative_axis_keepdims_random_select_last_index",
"test_argmax_negative_axis_keepdims_random",
// "test_argmax_no_keepdims_example_select_last_index",
"test_argmax_no_keepdims_example",
// "test_argmax_no_keepdims_random_select_last_index",
"test_argmax_no_keepdims_random",
// "test_argmin_default_axis_example_select_last_index",
"test_argmin_default_axis_example",
// "test_argmin_default_axis_random_select_last_index",
"test_argmin_default_axis_random",
// "test_argmin_keepdims_example_select_last_index",
"test_argmin_keepdims_example",
// "test_argmin_keepdims_random_select_last_index",
"test_argmin_keepdims_random",
// "test_argmin_negative_axis_keepdims_example_select_last_index",
"test_argmin_negative_axis_keepdims_example",
// "test_argmin_negative_axis_keepdims_random_select_last_index",
"test_argmin_negative_axis_keepdims_random",
// "test_argmin_no_keepdims_example_select_last_index",
"test_argmin_no_keepdims_example",
// "test_argmin_no_keepdims_random_select_last_index",
"test_argmin_no_keepdims_random",
"test_asin_example",
"test_asin",
"test_asinh_example",
@ -1133,8 +1133,8 @@
// "test_softmax_axis_0",
// "test_softmax_axis_1_expanded",
// "test_softmax_axis_1",
// "test_softmax_axis_2_expanded",
// "test_softmax_axis_2",
"test_softmax_axis_2_expanded",
"test_softmax_axis_2",
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded",
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded",
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob",
@ -1203,14 +1203,14 @@
// "test_softmax_cross_entropy_sum_log_prob_expanded",
// "test_softmax_cross_entropy_sum_log_prob",
// "test_softmax_cross_entropy_sum",
// "test_softmax_default_axis_expanded",
// "test_softmax_default_axis",
// "test_softmax_example_expanded",
// "test_softmax_example",
// "test_softmax_large_number_expanded",
// "test_softmax_large_number",
// "test_softmax_negative_axis_expanded",
// "test_softmax_negative_axis",
"opset13/test_softmax_default_axis_expanded",
"opset13/test_softmax_default_axis",
"test_softmax_example_expanded",
"test_softmax_example",
"test_softmax_large_number_expanded",
"test_softmax_large_number",
"test_softmax_negative_axis_expanded",
"test_softmax_negative_axis",
// // "test_softplus_example",
// // "test_softplus",
// // "test_softsign_example",

View file

@ -388,6 +388,7 @@ export class TensorResultValidator {
case 'int16':
case 'int32':
case 'uint32':
case 'int64':
case 'bool':
return this.integerEqual(
actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |

View file

@ -237,6 +237,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat);
@ -441,6 +452,18 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat)>,

View file

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "argminmax.h"
namespace onnxruntime {
namespace js {
#define REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMinMaxOp, sinceVersion, endVersion) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ArgMinMaxOp, \
kOnnxDomain, \
sinceVersion, endVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
ArgMinMaxOp<float>);
#define REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMinMaxOp, sinceVersion) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ArgMinMaxOp, \
kOnnxDomain, \
sinceVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
.InputMemoryType(OrtMemTypeCPU, 1), \
ArgMinMaxOp<float>);
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10);
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12);
REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMax, 13);
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10);
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12);
REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMin, 13);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
#include "core/providers/cpu/reduction/reduction_ops.h"
namespace onnxruntime {
namespace js {
#define JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMinMaxKernel) \
template <typename T, bool allow_multi_axes = false> \
class ArgMinMaxKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
public: \
using ReduceKernelBase<allow_multi_axes>::axes_; \
using ReduceKernelBase<allow_multi_axes>::select_last_index_; \
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
ArgMinMaxKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
std::vector<int32_t> axes(axes_.size()); \
if (axes_.size() > 0) { \
axes.push_back(-1); \
} \
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
JSEP_INIT_KERNEL_ATTRIBUTE(ArgMinMaxKernel, ({ \
"keepDims" : !!$1, \
"selectLastIndex" : !!$2, \
"axes" : $3, \
}), \
static_cast<int32_t>(keepdims_), \
static_cast<int32_t>(select_last_index_), \
gsl::narrow_cast<int32_t>(axes[0])); \
} \
};
JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMax);
JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMin);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "softmax.h"
namespace onnxruntime {
namespace js {
#define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
SoftmaxOp, \
kOnnxDomain, \
sinceVersion, endVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
SoftmaxOp<float>);
#define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SoftmaxOp, \
kOnnxDomain, \
sinceVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
.InputMemoryType(OrtMemTypeCPU, 1), \
SoftmaxOp<float>);
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10);
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12);
REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(Softmax, 13);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
#include "core/providers/cpu/reduction/reduction_ops.h"
namespace onnxruntime {
namespace js {
template <typename T>
class Softmax : public JsKernel {
public:
Softmax(const OpKernelInfo& info) : JsKernel(info) {
const auto& node = info.node();
opset_ = node.SinceVersion();
int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);
if (status.IsOK()) {
axis_ = gsl::narrow_cast<int>(axis);
} else {
if (opset_ < 13) {
axis_ = 1; // opset-12 and below, the default axis value is 1
} else {
axis_ = -1; // opset-13, the default axis value is -1
}
}
JSEP_INIT_KERNEL_ATTRIBUTE(Softmax, ({
"axis" : $1
}),
axis_);
}
private:
int axis_;
int opset_;
};
} // namespace js
} // namespace onnxruntime