mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
MatMul op: Support new integer types and double type as part of opset V9 compliance (#482)
* Support new integer types and double type as part of opset V9 compliance
This commit is contained in:
parent
b69c834c06
commit
c2b8ac0154
4 changed files with 404 additions and 110 deletions
|
|
@ -89,7 +89,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Ata
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, Gemm);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Hardmax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint32_t, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int64_t, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint64_t, MatMul);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softmax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, TopK);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, BatchNormalization);
|
||||
|
|
@ -342,7 +347,12 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, Gemm)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Hardmax)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, LogSoftmax)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MatMul)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, MatMul)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, MatMul)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, MatMul)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint32_t, MatMul)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int64_t, MatMul)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint64_t, MatMul)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softmax)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, TopK)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, BatchNormalization)>());
|
||||
|
|
|
|||
|
|
@ -9,15 +9,50 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
MatMul,
|
||||
1,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
MatMul<float>);
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
MatMul,
|
||||
1, 9,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
MatMul<float>);
|
||||
|
||||
template <>
|
||||
Status MatMul<float>::Compute(OpKernelContext* ctx) const {
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
MatMul,
|
||||
1, 9,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
MatMul<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
MatMul,
|
||||
9, 9,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
MatMul<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
MatMul,
|
||||
9, 9,
|
||||
uint32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint32_t>()),
|
||||
MatMul<uint32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
MatMul,
|
||||
9, 9,
|
||||
int64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
MatMul<int64_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
MatMul,
|
||||
9, 9,
|
||||
uint64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint64_t>()),
|
||||
MatMul<uint64_t>);
|
||||
|
||||
template <typename T>
|
||||
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* left_X = ctx->Input<Tensor>(0);
|
||||
const Tensor* right_X = ctx->Input<Tensor>(1);
|
||||
|
||||
|
|
@ -28,17 +63,17 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
|
||||
for (int i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
math::Gemm<float, CPUMathUtil>(
|
||||
math::Gemm<T, CPUMathUtil>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.K()),
|
||||
/* alpha */ 1.0f,
|
||||
left_X->template Data<float>() + helper.LeftOffsets()[i],
|
||||
right_X->template Data<float>() + helper.RightOffsets()[i],
|
||||
left_X->template Data<T>() + helper.LeftOffsets()[i],
|
||||
right_X->template Data<T>() + helper.RightOffsets()[i],
|
||||
/* beta */ 0.0f,
|
||||
Y->template MutableData<float>() + helper.OutputOffsets()[i],
|
||||
Y->template MutableData<T>() + helper.OutputOffsets()[i],
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,6 +51,59 @@
|
|||
namespace onnxruntime {
|
||||
namespace math {
|
||||
|
||||
// Gemm implementation purely based on Eigen.
|
||||
template <typename T>
|
||||
void GemmEigen(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
float alpha,
|
||||
const T* A,
|
||||
const T* B,
|
||||
float beta,
|
||||
T* C) {
|
||||
auto C_mat = EigenMatrixMap<T>(C, N, M);
|
||||
if (beta == 0) {
|
||||
C_mat.setZero();
|
||||
} else {
|
||||
C_mat *= static_cast<T>(beta);
|
||||
}
|
||||
switch (TransA) {
|
||||
case CblasNoTrans: {
|
||||
switch (TransB) {
|
||||
case CblasNoTrans:
|
||||
C_mat.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, N, K) *
|
||||
ConstEigenMatrixMap<T>(A, K, M));
|
||||
return;
|
||||
case CblasTrans:
|
||||
C_mat.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, K, N).transpose() *
|
||||
ConstEigenMatrixMap<T>(A, K, M));
|
||||
return;
|
||||
default:
|
||||
ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
|
||||
}
|
||||
}
|
||||
case CblasTrans: {
|
||||
switch (TransB) {
|
||||
case CblasNoTrans:
|
||||
C_mat.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, N, K) *
|
||||
ConstEigenMatrixMap<T>(A, M, K).transpose());
|
||||
return;
|
||||
case CblasTrans:
|
||||
C_mat.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, K, N).transpose() *
|
||||
ConstEigenMatrixMap<T>(A, M, K).transpose());
|
||||
return;
|
||||
default:
|
||||
ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
|
||||
}
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// BLAS alternatives.
|
||||
// Depending on whether we have specified an external BLAS library or not, we
|
||||
|
|
@ -110,47 +163,100 @@ void Gemm<float, CPUMathUtil>(
|
|||
int ldb = (int)((TransB == CblasNoTrans) ? N : K);
|
||||
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
|
||||
#else
|
||||
auto C_mat = EigenMatrixMap<float>(C, N, M);
|
||||
if (beta == 0) {
|
||||
C_mat.setZero();
|
||||
} else {
|
||||
C_mat *= beta;
|
||||
}
|
||||
switch (TransA) {
|
||||
case CblasNoTrans: {
|
||||
switch (TransB) {
|
||||
case CblasNoTrans:
|
||||
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, N, K) *
|
||||
ConstEigenMatrixMap<float>(A, K, M));
|
||||
return;
|
||||
case CblasTrans:
|
||||
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, K, N).transpose() *
|
||||
ConstEigenMatrixMap<float>(A, K, M));
|
||||
return;
|
||||
default:
|
||||
ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
|
||||
}
|
||||
}
|
||||
case CblasTrans: {
|
||||
switch (TransB) {
|
||||
case CblasNoTrans:
|
||||
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, N, K) *
|
||||
ConstEigenMatrixMap<float>(A, M, K).transpose());
|
||||
return;
|
||||
case CblasTrans:
|
||||
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, K, N).transpose() *
|
||||
ConstEigenMatrixMap<float>(A, M, K).transpose());
|
||||
return;
|
||||
default:
|
||||
ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
|
||||
}
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA);
|
||||
}
|
||||
GemmEigen<float>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<double, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const double* A,
|
||||
const double* B,
|
||||
const float beta,
|
||||
double* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No double precision Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<double>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int32_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const int32_t* A,
|
||||
const int32_t* B,
|
||||
const float beta,
|
||||
int32_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No int32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint32_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const uint32_t* A,
|
||||
const uint32_t* B,
|
||||
const float beta,
|
||||
uint32_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No uint32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int64_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const int64_t* A,
|
||||
const int64_t* B,
|
||||
const float beta,
|
||||
int64_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No int64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint64_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const uint64_t* A,
|
||||
const uint64_t* B,
|
||||
const float beta,
|
||||
uint64_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No uint64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
|
||||
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmEx<float, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
|
|
@ -343,6 +449,102 @@ void Gemm<float, CPUMathUtil>(
|
|||
beta, C, gsl::narrow_cast<int>(N));
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<double, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const double* A,
|
||||
const double* B,
|
||||
const float beta,
|
||||
double* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
int lda = gsl::narrow_cast<int>((TransA == CblasNoTrans) ? K : M);
|
||||
int ldb = gsl::narrow_cast<int>((TransB == CblasNoTrans) ? N : K);
|
||||
cblas_dgemm(CblasRowMajor, TransA, TransB,
|
||||
gsl::narrow_cast<int>(M),
|
||||
gsl::narrow_cast<int>(N),
|
||||
gsl::narrow_cast<int>(K),
|
||||
gsl::narrow_cast<double>(alpha), A, lda, B, ldb,
|
||||
gsl::narrow_cast<double>(beta), C, gsl::narrow_cast<int>(N));
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int32_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const int32_t* A,
|
||||
const int32_t* B,
|
||||
const float beta,
|
||||
int32_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No int32_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint32_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const uint32_t* A,
|
||||
const uint32_t* B,
|
||||
const float beta,
|
||||
uint32_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No uint32_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<int64_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const int64_t* A,
|
||||
const int64_t* B,
|
||||
const float beta,
|
||||
int64_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No int64_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<uint64_t, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const float alpha,
|
||||
const uint64_t* A,
|
||||
const uint64_t* B,
|
||||
const float beta,
|
||||
uint64_t* C,
|
||||
CPUMathUtil* /*provider*/,
|
||||
MLDataType /*math_type*/) {
|
||||
// No uint64_t Gemm offering from MKLML. Directly fallback to Eigen.
|
||||
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmEx<float, CPUMathUtil>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
|
|
|
|||
|
|
@ -7,75 +7,122 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(MathOpTest, MatMul) {
|
||||
std::vector<float> vals{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
|
||||
template <typename T>
|
||||
struct MatMulTestData {
|
||||
std::string name;
|
||||
std::vector<int64_t> input0_dims;
|
||||
std::vector<int64_t> input1_dims;
|
||||
std::vector<int64_t> expected_dims;
|
||||
std::vector<T> expected_vals;
|
||||
};
|
||||
|
||||
struct MatMulTest {
|
||||
std::string name;
|
||||
std::vector<int64_t> input0_dims;
|
||||
std::vector<int64_t> input1_dims;
|
||||
std::vector<int64_t> expected_dims;
|
||||
std::vector<float> expected_vals;
|
||||
};
|
||||
template <typename T>
|
||||
std::vector<MatMulTestData<T>> GenerateTestCases()
|
||||
{
|
||||
std::vector<MatMulTestData<T>> test_cases;
|
||||
|
||||
MatMulTest testcases[] = {
|
||||
{"test padding and broadcast",
|
||||
{3, 1, 1, 2},
|
||||
{2, 2, 2},
|
||||
{3, 2, 1, 2},
|
||||
{2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}},
|
||||
{"test padding and broadcast",
|
||||
{2, 3, 2},
|
||||
{3, 2, 2, 1},
|
||||
{3, 2, 3, 1},
|
||||
{1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}},
|
||||
{"test left 1D",
|
||||
{2},
|
||||
{3, 2, 1},
|
||||
{3, 1},
|
||||
{1, 3, 5}},
|
||||
{"test right 1D",
|
||||
{3, 1, 2},
|
||||
{2},
|
||||
{3, 1},
|
||||
{1, 3, 5}},
|
||||
{"test scalar output",
|
||||
{3},
|
||||
{3},
|
||||
{},
|
||||
{5}},
|
||||
{"test 2D",
|
||||
{3, 4},
|
||||
{4, 3},
|
||||
{3, 3},
|
||||
{42, 48, 54, 114, 136, 158, 186, 224, 262}},
|
||||
{"test 2D special",
|
||||
{2, 2, 3},
|
||||
{3, 4},
|
||||
{2, 2, 4},
|
||||
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}},
|
||||
{"test 2D special 2",
|
||||
{2, 2, 3},
|
||||
{1, 3, 4},
|
||||
{2, 2, 4},
|
||||
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}},
|
||||
};
|
||||
test_cases.push_back(
|
||||
{"test padding and broadcast",
|
||||
{3, 1, 1, 2},
|
||||
{2, 2, 2},
|
||||
{3, 2, 1, 2},
|
||||
{2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}});
|
||||
|
||||
for (auto t : testcases) {
|
||||
OpTester test("MatMul");
|
||||
test_cases.push_back(
|
||||
{"test padding and broadcast",
|
||||
{2, 3, 2},
|
||||
{3, 2, 2, 1},
|
||||
{3, 2, 3, 1},
|
||||
{1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"test left 1D",
|
||||
{2},
|
||||
{3, 2, 1},
|
||||
{3, 1},
|
||||
{1, 3, 5}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"test right 1D",
|
||||
{3, 1, 2},
|
||||
{2},
|
||||
{3, 1},
|
||||
{1, 3, 5}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"test scalar output",
|
||||
{3},
|
||||
{3},
|
||||
{},
|
||||
{5}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"test 2D",
|
||||
{3, 4},
|
||||
{4, 3},
|
||||
{3, 3},
|
||||
{42, 48, 54, 114, 136, 158, 186, 224, 262}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"test 2D special",
|
||||
{2, 2, 3},
|
||||
{3, 4},
|
||||
{2, 2, 4},
|
||||
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"test 2D special 2",
|
||||
{2, 2, 3},
|
||||
{1, 3, 4},
|
||||
{2, 2, 4},
|
||||
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
|
||||
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RunMatMulTest(int32_t opset_version = 7)
|
||||
{
|
||||
std::vector<T> common_input_vals{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
|
||||
for (auto t : GenerateTestCases<T>()) {
|
||||
OpTester test("MatMul", opset_version);
|
||||
|
||||
int64_t size0 = TensorShape::ReinterpretBaseType(t.input0_dims).SizeHelper(0, t.input0_dims.size());
|
||||
std::vector<float> input0_vals(vals.cbegin(), vals.cbegin() + size0);
|
||||
test.AddInput<float>("A", t.input0_dims, input0_vals);
|
||||
std::vector<T> input0_vals(common_input_vals.cbegin(), common_input_vals.cbegin() + size0);
|
||||
test.AddInput<T>("A", t.input0_dims, input0_vals);
|
||||
|
||||
int64_t size1 = TensorShape::ReinterpretBaseType(t.input1_dims).SizeHelper(0, t.input1_dims.size());
|
||||
std::vector<float> input1_vals(vals.cbegin(), vals.cbegin() + size1);
|
||||
test.AddInput<float>("B", t.input1_dims, input1_vals);
|
||||
std::vector<T> input1_vals(common_input_vals.cbegin(), common_input_vals.cbegin() + size1);
|
||||
test.AddInput<T>("B", t.input1_dims, input1_vals);
|
||||
|
||||
test.AddOutput<float>("Y", t.expected_dims, t.expected_vals);
|
||||
test.AddOutput<T>("Y", t.expected_dims, t.expected_vals);
|
||||
test.Run();
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MathOpTest, MatMulFloatType) {
|
||||
RunMatMulTest<float>();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, MatMulDoubleType) {
|
||||
RunMatMulTest<double>();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, MatMulInt32Type) {
|
||||
RunMatMulTest<int32_t>(9);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, MatMulUint32Type) {
|
||||
RunMatMulTest<uint32_t>(9);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, MatMulInt64Type) {
|
||||
RunMatMulTest<int64_t>(9);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, MatMulUint64Type) {
|
||||
RunMatMulTest<uint64_t>(9);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue