mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Support opset-11 Gemm kernels (#1923)
* Support optional bias in Gemm * Fix test * Update * More updates * Update * Update * Update gemm.cc * Update * Update * Fix build break * Update * PR comments * Update
This commit is contained in:
parent
31aff686e0
commit
a5e134405d
9 changed files with 92 additions and 37 deletions
|
|
@ -261,7 +261,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, BatchNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul);
|
||||
|
|
@ -358,6 +358,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Lp
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConvTranspose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, If);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm);
|
||||
|
||||
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -604,7 +605,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul)>,
|
||||
|
|
@ -701,6 +702,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, If)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -13,9 +13,17 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
Gemm<float>);
|
||||
|
||||
// opset 9 added support for additional types (int32, uint32, int64, uint64), however we haven't enabled those yet.
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Gemm,
|
||||
9,
|
||||
10,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float>);
|
||||
|
||||
// opset 11 made bias input 'C' optional
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Gemm,
|
||||
9,
|
||||
11,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float>);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -31,10 +31,12 @@ class Gemm : public OpKernel {
|
|||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
const auto X = context->Input<Tensor>(0);
|
||||
const auto W = context->Input<Tensor>(1);
|
||||
const auto B = context->Input<Tensor>(2);
|
||||
GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans, B->Shape());
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto* W = context->Input<Tensor>(1);
|
||||
const auto* B = context->Input<Tensor>(2);
|
||||
// Bias could be missing. Treat as scalar 0 if that is the case.
|
||||
GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans,
|
||||
B != nullptr ? B->Shape() : TensorShape({}));
|
||||
|
||||
if (!helper.State().IsOK())
|
||||
return helper.State();
|
||||
|
|
@ -42,13 +44,13 @@ class Gemm : public OpKernel {
|
|||
int64_t M = helper.M();
|
||||
int64_t N = helper.N();
|
||||
auto Y = context->Output(0, {M, N});
|
||||
// if input is emtpy tensor, return directly as nothing need to be calculated.
|
||||
// if input is empty tensor, return directly as nothing need to be calculated.
|
||||
if (M == 0 || N == 0)
|
||||
return Status::OK();
|
||||
T* y_data = Y->template MutableData<T>();
|
||||
|
||||
// Broadcast the bias as needed.
|
||||
if (beta_ != 0) {
|
||||
// Broadcast the bias as needed if bias is given
|
||||
if (beta_ != 0 && B != nullptr) {
|
||||
auto output_mat = EigenMatrixMapRowMajor<T>(y_data, M, N);
|
||||
const auto& b_shape = B->Shape();
|
||||
const T* b_data = B->template Data<T>();
|
||||
|
|
@ -77,7 +79,9 @@ class Gemm : public OpKernel {
|
|||
alpha_,
|
||||
X->template Data<T>(),
|
||||
W->template Data<T>(),
|
||||
beta_,
|
||||
// ideally we need to set the output buffer contents to 0 if bias is missing,
|
||||
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
|
||||
B != nullptr ? beta_ : 0,
|
||||
y_data,
|
||||
tp);
|
||||
|
||||
|
|
|
|||
|
|
@ -50,17 +50,18 @@ class GemmHelper {
|
|||
Status State() const { return status_; }
|
||||
|
||||
private:
|
||||
bool IsValidBroadcast(const TensorShape& shape, int64_t M, int64_t N) {
|
||||
if (shape.NumDimensions() != 1 && shape.NumDimensions() != 2)
|
||||
bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) {
|
||||
// valid shapes are (,) , (1, N) , (M, 1) , (M, N)
|
||||
if (bias_shape.NumDimensions() > 2)
|
||||
return false;
|
||||
// shape is (1,) or (1, 1), or (,)
|
||||
if (shape.Size() == 1)
|
||||
if (bias_shape.Size() == 1)
|
||||
return true;
|
||||
// shape is (N,) or (1, N) or (M, 1)
|
||||
// or (M, N), in last case no broadcast needed, but don't fail it
|
||||
return ((shape.NumDimensions() == 1 && shape[0] == N) ||
|
||||
(shape.NumDimensions() == 2 && shape[0] == M && (shape[1] == 1 || shape[1] == N)) ||
|
||||
(shape.NumDimensions() == 2 && shape[0] == 1 && shape[1] == N));
|
||||
// valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N),
|
||||
// In last case no broadcasting needed, so don't fail it
|
||||
return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) ||
|
||||
(bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) ||
|
||||
(bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N));
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -214,9 +214,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Ga
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, float, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, double, MatMul);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul);
|
||||
|
|
@ -537,10 +537,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Less);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Less);
|
||||
|
||||
// opset 10
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, Dropout);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign);
|
||||
|
||||
// opset 11
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm);
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
|
||||
|
|
@ -555,9 +561,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, float, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, double, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul)>,
|
||||
|
|
@ -880,6 +886,11 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Less)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign)>,
|
||||
|
||||
// opset 11
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -20,10 +20,20 @@ namespace cuda {
|
|||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Gemm<T>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
Gemm, \
|
||||
kOnnxDomain, \
|
||||
9, \
|
||||
10, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Gemm<T>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Gemm, \
|
||||
kOnnxDomain, \
|
||||
11, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
|
|
@ -38,10 +48,11 @@ template <typename T>
|
|||
Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
const auto X = ctx->Input<Tensor>(0);
|
||||
const auto W = ctx->Input<Tensor>(1);
|
||||
const auto B = ctx->Input<Tensor>(2);
|
||||
GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B->Shape());
|
||||
const auto* X = ctx->Input<Tensor>(0);
|
||||
const auto* W = ctx->Input<Tensor>(1);
|
||||
const auto* B = ctx->Input<Tensor>(2);
|
||||
// Bias could be missing. Treat as scalar 0 if that is the case.
|
||||
GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B != nullptr ? B->Shape() : TensorShape({}));
|
||||
|
||||
if (!helper.State().IsOK())
|
||||
return helper.State();
|
||||
|
|
@ -49,14 +60,14 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
int M = gsl::narrow_cast<int>(helper.M());
|
||||
int N = gsl::narrow_cast<int>(helper.N());
|
||||
int K = gsl::narrow_cast<int>(helper.K());
|
||||
auto Y = ctx->Output(0, TensorShape(std::vector<int64_t>{M, N}));
|
||||
auto* Y = ctx->Output(0, TensorShape(std::vector<int64_t>{M, N}));
|
||||
CudaT* out_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
|
||||
|
||||
CudaT one = ToCudaType<T>::FromFloat(1.0f);
|
||||
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
|
||||
|
||||
// broadcast bias if needed
|
||||
if (beta_ != 0) {
|
||||
// broadcast bias if needed and is present
|
||||
if (beta_ != 0 && B != nullptr) {
|
||||
auto& b_shape = B->Shape();
|
||||
const CudaT* b_data = reinterpret_cast<const CudaT*>(B->template Data<T>());
|
||||
|
||||
|
|
@ -112,7 +123,9 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
(trans_B_ ? K : N),
|
||||
reinterpret_cast<const CudaT*>(X->template Data<T>()),
|
||||
(trans_A_ ? M : K),
|
||||
&beta,
|
||||
// ideally we need to set the output buffer contents to 0 if bias is missing,
|
||||
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
|
||||
B != nullptr ? &beta : &zero,
|
||||
out_data, N));
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -471,7 +471,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
{"bitshift_left_uint64", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_left_uint32", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_left_uint16", "BitShift(11) not implemented yet"},
|
||||
{"gemm_default_scalar_bias", "Gemm ValidBroadcast() has bug to be fixed."},
|
||||
};
|
||||
|
||||
#ifdef USE_NGRAPH
|
||||
|
|
|
|||
|
|
@ -232,5 +232,24 @@ TEST(GemmOpTest, GemmEmptyTensor) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmNoBiasOpset11) {
|
||||
OpTester test("Gemm", 11);
|
||||
|
||||
test.AddAttribute("transA", static_cast<int64_t>(0));
|
||||
test.AddAttribute("transB", static_cast<int64_t>(0));
|
||||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{10.0f, 10.0f, 10.0f,
|
||||
-10.0f, -10.0f, -10.0f});
|
||||
// NGraph and tensorRT don't seem to support missing bias
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider, kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -148,8 +148,6 @@ def create_backend_test(testname=None):
|
|||
'^test_reduce_*',
|
||||
'^test_onehot_*',
|
||||
'^test_constant_pad_cpu.*',
|
||||
'^test_gemm_default_scalar_bias_cpu.*',
|
||||
'^test_gemm_*',
|
||||
'^test_edge_pad_cpu.*',
|
||||
'^test_reflect_pad_cpu.*'
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue