diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 991b01ea8a..7b6d72bc78 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -14,6 +14,8 @@ Do not modify directly.* | Acos | ai.onnx(7+) | | | Acosh | ai.onnx(9+) | | | Add | ai.onnx(7-12,13,14+) | | +| ArgMax | ai.onnx(1-10,11-12,13+) | | +| ArgMin | ai.onnx(1-10,11-12,13+) | | | Asin | ai.onnx(7+) | | | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | @@ -64,6 +66,7 @@ Do not modify directly.* | Sin | ai.onnx(7+) | | | Sinh | ai.onnx(9+) | | | Slice | ai.onnx(1-9,10,11-12,13+) | | +| Softmax | ai.onnx(1-10,11-12,13+) | | | Split | ai.onnx(1,2-10,11-12,13-17,18+) | | | Sqrt | ai.onnx(6-12,13+) | | | Squeeze | ai.onnx(1-10,11-12,13+) | | diff --git a/js/web/lib/onnxjs/tensor.ts b/js/web/lib/onnxjs/tensor.ts index 9ab076e0de..1a4c1dfe74 100644 --- a/js/web/lib/onnxjs/tensor.ts +++ b/js/web/lib/onnxjs/tensor.ts @@ -22,6 +22,7 @@ export declare namespace Tensor { uint16: Uint16Array; int32: Int32Array; uint32: Uint32Array; + int64: BigInt64Array; } export type DataType = keyof DataTypeMap; @@ -409,6 +410,8 @@ function dataviewConstructor(type: Tensor.DataType) { return Int32Array; case 'uint32': return Uint32Array; + case 'int64': + return BigInt64Array; case 'float32': return Float32Array; case 'float64': diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index 722993be1e..0a76d75e79 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -210,7 +210,7 @@ export class BroadcastUtil { // both inputs are scalars if (outputShape.length === 0) { - c.set([], op(a.get([]), b.get([]))); + c.set([], op(a.get([]) as number, b.get([]) as number)); } // atleast one input is a non-scalar @@ -223,11 +223,11 @@ export class BroadcastUtil { let isAScalar = false; let isBScalar = false; if (a.dims.length === 0) { - valA = a.get([]); + valA = a.get([]) as number; isAScalar = true; } if (b.dims.length === 0) { - valB = b.get([]); + valB = b.get([]) as number; isBScalar = true; } let rest: number; @@ -242,11 +242,11 @@ export class BroadcastUtil { if (!isAScalar) { // map outputIndices (which is actually broadcasted) to the originalIndices BroadcastUtil.fillIndex(outputIndices, a.dims, originalIndicesA); - valA = a.get(originalIndicesA); + valA = a.get(originalIndicesA) as number; } if (!isBScalar) { BroadcastUtil.fillIndex(outputIndices, b.dims, originalIndicesB); - valB = b.get(originalIndicesB); + valB = b.get(originalIndicesB) as number; } c.set(outputIndices, op(valA, valB)); diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ef05483a6d..4fa468cde4 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; @@ -13,6 +14,7 @@ import * as pool from './ops/pool'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSliceAttributes, slice} from './ops/slice'; +import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; @@ -27,6 +29,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Acos', [unaryOps.acos]], ['Acosh', [unaryOps.acosh]], ['Add', [binaryOps.add]], + ['ArgMax', [argMax, parseArgMinMaxAttributes]], + ['ArgMin', [argMin, parseArgMinMaxAttributes]], ['Asin', [unaryOps.asin]], ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], @@ -77,6 +81,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Slice', [slice, parseSliceAttributes]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], + ['Softmax', [softmax, parseSoftmaxAttributes]], ['Sub', [binaryOps.sub]], ['Tan', [unaryOps.tan]], ['Tanh', [unaryOps.tanh]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts new file mode 100644 index 0000000000..35aac1c245 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// TODO: this is the same naive implementation we use for reduce that has +// performance limitations when the reduced axis is long. Need to add +// a optimized codepath for this. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length === 0 || inputs.length > 2) { + throw new Error('ArgMinMaxOp op requires 1 or 2 inputs.'); + } + if (inputs[0].dataType !== DataType.float) { + throw new Error('Invalid input type.'); + } +}; + +export interface ArgMinMaxAttributes extends AttributeWithCacheKey { + keepDims: boolean; + axes: number; + selectLastIndex: number; +} + +type ArgMinMaxOp = (inputs: readonly TensorView[], axes: number[]) => string[]; + +const createReduceProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: ArgMinMaxAttributes, + argMinMaxOp: ArgMinMaxOp): ProgramInfo => { + const outputShape: number[] = []; + const inputShape = inputs[0].dims; + + const idxCopy: string[] = []; // copy output indexes to input indexes + + const axes = ShapeUtil.normalizeAxes([attributes.axes], inputs[0].dims.length); + const outputDimsLength = inputs[0].dims.length - (attributes.keepDims ? 0 : axes.length); + const ops = argMinMaxOp(inputs, axes); + const inputIndicesHelper = createIndicesHelper('input', inputShape); + const initInputIdx = (ops[1] === '') ? '' : `let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')};`; + let reduceOps = ` + let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')}; + ${ops[2]};`; + for (let k = 0; k < inputs[0].dims.length; k++) { + // if this axis is reduced + if (axes.indexOf(k) >= 0) { + if (attributes.keepDims) { + outputShape.push(1); + } + // loop over the d-th axis + reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) { + let lastIndex = j${k}; + inputIndices[${k}] = lastIndex; + ${reduceOps} + }`; + } else { + if (outputDimsLength > 1) { + idxCopy.push(`inputIndices[${k}] = outputIndices[${outputShape.length}];`); + } else { + idxCopy.push(`inputIndices[${k}] = outputIndices;`); + } + outputShape.push(inputs[0].dims[k]); + } + } + + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const outputSize = ShapeUtil.size(outputShape); + const dataType = 'f32'; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + @group(0) @binding(0) var _A : array<${dataType}>; + @group(0) @binding(1) var output : array; + + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} + + ${idxCopy.join('\n')} + ${ops[0]} // init ops + ${initInputIdx} + ${ops[1]} + ${reduceOps} + ${ops[3]} // final values + output[global_idx*2] = bestIndex; // result it int64 + }`; + + return { + ...metadata, + getShaderSource, + outputs: [{dims: outputShape, dataType: DataType.int64, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64)}) + }; + }; + +const createArgMinMaxAttributesFromInputs = + (inputs: readonly TensorView[], attributes: ArgMinMaxAttributes): ArgMinMaxAttributes => + createAttributeWithCacheKey( + {axes: attributes.axes, keepDims: attributes.keepDims, selectLastIndex: attributes.selectLastIndex}); + +const createReduceProgramInfoLoader = + (inputs: readonly TensorView[], name: string, attributes: ArgMinMaxAttributes, reduceOp: ArgMinMaxOp): + ProgramInfoLoader => { + const updatedAttributes: ArgMinMaxAttributes = + inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(inputs, attributes); + const metadata: + ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; + return {...metadata, get: () => createReduceProgramInfo(metadata, [inputs[0]], updatedAttributes, reduceOp)}; + }; + + +export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { + validateInputs(context.inputs); + const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => { + const idxZero = []; + for (let k = 0; k < inputs[0].dims.length; k++) { + if (axes.indexOf(k) >= 0 || axes.length === 0) { + idxZero.push(`inputIndices[${k}] = 0;`); // first element + } + } + return [ + `${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;', + 'if (_A[inputIdx] < value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', '' + ]; + }; + context.compute(createReduceProgramInfoLoader(context.inputs, 'ArgMin', attributes, argMinMaxOp), {inputs: [0]}); +}; + +export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { + validateInputs(context.inputs); + const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => { + const idxZero = []; + for (let k = 0; k < inputs[0].dims.length; k++) { + if (axes.indexOf(k) >= 0 || axes.length === 0) { + idxZero.push(`inputIndices[${k}] = 0;`); // first element + } + } + return [ + `${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;', + 'if (_A[inputIdx] > value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', '' + ]; + }; + context.compute(createReduceProgramInfoLoader(context.inputs, 'argMax', attributes, argMinMaxOp), {inputs: [0]}); +}; + +export const parseArgMinMaxAttributes = (attributes: Record): ArgMinMaxAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts new file mode 100644 index 0000000000..bdbf05e2f1 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// TODO: this is the same naive implementation we use for reduce that has +// performance limitations when the reduced axis is long. Need to add +// a optimized codepath for this. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; + +import {ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Softmax op requires 1 input.'); + } + if (inputs[0].dataType !== DataType.float) { + throw new Error('Softmax input needs to be float.'); + } +}; + +export interface SoftmaxAttributes extends AttributeWithCacheKey { + readonly axis: number; +} + +export const softmaxProgramMetadata = { + name: 'Softmax', + inputTypes: [GpuDataType.default] +}; + + +const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => { + const dataType = 'f32'; + const shape = input.dims; + const outputSize = ShapeUtil.size(shape); + const WG = 64; + let axis = attributes.axis; + if (axis < 0) { + axis = shape.length + axis; + } + if (axis < shape.length - 1) { + throw new Error('softmax only supports last axis for now.'); + } + + const cols = shape[axis]; + const rows = outputSize / cols; + + const getShaderSource = (_shaderHelper: ShaderHelper) => ` + var rowMaxShared : ${dataType}; + var rowSumShared : ${dataType}; + var threadShared : array<${dataType}, ${WG}>; + + @group(0) @binding(0) var x : array<${dataType}>; + @group(0) @binding(1) var result : array<${dataType}>; + + fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} { + let index = row * row_stride + col; + return x[index]; + } + + fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) { + let index = row * row_stride + col; + result[index] = value; + } + + @compute @workgroup_size(${WG}, 1, 1) + fn main(@builtin(local_invocation_id) local_id : vec3, @builtin(global_invocation_id) global_id : vec3u) { + let gindex = i32(global_id.x); + let lindex = i32(local_id.x); + const wg = ${WG}; + let row = gindex / wg; + let cols = ${cols}; + let row_stride : i32 = ${cols}; + + // find the rows max + var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec + for (var col = lindex; col < cols; col += wg) { + let value = getValue(row, col, row_stride); + threadMax = max(threadMax, value); + } + if (lindex < cols) { + threadShared[lindex] = threadMax; + } + workgroupBarrier(); + + var reduceSize = min(cols, wg); + for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) { + reduceSize = currSize + (reduceSize & 1); + if (lindex < currSize) { + threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]); + } + workgroupBarrier(); + } + if (lindex == 0) { + rowMaxShared = threadShared[0]; + } + workgroupBarrier(); + + // find the rows sum + var threadSum = 0.0; + for (var col = lindex; col < cols; col += wg) { + let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); + threadSum += subExp; + } + threadShared[lindex] = threadSum; + workgroupBarrier(); + + for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) { + if (lindex < currSize) { + threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize]; + } + workgroupBarrier(); + } + if (lindex == 0) { + rowSumShared = threadShared[0]; + } + workgroupBarrier(); + + // calculate final value for each element in the row + for (var col = lindex; col < cols; col += wg) { + let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; + setValue(row, col, row_stride, value); + } + }`; + return { + ...softmaxProgramMetadata, + outputs: [{dims: shape, dataType: input.dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: rows}) + }; +}; + + +export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => { + validateInputs(context.inputs); + context.compute({ + ...softmaxProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createSoftmaxProgramInfo(context.inputs[0], attributes) + }); +}; + +export const parseSoftmaxAttributes = (attributes: Record): SoftmaxAttributes => + createAttributeWithCacheKey({axis: attributes.axis as number}); diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c894f0c58f..00ac7acfc9 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -305,38 +305,38 @@ // "test_and2d", // "test_and3d", // "test_and4d", - // // "test_argmax_default_axis_example_select_last_index", - // // "test_argmax_default_axis_example", - // // "test_argmax_default_axis_random_select_last_index", - // // "test_argmax_default_axis_random", - // // "test_argmax_keepdims_example_select_last_index", - // // "test_argmax_keepdims_example", - // // "test_argmax_keepdims_random_select_last_index", - // // "test_argmax_keepdims_random", - // // "test_argmax_negative_axis_keepdims_example_select_last_index", - // // "test_argmax_negative_axis_keepdims_example", - // // "test_argmax_negative_axis_keepdims_random_select_last_index", - // // "test_argmax_negative_axis_keepdims_random", - // // "test_argmax_no_keepdims_example_select_last_index", - // // "test_argmax_no_keepdims_example", - // // "test_argmax_no_keepdims_random_select_last_index", - // // "test_argmax_no_keepdims_random", - // // "test_argmin_default_axis_example_select_last_index", - // // "test_argmin_default_axis_example", - // // "test_argmin_default_axis_random_select_last_index", - // // "test_argmin_default_axis_random", - // // "test_argmin_keepdims_example_select_last_index", - // // "test_argmin_keepdims_example", - // // "test_argmin_keepdims_random_select_last_index", - // // "test_argmin_keepdims_random", - // // "test_argmin_negative_axis_keepdims_example_select_last_index", - // // "test_argmin_negative_axis_keepdims_example", - // // "test_argmin_negative_axis_keepdims_random_select_last_index", - // // "test_argmin_negative_axis_keepdims_random", - // // "test_argmin_no_keepdims_example_select_last_index", - // // "test_argmin_no_keepdims_example", - // // "test_argmin_no_keepdims_random_select_last_index", - // // "test_argmin_no_keepdims_random", + // "test_argmax_default_axis_example_select_last_index", + "test_argmax_default_axis_example", + // "test_argmax_default_axis_random_select_last_index", + "test_argmax_default_axis_random", + // "test_argmax_keepdims_example_select_last_index", + "test_argmax_keepdims_example", + // "test_argmax_keepdims_random_select_last_index", + "test_argmax_keepdims_random", + // "test_argmax_negative_axis_keepdims_example_select_last_index", + "test_argmax_negative_axis_keepdims_example", + // "test_argmax_negative_axis_keepdims_random_select_last_index", + "test_argmax_negative_axis_keepdims_random", + // "test_argmax_no_keepdims_example_select_last_index", + "test_argmax_no_keepdims_example", + // "test_argmax_no_keepdims_random_select_last_index", + "test_argmax_no_keepdims_random", + // "test_argmin_default_axis_example_select_last_index", + "test_argmin_default_axis_example", + // "test_argmin_default_axis_random_select_last_index", + "test_argmin_default_axis_random", + // "test_argmin_keepdims_example_select_last_index", + "test_argmin_keepdims_example", + // "test_argmin_keepdims_random_select_last_index", + "test_argmin_keepdims_random", + // "test_argmin_negative_axis_keepdims_example_select_last_index", + "test_argmin_negative_axis_keepdims_example", + // "test_argmin_negative_axis_keepdims_random_select_last_index", + "test_argmin_negative_axis_keepdims_random", + // "test_argmin_no_keepdims_example_select_last_index", + "test_argmin_no_keepdims_example", + // "test_argmin_no_keepdims_random_select_last_index", + "test_argmin_no_keepdims_random", "test_asin_example", "test_asin", "test_asinh_example", @@ -1133,8 +1133,8 @@ // "test_softmax_axis_0", // "test_softmax_axis_1_expanded", // "test_softmax_axis_1", - // "test_softmax_axis_2_expanded", - // "test_softmax_axis_2", + "test_softmax_axis_2_expanded", + "test_softmax_axis_2", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob", @@ -1203,14 +1203,14 @@ // "test_softmax_cross_entropy_sum_log_prob_expanded", // "test_softmax_cross_entropy_sum_log_prob", // "test_softmax_cross_entropy_sum", - // "test_softmax_default_axis_expanded", - // "test_softmax_default_axis", - // "test_softmax_example_expanded", - // "test_softmax_example", - // "test_softmax_large_number_expanded", - // "test_softmax_large_number", - // "test_softmax_negative_axis_expanded", - // "test_softmax_negative_axis", + "opset13/test_softmax_default_axis_expanded", + "opset13/test_softmax_default_axis", + "test_softmax_example_expanded", + "test_softmax_example", + "test_softmax_large_number_expanded", + "test_softmax_large_number", + "test_softmax_negative_axis_expanded", + "test_softmax_negative_axis", // // "test_softplus_example", // // "test_softplus", // // "test_softsign_example", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 0f9febec89..d923837326 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -388,6 +388,7 @@ export class TensorResultValidator { case 'int16': case 'int32': case 'uint32': + case 'int64': case 'bool': return this.integerEqual( actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 68ec51c6f0..dba68137c7 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -237,6 +237,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat); @@ -441,6 +452,18 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/argminmax.cc b/onnxruntime/core/providers/js/operators/argminmax.cc new file mode 100644 index 0000000000..08d886269b --- /dev/null +++ b/onnxruntime/core/providers/js/operators/argminmax.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "argminmax.h" + +namespace onnxruntime { +namespace js { + +#define REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMinMaxOp, sinceVersion, endVersion) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + ArgMinMaxOp, \ + kOnnxDomain, \ + sinceVersion, endVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ArgMinMaxOp); + +#define REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMinMaxOp, sinceVersion) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ArgMinMaxOp, \ + kOnnxDomain, \ + sinceVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + ArgMinMaxOp); + +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10); +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12); +REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMax, 13); + +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10); +REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12); +REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMin, 13); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/argminmax.h b/onnxruntime/core/providers/js/operators/argminmax.h new file mode 100644 index 0000000000..c3d782372c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/argminmax.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/reduction/reduction_ops.h" + +namespace onnxruntime { +namespace js { +#define JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMinMaxKernel) \ + template \ + class ArgMinMaxKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::select_last_index_; \ + using ReduceKernelBase::keepdims_; \ + ArgMinMaxKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + axes.push_back(-1); \ + } \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + JSEP_INIT_KERNEL_ATTRIBUTE(ArgMinMaxKernel, ({ \ + "keepDims" : !!$1, \ + "selectLastIndex" : !!$2, \ + "axes" : $3, \ + }), \ + static_cast(keepdims_), \ + static_cast(select_last_index_), \ + gsl::narrow_cast(axes[0])); \ + } \ + }; + +JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMax); +JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMin); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/softmax.cc b/onnxruntime/core/providers/js/operators/softmax.cc new file mode 100644 index 0000000000..cbaecf9e4c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/softmax.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "softmax.h" + +namespace onnxruntime { +namespace js { + +#define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + SoftmaxOp, \ + kOnnxDomain, \ + sinceVersion, endVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SoftmaxOp); + +#define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SoftmaxOp, \ + kOnnxDomain, \ + sinceVersion, \ + float, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + SoftmaxOp); + +REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10); +REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12); +REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(Softmax, 13); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/softmax.h b/onnxruntime/core/providers/js/operators/softmax.h new file mode 100644 index 0000000000..068a59e6b2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/softmax.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/reduction/reduction_ops.h" + +namespace onnxruntime { +namespace js { +template +class Softmax : public JsKernel { + public: + Softmax(const OpKernelInfo& info) : JsKernel(info) { + const auto& node = info.node(); + opset_ = node.SinceVersion(); + + int64_t axis; + Status status = info.GetAttr("axis", &axis); + + if (status.IsOK()) { + axis_ = gsl::narrow_cast(axis); + } else { + if (opset_ < 13) { + axis_ = 1; // opset-12 and below, the default axis value is 1 + } else { + axis_ = -1; // opset-13, the default axis value is -1 + } + } + JSEP_INIT_KERNEL_ATTRIBUTE(Softmax, ({ + "axis" : $1 + }), + axis_); + } + + private: + int axis_; + int opset_; +}; + +} // namespace js +} // namespace onnxruntime