From 9642f1448e1fefcef37411343ceda154ae7eec6a Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 9 Oct 2020 12:39:22 -0700 Subject: [PATCH] Add OpSet 13 Registrations (#5426) Register Sigmoid for OpSet13 Register OpSet 13 for Sum, Min, Max, Mean. Add Erf OpSet 13 registration. Register Clip for OpSet 13 Add Gemm/MatMul Opset 13 resigstartions Signed-off-by: Dmitri Smirnov --- .../providers/cpu/activation/activations.cc | 2 +- .../providers/cpu/cpu_execution_provider.cc | 114 +++++++++++------- onnxruntime/core/providers/cpu/math/clip.cc | 44 ++++--- .../providers/cpu/math/element_wise_ops.cc | 20 ++- onnxruntime/core/providers/cpu/math/gemm.cc | 12 +- onnxruntime/core/providers/cpu/math/matmul.cc | 46 ++++++- 6 files changed, 160 insertions(+), 78 deletions(-) diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc index d962b0d21e..f83e7b9aa4 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.cc +++ b/onnxruntime/core/providers/cpu/activation/activations.cc @@ -40,7 +40,7 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(HardSigmoid, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(LeakyRelu, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(Relu, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(Selu, 6); -REGISTER_UNARY_ELEMENTWISE_KERNEL(Sigmoid, 6); +REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(Sigmoid, Sigmoid, 6, 12, 13) REGISTER_UNARY_ELEMENTWISE_KERNEL(Softplus, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1); REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(Tanh, Tanh, 6, 12, 13); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 4148569dea..9da8542cde 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -35,7 +35,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Har class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, LeakyRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Relu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Selu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Sigmoid); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softplus); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softsign); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Tanh); @@ -87,7 +88,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Log); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Sum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Sum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, float, Sum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Min); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 11, Min); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Max); @@ -106,7 +107,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, float, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, double, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, float, Mean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos); @@ -232,7 +233,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, Sign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Erf); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Erf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_int64_t_int64_t, OneHot); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_string_int64_t, OneHot); @@ -265,10 +266,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint8_t, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul); // Opset 10 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer); @@ -375,7 +376,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Se class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -408,10 +409,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, float, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, int32_t, ArgMin); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, Clip); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Min); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Max); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, Min); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, Max); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Pow); @@ -430,8 +431,22 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Einsum); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); +// REVIEW(codemzs): ConstEigenVectorArrayMap.cast KernelCreateInfo BuildKernelCreateInfo() { @@ -480,7 +490,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -554,7 +565,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -585,7 +596,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -790,7 +801,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // Opset 10 @@ -980,7 +991,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ConcatFromSequence)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1090,11 +1101,11 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1149,6 +1160,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Expand)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, // REVIEW(codemzs): ConstEigenVectorArrayMap.cast, @@ -1160,9 +1172,23 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // opset 13 - BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo()), Clip_6); -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Clip, - 11, - 11, - KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), - Clip); +#define REG_KERNEL_VERSIONED_NONTEMPL(OP_TYPE, START_VER, END_VER, KERNEL_CLASS, ...) \ + ONNX_CPU_OPERATOR_VERSIONED_KERNEL(OP_TYPE, \ + START_VER, \ + END_VER, \ + KernelDefBuilder() \ + .MayInplace(0, 0) \ + .TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()), \ + KERNEL_CLASS); #define REG_KERNEL_NONTEMPL(OP_TYPE, VERSION, KERNEL_CLASS, ...) \ - ONNX_CPU_OPERATOR_KERNEL( \ - OP_TYPE, \ - VERSION, \ - KernelDefBuilder() \ - .MayInplace(0, 0) \ + ONNX_CPU_OPERATOR_KERNEL( \ + OP_TYPE, \ + VERSION, \ + KernelDefBuilder() \ + .MayInplace(0, 0) \ .TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()), \ KERNEL_CLASS); -REG_KERNEL_NONTEMPL(Clip, 12, Clip, float, double, int8_t, uint8_t, int64_t, uint64_t); +REG_KERNEL_VERSIONED_NONTEMPL(Clip, 11, 11, Clip, float); +REG_KERNEL_VERSIONED_NONTEMPL(Clip, 12, 12, Clip, float, double, int8_t, uint8_t, int64_t, uint64_t); +REG_KERNEL_NONTEMPL(Clip, 13, Clip, float, double, int8_t, uint8_t, int64_t, uint64_t); -template +template Status Clip_6::Compute(OpKernelContext* ctx) const { - const auto* X = ctx->Input(0); - Tensor* Y = ctx->Output(0, X->Shape()); - EigenVectorMap(Y->template MutableData(), Y->Shape().Size()) = - ConstEigenVectorMap(X->template Data(), X->Shape().Size()) - .cwiseMax(this->min_) - .cwiseMin(this->max_); - return Status::OK(); + const auto* X = ctx->Input(0); + Tensor* Y = ctx->Output(0, X->Shape()); + EigenVectorMap(Y->template MutableData(), Y->Shape().Size()) = + ConstEigenVectorMap(X->template Data(), X->Shape().Size()) + .cwiseMax(this->min_) + .cwiseMin(this->max_); + return Status::OK(); } template diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index a1ec4d3d90..c4c40e27aa 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -134,15 +134,21 @@ REG_ELEMENTWISE_TYPED_KERNEL(Exp, 6, double, Exp); REG_ELEMENTWISE_TYPED_KERNEL(Log, 6, float, Log); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 6, 7, float, Sum_6); -REG_ELEMENTWISE_TYPED_KERNEL(Sum, 8, float, Sum_8); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 8, 12, float, Sum_8); +// Supposed to add BFloat16 but we are not supporting now, however, separate registration +REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, float, Sum_8); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6); REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, float, double); -REG_ELEMENTWISE_KERNEL_NONT(Max, 12, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 12, 12, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); +// Supposed to add BFloat16 but we are not supporting now, however, separate registration +REG_ELEMENTWISE_KERNEL_NONT(Max, 13, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6); REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 8, 11, Min_8, float); -REG_ELEMENTWISE_KERNEL_NONT(Min, 12, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 12, 12, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); +// Supposed to add BFloat16 but we are not supporting now, however, separate registration +REG_ELEMENTWISE_KERNEL_NONT(Min, 13, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, float, Less); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, double, Less); @@ -170,14 +176,18 @@ REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, float, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, double, Equal); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6); -REG_ELEMENTWISE_TYPED_KERNEL(Mean, 8, float, Mean_8); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 8, 12, float, Mean_8); +// Supposed to add BFloat16 but we are not supporting now, however, separate registration +REG_ELEMENTWISE_TYPED_KERNEL(Mean, 13, float, Mean_8); REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint8_t, BitShift); //REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint16_t, BitShift); REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint32_t, BitShift); REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint64_t, BitShift); -REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Erf, 9, 12, float, Erf); +// Supposed to add BFloat16 but we are not supporting now, however, separate registration +REG_ELEMENTWISE_TYPED_KERNEL(Erf, 13, float, Erf); // REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Not, 1, bool, Not); // REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(And, 7, bool, And); diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 49d625c1ee..fcb8e77761 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -15,15 +15,23 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // opset 9 added support for additional types (int32, uint32, int64, uint64), however we haven't enabled those yet. ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Gemm, - 9, + 9, 10, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); // opset 11 made bias input 'C' optional -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Gemm, 11, + 12, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); + +// opset 13 Adds BFloat16 support but we are not supporting it yet +ONNX_CPU_OPERATOR_KERNEL( + Gemm, + 13, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 31fdde40e2..3e58c402f8 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -24,34 +24,68 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( MatMul); // opset 9 supports more types -ONNX_CPU_OPERATOR_TYPED_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( MatMul, 9, + 12, + float, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, + 12, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, + 12, + int32_t, + KernelDefBuilder() + .TypeConstraint("T", BuildKernelDefConstraints()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, + 12, + int64_t, + KernelDefBuilder() + .TypeConstraint("T", BuildKernelDefConstraints()), + MatMul); + +ONNX_CPU_OPERATOR_TYPED_KERNEL( + MatMul, + 13, float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); ONNX_CPU_OPERATOR_TYPED_KERNEL( MatMul, - 9, + 13, double, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); ONNX_CPU_OPERATOR_TYPED_KERNEL( MatMul, - 9, + 13, int32_t, KernelDefBuilder() - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", BuildKernelDefConstraints()), MatMul); ONNX_CPU_OPERATOR_TYPED_KERNEL( MatMul, - 9, + 13, int64_t, KernelDefBuilder() - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", BuildKernelDefConstraints()), MatMul); template