diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 4bc6d526ad..7a8b947a51 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -278,9 +278,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint32_t, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint64_t, MatMul); // Opset 10 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer); @@ -898,12 +896,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { MatMul)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, // Opset 10 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 8a7d26fc13..6a66fd609f 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -1,13 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/cpu/math/matmul.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "matmul_helper.h" namespace onnxruntime { +template +class MatMul final : public OpKernel { + public: + MatMul(const OpKernelInfo& info) : OpKernel(info) {} + + Status Compute(OpKernelContext* context) const override; +}; + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( MatMul, 1, 8, @@ -41,41 +49,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( MatMul, 9, int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MatMul); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - MatMul, - 9, - uint32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MatMul); - ONNX_CPU_OPERATOR_TYPED_KERNEL( MatMul, 9, int64_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MatMul); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - MatMul, - 9, - uint64_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MatMul); - template Status MatMul::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const auto* left_X = ctx->Input(0); - const auto* right_X = ctx->Input(1); + const auto* a = ctx->Input(0); + const auto* b = ctx->Input(1); MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(left_X->Shape(), right_X->Shape())); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape())); + Tensor* y = ctx->Output(0, helper.OutputShape()); - Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Using DataRaw as int32_t/uint32_t and int64_t/uint64_t share a common + // operator body. + const auto* a_data = reinterpret_cast(a->DataRaw()); + const auto* b_data = reinterpret_cast(b->DataRaw()); + auto* y_data = reinterpret_cast(y->MutableDataRaw()); // TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well size_t max_len = helper.OutputOffsets().size(); @@ -84,9 +85,10 @@ Status MatMul::Compute(OpKernelContext* ctx) const { static_cast(helper.M()), static_cast(helper.N()), static_cast(helper.K()), - left_X->template Data() + helper.LeftOffsets()[i], - right_X->template Data() + helper.RightOffsets()[i], - Y->template MutableData() + helper.OutputOffsets()[i], thread_pool); + a_data + helper.LeftOffsets()[i], + b_data + helper.RightOffsets()[i], + y_data + helper.OutputOffsets()[i], + thread_pool); } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h deleted file mode 100644 index 02a42aa3ef..0000000000 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/framework/op_kernel.h" - -namespace onnxruntime { - -template -class MatMul final : public OpKernel { - public: - MatMul(const OpKernelInfo& info) - : OpKernel(info) { - } - - Status Compute(OpKernelContext* context) const override; -}; - -} // namespace onnxruntime