Support opset-11 Gemm kernels (#1923)

* Support optional bias in Gemm

* Fix test

* Update

* More updates

* Update

* Update

* Update gemm.cc

* Update

* Update

* Fix build break

* Update

* PR comments

* Update
This commit is contained in:
Hariharan Seshadri 2019-10-01 20:32:28 -07:00 committed by GitHub
parent 31aff686e0
commit a5e134405d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 92 additions and 37 deletions

View file

@ -261,7 +261,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, BatchNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul);
@ -358,6 +358,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Lp
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm);
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -604,7 +605,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul)>,
@ -701,6 +702,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -13,9 +13,17 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Gemm<float>);
// opset 9 added support for additional types (int32, uint32, int64, uint64), however we haven't enabled those yet.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Gemm,
9,
10,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm<float>);
// opset 11 made bias input 'C' optional
ONNX_CPU_OPERATOR_KERNEL(
Gemm,
9,
11,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm<float>);
} // namespace onnxruntime

View file

@ -31,10 +31,12 @@ class Gemm : public OpKernel {
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool();
const auto X = context->Input<Tensor>(0);
const auto W = context->Input<Tensor>(1);
const auto B = context->Input<Tensor>(2);
GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans, B->Shape());
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
const auto* B = context->Input<Tensor>(2);
// Bias could be missing. Treat as scalar 0 if that is the case.
GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans,
B != nullptr ? B->Shape() : TensorShape({}));
if (!helper.State().IsOK())
return helper.State();
@ -42,13 +44,13 @@ class Gemm : public OpKernel {
int64_t M = helper.M();
int64_t N = helper.N();
auto Y = context->Output(0, {M, N});
// if input is emtpy tensor, return directly as nothing need to be calculated.
// if input is empty tensor, return directly as nothing need to be calculated.
if (M == 0 || N == 0)
return Status::OK();
T* y_data = Y->template MutableData<T>();
// Broadcast the bias as needed.
if (beta_ != 0) {
// Broadcast the bias as needed if bias is given
if (beta_ != 0 && B != nullptr) {
auto output_mat = EigenMatrixMapRowMajor<T>(y_data, M, N);
const auto& b_shape = B->Shape();
const T* b_data = B->template Data<T>();
@ -77,7 +79,9 @@ class Gemm : public OpKernel {
alpha_,
X->template Data<T>(),
W->template Data<T>(),
beta_,
// ideally we need to set the output buffer contents to 0 if bias is missing,
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
B != nullptr ? beta_ : 0,
y_data,
tp);

View file

@ -50,17 +50,18 @@ class GemmHelper {
Status State() const { return status_; }
private:
bool IsValidBroadcast(const TensorShape& shape, int64_t M, int64_t N) {
if (shape.NumDimensions() != 1 && shape.NumDimensions() != 2)
bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) {
// valid shapes are (,) , (1, N) , (M, 1) , (M, N)
if (bias_shape.NumDimensions() > 2)
return false;
// shape is (1,) or (1, 1), or (,)
if (shape.Size() == 1)
if (bias_shape.Size() == 1)
return true;
// shape is (N,) or (1, N) or (M, 1)
// or (M, N), in last case no broadcast needed, but don't fail it
return ((shape.NumDimensions() == 1 && shape[0] == N) ||
(shape.NumDimensions() == 2 && shape[0] == M && (shape[1] == 1 || shape[1] == N)) ||
(shape.NumDimensions() == 2 && shape[0] == 1 && shape[1] == N));
// valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N),
// In last case no broadcasting needed, so don't fail it
return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) ||
(bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) ||
(bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N));
}
private:

View file

@ -214,9 +214,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Ga
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, float, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, double, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul);
@ -537,10 +537,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Less);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Less);
// opset 10
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, Dropout);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign);
// opset 11
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm);
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
@ -555,9 +561,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, float, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, double, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul)>,
@ -880,6 +886,11 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign)>,
// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -20,10 +20,20 @@ namespace cuda {
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
9, \
10, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
11, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
@ -38,10 +48,11 @@ template <typename T>
Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const auto X = ctx->Input<Tensor>(0);
const auto W = ctx->Input<Tensor>(1);
const auto B = ctx->Input<Tensor>(2);
GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B->Shape());
const auto* X = ctx->Input<Tensor>(0);
const auto* W = ctx->Input<Tensor>(1);
const auto* B = ctx->Input<Tensor>(2);
// Bias could be missing. Treat as scalar 0 if that is the case.
GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B != nullptr ? B->Shape() : TensorShape({}));
if (!helper.State().IsOK())
return helper.State();
@ -49,14 +60,14 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
int M = gsl::narrow_cast<int>(helper.M());
int N = gsl::narrow_cast<int>(helper.N());
int K = gsl::narrow_cast<int>(helper.K());
auto Y = ctx->Output(0, TensorShape(std::vector<int64_t>{M, N}));
auto* Y = ctx->Output(0, TensorShape(std::vector<int64_t>{M, N}));
CudaT* out_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
CudaT one = ToCudaType<T>::FromFloat(1.0f);
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
// broadcast bias if needed
if (beta_ != 0) {
// broadcast bias if needed and is present
if (beta_ != 0 && B != nullptr) {
auto& b_shape = B->Shape();
const CudaT* b_data = reinterpret_cast<const CudaT*>(B->template Data<T>());
@ -112,7 +123,9 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
(trans_B_ ? K : N),
reinterpret_cast<const CudaT*>(X->template Data<T>()),
(trans_A_ ? M : K),
&beta,
// ideally we need to set the output buffer contents to 0 if bias is missing,
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
B != nullptr ? &beta : &zero,
out_data, N));
return Status::OK();

View file

@ -471,7 +471,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
{"bitshift_left_uint64", "BitShift(11) not implemented yet"},
{"bitshift_left_uint32", "BitShift(11) not implemented yet"},
{"bitshift_left_uint16", "BitShift(11) not implemented yet"},
{"gemm_default_scalar_bias", "Gemm ValidBroadcast() has bug to be fixed."},
};
#ifdef USE_NGRAPH

View file

@ -232,5 +232,24 @@ TEST(GemmOpTest, GemmEmptyTensor) {
test.Run();
}
TEST(GemmOpTest, GemmNoBiasOpset11) {
OpTester test("Gemm", 11);
test.AddAttribute("transA", static_cast<int64_t>(0));
test.AddAttribute("transB", static_cast<int64_t>(0));
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);
test.AddInput<float>("A", {2, 4},
{1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f});
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
test.AddOutput<float>("Y", {2, 3},
{10.0f, 10.0f, 10.0f,
-10.0f, -10.0f, -10.0f});
// NGraph and tensorRT don't seem to support missing bias
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider, kTensorrtExecutionProvider});
}
} // namespace test
} // namespace onnxruntime

View file

@ -148,8 +148,6 @@ def create_backend_test(testname=None):
'^test_reduce_*',
'^test_onehot_*',
'^test_constant_pad_cpu.*',
'^test_gemm_default_scalar_bias_cpu.*',
'^test_gemm_*',
'^test_edge_pad_cpu.*',
'^test_reflect_pad_cpu.*'
)