Merge int32/uint32 and int64/uint64 MatMul kernels (#4531)

This commit is contained in:
Tracy Sharpe 2020-07-16 21:25:29 -07:00 committed by GitHub
parent 02aea5d2d4
commit 8b86c5cdb5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 52 deletions

View file

@ -278,9 +278,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint32_t, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint64_t, MatMul);
// Opset 10
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer);
@ -898,12 +896,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t,
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint32_t,
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t,
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint64_t,
MatMul)>,
// Opset 10
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer)>,

View file

@ -1,13 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cpu/math/matmul.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "matmul_helper.h"
namespace onnxruntime {
template <typename T>
class MatMul final : public OpKernel {
public:
MatMul(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* context) const override;
};
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
1, 8,
@ -41,41 +49,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
MatMul,
9,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<uint32_t>()}),
MatMul<int32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
MatMul,
9,
uint32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint32_t>()),
MatMul<uint32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
MatMul,
9,
int64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<int64_t>(), DataTypeImpl::GetTensorType<uint64_t>()}),
MatMul<int64_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
MatMul,
9,
uint64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint64_t>()),
MatMul<uint64_t>);
template <typename T>
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const auto* left_X = ctx->Input<Tensor>(0);
const auto* right_X = ctx->Input<Tensor>(1);
const auto* a = ctx->Input<Tensor>(0);
const auto* b = ctx->Input<Tensor>(1);
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(left_X->Shape(), right_X->Shape()));
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
Tensor* y = ctx->Output(0, helper.OutputShape());
Tensor* Y = ctx->Output(0, helper.OutputShape());
// Using DataRaw as int32_t/uint32_t and int64_t/uint64_t share a common
// operator body.
const auto* a_data = reinterpret_cast<const T*>(a->DataRaw());
const auto* b_data = reinterpret_cast<const T*>(b->DataRaw());
auto* y_data = reinterpret_cast<T*>(y->MutableDataRaw());
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
size_t max_len = helper.OutputOffsets().size();
@ -84,9 +85,10 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()),
left_X->template Data<T>() + helper.LeftOffsets()[i],
right_X->template Data<T>() + helper.RightOffsets()[i],
Y->template MutableData<T>() + helper.OutputOffsets()[i], thread_pool);
a_data + helper.LeftOffsets()[i],
b_data + helper.RightOffsets()[i],
y_data + helper.OutputOffsets()[i],
thread_pool);
}
return Status::OK();

View file

@ -1,21 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
template <typename T>
class MatMul final : public OpKernel {
public:
MatMul(const OpKernelInfo& info)
: OpKernel(info) {
}
Status Compute(OpKernelContext* context) const override;
};
} // namespace onnxruntime