MatMul op: Support new integer types and double type as part of opset V9 compliance (#482)

* Support new integer types and double type as part of opset V9 compliance
This commit is contained in:
Hariharan Seshadri 2019-02-20 17:03:37 -08:00 committed by GitHub
parent b69c834c06
commit c2b8ac0154
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 404 additions and 110 deletions

View file

@ -89,7 +89,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Ata
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Hardmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, LogSoftmax);
// NOTE(review): untyped MatMul declaration sitting next to the typed ones below; it looks like the
// pre-change line of this diff. Confirm only one form remains in the real file.
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MatMul);
// Per-element-type MatMul kernel class names. float and double carry version arguments (1, 9);
// the four integer types carry (9, 9).
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint32_t, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int64_t, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint64_t, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, BatchNormalization);
@ -342,7 +347,12 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, Gemm)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Hardmax)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, LogSoftmax)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MatMul)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, MatMul)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, MatMul)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, MatMul)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint32_t, MatMul)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int64_t, MatMul)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint64_t, MatMul)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softmax)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, TopK)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, BatchNormalization)>());

View file

@ -9,15 +9,50 @@
namespace onnxruntime {
// MatMul<float> registered through the untyped versioned macro with version arguments (1, 9).
// NOTE(review): the typed float registration that follows uses the same version arguments; this
// looks like the pre-change form shown by the diff. Confirm both are not present in the real file.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
MatMul,
1,
9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
MatMul<float>);
// Typed registration: float MatMul with version arguments (1, 9); type constraint "T" is limited
// to float tensors.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
1, 9,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
MatMul<float>);
template <>
Status MatMul<float>::Compute(OpKernelContext* ctx) const {
// Typed registration: double MatMul with version arguments (1, 9); "T" is limited to double tensors.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
1, 9,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
MatMul<double>);
// The integer registrations below all use version arguments (9, 9), unlike float/double above.
// Typed registration: int32_t MatMul; "T" is limited to int32_t tensors.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
9, 9,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
MatMul<int32_t>);
// Typed registration: uint32_t MatMul; "T" is limited to uint32_t tensors.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
9, 9,
uint32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint32_t>()),
MatMul<uint32_t>);
// Typed registration: int64_t MatMul; "T" is limited to int64_t tensors.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
9, 9,
int64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
MatMul<int64_t>);
// Typed registration: uint64_t MatMul; "T" is limited to uint64_t tensors.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
9, 9,
uint64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint64_t>()),
MatMul<uint64_t>);
template <typename T>
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
const Tensor* left_X = ctx->Input<Tensor>(0);
const Tensor* right_X = ctx->Input<Tensor>(1);
@ -28,17 +63,17 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
for (int i = 0; i < helper.OutputOffsets().size(); i++) {
math::Gemm<float, CPUMathUtil>(
math::Gemm<T, CPUMathUtil>(
CblasNoTrans,
CblasNoTrans,
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()),
/* alpha */ 1.0f,
left_X->template Data<float>() + helper.LeftOffsets()[i],
right_X->template Data<float>() + helper.RightOffsets()[i],
left_X->template Data<T>() + helper.LeftOffsets()[i],
right_X->template Data<T>() + helper.RightOffsets()[i],
/* beta */ 0.0f,
Y->template MutableData<float>() + helper.OutputOffsets()[i],
Y->template MutableData<T>() + helper.OutputOffsets()[i],
&CPUMathUtil::Instance());
}

View file

@ -51,6 +51,59 @@
namespace onnxruntime {
namespace math {
// Eigen-only GEMM used as the generic fallback: C = alpha * op(A) * op(B) + beta * C.
//
// TransA / TransB select whether A / B are consumed transposed; M, N, K are the
// GEMM dimensions used to size the Eigen maps over the raw buffers.
// alpha and beta are taken as float and converted to T with static_cast, so for
// integer T any fractional part of either scale factor is discarded.
//
// The beta step is applied to C before the transpose flags are inspected, which
// means C has already been modified if an unexpected flag value makes this throw.
//
// NOTE(review): C is mapped as N x M and the product is written as B-map * A-map,
// which presumably relies on EigenMatrixMap being column-major over row-major
// buffers; confirm against the EigenMatrixMap alias.
template <typename T>
void GemmEigen(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    int64_t M,
    int64_t N,
    int64_t K,
    float alpha,
    const T* A,
    const T* B,
    float beta,
    T* C) {
  auto out = EigenMatrixMap<T>(C, N, M);

  // Pre-scale the destination: clear it when beta is exactly zero, otherwise multiply in place.
  if (beta != 0) {
    out *= static_cast<T>(beta);
  } else {
    out.setZero();
  }

  if (TransA == CblasNoTrans) {
    if (TransB == CblasNoTrans) {
      out.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, N, K) *
                                                ConstEigenMatrixMap<T>(A, K, M));
      return;
    }
    if (TransB == CblasTrans) {
      out.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, K, N).transpose() *
                                                ConstEigenMatrixMap<T>(A, K, M));
      return;
    }
    ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
  }

  if (TransA == CblasTrans) {
    if (TransB == CblasNoTrans) {
      out.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, N, K) *
                                                ConstEigenMatrixMap<T>(A, M, K).transpose());
      return;
    }
    if (TransB == CblasTrans) {
      out.noalias() += static_cast<T>(alpha) * (ConstEigenMatrixMap<T>(B, K, N).transpose() *
                                                ConstEigenMatrixMap<T>(A, M, K).transpose());
      return;
    }
    ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
  }

  ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA);
}
////////////////////////////////////////////////////////////////////////////////
// BLAS alternatives.
// Depending on whether we have specified an external BLAS library or not, we
@ -110,47 +163,100 @@ void Gemm<float, CPUMathUtil>(
int ldb = (int)((TransB == CblasNoTrans) ? N : K);
MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
#else
auto C_mat = EigenMatrixMap<float>(C, N, M);
if (beta == 0) {
C_mat.setZero();
} else {
C_mat *= beta;
}
switch (TransA) {
case CblasNoTrans: {
switch (TransB) {
case CblasNoTrans:
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, N, K) *
ConstEigenMatrixMap<float>(A, K, M));
return;
case CblasTrans:
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, K, N).transpose() *
ConstEigenMatrixMap<float>(A, K, M));
return;
default:
ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
}
}
case CblasTrans: {
switch (TransB) {
case CblasNoTrans:
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, N, K) *
ConstEigenMatrixMap<float>(A, M, K).transpose());
return;
case CblasTrans:
C_mat.noalias() += alpha * (ConstEigenMatrixMap<float>(B, K, N).transpose() *
ConstEigenMatrixMap<float>(A, M, K).transpose());
return;
default:
ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
}
}
default:
ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA);
}
GemmEigen<float>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
#endif
}
// Gemm specialization for double on CPU: forwards every argument unchanged to GemmEigen<double>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<double, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const double* A,
const double* B,
const float beta,
double* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No double precision Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
// alpha and beta are passed on as float here.
GemmEigen<double>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
// Gemm specialization for int32_t on CPU: forwards every argument unchanged to GemmEigen<int32_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<int32_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const int32_t* A,
const int32_t* B,
const float beta,
int32_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No int32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
// Gemm specialization for uint32_t on CPU: forwards every argument unchanged to GemmEigen<uint32_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<uint32_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const uint32_t* A,
const uint32_t* B,
const float beta,
uint32_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No uint32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
// Gemm specialization for int64_t on CPU: forwards every argument unchanged to GemmEigen<int64_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<int64_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const int64_t* A,
const int64_t* B,
const float beta,
int64_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No int64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
// Gemm specialization for uint64_t on CPU: forwards every argument unchanged to GemmEigen<uint64_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<uint64_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const uint64_t* A,
const uint64_t* B,
const float beta,
uint64_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No uint64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void GemmEx<float, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
@ -343,6 +449,102 @@ void Gemm<float, CPUMathUtil>(
beta, C, gsl::narrow_cast<int>(N));
}
// Double-precision Gemm on CPU, dispatched to cblas_dgemm with CblasRowMajor storage.
// alpha and beta are received as float and converted to double for the BLAS call;
// M, N and K are converted from int64_t to int with gsl::narrow_cast.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<double, CPUMathUtil>(
    const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB,
    const int64_t M,
    const int64_t N,
    const int64_t K,
    const float alpha,
    const double* A,
    const double* B,
    const float beta,
    double* C,
    CPUMathUtil* /*provider*/,
    MLDataType /*math_type*/) {
  // Leading dimensions of A and B depend on whether each operand is passed transposed.
  const int64_t a_leading = (TransA == CblasNoTrans) ? K : M;
  const int64_t b_leading = (TransB == CblasNoTrans) ? N : K;

  const int rows = gsl::narrow_cast<int>(M);
  const int cols = gsl::narrow_cast<int>(N);
  const int inner = gsl::narrow_cast<int>(K);
  const double scale_product = gsl::narrow_cast<double>(alpha);
  const double scale_dest = gsl::narrow_cast<double>(beta);

  // The leading dimension of C is N.
  cblas_dgemm(CblasRowMajor, TransA, TransB,
              rows, cols, inner,
              scale_product, A, gsl::narrow_cast<int>(a_leading),
              B, gsl::narrow_cast<int>(b_leading),
              scale_dest, C, cols);
}
// Gemm specialization for int32_t on CPU: forwards every argument unchanged to GemmEigen<int32_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<int32_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const int32_t* A,
const int32_t* B,
const float beta,
int32_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No int32_t Gemm offering from MKLML. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<int32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
// Gemm specialization for uint32_t on CPU: forwards every argument unchanged to GemmEigen<uint32_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<uint32_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const uint32_t* A,
const uint32_t* B,
const float beta,
uint32_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No uint32_t Gemm offering from MKLML. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<uint32_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
// Gemm specialization for int64_t on CPU: forwards every argument unchanged to GemmEigen<int64_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<int64_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const int64_t* A,
const int64_t* B,
const float beta,
int64_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No int64_t Gemm offering from MKLML. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<int64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
// Gemm specialization for uint64_t on CPU: forwards every argument unchanged to GemmEigen<uint64_t>.
// The provider and math_type parameters are unnamed and unused in this body.
template <>
void Gemm<uint64_t, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const uint64_t* A,
const uint64_t* B,
const float beta,
uint64_t* C,
CPUMathUtil* /*provider*/,
MLDataType /*math_type*/) {
// No uint64_t Gemm offering from MKLML. Directly fallback to Eigen.
// alpha and beta stay float in this signature even though the data is integral.
GemmEigen<uint64_t>(TransA, TransB, M, N, K, alpha, A, B, beta, C);
}
template <>
void GemmEx<float, CPUMathUtil>(
const CBLAS_TRANSPOSE TransA,

View file

@ -7,75 +7,122 @@
namespace onnxruntime {
namespace test {
TEST(MathOpTest, MatMul) {
std::vector<float> vals{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
// One MatMul test case for element type T. Instances are built with positional brace
// initialization, so the field order below is part of the contract.
template <typename T>
struct MatMulTestData {
// Human-readable label for the case.
std::string name;
// Shape passed for the first input ("A").
std::vector<int64_t> input0_dims;
// Shape passed for the second input ("B").
std::vector<int64_t> input1_dims;
// Shape expected for the output ("Y"); empty for a scalar result.
std::vector<int64_t> expected_dims;
// Expected output values.
std::vector<T> expected_vals;
};
// Float-only MatMul test case descriptor with the same fields as MatMulTestData<T>.
// NOTE(review): looks like the pre-change struct shown by the diff; confirm whether it still
// exists in the real file.
struct MatMulTest {
// Human-readable label for the case.
std::string name;
// Shape passed for the first input.
std::vector<int64_t> input0_dims;
// Shape passed for the second input.
std::vector<int64_t> input1_dims;
// Shape expected for the output; empty for a scalar result.
std::vector<int64_t> expected_dims;
// Expected output values, fixed to float.
std::vector<float> expected_vals;
};
// Builds and returns the list of MatMul test cases for element type T.
// Each case lists a label, the two input shapes, the expected output shape and the expected values.
// NOTE(review): as rendered, this span interleaves a `MatMulTest testcases[]` array literal and a
// `for (auto t : testcases)` loop header with the `test_cases.push_back(...)` calls; the former
// look like pre-change diff lines. Confirm against the actual file.
// NOTE(review): two cases share the label "test padding and broadcast".
template <typename T>
std::vector<MatMulTestData<T>> GenerateTestCases()
{
std::vector<MatMulTestData<T>> test_cases;
MatMulTest testcases[] = {
{"test padding and broadcast",
{3, 1, 1, 2},
{2, 2, 2},
{3, 2, 1, 2},
{2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}},
{"test padding and broadcast",
{2, 3, 2},
{3, 2, 2, 1},
{3, 2, 3, 1},
{1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}},
{"test left 1D",
{2},
{3, 2, 1},
{3, 1},
{1, 3, 5}},
{"test right 1D",
{3, 1, 2},
{2},
{3, 1},
{1, 3, 5}},
{"test scalar output",
{3},
{3},
{},
{5}},
{"test 2D",
{3, 4},
{4, 3},
{3, 3},
{42, 48, 54, 114, 136, 158, 186, 224, 262}},
{"test 2D special",
{2, 2, 3},
{3, 4},
{2, 2, 4},
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}},
{"test 2D special 2",
{2, 2, 3},
{1, 3, 4},
{2, 2, 4},
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}},
};
// 4-D left operand against a 3-D right operand.
test_cases.push_back(
{"test padding and broadcast",
{3, 1, 1, 2},
{2, 2, 2},
{3, 2, 1, 2},
{2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}});
for (auto t : testcases) {
OpTester test("MatMul");
// 3-D left operand against a 4-D right operand.
test_cases.push_back(
{"test padding and broadcast",
{2, 3, 2},
{3, 2, 2, 1},
{3, 2, 3, 1},
{1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}});
// 1-D left operand.
test_cases.push_back(
{"test left 1D",
{2},
{3, 2, 1},
{3, 1},
{1, 3, 5}});
// 1-D right operand.
test_cases.push_back(
{"test right 1D",
{3, 1, 2},
{2},
{3, 1},
{1, 3, 5}});
// Two 1-D operands; the expected output shape is empty.
test_cases.push_back(
{"test scalar output",
{3},
{3},
{},
{5}});
// Plain 2-D by 2-D product.
test_cases.push_back(
{"test 2D",
{3, 4},
{4, 3},
{3, 3},
{42, 48, 54, 114, 136, 158, 186, 224, 262}});
// 3-D left operand against a 2-D right operand.
test_cases.push_back(
{"test 2D special",
{2, 2, 3},
{3, 4},
{2, 2, 4},
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
// Same as above, with the right operand given a leading dimension of 1.
test_cases.push_back(
{"test 2D special 2",
{2, 2, 3},
{1, 3, 4},
{2, 2, 4},
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
return test_cases;
}
// Runs every case from GenerateTestCases<T>() through an OpTester for "MatMul" at the given
// opset version (default 7). Both inputs are taken as leading slices of the shared 0..11 value
// list, with slice lengths computed from the case's input dims.
// NOTE(review): the float-typed lines that reference `vals` and the extra float AddOutput look
// like pre-change diff lines interleaved with the templated ones; confirm against the actual file.
template <typename T>
void RunMatMulTest(int32_t opset_version = 7)
{
std::vector<T> common_input_vals{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
// NOTE(review): `auto t` copies each test case per iteration; `const auto&` would avoid the copy.
for (auto t : GenerateTestCases<T>()) {
OpTester test("MatMul", opset_version);
// presumably the element count of input 0 (SizeHelper over all of its dims) — TODO confirm
int64_t size0 = TensorShape::ReinterpretBaseType(t.input0_dims).SizeHelper(0, t.input0_dims.size());
std::vector<float> input0_vals(vals.cbegin(), vals.cbegin() + size0);
test.AddInput<float>("A", t.input0_dims, input0_vals);
std::vector<T> input0_vals(common_input_vals.cbegin(), common_input_vals.cbegin() + size0);
test.AddInput<T>("A", t.input0_dims, input0_vals);
// presumably the element count of input 1 — TODO confirm
int64_t size1 = TensorShape::ReinterpretBaseType(t.input1_dims).SizeHelper(0, t.input1_dims.size());
std::vector<float> input1_vals(vals.cbegin(), vals.cbegin() + size1);
test.AddInput<float>("B", t.input1_dims, input1_vals);
std::vector<T> input1_vals(common_input_vals.cbegin(), common_input_vals.cbegin() + size1);
test.AddInput<T>("B", t.input1_dims, input1_vals);
test.AddOutput<float>("Y", t.expected_dims, t.expected_vals);
test.AddOutput<T>("Y", t.expected_dims, t.expected_vals);
test.Run();
}
}
// float MatMul cases, run with RunMatMulTest's default opset_version argument.
TEST(MathOpTest, MatMulFloatType) {
RunMatMulTest<float>();
}
// double MatMul cases, run with RunMatMulTest's default opset_version argument.
TEST(MathOpTest, MatMulDoubleType) {
RunMatMulTest<double>();
}
// int32_t MatMul cases, run with opset_version 9.
TEST(MathOpTest, MatMulInt32Type) {
RunMatMulTest<int32_t>(9);
}
// uint32_t MatMul cases, run with opset_version 9.
TEST(MathOpTest, MatMulUint32Type) {
RunMatMulTest<uint32_t>(9);
}
// int64_t MatMul cases, run with opset_version 9.
TEST(MathOpTest, MatMulInt64Type) {
RunMatMulTest<int64_t>(9);
}
// uint64_t MatMul cases, run with opset_version 9.
TEST(MathOpTest, MatMulUint64Type) {
RunMatMulTest<uint64_t>(9);
}
} // namespace test
} // namespace onnxruntime