mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Merge int32/uint32 and int64/uint64 MatMul kernels (#4531)
This commit is contained in:
parent
02aea5d2d4
commit
8b86c5cdb5
3 changed files with 27 additions and 52 deletions
|
|
@ -278,9 +278,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint32_t, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint64_t, MatMul);
|
||||
|
||||
// Opset 10
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer);
|
||||
|
|
@ -898,12 +896,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t,
|
||||
MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint32_t,
|
||||
MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t,
|
||||
MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint64_t,
|
||||
MatMul)>,
|
||||
|
||||
// Opset 10
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer)>,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/math/matmul.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "matmul_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T>
|
||||
class MatMul final : public OpKernel {
|
||||
public:
|
||||
MatMul(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
MatMul,
|
||||
1, 8,
|
||||
|
|
@ -41,41 +49,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
|||
MatMul,
|
||||
9,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<uint32_t>()}),
|
||||
MatMul<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
MatMul,
|
||||
9,
|
||||
uint32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint32_t>()),
|
||||
MatMul<uint32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
MatMul,
|
||||
9,
|
||||
int64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<int64_t>(), DataTypeImpl::GetTensorType<uint64_t>()}),
|
||||
MatMul<int64_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
MatMul,
|
||||
9,
|
||||
uint64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint64_t>()),
|
||||
MatMul<uint64_t>);
|
||||
|
||||
template <typename T>
|
||||
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
||||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
||||
const auto* left_X = ctx->Input<Tensor>(0);
|
||||
const auto* right_X = ctx->Input<Tensor>(1);
|
||||
const auto* a = ctx->Input<Tensor>(0);
|
||||
const auto* b = ctx->Input<Tensor>(1);
|
||||
|
||||
MatMulComputeHelper helper;
|
||||
ORT_RETURN_IF_ERROR(helper.Compute(left_X->Shape(), right_X->Shape()));
|
||||
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
|
||||
Tensor* y = ctx->Output(0, helper.OutputShape());
|
||||
|
||||
Tensor* Y = ctx->Output(0, helper.OutputShape());
|
||||
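// Use DataRaw()/MutableDataRaw() instead of the typed Data<T>() accessors: the
|
||||
// int32_t/uint32_t and int64_t/uint64_t kernels share one operator body, so the
// tensor's element type may differ from T (e.g. uint32_t data in MatMul<int32_t>).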
|
||||
const auto* a_data = reinterpret_cast<const T*>(a->DataRaw());
|
||||
const auto* b_data = reinterpret_cast<const T*>(b->DataRaw());
|
||||
auto* y_data = reinterpret_cast<T*>(y->MutableDataRaw());
|
||||
|
||||
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
|
||||
size_t max_len = helper.OutputOffsets().size();
|
||||
|
|
@ -84,9 +85,10 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
|||
static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.K()),
|
||||
left_X->template Data<T>() + helper.LeftOffsets()[i],
|
||||
right_X->template Data<T>() + helper.RightOffsets()[i],
|
||||
Y->template MutableData<T>() + helper.OutputOffsets()[i], thread_pool);
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
b_data + helper.RightOffsets()[i],
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
thread_pool);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T>
|
||||
class MatMul final : public OpKernel {
|
||||
public:
|
||||
MatMul(const OpKernelInfo& info)
|
||||
: OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue