mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
js/webgpu: argmax,argmin,softmax support (#16882)
argmax and argmin are similar to reduce. Eventually we need to add optimized flavors of the shader. softmax is optimized but only works on the last axis for now which should be the common use case. todo: enable more ut for argmax/argmin
This commit is contained in:
parent
506ddb3d5d
commit
0df2e14038
13 changed files with 544 additions and 47 deletions
|
|
@ -14,6 +14,8 @@ Do not modify directly.*
|
|||
| Acos | ai.onnx(7+) | |
|
||||
| Acosh | ai.onnx(9+) | |
|
||||
| Add | ai.onnx(7-12,13,14+) | |
|
||||
| ArgMax | ai.onnx(1-10,11-12,13+) | |
|
||||
| ArgMin | ai.onnx(1-10,11-12,13+) | |
|
||||
| Asin | ai.onnx(7+) | |
|
||||
| Asinh | ai.onnx(9+) | |
|
||||
| Atan | ai.onnx(7+) | |
|
||||
|
|
@ -64,6 +66,7 @@ Do not modify directly.*
|
|||
| Sin | ai.onnx(7+) | |
|
||||
| Sinh | ai.onnx(9+) | |
|
||||
| Slice | ai.onnx(1-9,10,11-12,13+) | |
|
||||
| Softmax | ai.onnx(1-10,11-12,13+) | |
|
||||
| Split | ai.onnx(1,2-10,11-12,13-17,18+) | |
|
||||
| Sqrt | ai.onnx(6-12,13+) | |
|
||||
| Squeeze | ai.onnx(1-10,11-12,13+) | |
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ export declare namespace Tensor {
|
|||
uint16: Uint16Array;
|
||||
int32: Int32Array;
|
||||
uint32: Uint32Array;
|
||||
int64: BigInt64Array;
|
||||
}
|
||||
|
||||
export type DataType = keyof DataTypeMap;
|
||||
|
|
@ -409,6 +410,8 @@ function dataviewConstructor(type: Tensor.DataType) {
|
|||
return Int32Array;
|
||||
case 'uint32':
|
||||
return Uint32Array;
|
||||
case 'int64':
|
||||
return BigInt64Array;
|
||||
case 'float32':
|
||||
return Float32Array;
|
||||
case 'float64':
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ export class BroadcastUtil {
|
|||
|
||||
// both inputs are scalars
|
||||
if (outputShape.length === 0) {
|
||||
c.set([], op(a.get([]), b.get([])));
|
||||
c.set([], op(a.get([]) as number, b.get([]) as number));
|
||||
}
|
||||
|
||||
// atleast one input is a non-scalar
|
||||
|
|
@ -223,11 +223,11 @@ export class BroadcastUtil {
|
|||
let isAScalar = false;
|
||||
let isBScalar = false;
|
||||
if (a.dims.length === 0) {
|
||||
valA = a.get([]);
|
||||
valA = a.get([]) as number;
|
||||
isAScalar = true;
|
||||
}
|
||||
if (b.dims.length === 0) {
|
||||
valB = b.get([]);
|
||||
valB = b.get([]) as number;
|
||||
isBScalar = true;
|
||||
}
|
||||
let rest: number;
|
||||
|
|
@ -242,11 +242,11 @@ export class BroadcastUtil {
|
|||
if (!isAScalar) {
|
||||
// map outputIndices (which is actually broadcasted) to the originalIndices
|
||||
BroadcastUtil.fillIndex(outputIndices, a.dims, originalIndicesA);
|
||||
valA = a.get(originalIndicesA);
|
||||
valA = a.get(originalIndicesA) as number;
|
||||
}
|
||||
if (!isBScalar) {
|
||||
BroadcastUtil.fillIndex(outputIndices, b.dims, originalIndicesB);
|
||||
valB = b.get(originalIndicesB);
|
||||
valB = b.get(originalIndicesB) as number;
|
||||
}
|
||||
|
||||
c.set(outputIndices, op(valA, valB));
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
|
||||
import * as binaryOps from './ops/binary-op';
|
||||
import {concat, parseConcatAttributes} from './ops/concat';
|
||||
import {conv, parseConvAttributes} from './ops/conv';
|
||||
|
|
@ -13,6 +14,7 @@ import * as pool from './ops/pool';
|
|||
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
|
||||
import {parseResizeAttributes, resize} from './ops/resize';
|
||||
import {parseSliceAttributes, slice} from './ops/slice';
|
||||
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
|
||||
import {parseSplitAttributes, split} from './ops/split';
|
||||
import {parseTransposeAttributes, transpose} from './ops/transpose';
|
||||
import * as unaryOps from './ops/unary-op';
|
||||
|
|
@ -27,6 +29,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Acos', [unaryOps.acos]],
|
||||
['Acosh', [unaryOps.acosh]],
|
||||
['Add', [binaryOps.add]],
|
||||
['ArgMax', [argMax, parseArgMinMaxAttributes]],
|
||||
['ArgMin', [argMin, parseArgMinMaxAttributes]],
|
||||
['Asin', [unaryOps.asin]],
|
||||
['Asinh', [unaryOps.asinh]],
|
||||
['Atan', [unaryOps.atan]],
|
||||
|
|
@ -77,6 +81,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Slice', [slice, parseSliceAttributes]],
|
||||
['Split', [split, parseSplitAttributes]],
|
||||
['Sqrt', [unaryOps.sqrt]],
|
||||
['Softmax', [softmax, parseSoftmaxAttributes]],
|
||||
['Sub', [binaryOps.sub]],
|
||||
['Tan', [unaryOps.tan]],
|
||||
['Tanh', [unaryOps.tanh]],
|
||||
|
|
|
|||
156
js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
Normal file
156
js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// TODO: this is the same naive implementation we use for reduce that has
|
||||
// performance limitations when the reduced axis is long. Need to add
|
||||
// a optimized codepath for this.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
||||
import {createIndicesHelper, ShaderHelper} from './common';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length === 0 || inputs.length > 2) {
|
||||
throw new Error('ArgMinMaxOp op requires 1 or 2 inputs.');
|
||||
}
|
||||
if (inputs[0].dataType !== DataType.float) {
|
||||
throw new Error('Invalid input type.');
|
||||
}
|
||||
};
|
||||
|
||||
export interface ArgMinMaxAttributes extends AttributeWithCacheKey {
|
||||
keepDims: boolean;
|
||||
axes: number;
|
||||
selectLastIndex: number;
|
||||
}
|
||||
|
||||
type ArgMinMaxOp = (inputs: readonly TensorView[], axes: number[]) => string[];
|
||||
|
||||
const createReduceProgramInfo =
|
||||
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: ArgMinMaxAttributes,
|
||||
argMinMaxOp: ArgMinMaxOp): ProgramInfo => {
|
||||
const outputShape: number[] = [];
|
||||
const inputShape = inputs[0].dims;
|
||||
|
||||
const idxCopy: string[] = []; // copy output indexes to input indexes
|
||||
|
||||
const axes = ShapeUtil.normalizeAxes([attributes.axes], inputs[0].dims.length);
|
||||
const outputDimsLength = inputs[0].dims.length - (attributes.keepDims ? 0 : axes.length);
|
||||
const ops = argMinMaxOp(inputs, axes);
|
||||
const inputIndicesHelper = createIndicesHelper('input', inputShape);
|
||||
const initInputIdx = (ops[1] === '') ? '' : `let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')};`;
|
||||
let reduceOps = `
|
||||
let inputIdx = ${inputIndicesHelper.i2oExpression('inputIndices')};
|
||||
${ops[2]};`;
|
||||
for (let k = 0; k < inputs[0].dims.length; k++) {
|
||||
// if this axis is reduced
|
||||
if (axes.indexOf(k) >= 0) {
|
||||
if (attributes.keepDims) {
|
||||
outputShape.push(1);
|
||||
}
|
||||
// loop over the d-th axis
|
||||
reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) {
|
||||
let lastIndex = j${k};
|
||||
inputIndices[${k}] = lastIndex;
|
||||
${reduceOps}
|
||||
}`;
|
||||
} else {
|
||||
if (outputDimsLength > 1) {
|
||||
idxCopy.push(`inputIndices[${k}] = outputIndices[${outputShape.length}];`);
|
||||
} else {
|
||||
idxCopy.push(`inputIndices[${k}] = outputIndices;`);
|
||||
}
|
||||
outputShape.push(inputs[0].dims[k]);
|
||||
}
|
||||
}
|
||||
|
||||
const outputIndicesHelper = createIndicesHelper('output', outputShape);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const dataType = 'f32';
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
@group(0) @binding(0) var<storage, read> _A : array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read_write> output : array<i32>;
|
||||
|
||||
${outputIndicesHelper.o2iImpl}
|
||||
${inputIndicesHelper.i2oImpl}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
|
||||
${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
|
||||
${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')}
|
||||
|
||||
${idxCopy.join('\n')}
|
||||
${ops[0]} // init ops
|
||||
${initInputIdx}
|
||||
${ops[1]}
|
||||
${reduceOps}
|
||||
${ops[3]} // final values
|
||||
output[global_idx*2] = bestIndex; // result it int64
|
||||
}`;
|
||||
|
||||
return {
|
||||
...metadata,
|
||||
getShaderSource,
|
||||
outputs: [{dims: outputShape, dataType: DataType.int64, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: () => ({x: Math.ceil(outputSize / 64)})
|
||||
};
|
||||
};
|
||||
|
||||
const createArgMinMaxAttributesFromInputs =
|
||||
(inputs: readonly TensorView[], attributes: ArgMinMaxAttributes): ArgMinMaxAttributes =>
|
||||
createAttributeWithCacheKey(
|
||||
{axes: attributes.axes, keepDims: attributes.keepDims, selectLastIndex: attributes.selectLastIndex});
|
||||
|
||||
const createReduceProgramInfoLoader =
|
||||
(inputs: readonly TensorView[], name: string, attributes: ArgMinMaxAttributes, reduceOp: ArgMinMaxOp):
|
||||
ProgramInfoLoader => {
|
||||
const updatedAttributes: ArgMinMaxAttributes =
|
||||
inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(inputs, attributes);
|
||||
const metadata:
|
||||
ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
|
||||
return {...metadata, get: () => createReduceProgramInfo(metadata, [inputs[0]], updatedAttributes, reduceOp)};
|
||||
};
|
||||
|
||||
|
||||
export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => {
|
||||
const idxZero = [];
|
||||
for (let k = 0; k < inputs[0].dims.length; k++) {
|
||||
if (axes.indexOf(k) >= 0 || axes.length === 0) {
|
||||
idxZero.push(`inputIndices[${k}] = 0;`); // first element
|
||||
}
|
||||
}
|
||||
return [
|
||||
`${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;',
|
||||
'if (_A[inputIdx] < value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', ''
|
||||
];
|
||||
};
|
||||
context.compute(createReduceProgramInfoLoader(context.inputs, 'ArgMin', attributes, argMinMaxOp), {inputs: [0]});
|
||||
};
|
||||
|
||||
export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const argMinMaxOp: ArgMinMaxOp = (inputs: TensorView[], axes: number[]): string[] => {
|
||||
const idxZero = [];
|
||||
for (let k = 0; k < inputs[0].dims.length; k++) {
|
||||
if (axes.indexOf(k) >= 0 || axes.length === 0) {
|
||||
idxZero.push(`inputIndices[${k}] = 0;`); // first element
|
||||
}
|
||||
}
|
||||
return [
|
||||
`${idxZero.join('\n')}`, 'var value = _A[inputIdx];\nvar bestIndex : i32 = 0;',
|
||||
'if (_A[inputIdx] > value) {value = _A[inputIdx]; bestIndex = i32(lastIndex);} ', ''
|
||||
];
|
||||
};
|
||||
context.compute(createReduceProgramInfoLoader(context.inputs, 'argMax', attributes, argMinMaxOp), {inputs: [0]});
|
||||
};
|
||||
|
||||
export const parseArgMinMaxAttributes = (attributes: Record<string, unknown>): ArgMinMaxAttributes =>
|
||||
createAttributeWithCacheKey(attributes as Omit<ArgMinMaxAttributes, keyof AttributeWithCacheKey>);
|
||||
147
js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Normal file
147
js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// TODO: this is the same naive implementation we use for reduce that has
|
||||
// performance limitations when the reduced axis is long. Need to add
|
||||
// a optimized codepath for this.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
|
||||
|
||||
import {ShaderHelper} from './common';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
throw new Error('Softmax op requires 1 input.');
|
||||
}
|
||||
if (inputs[0].dataType !== DataType.float) {
|
||||
throw new Error('Softmax input needs to be float.');
|
||||
}
|
||||
};
|
||||
|
||||
export interface SoftmaxAttributes extends AttributeWithCacheKey {
|
||||
readonly axis: number;
|
||||
}
|
||||
|
||||
export const softmaxProgramMetadata = {
|
||||
name: 'Softmax',
|
||||
inputTypes: [GpuDataType.default]
|
||||
};
|
||||
|
||||
|
||||
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
|
||||
const dataType = 'f32';
|
||||
const shape = input.dims;
|
||||
const outputSize = ShapeUtil.size(shape);
|
||||
const WG = 64;
|
||||
let axis = attributes.axis;
|
||||
if (axis < 0) {
|
||||
axis = shape.length + axis;
|
||||
}
|
||||
if (axis < shape.length - 1) {
|
||||
throw new Error('softmax only supports last axis for now.');
|
||||
}
|
||||
|
||||
const cols = shape[axis];
|
||||
const rows = outputSize / cols;
|
||||
|
||||
const getShaderSource = (_shaderHelper: ShaderHelper) => `
|
||||
var<workgroup> rowMaxShared : ${dataType};
|
||||
var<workgroup> rowSumShared : ${dataType};
|
||||
var<workgroup> threadShared : array<${dataType}, ${WG}>;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read_write> result : array<${dataType}>;
|
||||
|
||||
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
|
||||
let index = row * row_stride + col;
|
||||
return x[index];
|
||||
}
|
||||
|
||||
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
|
||||
let index = row * row_stride + col;
|
||||
result[index] = value;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(${WG}, 1, 1)
|
||||
fn main(@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(global_invocation_id) global_id : vec3u) {
|
||||
let gindex = i32(global_id.x);
|
||||
let lindex = i32(local_id.x);
|
||||
const wg = ${WG};
|
||||
let row = gindex / wg;
|
||||
let cols = ${cols};
|
||||
let row_stride : i32 = ${cols};
|
||||
|
||||
// find the rows max
|
||||
var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
|
||||
for (var col = lindex; col < cols; col += wg) {
|
||||
let value = getValue(row, col, row_stride);
|
||||
threadMax = max(threadMax, value);
|
||||
}
|
||||
if (lindex < cols) {
|
||||
threadShared[lindex] = threadMax;
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
var reduceSize = min(cols, wg);
|
||||
for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {
|
||||
reduceSize = currSize + (reduceSize & 1);
|
||||
if (lindex < currSize) {
|
||||
threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
if (lindex == 0) {
|
||||
rowMaxShared = threadShared[0];
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
// find the rows sum
|
||||
var threadSum = 0.0;
|
||||
for (var col = lindex; col < cols; col += wg) {
|
||||
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
|
||||
threadSum += subExp;
|
||||
}
|
||||
threadShared[lindex] = threadSum;
|
||||
workgroupBarrier();
|
||||
|
||||
for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {
|
||||
if (lindex < currSize) {
|
||||
threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
if (lindex == 0) {
|
||||
rowSumShared = threadShared[0];
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
// calculate final value for each element in the row
|
||||
for (var col = lindex; col < cols; col += wg) {
|
||||
let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
|
||||
setValue(row, col, row_stride, value);
|
||||
}
|
||||
}`;
|
||||
return {
|
||||
...softmaxProgramMetadata,
|
||||
outputs: [{dims: shape, dataType: input.dataType, gpuDataType: GpuDataType.default}],
|
||||
getShaderSource,
|
||||
dispatchGroup: () => ({x: rows})
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
context.compute({
|
||||
...softmaxProgramMetadata,
|
||||
cacheHint: attributes.cacheKey,
|
||||
get: () => createSoftmaxProgramInfo(context.inputs[0], attributes)
|
||||
});
|
||||
};
|
||||
|
||||
export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes =>
|
||||
createAttributeWithCacheKey({axis: attributes.axis as number});
|
||||
|
|
@ -305,38 +305,38 @@
|
|||
// "test_and2d",
|
||||
// "test_and3d",
|
||||
// "test_and4d",
|
||||
// // "test_argmax_default_axis_example_select_last_index",
|
||||
// // "test_argmax_default_axis_example",
|
||||
// // "test_argmax_default_axis_random_select_last_index",
|
||||
// // "test_argmax_default_axis_random",
|
||||
// // "test_argmax_keepdims_example_select_last_index",
|
||||
// // "test_argmax_keepdims_example",
|
||||
// // "test_argmax_keepdims_random_select_last_index",
|
||||
// // "test_argmax_keepdims_random",
|
||||
// // "test_argmax_negative_axis_keepdims_example_select_last_index",
|
||||
// // "test_argmax_negative_axis_keepdims_example",
|
||||
// // "test_argmax_negative_axis_keepdims_random_select_last_index",
|
||||
// // "test_argmax_negative_axis_keepdims_random",
|
||||
// // "test_argmax_no_keepdims_example_select_last_index",
|
||||
// // "test_argmax_no_keepdims_example",
|
||||
// // "test_argmax_no_keepdims_random_select_last_index",
|
||||
// // "test_argmax_no_keepdims_random",
|
||||
// // "test_argmin_default_axis_example_select_last_index",
|
||||
// // "test_argmin_default_axis_example",
|
||||
// // "test_argmin_default_axis_random_select_last_index",
|
||||
// // "test_argmin_default_axis_random",
|
||||
// // "test_argmin_keepdims_example_select_last_index",
|
||||
// // "test_argmin_keepdims_example",
|
||||
// // "test_argmin_keepdims_random_select_last_index",
|
||||
// // "test_argmin_keepdims_random",
|
||||
// // "test_argmin_negative_axis_keepdims_example_select_last_index",
|
||||
// // "test_argmin_negative_axis_keepdims_example",
|
||||
// // "test_argmin_negative_axis_keepdims_random_select_last_index",
|
||||
// // "test_argmin_negative_axis_keepdims_random",
|
||||
// // "test_argmin_no_keepdims_example_select_last_index",
|
||||
// // "test_argmin_no_keepdims_example",
|
||||
// // "test_argmin_no_keepdims_random_select_last_index",
|
||||
// // "test_argmin_no_keepdims_random",
|
||||
// "test_argmax_default_axis_example_select_last_index",
|
||||
"test_argmax_default_axis_example",
|
||||
// "test_argmax_default_axis_random_select_last_index",
|
||||
"test_argmax_default_axis_random",
|
||||
// "test_argmax_keepdims_example_select_last_index",
|
||||
"test_argmax_keepdims_example",
|
||||
// "test_argmax_keepdims_random_select_last_index",
|
||||
"test_argmax_keepdims_random",
|
||||
// "test_argmax_negative_axis_keepdims_example_select_last_index",
|
||||
"test_argmax_negative_axis_keepdims_example",
|
||||
// "test_argmax_negative_axis_keepdims_random_select_last_index",
|
||||
"test_argmax_negative_axis_keepdims_random",
|
||||
// "test_argmax_no_keepdims_example_select_last_index",
|
||||
"test_argmax_no_keepdims_example",
|
||||
// "test_argmax_no_keepdims_random_select_last_index",
|
||||
"test_argmax_no_keepdims_random",
|
||||
// "test_argmin_default_axis_example_select_last_index",
|
||||
"test_argmin_default_axis_example",
|
||||
// "test_argmin_default_axis_random_select_last_index",
|
||||
"test_argmin_default_axis_random",
|
||||
// "test_argmin_keepdims_example_select_last_index",
|
||||
"test_argmin_keepdims_example",
|
||||
// "test_argmin_keepdims_random_select_last_index",
|
||||
"test_argmin_keepdims_random",
|
||||
// "test_argmin_negative_axis_keepdims_example_select_last_index",
|
||||
"test_argmin_negative_axis_keepdims_example",
|
||||
// "test_argmin_negative_axis_keepdims_random_select_last_index",
|
||||
"test_argmin_negative_axis_keepdims_random",
|
||||
// "test_argmin_no_keepdims_example_select_last_index",
|
||||
"test_argmin_no_keepdims_example",
|
||||
// "test_argmin_no_keepdims_random_select_last_index",
|
||||
"test_argmin_no_keepdims_random",
|
||||
"test_asin_example",
|
||||
"test_asin",
|
||||
"test_asinh_example",
|
||||
|
|
@ -1133,8 +1133,8 @@
|
|||
// "test_softmax_axis_0",
|
||||
// "test_softmax_axis_1_expanded",
|
||||
// "test_softmax_axis_1",
|
||||
// "test_softmax_axis_2_expanded",
|
||||
// "test_softmax_axis_2",
|
||||
"test_softmax_axis_2_expanded",
|
||||
"test_softmax_axis_2",
|
||||
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded",
|
||||
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded",
|
||||
// "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob",
|
||||
|
|
@ -1203,14 +1203,14 @@
|
|||
// "test_softmax_cross_entropy_sum_log_prob_expanded",
|
||||
// "test_softmax_cross_entropy_sum_log_prob",
|
||||
// "test_softmax_cross_entropy_sum",
|
||||
// "test_softmax_default_axis_expanded",
|
||||
// "test_softmax_default_axis",
|
||||
// "test_softmax_example_expanded",
|
||||
// "test_softmax_example",
|
||||
// "test_softmax_large_number_expanded",
|
||||
// "test_softmax_large_number",
|
||||
// "test_softmax_negative_axis_expanded",
|
||||
// "test_softmax_negative_axis",
|
||||
"opset13/test_softmax_default_axis_expanded",
|
||||
"opset13/test_softmax_default_axis",
|
||||
"test_softmax_example_expanded",
|
||||
"test_softmax_example",
|
||||
"test_softmax_large_number_expanded",
|
||||
"test_softmax_large_number",
|
||||
"test_softmax_negative_axis_expanded",
|
||||
"test_softmax_negative_axis",
|
||||
// // "test_softplus_example",
|
||||
// // "test_softplus",
|
||||
// // "test_softsign_example",
|
||||
|
|
|
|||
|
|
@ -388,6 +388,7 @@ export class TensorResultValidator {
|
|||
case 'int16':
|
||||
case 'int32':
|
||||
case 'uint32':
|
||||
case 'int64':
|
||||
case 'bool':
|
||||
return this.integerEqual(
|
||||
actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
|
||||
|
|
|
|||
|
|
@ -237,6 +237,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat);
|
||||
|
|
@ -441,6 +452,18 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
|
||||
|
|
|
|||
41
onnxruntime/core/providers/js/operators/argminmax.cc
Normal file
41
onnxruntime/core/providers/js/operators/argminmax.cc
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "argminmax.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMinMaxOp, sinceVersion, endVersion) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
ArgMinMaxOp, \
|
||||
kOnnxDomain, \
|
||||
sinceVersion, endVersion, \
|
||||
float, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
|
||||
ArgMinMaxOp<float>);
|
||||
|
||||
#define REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMinMaxOp, sinceVersion) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
ArgMinMaxOp, \
|
||||
kOnnxDomain, \
|
||||
sinceVersion, \
|
||||
float, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
|
||||
.InputMemoryType(OrtMemTypeCPU, 1), \
|
||||
ArgMinMaxOp<float>);
|
||||
|
||||
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10);
|
||||
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12);
|
||||
REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMax, 13);
|
||||
|
||||
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10);
|
||||
REGISTER_ARGMAX_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12);
|
||||
REGISTER_ARGMAX_ELEMENTWISE_KERNEL(ArgMin, 13);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
39
onnxruntime/core/providers/js/operators/argminmax.h
Normal file
39
onnxruntime/core/providers/js/operators/argminmax.h
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/providers/cpu/reduction/reduction_ops.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
#define JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMinMaxKernel) \
|
||||
template <typename T, bool allow_multi_axes = false> \
|
||||
class ArgMinMaxKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
|
||||
public: \
|
||||
using ReduceKernelBase<allow_multi_axes>::axes_; \
|
||||
using ReduceKernelBase<allow_multi_axes>::select_last_index_; \
|
||||
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
|
||||
ArgMinMaxKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
|
||||
std::vector<int32_t> axes(axes_.size()); \
|
||||
if (axes_.size() > 0) { \
|
||||
axes.push_back(-1); \
|
||||
} \
|
||||
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
|
||||
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(ArgMinMaxKernel, ({ \
|
||||
"keepDims" : !!$1, \
|
||||
"selectLastIndex" : !!$2, \
|
||||
"axes" : $3, \
|
||||
}), \
|
||||
static_cast<int32_t>(keepdims_), \
|
||||
static_cast<int32_t>(select_last_index_), \
|
||||
gsl::narrow_cast<int32_t>(axes[0])); \
|
||||
} \
|
||||
};
|
||||
|
||||
JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMax);
|
||||
JSEP_DEFINE_ARGMINMAX_KERNEL(ArgMin);
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
37
onnxruntime/core/providers/js/operators/softmax.cc
Normal file
37
onnxruntime/core/providers/js/operators/softmax.cc
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "softmax.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
SoftmaxOp, \
|
||||
kOnnxDomain, \
|
||||
sinceVersion, endVersion, \
|
||||
float, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
|
||||
SoftmaxOp<float>);
|
||||
|
||||
#define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
SoftmaxOp, \
|
||||
kOnnxDomain, \
|
||||
sinceVersion, \
|
||||
float, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
|
||||
.InputMemoryType(OrtMemTypeCPU, 1), \
|
||||
SoftmaxOp<float>);
|
||||
|
||||
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10);
|
||||
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12);
|
||||
REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(Softmax, 13);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
42
onnxruntime/core/providers/js/operators/softmax.h
Normal file
42
onnxruntime/core/providers/js/operators/softmax.h
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/providers/cpu/reduction/reduction_ops.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
template <typename T>
|
||||
class Softmax : public JsKernel {
|
||||
public:
|
||||
Softmax(const OpKernelInfo& info) : JsKernel(info) {
|
||||
const auto& node = info.node();
|
||||
opset_ = node.SinceVersion();
|
||||
|
||||
int64_t axis;
|
||||
Status status = info.GetAttr<int64_t>("axis", &axis);
|
||||
|
||||
if (status.IsOK()) {
|
||||
axis_ = gsl::narrow_cast<int>(axis);
|
||||
} else {
|
||||
if (opset_ < 13) {
|
||||
axis_ = 1; // opset-12 and below, the default axis value is 1
|
||||
} else {
|
||||
axis_ = -1; // opset-13, the default axis value is -1
|
||||
}
|
||||
}
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Softmax, ({
|
||||
"axis" : $1
|
||||
}),
|
||||
axis_);
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
int opset_;
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue