Support double for operator Gemm + fix bug in gemm implementation for cuda, rocm when sizeof(type) != sizeof(float) (#6223)

* Support double for operator Gemm
* fix type size while copying data in gemm operator for GPU
* fix type in gemm implementation for rocm
This commit is contained in:
Xavier Dupré 2020-12-31 11:24:16 +01:00 committed by GitHub
parent 70e2f96ef4
commit 5968a91ea6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 276 additions and 158 deletions

View file

@ -120,7 +120,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Tan
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Acos);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
@ -274,7 +275,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint8_t, Where);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul);
@ -393,7 +395,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Se
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift);
@ -471,7 +474,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
@ -797,7 +801,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Acos)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
Hardmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
@ -1062,8 +1067,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float,
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double,
@ -1208,7 +1215,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ConcatFromSequence)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t,
BitShift)>,
@ -1395,7 +1403,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Mean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sum)>,

View file

@ -9,35 +9,66 @@
namespace onnxruntime {
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Gemm,
7,
8,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm<float>);
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Gemm,
7,
8,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Gemm<double>);
// opset 9 added support for additional types (int32, uint32, int64, uint64), however we haven't enabled those yet.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Gemm,
9,
10,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm<float>);
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Gemm,
9,
10,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Gemm<double>);
// opset 11 made bias input 'C' optional
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Gemm,
11,
12,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm<float>);
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Gemm,
11,
12,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Gemm<double>);
// opset 13 adds BFloat16 support, but we do not support it yet
ONNX_CPU_OPERATOR_KERNEL(
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Gemm,
13,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Gemm,
13,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Gemm<double>);
bool GemmPackBFp32(const OpKernelInfo& info,
const Tensor& tensor_b,

View file

@ -115,7 +115,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
out_data, N, device_prop));
} else {
// B is (M, N), no broadcast needed.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, M * N * sizeof(float), cudaMemcpyDeviceToDevice));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, M * N * sizeof(T), cudaMemcpyDeviceToDevice));
}
}

View file

@ -115,7 +115,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
out_data, N));
} else {
// B is (M, N), no broadcast needed.
HIP_RETURN_IF_ERROR(hipMemcpyAsync(out_data, b_data, M * N * sizeof(float), hipMemcpyDeviceToDevice));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(out_data, b_data, M * N * sizeof(T), hipMemcpyDeviceToDevice));
}
}

View file

@ -8,6 +8,7 @@
namespace onnxruntime {
namespace test {
template <typename T>
void TestGemmNoTrans(bool b_is_initializer) {
OpTester test("Gemm");
@ -16,24 +17,29 @@ void TestGemmNoTrans(bool b_is_initializer) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f), b_is_initializer);
test.AddInput<float>("C", {2, 3}, std::vector<float>(6, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f), b_is_initializer);
test.AddInput<T>("C", {2, 3}, std::vector<T>(6, 1.0f));
test.AddOutput<T>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
test.Run();
}
TEST(GemmOpTest, GemmNoTrans) {
TestGemmNoTrans(false);
// Runs the TestGemmNoTrans helper with float elements and b_is_initializer = false.
TEST(GemmOpTest, GemmNoTrans_float) {
TestGemmNoTrans<float>(false);
}
// Runs the TestGemmNoTrans helper with double elements and b_is_initializer = false.
TEST(GemmOpTest, GemmNoTrans_double) {
TestGemmNoTrans<double>(false);
}
// NNAPI EP requires weight to be an initializer
TEST(GemmOpTest, GemmNoTransBIsInitializer) {
TestGemmNoTrans(true);
TestGemmNoTrans<float>(true);
TestGemmNoTrans<double>(true);
}
// Only CUDA kernel has float 16 support
@ -75,7 +81,8 @@ TEST(GemmOpTest, GemmNoTrans_f16) {
}
#endif
static void TestGemmBroadcast(bool b_is_initializer) {
template <typename T>
void TestGemmBroadcast(bool b_is_initializer) {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -83,14 +90,14 @@ static void TestGemmBroadcast(bool b_is_initializer) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f), b_is_initializer);
test.AddInput<float>("C", {3}, std::vector<float>{1.0f, 2.0f, 3.0f});
test.AddOutput<float>("Y", {2, 3},
{11.0f, 12.0f, 13.0f,
-9.0f, -8.0f, -7.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f), b_is_initializer);
test.AddInput<T>("C", {3}, std::vector<T>{1.0f, 2.0f, 3.0f});
test.AddOutput<T>("Y", {2, 3},
{11.0f, 12.0f, 13.0f,
-9.0f, -8.0f, -7.0f});
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO : Temporarily disabled due to accuracy issues
#else
@ -99,13 +106,16 @@ static void TestGemmBroadcast(bool b_is_initializer) {
}
TEST(GemmOpTest, GemmBroadcast) {
TestGemmBroadcast(false);
TestGemmBroadcast<float>(false);
TestGemmBroadcast<double>(false);
}
TEST(GemmOpTest, GemmBroadcastBIsInitializer) {
TestGemmBroadcast(true);
TestGemmBroadcast<float>(true);
TestGemmBroadcast<double>(true);
}
template <typename T>
static void TestGemmTrans(bool b_is_initializer) {
OpTester test("Gemm");
@ -114,34 +124,37 @@ static void TestGemmTrans(bool b_is_initializer) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {4, 2},
{1.0f, -1.0f,
2.0f, -2.0f,
3.0f, -3.0f,
4.0f, -4.0f});
test.AddInput<float>("B", {3, 4}, std::vector<float>(12, 1.0f), b_is_initializer);
test.AddInput<float>("C", {3}, std::vector<float>(3, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
#else
test.AddInput<T>("A", {4, 2},
{1.0f, -1.0f,
2.0f, -2.0f,
3.0f, -3.0f,
4.0f, -4.0f});
test.AddInput<T>("B", {3, 4}, std::vector<T>(12, 1.0f), b_is_initializer);
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
test.AddOutput<T>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
#else
test.Run();
#endif
#endif
}
TEST(GemmOpTest, GemmTrans) {
TestGemmTrans(false);
TestGemmTrans<float>(false);
TestGemmTrans<double>(false);
}
TEST(GemmOpTest, GemmTransBIsInitializer) {
TestGemmTrans(true);
TestGemmTrans<float>(true);
TestGemmTrans<double>(true);
}
// NNAPI EP's GEMM only works as A*B', add case only B is transposed
// Also test NNAPI EP's handling of non-1D bias (C of Gemm)
TEST(GemmOpTest, GemmTransB) {
template <typename T>
static void TestGemmTransB() {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -149,14 +162,14 @@ TEST(GemmOpTest, GemmTransB) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {3, 4}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {1, 3}, std::vector<float>(3, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {3, 4}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {1, 3}, std::vector<T>(3, 1.0f));
test.AddOutput<T>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
#else
@ -164,9 +177,15 @@ TEST(GemmOpTest, GemmTransB) {
#endif
}
// Runs the TestGemmTransB helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmTransB) {
TestGemmTransB<float>();
TestGemmTransB<double>();
}
// NNAPI EP's GEMM only works as A*B', add case only B is transposed
// Also test NNAPI EP's handling of non-1D bias (C of Gemm) which is broadcastable but not valid for NNAPI
TEST(GemmOpTest, GemmTransB_1) {
template <typename T>
void TestGemmTransB_1() {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -174,14 +193,14 @@ TEST(GemmOpTest, GemmTransB_1) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {3, 4}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {2, 1}, std::vector<float>(2, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {3, 4}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {2, 1}, std::vector<T>(2, 1.0f));
test.AddOutput<T>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
#else
@ -189,7 +208,13 @@ TEST(GemmOpTest, GemmTransB_1) {
#endif
}
TEST(GemmOpTest, GemmAlphaBeta) {
// Runs the TestGemmTransB_1 helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmTransB_1) {
TestGemmTransB_1<float>();
TestGemmTransB_1<double>();
}
template <typename T>
void TestGemmAlphaBeta() {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -197,14 +222,14 @@ TEST(GemmOpTest, GemmAlphaBeta) {
test.AddAttribute("alpha", 0.5f);
test.AddAttribute("beta", 2.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {3}, std::vector<float>(3, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{7.0f, 7.0f, 7.0f,
-3.0f, -3.0f, -3.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
test.AddOutput<T>("Y", {2, 3},
{7.0f, 7.0f, 7.0f,
-3.0f, -3.0f, -3.0f});
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
#else
@ -212,7 +237,13 @@ TEST(GemmOpTest, GemmAlphaBeta) {
#endif
}
TEST(GemmOpTest, GemmNaN) {
// Runs the TestGemmAlphaBeta helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmAlphaBeta) {
TestGemmAlphaBeta<float>();
TestGemmAlphaBeta<double>();
}
template <typename T>
void TestGemmNaN() {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -220,18 +251,24 @@ TEST(GemmOpTest, GemmNaN) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 0.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {2, 3}, std::vector<float>(6, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{10.0f, 10.0f, 10.0f,
-10.0f, -10.0f, -10.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {2, 3}, std::vector<T>(6, 1.0f));
test.AddOutput<T>("Y", {2, 3},
{10.0f, 10.0f, 10.0f,
-10.0f, -10.0f, -10.0f});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser
}
TEST(GemmOpTest, GemmScalarBroadcast) {
// Runs the TestGemmNaN helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmNaN) {
TestGemmNaN<float>();
TestGemmNaN<double>();
}
template <typename T>
void TestGemmScalarBroadcast() {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -239,37 +276,49 @@ TEST(GemmOpTest, GemmScalarBroadcast) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {1}, std::vector<float>{1.0f});
test.AddOutput<float>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {1}, std::vector<T>{1.0f});
test.AddOutput<T>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-9.0f, -9.0f, -9.0f});
test.Run();
}
// Runs the TestGemmScalarBroadcast helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmScalarBroadcast) {
TestGemmScalarBroadcast<float>();
TestGemmScalarBroadcast<double>();
}
// Checks Gemm with a 2-D bias C of shape {2, 1} for element type T.
//
// With alpha = beta = 1 and no transposition, Y = A * B + C. Every row of
// A * B is constant (10 for the first row, -10 for the second), and the
// expected output adds C's single per-row value to each column of that row:
// 10 + 1 = 11 and -10 + 2 = -8.
template <typename T>
void TestGemm2DBroadcast_1() {
  OpTester test("Gemm");

  test.AddAttribute("transA", static_cast<int64_t>(0));
  test.AddAttribute("transB", static_cast<int64_t>(0));
  test.AddAttribute("alpha", 1.0f);
  test.AddAttribute("beta", 1.0f);

  test.AddInput<T>("A", {2, 4},
                   {1.0f, 2.0f, 3.0f, 4.0f,
                    -1.0f, -2.0f, -3.0f, -4.0f});
  test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
  // Float literals are used throughout: they convert exactly to both float
  // and double, so no double -> float truncation diagnostic can arise when
  // T is float.
  test.AddInput<T>("C", {2, 1}, std::vector<T>{1.0f, 2.0f});
  test.AddOutput<T>("Y", {2, 3},
                    {11.0f, 11.0f, 11.0f,
                     -8.0f, -8.0f, -8.0f});
  test.Run();
}
TEST(GemmOpTest, Gemm2DBroadcast_1) {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
test.AddAttribute("transB", (int64_t)0);
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {2, 1}, std::vector<float>{1.0f, 2.0f});
test.AddOutput<float>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-8.0f, -8.0f, -8.0f});
test.Run();
TestGemm2DBroadcast_1<float>();
TestGemm2DBroadcast_1<double>();
}
TEST(GemmOpTest, Gemm2DBroadcast_2) {
template <typename T>
void TestGemm2DBroadcast_2() {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -278,18 +327,24 @@ TEST(GemmOpTest, Gemm2DBroadcast_2) {
test.AddAttribute("beta", 1.0f);
// Same as GemmBroadcast, but adding the unnecessary second dimension.
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {1, 3}, std::vector<float>{1.0f, 2.0f, 3.0f});
test.AddOutput<float>("Y", {2, 3},
{11.0f, 12.0f, 13.0f,
-9.0f, -8.0f, -7.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {1, 3}, std::vector<T>{1.0f, 2.0f, 3.0f});
test.AddOutput<T>("Y", {2, 3},
{11.0f, 12.0f, 13.0f,
-9.0f, -8.0f, -7.0f});
test.Run();
}
TEST(GemmOpTest, GemmFalseBroadcast) {
// Runs the TestGemm2DBroadcast_2 helper once per supported element type (float, then double).
TEST(GemmOpTest, Gemm2DBroadcast_2) {
TestGemm2DBroadcast_2<float>();
TestGemm2DBroadcast_2<double>();
}
template <typename T>
void TestGemmFalseBroadcast() {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -297,18 +352,24 @@ TEST(GemmOpTest, GemmFalseBroadcast) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {2, 3}, std::vector<float>{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f});
test.AddOutput<float>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-8.0f, -8.0f, -8.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {2, 3}, std::vector<T>{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f});
test.AddOutput<T>("Y", {2, 3},
{11.0f, 11.0f, 11.0f,
-8.0f, -8.0f, -8.0f});
test.Run();
}
TEST(GemmOpTest, GemmEmptyTensor) {
// Runs the TestGemmFalseBroadcast helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmFalseBroadcast) {
TestGemmFalseBroadcast<float>();
TestGemmFalseBroadcast<double>();
}
template <typename T>
void TestGemmEmptyTensor() {
OpTester test("Gemm");
test.AddAttribute("transA", static_cast<int64_t>(0));
@ -316,16 +377,22 @@ TEST(GemmOpTest, GemmEmptyTensor) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {0, 4},
{});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddInput<float>("C", {3}, std::vector<float>(3, 1.0f));
test.AddOutput<float>("Y", {0, 3},
{});
test.AddInput<T>("A", {0, 4},
{});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
test.AddOutput<T>("Y", {0, 3},
{});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider}); //TensorRT: doesn't support dynamic shape yet
}
TEST(GemmOpTest, GemmNoBiasOpset11) {
// Runs the TestGemmEmptyTensor helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmEmptyTensor) {
TestGemmEmptyTensor<float>();
TestGemmEmptyTensor<double>();
}
template <typename T>
static void TestGemmNoBiasOpset11() {
OpTester test("Gemm", 11);
test.AddAttribute("transA", static_cast<int64_t>(0));
@ -333,29 +400,40 @@ TEST(GemmOpTest, GemmNoBiasOpset11) {
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{10.0f, 10.0f, 10.0f,
-10.0f, -10.0f, -10.0f});
test.AddInput<T>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
test.AddOutput<T>("Y", {2, 3},
{10.0f, 10.0f, 10.0f,
-10.0f, -10.0f, -10.0f});
// TensorRT doesn't seem to support a missing bias
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Runs the TestGemmNoBiasOpset11 helper once per supported element type (float, then double).
TEST(GemmOpTest, GemmNoBiasOpset11) {
TestGemmNoBiasOpset11<float>();
TestGemmNoBiasOpset11<double>();
}
// Opset-11 Gemm for element type T with alpha = 2 and no C (bias) input.
// Only "alpha" is set explicitly; the remaining attributes are left unset.
template <typename T>
static void TestGemmWithAlphaOpset11() {
  OpTester test("Gemm", 11);
  test.AddAttribute("alpha", 2.0f);

  // A = [[1, 2], [3, 4]] and B is all ones, so A * B = [[3, 3], [7, 7]];
  // the expected output below is that product scaled by alpha.
  const std::vector<T> a_values{1.0f, 2.0f, 3.0f, 4.0f};
  const std::vector<T> b_values(4, 1.0f);
  test.AddInput<T>("A", {2, 2}, a_values);
  test.AddInput<T>("B", {2, 2}, b_values);
  test.AddOutput<T>("Y", {2, 2},
                    {6.0f, 6.0f, 14.0f, 14.0f});

  // TensorRT doesn't seem to support a missing bias, hence it is named in the
  // provider set handed to Run.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
TEST(GemmOpTest, GemmWithAlphaOpset11) {
OpTester test("Gemm", 11);
test.AddAttribute("alpha", 2.0f);
test.AddInput<float>("A", {2, 2},
{1.0f, 2.0f, 3.0f, 4.0f});
test.AddInput<float>("B", {2, 2}, std::vector<float>(4, 1.0f));
test.AddOutput<float>("Y", {2, 2},
{6.0f, 6.0f, 14.0f, 14.0f});
// tensorRT don't seem to support missing bias
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
TestGemmWithAlphaOpset11<float>();
TestGemmWithAlphaOpset11<double>();
}
} // namespace test