[JSEP] Upgrade to ONNX Opset 21 (#22595)

### JSEP Ops that need updating

- [x] Cast
- [x] ReduceMax
- [x] ReduceMin
- [x] Squeeze
- [x] Unsqueeze
- [x] Transpose
- [x] AveragePool
- [x] Flatten
- [x] Pad
- [x] If
This commit is contained in:
Prathik Rao 2024-10-29 17:44:38 -07:00 committed by GitHub
parent e2e837584f
commit 5cc7fb4a74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 156 additions and 43 deletions

View file

@ -10,6 +10,8 @@ import * as path from 'path';
import { downloadZip, extractFile } from './utils';
const TEST_DATA_OPSET_VERSIONS = [
['opset21', '1.16.2'],
['opset20', '1.15.0'],
['opset19', '1.14.0'],
['opset18', '1.13.1'],
['opset17', '1.12.1'],

View file

@ -21,11 +21,11 @@ Do not modify directly.*
| Atan | ai.onnx(7+) | |
| Atanh | ai.onnx(9+) | |
| Attention | com.microsoft(1+) | need implementing mask and past/present |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation |
| AveragePool | ai.onnx(7-9,10,11-18,19+); com.ms.internal.nhwc(7-9,10,11-18,19+) | need perf optimization; need implementing activation |
| BatchNormalization | ai.onnx(7-8,9-13,14,15+); com.ms.internal.nhwc(7-8,9-13,14,15+) | |
| BiasAdd | com.microsoft(1+) | |
| BiasSplitGelu | com.microsoft(1+) | |
| Cast | ai.onnx(6-8,9-12,13-18,19+) | |
| Cast | ai.onnx(6-8,9-12,13-18,19-20,21+) | |
| Ceil | ai.onnx(6-12,13+) | |
| Clip | ai.onnx(6-10,11,12,13+) | |
| Concat | ai.onnx(1-3,4-10,11-12,13+) | |
@ -44,7 +44,7 @@ Do not modify directly.*
| Exp | ai.onnx(6-12,13+) | |
| Expand | ai.onnx(8-12,13+) | |
| FastGelu | com.microsoft(1+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13-20,21+) | |
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
@ -58,7 +58,7 @@ Do not modify directly.*
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| GroupQueryAttention | com.microsoft(1+) | |
| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| If | ai.onnx(1-10,11-12,13-18,19-20,21+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(1-16,17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
@ -74,7 +74,7 @@ Do not modify directly.*
| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present |
| Neg | ai.onnx(6-12,13+) | |
| Not | ai.onnx(1+) | |
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
| Pad | ai.onnx(2-10,11-12,13-17,18,19-20,21+) | |
| Pow | ai.onnx(7-11,12,13-14,15+) | |
| QuickGelu | com.microsoft(1+) | |
| Range | ai.onnx(11+) | |
@ -83,9 +83,9 @@ Do not modify directly.*
| ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | |
| ReduceLogSum | ai.onnx(1-10,11-12,13-17,18+) | |
| ReduceLogSumExp | ai.onnx(1-10,11-12,13-17,18+) | |
| ReduceMax | ai.onnx(1-10,11,12,13-17,18+) | |
| ReduceMax | ai.onnx(1-10,11,12,13-17,18-19,20+) | |
| ReduceMean | ai.onnx(1-10,11-12,13-17,18+) | |
| ReduceMin | ai.onnx(1-10,11,12,13-17,18+) | |
| ReduceMin | ai.onnx(1-10,11,12,13-17,18-19,20+) | |
| ReduceProd | ai.onnx(1-10,11-12,13-17,18+) | |
| ReduceSum | ai.onnx(1-10,11-12,13+) | |
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
@ -104,12 +104,12 @@ Do not modify directly.*
| Softmax | ai.onnx(1-10,11-12,13+) | |
| Split | ai.onnx(1,2-10,11-12,13-17,18+) | |
| Sqrt | ai.onnx(6-12,13+) | |
| Squeeze | ai.onnx(1-10,11-12,13+) | |
| Squeeze | ai.onnx(1-10,11-12,13-20,21+) | |
| Sub | ai.onnx(7-12,13,14+) | |
| Tan | ai.onnx(7+) | |
| Tanh | ai.onnx(6-12,13+) | |
| ThresholdedRelu | ai.onnx(10+) | |
| Tile | ai.onnx(6-12,13+) | |
| Transpose | ai.onnx(1-12,13+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13+) | |
| Transpose | ai.onnx(1-12,13-20,21+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13-20,21+) | |
| Where | ai.onnx(9-15,16+) | |

View file

@ -121,7 +121,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, Not)
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 8, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 10, Clip);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, Clip);
@ -139,7 +140,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMean);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceMean);
@ -150,7 +152,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMin);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMin);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceProd);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceProd);
@ -233,17 +236,20 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Res
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Unsqueeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Transpose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, DepthToSpace);
@ -273,10 +279,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 18, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 18, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, AveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool);
@ -341,7 +349,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sli
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile);
@ -358,12 +367,14 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization);
@ -439,7 +450,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast),
KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast),
KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast),
KERNEL_CREATE_INFO(19, Cast),
KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast),
KERNEL_CREATE_INFO(21, Cast),
// activations
KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip),
@ -501,12 +513,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceMean)>,
@ -515,13 +529,15 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Unsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceProd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceProd)>,
@ -561,7 +577,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(16, Where),
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
@ -591,10 +608,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 18, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 18, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool)>,
@ -660,7 +679,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile)>,
@ -677,12 +697,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,

View file

@ -49,10 +49,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("T1", CastOpTypeConstraints())
.TypeConstraint("T2", CastOpTypeConstraints()),
Cast);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Cast,
kOnnxDomain,
19, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", CastOpTypeConstraints())
.TypeConstraint("T2", CastOpTypeConstraints()),
Cast);
ONNX_OPERATOR_KERNEL_EX(
Cast,
kOnnxDomain,
19,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", CastOpTypeConstraints())

View file

@ -36,10 +36,20 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
13, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);
ONNX_OPERATOR_KERNEL_EX(
Flatten,
kOnnxDomain,
13,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)

View file

@ -44,9 +44,21 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
If);
// opset-19 supports float8
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
19, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
// Support sequence/optional tensors when all JSEP infra
// (including tests runner) supports it
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
ONNX_OPERATOR_KERNEL_EX(If,
kOnnxDomain,
19,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU

View file

@ -56,10 +56,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 3),
Pad);
ONNX_OPERATOR_KERNEL_EX(
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Pad,
kOnnxDomain,
19,
20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
Pad);
ONNX_OPERATOR_KERNEL_EX(
Pad,
kOnnxDomain,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedFloatTypes())

View file

@ -55,8 +55,10 @@ POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10)
POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 10, 10)
POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11)
POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11)
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 11, 18)
POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11, 18)
POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 19)
POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 19)
POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1)
POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1)

View file

@ -20,6 +20,16 @@ namespace js {
// a new opset version update applies to Reduce* operators, we may need to add another macro like
// REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT to set input memory type.
// i.e. we cannot use REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL to version 18 when the opset version is increased.
#define REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceOp, sinceVersion, endVersion) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
ReduceOp, \
kOnnxDomain, \
sinceVersion, endVersion, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", JsepSupportedFloatTypes()) \
.InputMemoryType(OrtMemTypeCPU, 1), \
ReduceOp<true>);
#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \
ONNX_OPERATOR_KERNEL_EX( \
@ -41,13 +51,15 @@ REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17);
REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMax, 18);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMax, 18, 19);
REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMax, 20);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 11, 11);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 12, 12);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 13, 17);
REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMin, 18);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMin, 18, 19);
REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMin, 20);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12);

View file

@ -10,7 +10,7 @@ namespace js {
ONNX_OPERATOR_KERNEL_EX(
Squeeze,
kOnnxDomain,
13,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes())
@ -19,6 +19,17 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 1),
Squeeze);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Squeeze,
kOnnxDomain,
13, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes())
.Alias(0, 0)
.InputMemoryType(OrtMemTypeCPU, 1),
Squeeze);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Squeeze,
kOnnxDomain,

View file

@ -15,10 +15,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("T", JsepSupportedDataTypes()),
Transpose);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Transpose,
kOnnxDomain,
13, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes()),
Transpose);
ONNX_OPERATOR_KERNEL_EX(
Transpose,
kOnnxDomain,
13,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes()),

View file

@ -10,7 +10,7 @@ namespace js {
ONNX_OPERATOR_KERNEL_EX(
Unsqueeze,
kOnnxDomain,
13,
21,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes())
@ -19,6 +19,17 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 1),
Unsqueeze);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Unsqueeze,
kOnnxDomain,
13, 20,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedDataTypes())
.Alias(0, 0)
.InputMemoryType(OrtMemTypeCPU, 1),
Unsqueeze);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Unsqueeze,
kOnnxDomain,