mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Address ZeroK case for Gemm for CPU and CUDA (#22111)
### Description When K == 0 output a MxN matrix filled with bias if present or filled with zeros. This brings it inline with MatMul behavior especially when Gemm is used to fuse MatMul with Add. ### Motivation and Context * Comply with numpy spec of MatMul * Address a case when empty initializers are used for computation.
This commit is contained in:
parent
8d2d40781c
commit
fe8a10caa4
5 changed files with 95 additions and 27 deletions
|
|
@ -154,6 +154,14 @@ void Gemm<T>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
|
|||
// Broadcast the bias as needed if bias is given
|
||||
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
|
||||
|
||||
if (K == 0) {
|
||||
if (beta == 0 || c_data == nullptr) {
|
||||
EigenMatrixMapRowMajor<T> dest(y_data, narrow<Eigen::Index>(M), narrow<Eigen::Index>(N));
|
||||
dest.setZero();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
math::Gemm<T>(trans_a, trans_b,
|
||||
M, N, K,
|
||||
alpha,
|
||||
|
|
@ -179,16 +187,18 @@ void Gemm<MLFloat16>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans
|
|||
if (M == 0 || N == 0)
|
||||
return;
|
||||
|
||||
#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wclass-memaccess"
|
||||
#endif
|
||||
// MLFloat16's constructor is explicit, so here we need to use memset
|
||||
if (K == 0) {
|
||||
if (beta != onnxruntime::MLFloat16::Zero && c_data != nullptr) {
|
||||
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
|
||||
} else {
|
||||
auto output_span = gsl::make_span(y_data, SafeInt<size_t>(M) * N);
|
||||
std::fill(output_span.begin(), output_span.end(), onnxruntime::MLFloat16::Zero);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (c_data == nullptr)
|
||||
memset(&beta, 0, sizeof(MLFloat16));
|
||||
#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
beta = onnxruntime::MLFloat16::Zero;
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
bool support_mlas = false;
|
||||
if (c_shape == nullptr) {
|
||||
|
|
@ -413,19 +423,24 @@ Status Gemm<float>::Compute(OpKernelContext* context) const {
|
|||
c_data, c_shape, y_data, thread_pool);
|
||||
} else {
|
||||
GemmBroadcastBias(M, N, beta_, c_data, c_shape, y_data);
|
||||
MlasGemm(
|
||||
trans_A_,
|
||||
static_cast<size_t>(M),
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(K),
|
||||
alpha_,
|
||||
A->Data<float>(),
|
||||
static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
|
||||
packed_b_.get(),
|
||||
c_data != nullptr ? beta_ : 0.0f,
|
||||
y_data,
|
||||
static_cast<size_t>(N),
|
||||
thread_pool);
|
||||
if (K > 0) {
|
||||
MlasGemm(
|
||||
trans_A_,
|
||||
static_cast<size_t>(M),
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(K),
|
||||
alpha_,
|
||||
A->Data<float>(),
|
||||
static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
|
||||
packed_b_.get(),
|
||||
c_data != nullptr ? beta_ : 0.0f,
|
||||
y_data,
|
||||
static_cast<size_t>(N),
|
||||
thread_pool);
|
||||
} else if (beta_ == 0 || c_data == nullptr) {
|
||||
EigenMatrixMapRowMajor<float> dest(y_data, narrow<Eigen::Index>(M), narrow<Eigen::Index>(N));
|
||||
dest.setZero();
|
||||
}
|
||||
}
|
||||
|
||||
ComputeActivation(y_data, SafeInt<size_t>(M) * N, thread_pool);
|
||||
|
|
|
|||
|
|
@ -56,7 +56,8 @@ class GemmHelper {
|
|||
status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast");
|
||||
|
||||
// it is possible the input is empty tensor, for example the output of roipool in fast rcnn.
|
||||
ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0);
|
||||
// it is also possible that K == 0
|
||||
ORT_ENFORCE(M_ >= 0 && K_ >= 0 && N_ >= 0);
|
||||
}
|
||||
|
||||
ptrdiff_t M() const { return M_; }
|
||||
|
|
|
|||
|
|
@ -106,8 +106,9 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
|
|||
if (helper.K() == 0) {
|
||||
// When we have (M, 0, N) then the inputs are empty, but the output should
|
||||
// be filled out with zeros.
|
||||
auto output_span = y->MutableDataAsSpan<T>();
|
||||
std::fill(output_span.begin(), output_span.end(), T{});
|
||||
EigenMatrixMapRowMajor<T> dest(y->MutableData<T>(),
|
||||
narrow<Eigen::Index>(helper.M()), narrow<Eigen::Index>(helper.N()));
|
||||
dest.setZero();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -241,8 +242,9 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
|
|||
if (helper.K() == 0) {
|
||||
// When we have (M, 0, N) then the inputs are empty, but the output should
|
||||
// be filled out with zeros.
|
||||
auto output_span = y->MutableDataAsSpan<float>();
|
||||
std::fill(output_span.begin(), output_span.end(), float{});
|
||||
EigenMatrixMapRowMajor<float> dest(y->MutableData<float>(),
|
||||
narrow<Eigen::Index>(helper.M()), narrow<Eigen::Index>(helper.N()));
|
||||
dest.setZero();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -137,6 +137,16 @@ Status Gemm<T>::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const
|
|||
}
|
||||
}
|
||||
|
||||
if (K == 0) {
|
||||
if (beta_ == 0 || B == nullptr) {
|
||||
// When we have (M, 0, N) then the output should be filled out with zeros
|
||||
// unless we have a bias
|
||||
Fill<CudaT>(Stream(ctx), reinterpret_cast<CudaT*>(Y->MutableData<T>()), CudaT(0.f),
|
||||
Y->Shape().Size());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
CudaT alpha = ToCudaType<T>::FromFloat(alpha_);
|
||||
CudaT beta = ToCudaType<T>::FromFloat(beta_);
|
||||
// Gemm, note that CUDA assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
|
||||
|
|
|
|||
|
|
@ -641,6 +641,46 @@ TYPED_TEST(GemmOpTypedTests, GemmEmptyTensor) {
|
|||
.Config(run_with_tunable_op)
|
||||
.RunWithConfig();
|
||||
}
|
||||
|
||||
TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) {
|
||||
OpTester test("Gemm", 13);
|
||||
|
||||
test.AddAttribute("transA", static_cast<int64_t>(0));
|
||||
test.AddAttribute("transB", static_cast<int64_t>(0));
|
||||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<TypeParam>("A", {4, 0}, {});
|
||||
test.AddInput<TypeParam>("B", {0, 4}, {});
|
||||
test.AddInput<TypeParam>("C", {4}, std::vector<TypeParam>(4, static_cast<TypeParam>(1.0f)));
|
||||
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(1.0f)));
|
||||
|
||||
test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
|
||||
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
|
||||
kOpenVINOExecutionProvider})
|
||||
.Config(run_with_tunable_op)
|
||||
.RunWithConfig();
|
||||
}
|
||||
|
||||
TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) {
|
||||
OpTester test("Gemm", 13);
|
||||
|
||||
test.AddAttribute("transA", static_cast<int64_t>(0));
|
||||
test.AddAttribute("transB", static_cast<int64_t>(0));
|
||||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", .0f);
|
||||
|
||||
test.AddInput<TypeParam>("A", {4, 0}, {});
|
||||
test.AddInput<TypeParam>("B", {0, 4}, {});
|
||||
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(0.0f)));
|
||||
|
||||
test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
|
||||
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
|
||||
kOpenVINOExecutionProvider})
|
||||
.Config(run_with_tunable_op)
|
||||
.RunWithConfig();
|
||||
}
|
||||
|
||||
TYPED_TEST(GemmOpTypedTests, MissingBias) {
|
||||
OpTester test("Gemm", 11);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue