diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 1006c43803..55b0e77556 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -120,7 +120,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Tan class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Acos); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); @@ -274,7 +275,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint8_t, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); @@ -393,7 +395,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Se class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -471,7 +474,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); @@ -797,7 +801,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1395,7 +1403,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 0b43ec6a80..15a7034d1c 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -9,35 +9,66 @@ namespace onnxruntime { -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( Gemm, 7, 8, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + Gemm, + 7, + 8, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); // opset 9 added support for additional types (int32, uint32, int64, uint64), however we haven't enabled those yet. -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( Gemm, 9, 10, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + Gemm, + 9, + 10, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); // opset 11 made bias input 'C' optional -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( Gemm, 11, 12, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + Gemm, + 11, + 12, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); // opset 13 Adds BFloat16 support but we are not supporting it yet -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_TYPED_KERNEL( Gemm, 13, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Gemm, + 13, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); bool GemmPackBFp32(const OpKernelInfo& info, const Tensor& tensor_b, diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 02a177d065..04f47fb3d8 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -115,7 +115,7 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { out_data, N, device_prop)); } else { // B is (M, N), no broadcast needed. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, M * N * sizeof(float), cudaMemcpyDeviceToDevice)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, M * N * sizeof(T), cudaMemcpyDeviceToDevice)); } } diff --git a/onnxruntime/core/providers/rocm/math/gemm.cc b/onnxruntime/core/providers/rocm/math/gemm.cc index 4d9b08847f..413744d595 100644 --- a/onnxruntime/core/providers/rocm/math/gemm.cc +++ b/onnxruntime/core/providers/rocm/math/gemm.cc @@ -115,7 +115,7 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { out_data, N)); } else { // B is (M, N), no broadcast needed. - HIP_RETURN_IF_ERROR(hipMemcpyAsync(out_data, b_data, M * N * sizeof(float), hipMemcpyDeviceToDevice)); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(out_data, b_data, M * N * sizeof(T), hipMemcpyDeviceToDevice)); } } diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 895ac6da84..3c9859d72b 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -8,6 +8,7 @@ namespace onnxruntime { namespace test { +template void TestGemmNoTrans(bool b_is_initializer) { OpTester test("Gemm"); @@ -16,24 +17,29 @@ void TestGemmNoTrans(bool b_is_initializer) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f), b_is_initializer); - test.AddInput("C", {2, 3}, std::vector(6, 1.0f)); - test.AddOutput("Y", {2, 3}, - {11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f), b_is_initializer); + test.AddInput("C", {2, 3}, std::vector(6, 1.0f)); + test.AddOutput("Y", {2, 3}, + {11.0f, 11.0f, 11.0f, + -9.0f, -9.0f, -9.0f}); test.Run(); } -TEST(GemmOpTest, GemmNoTrans) { - TestGemmNoTrans(false); +TEST(GemmOpTest, GemmNoTrans_float) { + TestGemmNoTrans(false); +} + +TEST(GemmOpTest, GemmNoTrans_double) { + TestGemmNoTrans(false); } // NNAPI EP requires weight to be an initializer TEST(GemmOpTest, GemmNoTransBIsInitializer) { - TestGemmNoTrans(true); + TestGemmNoTrans(true); + TestGemmNoTrans(true); } // Only CUDA kernel has float 16 support @@ -75,7 +81,8 @@ TEST(GemmOpTest, GemmNoTrans_f16) { } #endif -static void TestGemmBroadcast(bool b_is_initializer) { +template +void TestGemmBroadcast(bool b_is_initializer) { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -83,14 +90,14 @@ static void TestGemmBroadcast(bool b_is_initializer) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f), b_is_initializer); - test.AddInput("C", {3}, std::vector{1.0f, 2.0f, 3.0f}); - test.AddOutput("Y", {2, 3}, - {11.0f, 12.0f, 13.0f, - -9.0f, -8.0f, -7.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f), b_is_initializer); + test.AddInput("C", {3}, std::vector{1.0f, 2.0f, 3.0f}); + test.AddOutput("Y", {2, 3}, + {11.0f, 12.0f, 13.0f, + -9.0f, -8.0f, -7.0f}); #if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO : Temporarily disabled due to accuracy issues #else @@ -99,13 +106,16 @@ static void TestGemmBroadcast(bool b_is_initializer) { } TEST(GemmOpTest, GemmBroadcast) { - TestGemmBroadcast(false); + TestGemmBroadcast(false); + TestGemmBroadcast(false); } TEST(GemmOpTest, GemmBroadcastBIsInitializer) { - TestGemmBroadcast(true); + TestGemmBroadcast(true); + TestGemmBroadcast(true); } +template static void TestGemmTrans(bool b_is_initializer) { OpTester test("Gemm"); @@ -114,34 +124,37 @@ static void TestGemmTrans(bool b_is_initializer) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {4, 2}, - {1.0f, -1.0f, - 2.0f, -2.0f, - 3.0f, -3.0f, - 4.0f, -4.0f}); - test.AddInput("B", {3, 4}, std::vector(12, 1.0f), b_is_initializer); - test.AddInput("C", {3}, std::vector(3, 1.0f)); - test.AddOutput("Y", {2, 3}, - {11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}); - #if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues - #else + test.AddInput("A", {4, 2}, + {1.0f, -1.0f, + 2.0f, -2.0f, + 3.0f, -3.0f, + 4.0f, -4.0f}); + test.AddInput("B", {3, 4}, std::vector(12, 1.0f), b_is_initializer); + test.AddInput("C", {3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {2, 3}, + {11.0f, 11.0f, 11.0f, + -9.0f, -9.0f, -9.0f}); +#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else test.Run(); - #endif +#endif } TEST(GemmOpTest, GemmTrans) { - TestGemmTrans(false); + TestGemmTrans(false); + TestGemmTrans(false); } TEST(GemmOpTest, GemmTransBIsInitializer) { - TestGemmTrans(true); + TestGemmTrans(true); + TestGemmTrans(true); } // NNAPI EP's GEMM only works as A*B', add case only B is transposed // Also test NNAPI EP's handling of non-1D bias (C of Gemm) -TEST(GemmOpTest, GemmTransB) { +template +static void TestGemmTransB() { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -149,14 +162,14 @@ TEST(GemmOpTest, GemmTransB) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {3, 4}, std::vector(12, 1.0f)); - test.AddInput("C", {1, 3}, std::vector(3, 1.0f)); - test.AddOutput("Y", {2, 3}, - {11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {3, 4}, std::vector(12, 1.0f)); + test.AddInput("C", {1, 3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {2, 3}, + {11.0f, 11.0f, 11.0f, + -9.0f, -9.0f, -9.0f}); #if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues #else @@ -164,9 +177,15 @@ TEST(GemmOpTest, GemmTransB) { #endif } +TEST(GemmOpTest, GemmTransB) { + TestGemmTransB(); + TestGemmTransB(); +} + // NNAPI EP's GEMM only works as A*B', add case only B is transposed // Also test NNAPI EP's handling of non-1D bias (C of Gemm) which is broadcastable but not valid for NNAPI -TEST(GemmOpTest, GemmTransB_1) { +template +void TestGemmTransB_1() { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -174,14 +193,14 @@ TEST(GemmOpTest, GemmTransB_1) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {3, 4}, std::vector(12, 1.0f)); - test.AddInput("C", {2, 1}, std::vector(2, 1.0f)); - test.AddOutput("Y", {2, 3}, - {11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {3, 4}, std::vector(12, 1.0f)); + test.AddInput("C", {2, 1}, std::vector(2, 1.0f)); + test.AddOutput("Y", {2, 3}, + {11.0f, 11.0f, 11.0f, + -9.0f, -9.0f, -9.0f}); #if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues #else @@ -189,7 +208,13 @@ TEST(GemmOpTest, GemmTransB_1) { #endif } -TEST(GemmOpTest, GemmAlphaBeta) { +TEST(GemmOpTest, GemmTransB_1) { + TestGemmTransB_1(); + TestGemmTransB_1(); +} + +template +void TestGemmAlphaBeta() { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -197,14 +222,14 @@ TEST(GemmOpTest, GemmAlphaBeta) { test.AddAttribute("alpha", 0.5f); test.AddAttribute("beta", 2.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddInput("C", {3}, std::vector(3, 1.0f)); - test.AddOutput("Y", {2, 3}, - {7.0f, 7.0f, 7.0f, - -3.0f, -3.0f, -3.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {2, 3}, + {7.0f, 7.0f, 7.0f, + -3.0f, -3.0f, -3.0f}); #if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues #else @@ -212,7 +237,13 @@ TEST(GemmOpTest, GemmAlphaBeta) { #endif } -TEST(GemmOpTest, GemmNaN) { +TEST(GemmOpTest, GemmAlphaBeta) { + TestGemmAlphaBeta(); + TestGemmAlphaBeta(); +} + +template +void TestGemmNaN() { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -220,18 +251,24 @@ TEST(GemmOpTest, GemmNaN) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 0.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddInput("C", {2, 3}, std::vector(6, 1.0f)); - test.AddOutput("Y", {2, 3}, - {10.0f, 10.0f, 10.0f, - -10.0f, -10.0f, -10.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {2, 3}, std::vector(6, 1.0f)); + test.AddOutput("Y", {2, 3}, + {10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f}); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser } -TEST(GemmOpTest, GemmScalarBroadcast) { +TEST(GemmOpTest, GemmNaN) { + TestGemmNaN(); + TestGemmNaN(); +} + +template +void TestGemmScalarBroadcast() { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -239,37 +276,49 @@ TEST(GemmOpTest, GemmScalarBroadcast) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddInput("C", {1}, std::vector{1.0f}); - test.AddOutput("Y", {2, 3}, - {11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {1}, std::vector{1.0f}); + test.AddOutput("Y", {2, 3}, + {11.0f, 11.0f, 11.0f, + -9.0f, -9.0f, -9.0f}); + test.Run(); +} + +TEST(GemmOpTest, GemmScalarBroadcast) { + TestGemmScalarBroadcast(); + TestGemmScalarBroadcast(); +} + +template +void TestGemm2DBroadcast_1() { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {2, 1}, std::vector{1.0, 2.0f}); + test.AddOutput("Y", {2, 3}, + {11.0f, 11.0f, 11.0f, + -8.0f, -8.0f, -8.0f}); test.Run(); } TEST(GemmOpTest, Gemm2DBroadcast_1) { - OpTester test("Gemm"); - - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); - - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddInput("C", {2, 1}, std::vector{1.0f, 2.0f}); - test.AddOutput("Y", {2, 3}, - {11.0f, 11.0f, 11.0f, - -8.0f, -8.0f, -8.0f}); - test.Run(); + TestGemm2DBroadcast_1(); + TestGemm2DBroadcast_1(); } -TEST(GemmOpTest, Gemm2DBroadcast_2) { +template +void TestGemm2DBroadcast_2() { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -278,18 +327,24 @@ TEST(GemmOpTest, Gemm2DBroadcast_2) { test.AddAttribute("beta", 1.0f); // Same as GemmBroadcast, but adding the unnecessary second dimension. - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddInput("C", {1, 3}, std::vector{1.0f, 2.0f, 3.0f}); - test.AddOutput("Y", {2, 3}, - {11.0f, 12.0f, 13.0f, - -9.0f, -8.0f, -7.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {1, 3}, std::vector{1.0f, 2.0f, 3.0f}); + test.AddOutput("Y", {2, 3}, + {11.0f, 12.0f, 13.0f, + -9.0f, -8.0f, -7.0f}); test.Run(); } -TEST(GemmOpTest, GemmFalseBroadcast) { +TEST(GemmOpTest, Gemm2DBroadcast_2) { + TestGemm2DBroadcast_2(); + TestGemm2DBroadcast_2(); +} + +template +void TestGemmFalseBroadcast() { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -297,18 +352,24 @@ TEST(GemmOpTest, GemmFalseBroadcast) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddInput("C", {2, 3}, std::vector{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}); - test.AddOutput("Y", {2, 3}, - {11.0f, 11.0f, 11.0f, - -8.0f, -8.0f, -8.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {2, 3}, std::vector{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}); + test.AddOutput("Y", {2, 3}, + {11.0f, 11.0f, 11.0f, + -8.0f, -8.0f, -8.0f}); test.Run(); } -TEST(GemmOpTest, GemmEmptyTensor) { +TEST(GemmOpTest, GemmFalseBroadcast) { + TestGemmFalseBroadcast(); + TestGemmFalseBroadcast(); +} + +template +void TestGemmEmptyTensor() { OpTester test("Gemm"); test.AddAttribute("transA", static_cast(0)); @@ -316,16 +377,22 @@ TEST(GemmOpTest, GemmEmptyTensor) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {0, 4}, - {}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddInput("C", {3}, std::vector(3, 1.0f)); - test.AddOutput("Y", {0, 3}, - {}); + test.AddInput("A", {0, 4}, + {}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {0, 3}, + {}); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider}); //TensorRT: doesn't support dynamic shape yet } -TEST(GemmOpTest, GemmNoBiasOpset11) { +TEST(GemmOpTest, GemmEmptyTensor) { + TestGemmEmptyTensor(); + TestGemmEmptyTensor(); +} + +template +static void TestGemmNoBiasOpset11() { OpTester test("Gemm", 11); test.AddAttribute("transA", static_cast(0)); @@ -333,29 +400,40 @@ TEST(GemmOpTest, GemmNoBiasOpset11) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f}); - test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); - test.AddOutput("Y", {2, 3}, - {10.0f, 10.0f, 10.0f, - -10.0f, -10.0f, -10.0f}); + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddOutput("Y", {2, 3}, + {10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f}); + // tensorRT don't seem to support missing bias + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(GemmOpTest, GemmNoBiasOpset11) { + TestGemmNoBiasOpset11(); + TestGemmNoBiasOpset11(); +} + +template +static void TestGemmWithAlphaOpset11() { + OpTester test("Gemm", 11); + + test.AddAttribute("alpha", 2.0f); + + test.AddInput("A", {2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f}); + test.AddInput("B", {2, 2}, std::vector(4, 1.0f)); + test.AddOutput("Y", {2, 2}, + {6.0f, 6.0f, 14.0f, 14.0f}); // tensorRT don't seem to support missing bias test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(GemmOpTest, GemmWithAlphaOpset11) { - OpTester test("Gemm", 11); - - test.AddAttribute("alpha", 2.0f); - - test.AddInput("A", {2, 2}, - {1.0f, 2.0f, 3.0f, 4.0f}); - test.AddInput("B", {2, 2}, std::vector(4, 1.0f)); - test.AddOutput("Y", {2, 2}, - {6.0f, 6.0f, 14.0f, 14.0f}); - // tensorRT don't seem to support missing bias - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + TestGemmWithAlphaOpset11(); + TestGemmWithAlphaOpset11(); } } // namespace test