From 0df2e14038394dc3cef3de6384ff2c7514d1e187 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 2 Aug 2023 18:16:19 -0700 Subject: [PATCH] js/webgpu: argmax,argmin,softmax support (#16882) argmax and argmin are similar to reduce. Eventually we need to add optimized flavors of the shader. softmax is optimized but only works on the last axis for now which should be the common use case. todo: enable more ut for argmax/argmin --- js/web/docs/webgpu-operators.md | 3 + js/web/lib/onnxjs/tensor.ts | 3 + js/web/lib/onnxjs/util.ts | 10 +- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 5 + js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts | 156 ++++++++++++++++++ js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 147 +++++++++++++++++ js/web/test/suite-test-list.jsonc | 84 +++++----- js/web/test/test-runner.ts | 1 + .../providers/js/js_execution_provider.cc | 23 +++ .../core/providers/js/operators/argminmax.cc | 41 +++++ .../core/providers/js/operators/argminmax.h | 39 +++++ .../core/providers/js/operators/softmax.cc | 37 +++++ .../core/providers/js/operators/softmax.h | 42 +++++ 13 files changed, 544 insertions(+), 47 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/softmax.ts create mode 100644 onnxruntime/core/providers/js/operators/argminmax.cc create mode 100644 onnxruntime/core/providers/js/operators/argminmax.h create mode 100644 onnxruntime/core/providers/js/operators/softmax.cc create mode 100644 onnxruntime/core/providers/js/operators/softmax.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 991b01ea8a..7b6d72bc78 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -14,6 +14,8 @@ Do not modify directly.* | Acos | ai.onnx(7+) | | | Acosh | ai.onnx(9+) | | | Add | ai.onnx(7-12,13,14+) | | +| ArgMax | ai.onnx(1-10,11-12,13+) | | +| ArgMin | ai.onnx(1-10,11-12,13+) | | | Asin | ai.onnx(7+) | | | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | @@ -64,6 +66,7 @@ Do not modify directly.* | Sin | ai.onnx(7+) | | | Sinh | ai.onnx(9+) | | | Slice | ai.onnx(1-9,10,11-12,13+) | | +| Softmax | ai.onnx(1-10,11-12,13+) | | | Split | ai.onnx(1,2-10,11-12,13-17,18+) | | | Sqrt | ai.onnx(6-12,13+) | | | Squeeze | ai.onnx(1-10,11-12,13+) | | diff --git a/js/web/lib/onnxjs/tensor.ts b/js/web/lib/onnxjs/tensor.ts index 9ab076e0de..1a4c1dfe74 100644 --- a/js/web/lib/onnxjs/tensor.ts +++ b/js/web/lib/onnxjs/tensor.ts @@ -22,6 +22,7 @@ export declare namespace Tensor { uint16: Uint16Array; int32: Int32Array; uint32: Uint32Array; + int64: BigInt64Array; } export type DataType = keyof DataTypeMap; @@ -409,6 +410,8 @@ function dataviewConstructor(type: Tensor.DataType) { return Int32Array; case 'uint32': return Uint32Array; + case 'int64': + return BigInt64Array; case 'float32': return Float32Array; case 'float64': diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index 722993be1e..0a76d75e79 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -210,7 +210,7 @@ export class BroadcastUtil { // both inputs are scalars if (outputShape.length === 0) { - c.set([], op(a.get([]), b.get([]))); + c.set([], op(a.get([]) as number, b.get([]) as number)); } // atleast one input is a non-scalar @@ -223,11 +223,11 @@ export class BroadcastUtil { let isAScalar = false; let isBScalar = false; if (a.dims.length === 0) { - valA = a.get([]); + valA = a.get([]) as number; isAScalar = true; } if (b.dims.length === 0) { - valB = b.get([]); + valB = b.get([]) as number; isBScalar = true; } let rest: number; @@ -242,11 +242,11 @@ export class BroadcastUtil { if (!isAScalar) { // map outputIndices (which is actually broadcasted) to the originalIndices BroadcastUtil.fillIndex(outputIndices, a.dims, originalIndicesA); - valA = a.get(originalIndicesA); + valA = a.get(originalIndicesA) as number; } if (!isBScalar) { BroadcastUtil.fillIndex(outputIndices, b.dims, originalIndicesB); - valB = b.get(originalIndicesB); + valB = b.get(originalIndicesB) as number; } c.set(outputIndices, op(valA, valB)); diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ef05483a6d..4fa468cde4 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; @@ -13,6 +14,7 @@ import * as pool from './ops/pool'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSliceAttributes, slice} from './ops/slice'; +import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; @@ -27,6 +29,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Acos', [unaryOps.acos]], ['Acosh', [unaryOps.acosh]], ['Add', [binaryOps.add]], + ['ArgMax', [argMax, parseArgMinMaxAttributes]], + ['ArgMin', [argMin, parseArgMinMaxAttributes]], ['Asin', [unaryOps.asin]], ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], @@ -77,6 +81,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Slice', [slice, parseSliceAttributes]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], + ['Softmax', [softmax, parseSoftmaxAttributes]], ['Sub', [binaryOps.sub]], ['Tan', [unaryOps.tan]], ['Tanh', [unaryOps.tanh]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts new file mode 100644 index 0000000000..35aac1c245 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// TODO: this is the same naive implementation we use for reduce that has +// performance limitations when the reduced axis is long. Need to add +// a optimized codepath for this. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length === 0 || inputs.length > 2) { + throw new Error('ArgMinMaxOp op requires 1 or 2 inputs.'); + } + if (inputs[0].dataType !== DataType.float) { + throw new Error('Invalid input type.'); + } +}; + +export interface ArgMinMaxAttributes extends AttributeWithCacheKey { + keepDims: boolean; + axes: number; + selectLastIndex: number; +} + +type ArgMinMaxOp = (inputs: readonly TensorView[], axes: number[]) => string[]; + +const createReduceProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: ArgMinMaxAttributes, + argMinMaxOp: ArgMinMaxOp): ProgramInfo => { + const outputShape: number[] = []; + const inputShape = inputs[0].dims; + + const idxCopy: string[] = []; // copy output indexes to input indexes + + const axes = ShapeUtil.normalizeAxes([attributes.axes], inputs[0].dims.length); + const outputDimsLength = inputs[0].dims.length - (attributes.keepDims ? 0 : axes.length); + const ops = argMinMaxOp(inputs, axes); + const inputIndicesHelper = createIndicesHelper('input', inputShape); + const initInputIdx = (ops[1] === '') ? '' : `let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')};`; + let reduceOps = ` + let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')}; + ${ops[2]};`; + for (let k = 0; k < inputs[0].dims.length; k++) { + // if this axis is reduced + if (axes.indexOf(k) >= 0) { + if (attributes.keepDims) { + outputShape.push(1); + } + // loop over the d-th axis + reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) { + let lastIndex = j${k}; + inputIndices[${k}] = lastIndex; + ${reduceOps} + }`; + } else { + if (outputDimsLength > 1) { + idxCopy.push(`inputIndices[${k}] = outputIndices[${outputShape.length}];`); + } else { + idxCopy.push(`inputIndices[${k}] = outputIndices;`); + } + outputShape.push(inputs[0].dims[k]); + } + } + + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const outputSize = ShapeUtil.size(outputShape); + const dataType = 'f32'; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + @group(0) @binding(0) var _A : array<${dataType}>; + @group(0) @binding(1) var output : array; + + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} + + ${idxCopy.join('\n')} + ${ops[0]} // init ops + ${initInputIdx} + ${ops[1]} + ${reduceOps} + ${ops[3]} // final values + output[global_idx*2] = bestIndex; // result it int64 + }`; + + return { + ...metadata, + getShaderSource, + outputs: [{dims: outputShape, dataType: DataType.int64, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64)}) + }; + }; + +const createArgMinMaxAttributesFromInputs = + (inputs: readonly TensorView[], attributes: ArgMinMaxAttributes): ArgMinMaxAttributes => + createAttributeWithCacheKey( + {axes: attributes.axes, keepDims: attributes.keepDims, selectLastIndex: attributes.selectLastIndex}); + +const createReduceProgramInfoLoader = + (inputs: readonly TensorView[], name: string, attributes: ArgMinMaxAttributes, reduceOp: ArgMinMaxOp): + ProgramInfoLoader => { + const updatedAttributes: ArgMinMaxAttributes = + inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(inputs, attributes); + const metadata: + ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; + return {...metadata, get: () => createReduceProgramInfo(metadata, [inputs[0]], updatedAttributes, reduceOp)}; + }; + + +export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { + validateInputs(context.inputs); + const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => { + const idxZero = []; + for (let k = 0; k < inputs[0].dims.length; k++) { + if (axes.indexOf(k) >= 0 || axes.length === 0) { + idxZero.push(`inputIndices[${k}] = 0;`); // first element + } + } + return [ + `${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;', + 'if (_A[inputIdx] < value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', '' + ]; + }; + context.compute(createReduceProgramInfoLoader(context.inputs, 'ArgMin', attributes, argMinMaxOp), {inputs: [0]}); +}; + +export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { + validateInputs(context.inputs); + const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => { + const idxZero = []; + for (let k = 0; k < inputs[0].dims.length; k++) { + if (axes.indexOf(k) >= 0 || axes.length === 0) { + idxZero.push(`inputIndices[${k}] = 0;`); // first element + } + } + return [ + `${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;', + 'if (_A[inputIdx] > value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', '' + ]; + }; + context.compute(createReduceProgramInfoLoader(context.inputs, 'argMax', attributes, argMinMaxOp), {inputs: [0]}); +}; + +export const parseArgMinMaxAttributes = (attributes: Record): ArgMinMaxAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts new file mode 100644 index 0000000000..bdbf05e2f1 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// TODO: this is the same naive implementation we use for reduce that has +// performance limitations when the reduced axis is long. Need to add +// a optimized codepath for this. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; + +import {ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Softmax op requires 1 input.'); + } + if (inputs[0].dataType !== DataType.float) { + throw new Error('Softmax input needs to be float.'); + } +}; + +export interface SoftmaxAttributes extends AttributeWithCacheKey { + readonly axis: number; +} + +export const softmaxProgramMetadata = { + name: 'Softmax', + inputTypes: [GpuDataType.default] +}; + + +const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => { + const dataType = 'f32'; + const shape = input.dims; + const outputSize = ShapeUtil.size(shape); + const WG = 64; + let axis = attributes.axis; + if (axis < 0) { + axis = shape.length + axis; + } + if (axis < shape.length - 1) { + throw new Error('softmax only supports last axis for now.'); + } + + const cols = shape[axis]; + const rows = outputSize / cols; + + const getShaderSource = (_shaderHelper: ShaderHelper) => ` + var rowMaxShared : ${dataType}; + var rowSumShared : ${dataType}; + var threadShared : array<${dataType}, ${WG}>; + + @group(0) @binding(0) var x : array<${dataType}>; + @group(0) @binding(1) var result : array<${dataType}>; + + fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} { + let index = row * row_stride + col; + return x[index]; + } + + fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) { + let index = row * row_stride + col; + result[index] = value; + } + + @compute @workgroup_size(${WG}, 1, 1) + fn main(@builtin(local_invocation_id) local_id : vec3, @builtin(global_invocation_id) global_id : vec3u) { + let gindex = i32(global_id.x); + let lindex = i32(local_id.x); + const wg = ${WG}; + let row = gindex / wg; + let cols = ${cols}; + let row_stride : i32 = ${cols}; + + // find the rows max + var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec + for (var col = lindex; col < cols; col += wg) { + let value = getValue(row, col, row_stride); + threadMax = max(threadMax, value); + } + if (lindex < cols) { + threadShared[lindex] = threadMax; + } + workgroupBarrier(); + + var reduceSize = min(cols, wg); + for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) { + reduceSize = currSize + (reduceSize & 1); + if (lindex < currSize) { + threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]); + } + workgroupBarrier(); + } + if (lindex == 0) { + rowMaxShared = threadShared[0]; + } + workgroupBarrier(); + + // find the rows sum + var threadSum = 0.0; + for (var col = lindex; col < cols; col += wg) { + let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); + threadSum += subExp; + } + threadShared[lindex] = threadSum; + workgroupBarrier(); + + for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) { + if (lindex < currSize) { + threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize]; + } + workgroupBarrier(); + } + if (lindex == 0) { + rowSumShared = threadShared[0]; + } + workgroupBarrier(); + + // calculate final value for each element in the row + for (var col = lindex; col < cols; col += wg) { + let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; + setValue(row, col, row_stride, value); + } + }`; + return { + ...softmaxProgramMetadata, + outputs: [{dims: shape, dataType: input.dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: rows}) + }; +}; + + +export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => { + validateInputs(context.inputs); + context.compute({ + ...softmaxProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createSoftmaxProgramInfo(context.inputs[0], attributes) + }); +}; + +export const parseSoftmaxAttributes = (attributes: Record): SoftmaxAttributes => + createAttributeWithCacheKey({axis: attributes.axis as number}); diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c894f0c58f..00ac7acfc9 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -305,38 +305,38 @@ // "test_and2d", // "test_and3d", // "test_and4d", - // // "test_argmax_default_axis_example_select_last_index", - // // "test_argmax_default_axis_example", - // // "test_argmax_default_axis_random_select_last_index", - // // "test_argmax_default_axis_random", - // // "test_argmax_keepdims_example_select_last_index", - // // "test_argmax_keepdims_example", - // // "test_argmax_keepdims_random_select_last_index", - // // "test_argmax_keepdims_random", - // // "test_argmax_negative_axis_keepdims_example_select_last_index", - // // "test_argmax_negative_axis_keepdims_example", - // // "test_argmax_negative_axis_keepdims_random_select_last_index", - // // "test_argmax_negative_axis_keepdims_random", - // // "test_argmax_no_keepdims_example_select_last_index", - // // "test_argmax_no_keepdims_example", - // // "test_argmax_no_keepdims_random_select_last_index", - // // "test_argmax_no_keepdims_random", - // // "test_argmin_default_axis_example_select_last_index", - // // "test_argmin_default_axis_example", - // // "test_argmin_default_axis_random_select_last_index", - // // "test_argmin_default_axis_random", - // // "test_argmin_keepdims_example_select_last_index", - // // "test_argmin_keepdims_example", - // // "test_argmin_keepdims_random_select_last_index", - // // "test_argmin_keepdims_random", - // // "test_argmin_negative_axis_keepdims_example_select_last_index", - // // "test_argmin_negative_axis_keepdims_example", - // // "test_argmin_negative_axis_keepdims_random_select_last_index", - // // "test_argmin_negative_axis_keepdims_random", - // // "test_argmin_no_keepdims_example_select_last_index", - // // "test_argmin_no_keepdims_example", - // // "test_argmin_no_keepdims_random_select_last_index", - // // "test_argmin_no_keepdims_random", + // "test_argmax_default_axis_example_select_last_index", + "test_argmax_default_axis_example", + // "test_argmax_default_axis_random_select_last_index", + "test_argmax_default_axis_random", + // "test_argmax_keepdims_example_select_last_index", + "test_argmax_keepdims_example", + // "test_argmax_keepdims_random_select_last_index", + "test_argmax_keepdims_random", + // "test_argmax_negative_axis_keepdims_example_select_last_index", + "test_argmax_negative_axis_keepdims_example", + // "test_argmax_negative_axis_keepdims_random_select_last_index", + "test_argmax_negative_axis_keepdims_random", + // "test_argmax_no_keepdims_example_select_last_index", + "test_argmax_no_keepdims_example", + // "test_argmax_no_keepdims_random_select_last_index", + "test_argmax_no_keepdims_random", + // "test_argmin_default_axis_example_select_last_index", + "test_argmin_default_axis_example", + // "test_argmin_default_axis_random_select_last_index", + "test_argmin_default_axis_random", + // "test_argmin_keepdims_example_select_last_index", + "test_argmin_keepdims_example", + // "test_argmin_keepdims_random_select_last_index", + "test_argmin_keepdims_random", + // "test_argmin_negative_axis_keepdims_example_select_last_index", + "test_argmin_negative_axis_keepdims_example", + // "test_argmin_negative_axis_keepdims_random_select_last_index", + "test_argmin_negative_axis_keepdims_random", + // "test_argmin_no_keepdims_example_select_last_index", + "test_argmin_no_keepdims_example", + // "test_argmin_no_keepdims_random_select_last_index", + "test_argmin_no_keepdims_random", "test_asin_example", "test_asin", "test_asinh_example", @@ -1133,8 +1133,8 @@ // "test_softmax_axis_0", // "test_softmax_axis_1_expanded", // "test_softmax_axis_1", - // "test_softmax_axis_2_expanded", - // "test_softmax_axis_2", + "test_softmax_axis_2_expanded", + "test_softmax_axis_2", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob", @@ -1203,14 +1203,14 @@ // "test_softmax_cross_entropy_sum_log_prob_expanded", // "test_softmax_cross_entropy_sum_log_prob", // "test_softmax_cross_entropy_sum", - // "test_softmax_default_axis_expanded", - // "test_softmax_default_axis", - // "test_softmax_example_expanded", - // "test_softmax_example", - // "test_softmax_large_number_expanded", - // "test_softmax_large_number", - // "test_softmax_negative_axis_expanded", - // "test_softmax_negative_axis", + "opset13/test_softmax_default_axis_expanded", + "opset13/test_softmax_default_axis", + "test_softmax_example_expanded", + "test_softmax_example", + "test_softmax_large_number_expanded", + "test_softmax_large_number", + "test_softmax_negative_axis_expanded", + "test_softmax_negative_axis", // // "test_softplus_example", // // "test_softplus", // // "test_softsign_example", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 0f9febec89..d923837326 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -388,6 +388,7 @@ export class TensorResultValidator { case 'int16': case 'int32': case 'uint32': + case 'int64': case 'bool': return this.integerEqual( actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 68ec51c6f0..dba68137c7 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -237,6 +237,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat); @@ -441,6 +452,18 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/argminmax.cc b/onnxruntime/core/providers/js/operators/argminmax.cc new file mode 100644 index 0000000000..08d886269b --- /dev/null +++ b/onnxruntime/core/providers/js/operators/argminmax.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "argminmax.h" + +namespace onnxruntime { +namespace js { + +#define REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMinMaxOp, sinceVersion, endVersion) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + ArgMinMaxOp, \ + kOnnxDomain, \ + sinceVersion, endVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ArgMinMaxOp); + +#define REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMinMaxOp, sinceVersion) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ArgMinMaxOp, \ + kOnnxDomain, \ + sinceVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + ArgMinMaxOp); + +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10); +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12); +REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMax, 13); + +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10); +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12); +REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMin, 13); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/argminmax.h b/onnxruntime/core/providers/js/operators/argminmax.h new file mode 100644 index 0000000000..c3d782372c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/argminmax.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/reduction/reduction_ops.h" + +namespace onnxruntime { +namespace js { +#define JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMinMaxKernel) \ + template \ + class ArgMinMaxKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::select_last_index_; \ + using ReduceKernelBase::keepdims_; \ + ArgMinMaxKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + axes.push_back(-1); \ + } \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + JSEP_INIT_KERNEL_ATTRIBUTE(ArgMinMaxKernel, ({ \ + "keepDims" : !!$1, \ + "selectLastIndex" : !!$2, \ + "axes" : $3, \ + }), \ + static_cast(keepdims_), \ + static_cast(select_last_index_), \ + gsl::narrow_cast(axes[0])); \ + } \ + }; + +JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMax); +JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMin); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/softmax.cc b/onnxruntime/core/providers/js/operators/softmax.cc new file mode 100644 index 0000000000..cbaecf9e4c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/softmax.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "softmax.h" + +namespace onnxruntime { +namespace js { + +#define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + SoftmaxOp, \ + kOnnxDomain, \ + sinceVersion, endVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SoftmaxOp); + +#define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SoftmaxOp, \ + kOnnxDomain, \ + sinceVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + SoftmaxOp); + +REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10); +REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12); +REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(Softmax, 13); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/softmax.h b/onnxruntime/core/providers/js/operators/softmax.h new file mode 100644 index 0000000000..068a59e6b2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/softmax.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/reduction/reduction_ops.h" + +namespace onnxruntime { +namespace js { +template +class Softmax : public JsKernel { + public: + Softmax(const OpKernelInfo& info) : JsKernel(info) { + const auto& node = info.node(); + opset_ = node.SinceVersion(); + + int64_t axis; + Status status = info.GetAttr("axis", &axis); + + if (status.IsOK()) { + axis_ = gsl::narrow_cast(axis); + } else { + if (opset_ < 13) { + axis_ = 1; // opset-12 and below, the default axis value is 1 + } else { + axis_ = -1; // opset-13, the default axis value is -1 + } + } + JSEP_INIT_KERNEL_ATTRIBUTE(Softmax, ({ + "axis" : $1 + }), + axis_); + } + + private: + int axis_; + int opset_; +}; + +} // namespace js +} // namespace onnxruntime