mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Support double for operator Gemm + fix bug in gemm implementation for cuda, rocm when sizeof(type) != sizeof(float) (#6223)
* Support double for operator Gemm * fix type size while copying data in gemm operator for GPU * fix type in gemm implementation for rocm
This commit is contained in:
parent
70e2f96ef4
commit
5968a91ea6
5 changed files with 276 additions and 158 deletions
|
|
@ -120,7 +120,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Tan
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Acos);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
|
||||
|
|
@ -274,7 +275,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint8_t, Where);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul);
|
||||
|
|
@ -393,7 +395,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Se
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift);
|
||||
|
|
@ -471,7 +474,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
|
||||
|
|
@ -797,7 +801,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Acos)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
|
||||
Hardmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
|
||||
|
|
@ -1062,8 +1067,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
|
||||
Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
|
||||
Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
|
||||
float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
|
||||
double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float,
|
||||
MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double,
|
||||
|
|
@ -1208,7 +1215,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
ConcatFromSequence)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t,
|
||||
BitShift)>,
|
||||
|
|
@ -1395,7 +1403,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Mean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sum)>,
|
||||
|
|
|
|||
|
|
@ -9,35 +9,66 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Gemm,
|
||||
7,
|
||||
8,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float>);
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Gemm,
|
||||
7,
|
||||
8,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Gemm<double>);
|
||||
|
||||
// opset 9 added support for additional types (int32, uint32, int64, uint64), however we haven't enabled those yet.
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Gemm,
|
||||
9,
|
||||
10,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float>);
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Gemm,
|
||||
9,
|
||||
10,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Gemm<double>);
|
||||
|
||||
// opset 11 made bias input 'C' optional
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Gemm,
|
||||
11,
|
||||
12,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float>);
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Gemm,
|
||||
11,
|
||||
12,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Gemm<double>);
|
||||
|
||||
// opset 13 Adds BFloat16 support but we are not supporting it yet
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Gemm,
|
||||
13,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float>);
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Gemm,
|
||||
13,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Gemm<double>);
|
||||
|
||||
bool GemmPackBFp32(const OpKernelInfo& info,
|
||||
const Tensor& tensor_b,
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
out_data, N, device_prop));
|
||||
} else {
|
||||
// B is (M, N), no broadcast needed.
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, M * N * sizeof(float), cudaMemcpyDeviceToDevice));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, M * N * sizeof(T), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
out_data, N));
|
||||
} else {
|
||||
// B is (M, N), no broadcast needed.
|
||||
HIP_RETURN_IF_ERROR(hipMemcpyAsync(out_data, b_data, M * N * sizeof(float), hipMemcpyDeviceToDevice));
|
||||
HIP_RETURN_IF_ERROR(hipMemcpyAsync(out_data, b_data, M * N * sizeof(T), hipMemcpyDeviceToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template <typename T>
|
||||
void TestGemmNoTrans(bool b_is_initializer) {
|
||||
OpTester test("Gemm");
|
||||
|
||||
|
|
@ -16,24 +17,29 @@ void TestGemmNoTrans(bool b_is_initializer) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f), b_is_initializer);
|
||||
test.AddInput<float>("C", {2, 3}, std::vector<float>(6, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f), b_is_initializer);
|
||||
test.AddInput<T>("C", {2, 3}, std::vector<T>(6, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmNoTrans) {
|
||||
TestGemmNoTrans(false);
|
||||
TEST(GemmOpTest, GemmNoTrans_float) {
|
||||
TestGemmNoTrans<float>(false);
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmNoTrans_double) {
|
||||
TestGemmNoTrans<double>(false);
|
||||
}
|
||||
|
||||
// NNAPI EP requires weight to be an initializer
|
||||
TEST(GemmOpTest, GemmNoTransBIsInitializer) {
|
||||
TestGemmNoTrans(true);
|
||||
TestGemmNoTrans<float>(true);
|
||||
TestGemmNoTrans<double>(true);
|
||||
}
|
||||
|
||||
// Only CUDA kernel has float 16 support
|
||||
|
|
@ -75,7 +81,8 @@ TEST(GemmOpTest, GemmNoTrans_f16) {
|
|||
}
|
||||
#endif
|
||||
|
||||
static void TestGemmBroadcast(bool b_is_initializer) {
|
||||
template <typename T>
|
||||
void TestGemmBroadcast(bool b_is_initializer) {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -83,14 +90,14 @@ static void TestGemmBroadcast(bool b_is_initializer) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f), b_is_initializer);
|
||||
test.AddInput<float>("C", {3}, std::vector<float>{1.0f, 2.0f, 3.0f});
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 12.0f, 13.0f,
|
||||
-9.0f, -8.0f, -7.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f), b_is_initializer);
|
||||
test.AddInput<T>("C", {3}, std::vector<T>{1.0f, 2.0f, 3.0f});
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 12.0f, 13.0f,
|
||||
-9.0f, -8.0f, -7.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO : Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
|
|
@ -99,13 +106,16 @@ static void TestGemmBroadcast(bool b_is_initializer) {
|
|||
}
|
||||
|
||||
TEST(GemmOpTest, GemmBroadcast) {
|
||||
TestGemmBroadcast(false);
|
||||
TestGemmBroadcast<float>(false);
|
||||
TestGemmBroadcast<double>(false);
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmBroadcastBIsInitializer) {
|
||||
TestGemmBroadcast(true);
|
||||
TestGemmBroadcast<float>(true);
|
||||
TestGemmBroadcast<double>(true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void TestGemmTrans(bool b_is_initializer) {
|
||||
OpTester test("Gemm");
|
||||
|
||||
|
|
@ -114,34 +124,37 @@ static void TestGemmTrans(bool b_is_initializer) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {4, 2},
|
||||
{1.0f, -1.0f,
|
||||
2.0f, -2.0f,
|
||||
3.0f, -3.0f,
|
||||
4.0f, -4.0f});
|
||||
test.AddInput<float>("B", {3, 4}, std::vector<float>(12, 1.0f), b_is_initializer);
|
||||
test.AddInput<float>("C", {3}, std::vector<float>(3, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
test.AddInput<T>("A", {4, 2},
|
||||
{1.0f, -1.0f,
|
||||
2.0f, -2.0f,
|
||||
3.0f, -3.0f,
|
||||
4.0f, -4.0f});
|
||||
test.AddInput<T>("B", {3, 4}, std::vector<T>(12, 1.0f), b_is_initializer);
|
||||
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
test.Run();
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmTrans) {
|
||||
TestGemmTrans(false);
|
||||
TestGemmTrans<float>(false);
|
||||
TestGemmTrans<double>(false);
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmTransBIsInitializer) {
|
||||
TestGemmTrans(true);
|
||||
TestGemmTrans<float>(true);
|
||||
TestGemmTrans<double>(true);
|
||||
}
|
||||
|
||||
// NNAPI EP's GEMM only works as A*B', add case only B is transposed
|
||||
// Also test NNAPI EP's handling of non-1D bias (C of Gemm)
|
||||
TEST(GemmOpTest, GemmTransB) {
|
||||
template <typename T>
|
||||
static void TestGemmTransB() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -149,14 +162,14 @@ TEST(GemmOpTest, GemmTransB) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {3, 4}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {1, 3}, std::vector<float>(3, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {3, 4}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {1, 3}, std::vector<T>(3, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
|
|
@ -164,9 +177,15 @@ TEST(GemmOpTest, GemmTransB) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmTransB) {
|
||||
TestGemmTransB<float>();
|
||||
TestGemmTransB<double>();
|
||||
}
|
||||
|
||||
// NNAPI EP's GEMM only works as A*B', add case only B is transposed
|
||||
// Also test NNAPI EP's handling of non-1D bias (C of Gemm) which is broadcastable but not valid for NNAPI
|
||||
TEST(GemmOpTest, GemmTransB_1) {
|
||||
template <typename T>
|
||||
void TestGemmTransB_1() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -174,14 +193,14 @@ TEST(GemmOpTest, GemmTransB_1) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {3, 4}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {2, 1}, std::vector<float>(2, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {3, 4}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {2, 1}, std::vector<T>(2, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
|
|
@ -189,7 +208,13 @@ TEST(GemmOpTest, GemmTransB_1) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmAlphaBeta) {
|
||||
TEST(GemmOpTest, GemmTransB_1) {
|
||||
TestGemmTransB_1<float>();
|
||||
TestGemmTransB_1<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmAlphaBeta() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -197,14 +222,14 @@ TEST(GemmOpTest, GemmAlphaBeta) {
|
|||
test.AddAttribute("alpha", 0.5f);
|
||||
test.AddAttribute("beta", 2.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {3}, std::vector<float>(3, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{7.0f, 7.0f, 7.0f,
|
||||
-3.0f, -3.0f, -3.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{7.0f, 7.0f, 7.0f,
|
||||
-3.0f, -3.0f, -3.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
|
|
@ -212,7 +237,13 @@ TEST(GemmOpTest, GemmAlphaBeta) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmNaN) {
|
||||
TEST(GemmOpTest, GemmAlphaBeta) {
|
||||
TestGemmAlphaBeta<float>();
|
||||
TestGemmAlphaBeta<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmNaN() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -220,18 +251,24 @@ TEST(GemmOpTest, GemmNaN) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 0.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {2, 3}, std::vector<float>(6, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{10.0f, 10.0f, 10.0f,
|
||||
-10.0f, -10.0f, -10.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {2, 3}, std::vector<T>(6, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{10.0f, 10.0f, 10.0f,
|
||||
-10.0f, -10.0f, -10.0f});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmScalarBroadcast) {
|
||||
TEST(GemmOpTest, GemmNaN) {
|
||||
TestGemmNaN<float>();
|
||||
TestGemmNaN<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmScalarBroadcast() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -239,37 +276,49 @@ TEST(GemmOpTest, GemmScalarBroadcast) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {1}, std::vector<float>{1.0f});
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {1}, std::vector<T>{1.0f});
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-9.0f, -9.0f, -9.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmScalarBroadcast) {
|
||||
TestGemmScalarBroadcast<float>();
|
||||
TestGemmScalarBroadcast<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemm2DBroadcast_1() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
test.AddAttribute("transB", (int64_t)0);
|
||||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {2, 1}, std::vector<T>{1.0, 2.0f});
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-8.0f, -8.0f, -8.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, Gemm2DBroadcast_1) {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
test.AddAttribute("transB", (int64_t)0);
|
||||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {2, 1}, std::vector<float>{1.0f, 2.0f});
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-8.0f, -8.0f, -8.0f});
|
||||
test.Run();
|
||||
TestGemm2DBroadcast_1<float>();
|
||||
TestGemm2DBroadcast_1<double>();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, Gemm2DBroadcast_2) {
|
||||
template <typename T>
|
||||
void TestGemm2DBroadcast_2() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -278,18 +327,24 @@ TEST(GemmOpTest, Gemm2DBroadcast_2) {
|
|||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
// Same as GemmBroadcast, but adding the unnecessary second dimension.
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {1, 3}, std::vector<float>{1.0f, 2.0f, 3.0f});
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 12.0f, 13.0f,
|
||||
-9.0f, -8.0f, -7.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {1, 3}, std::vector<T>{1.0f, 2.0f, 3.0f});
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 12.0f, 13.0f,
|
||||
-9.0f, -8.0f, -7.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmFalseBroadcast) {
|
||||
TEST(GemmOpTest, Gemm2DBroadcast_2) {
|
||||
TestGemm2DBroadcast_2<float>();
|
||||
TestGemm2DBroadcast_2<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmFalseBroadcast() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -297,18 +352,24 @@ TEST(GemmOpTest, GemmFalseBroadcast) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {2, 3}, std::vector<float>{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f});
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-8.0f, -8.0f, -8.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {2, 3}, std::vector<T>{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f});
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{11.0f, 11.0f, 11.0f,
|
||||
-8.0f, -8.0f, -8.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmEmptyTensor) {
|
||||
TEST(GemmOpTest, GemmFalseBroadcast) {
|
||||
TestGemmFalseBroadcast<float>();
|
||||
TestGemmFalseBroadcast<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmEmptyTensor() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", static_cast<int64_t>(0));
|
||||
|
|
@ -316,16 +377,22 @@ TEST(GemmOpTest, GemmEmptyTensor) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {0, 4},
|
||||
{});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {3}, std::vector<float>(3, 1.0f));
|
||||
test.AddOutput<float>("Y", {0, 3},
|
||||
{});
|
||||
test.AddInput<T>("A", {0, 4},
|
||||
{});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
|
||||
test.AddOutput<T>("Y", {0, 3},
|
||||
{});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider}); //TensorRT: doesn't support dynamic shape yet
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmNoBiasOpset11) {
|
||||
TEST(GemmOpTest, GemmEmptyTensor) {
|
||||
TestGemmEmptyTensor<float>();
|
||||
TestGemmEmptyTensor<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void TestGemmNoBiasOpset11() {
|
||||
OpTester test("Gemm", 11);
|
||||
|
||||
test.AddAttribute("transA", static_cast<int64_t>(0));
|
||||
|
|
@ -333,29 +400,40 @@ TEST(GemmOpTest, GemmNoBiasOpset11) {
|
|||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{10.0f, 10.0f, 10.0f,
|
||||
-10.0f, -10.0f, -10.0f});
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{10.0f, 10.0f, 10.0f,
|
||||
-10.0f, -10.0f, -10.0f});
|
||||
// tensorRT don't seem to support missing bias
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmNoBiasOpset11) {
|
||||
TestGemmNoBiasOpset11<float>();
|
||||
TestGemmNoBiasOpset11<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void TestGemmWithAlphaOpset11() {
|
||||
OpTester test("Gemm", 11);
|
||||
|
||||
test.AddAttribute("alpha", 2.0f);
|
||||
|
||||
test.AddInput<T>("A", {2, 2},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f});
|
||||
test.AddInput<T>("B", {2, 2}, std::vector<T>(4, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 2},
|
||||
{6.0f, 6.0f, 14.0f, 14.0f});
|
||||
// tensorRT don't seem to support missing bias
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmWithAlphaOpset11) {
|
||||
OpTester test("Gemm", 11);
|
||||
|
||||
test.AddAttribute("alpha", 2.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 2},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f});
|
||||
test.AddInput<float>("B", {2, 2}, std::vector<float>(4, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 2},
|
||||
{6.0f, 6.0f, 14.0f, 14.0f});
|
||||
// tensorRT don't seem to support missing bias
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
TestGemmWithAlphaOpset11<float>();
|
||||
TestGemmWithAlphaOpset11<double>();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue