From d4a88cfe3f143d529df2a54f5eb3b4971ca933d7 Mon Sep 17 00:00:00 2001 From: George Nash Date: Mon, 23 Aug 2021 08:45:34 -0700 Subject: [PATCH] Add Gemm op to DNNL Exectution provider (#8799) * Implement Gemm op for DNNL execution provider Signed-off-by: George Nash * Remove KernelRegistry and Gemm op for dnnl ep The KernelRegistry for the dnnl execution provider only registered a Gemm op that as best we can tell was never actually used and also was not using the dnnl library. We have implemented a Gemm op in the DNNL execution provider subgraph code and thus are removing the unused Gemm op that was in the dnnl KernelRegistry. Signed-off-by: George Nash * Fix duplicated output and kernelshape inference fix getcapability to make sure subgraph outputs do not have duplicates fix kernelshape inference in pool Signed-off-by: Wang * Removed most dnnl specialized ifdefs from gradient_ops_test code Re-enable GlobalAveragePoolGrad test for dnnl ep The bugs that were exposed by the GlobalAveragePoolGrad test have been fixed and this test no longer needs to be disabled for DNNL. Removed the ReluGradDnnl test. We are getting the testing from the already existing ReluGrad test. MaxPoolGrad test no longer has specialized execution provider enabling for DNNL execution provider. It will now run without the extra enabling. ConvGrad is the only test that still has dnnl specialized ifdefs However, the ConvGrad code was not being executed by the code unless it was listed first in the list of execution providers. Signed-off-by: George Nash * Fix transpose issue on Gemm On transposing square matrices, getmemoryandreshape will fail to reshape fix by adding a bool Signed-off-by: Wang * Save memory space by reusing internal tensor for output The intermediat matmul output tensor can be used as the output tensor for the binary calculation. Remove the unused IsAttributeSupported from the DnnlGemmNodeCapability class since we now support all of the Gemm attributes in our implementation. Signed-off-by: George Nash Co-authored-by: Wang --- .../providers/dnnl/dnnl_execution_provider.cc | 36 +--- .../providers/dnnl/dnnl_execution_provider.h | 2 - .../providers/dnnl/dnnl_node_capability.cc | 10 + .../providers/dnnl/dnnl_node_capability.h | 14 ++ .../core/providers/dnnl/dnnl_op_manager.cc | 1 + .../providers/dnnl/dnnl_provider_factory.cc | 4 +- onnxruntime/core/providers/dnnl/math/gemm.cc | 203 ------------------ onnxruntime/core/providers/dnnl/math/gemm.h | 32 --- .../core/providers/dnnl/subgraph/dnnl_gemm.cc | 181 ++++++++++++++++ .../core/providers/dnnl/subgraph/dnnl_gemm.h | 34 +++ .../core/providers/dnnl/subgraph/dnnl_pool.cc | 9 +- .../core/providers/dnnl/subgraph/dnnl_pool.h | 2 +- .../dnnl/subgraph/dnnl_subgraph_primitive.cc | 7 +- .../dnnl/subgraph/dnnl_subgraph_primitive.h | 2 +- .../test/providers/cpu/math/gemm_test.cc | 61 ++++++ .../test/gradient/gradient_op_test_utils.cc | 6 + .../test/gradient/gradient_ops_test.cc | 48 +---- 17 files changed, 331 insertions(+), 321 deletions(-) delete mode 100644 onnxruntime/core/providers/dnnl/math/gemm.cc delete mode 100644 onnxruntime/core/providers/dnnl/math/gemm.h create mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc create mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 500f4a8658..f29ae5a1b1 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -54,40 +54,6 @@ DNNLExecutionProvider::DNNLExecutionProvider(const DNNLExecutionProviderInfo& in DNNLExecutionProvider::~DNNLExecutionProvider() { } -namespace ort_dnnl { -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm); - -Status RegisterDNNLKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - }; - - for (auto& function_table_entry : function_table) { - ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); - } - return Status::OK(); -} - -} // namespace ort_dnnl - -static std::shared_ptr s_kernel_registry; - -void Shutdown_DeleteRegistry() { - s_kernel_registry.reset(); -} - -std::shared_ptr DNNLExecutionProvider::GetKernelRegistry() const { - if (!s_kernel_registry) { - s_kernel_registry = KernelRegistry::Create(); - auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry); - if (!status.IsOK()) - s_kernel_registry.reset(); - ORT_THROW_IF_ERROR(status); - } - - return s_kernel_registry; -} - std::vector> DNNLExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer) const { std::vector> supported_node_vecs; std::vector supported_node_vec; @@ -220,7 +186,7 @@ std::vector> DNNLExecutionProvider::GetCapabi for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { if (node_set.count(it->GetNode().Index()) == 0) { const auto* output_def = output_defs[it->GetSrcArgIndex()]; - if (subgraph_outputs.count(output_def) == 0) { + if (subgraph_outputs.count(output_def) == 0 && graph_outputs.count(output_def) == 0) { subgraph_outputs.insert(output_def); ordered_subgraph_outputs.push_back(output_def); } diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index 2cb2295804..011ae905c0 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -30,8 +30,6 @@ class DNNLExecutionProvider : public IExecutionProvider { explicit DNNLExecutionProvider(const DNNLExecutionProviderInfo& info); virtual ~DNNLExecutionProvider(); - virtual std::shared_ptr GetKernelRegistry() const override; - std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& /*kernel_registries*/) const override; diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index 3117444a33..62d354ca8c 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -339,6 +339,8 @@ bool DnnlSumNodeCapability::IsDimensionSupported(const Node* node) const { return true; } +// DnnlBinaryNodeCapability class +//------------------------------------- bool DnnlBinaryNodeCapability::Supported(const Node* node) const { if (!IsTypeSupported(node)) return false; if (!IsDimensionSupported(node)) return false; @@ -370,4 +372,12 @@ bool DnnlBinaryNodeCapability::IsDimensionSupported(const Node* node) const { return true; } +// DnnlGemmNodeCapability class +//------------------------------------- +bool DnnlGemmNodeCapability::Supported(const Node* node) const { + if (!_matmul.Supported(node)) return false; + if (!_binary.Supported(node)) return false; + return true; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h index 354f9ba7fd..702874b773 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h @@ -207,4 +207,18 @@ class DnnlBinaryNodeCapability : public DnnlDefaultNodeCapability { bool IsDimensionSupported(const Node* node) const; }; +/** + * Decide if a Gemm op is supported by DnnlExecutionProvider + */ +class DnnlGemmNodeCapability : public DnnlDefaultNodeCapability { + public: + DnnlGemmNodeCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + DnnlMatMulNodeCapability _matmul; + DnnlBinaryNodeCapability _binary; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc index d65acd3c7b..b7d133c34a 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc @@ -11,6 +11,7 @@ DnnlOpManager::DnnlOpManager() { dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr(new DnnlBatchNormalizationNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr(new DnnlDefaultNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr(new DnnlBinaryNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr(new DnnlGemmNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr(new DnnlDefaultNodeCapability()))); diff --git a/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc b/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc index 123beb4f87..0cf1f56f50 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc @@ -11,8 +11,6 @@ using namespace onnxruntime; namespace onnxruntime { -void Shutdown_DeleteRegistry(); - struct DnnlProviderFactory : IExecutionProviderFactory { DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {} ~DnnlProviderFactory() override {} @@ -50,7 +48,7 @@ struct Dnnl_Provider : Provider { } void Shutdown() override { - Shutdown_DeleteRegistry(); + return; } } g_provider; diff --git a/onnxruntime/core/providers/dnnl/math/gemm.cc b/onnxruntime/core/providers/dnnl/math/gemm.cc deleted file mode 100644 index 844df4b6a8..0000000000 --- a/onnxruntime/core/providers/dnnl/math/gemm.cc +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "gemm.h" -#include "dnnl.h" -#include "dnnl.hpp" -#include "core/providers/dnnl/dnnl_fwd.h" -#include "gsl/gsl" -#include "onnxruntime_config.h" -// build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71: -// error: ignoring attributes on template argument "Eigen::PacketType::type {aka __vector(4) float}" [-Werror=ignored-attributes] -#if defined(__GNUC__) -#pragma GCC diagnostic push -#if __GNUC__ >= 6 -#pragma GCC diagnostic ignored "-Wignored-attributes" -#endif -#pragma GCC diagnostic ignored "-Wunused-parameter" -#ifdef HAS_DEPRECATED_COPY -#pragma GCC diagnostic ignored "-Wdeprecated-copy" -#endif -#elif defined(_MSC_VER) -// build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76): -// warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence - -// unsupported\eigen\cxx11\src\Tensor\TensorUInt128.h(150,0): Warning C4245: 'initializing': conversion from '__int64' -// to 'uint64_t', signed/unsigned mismatch -#pragma warning(push) -#pragma warning(disable : 4554) -#pragma warning(disable : 4245) -#pragma warning(disable : 4127) -#endif - -#include "Eigen/Core" - -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif - - - -namespace onnxruntime { -namespace ort_dnnl { - -ONNX_OPERATOR_KERNEL_EX( - Gemm, - kOnnxDomain, - 7, - kDnnlExecutionProvider, - KernelDefBuilder::Create()->TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gemm); - -class GemmHelper { - public: - GemmHelper(const TensorShape& left, bool trans_left, const TensorShape& right, bool trans_right, const TensorShape& bias) { - ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1); - ORT_ENFORCE(right.NumDimensions() == 2); - - if (trans_left) { - M_ = left.NumDimensions() == 2 ? left[1] : left[0]; - K_ = left.NumDimensions() == 2 ? left[0] : 1; - } else { - M_ = left.NumDimensions() == 2 ? left[0] : 1; - K_ = left.NumDimensions() == 2 ? left[1] : left[0]; - } - - int k_dim; - if (trans_right) { - N_ = right[0]; - k_dim = 1; - } else { - N_ = right[1]; - k_dim = 0; - } - - if (right[k_dim] != K_) - status_ = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "GEMM: Dimension mismatch, W: ", - right.ToString(), - " K: " + std::to_string(K_), - " N:" + std::to_string(N_)); - - if (!IsValidBroadcast(bias, M_, N_)) - status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast"); - - // it is possible the input is empty tensor, for example the output of roipool in fast rcnn. - ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0); - } - - int64_t M() const { return M_; } - int64_t N() const { return N_; } - int64_t K() const { return K_; } - Status State() const { return status_; } - - private: - bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) { - // valid shapes are (,) , (1, N) , (M, 1) , (M, N) - if (bias_shape.NumDimensions() > 2) - return false; - // shape is (1,) or (1, 1), or (,) - if (bias_shape.Size() == 1) - return true; - // valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N), - // In last case no broadcasting needed, so don't fail it - return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) || - (bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) || - (bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N)); - } - - private: - int64_t M_; - int64_t K_; - int64_t N_; - Status status_; -}; - -template -using EigenMatrixMapRowMajor = Eigen::Map>; -template -using ConstEigenVectorMap = Eigen::Map>; -template -using ConstEigenMatrixMapRowMajor = Eigen::Map>; - -template <> -Status Gemm::Compute(OpKernelContext* ctx) const { - const auto X = ctx->Input(0); - const auto W = ctx->Input(1); - const auto B = ctx->Input(2); - GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B->Shape()); - - if (!helper.State().IsOK()) - return helper.State(); - - dnnl::memory::dim M = gsl::narrow_cast(helper.M()); - dnnl::memory::dim N = gsl::narrow_cast(helper.N()); - dnnl::memory::dim K = gsl::narrow_cast(helper.K()); - auto Y = ctx->Output(0, TensorShape({M, N})); - - if (M <= 0) - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Empty Tensor not supported"); - - if (beta_ != 0) { - auto output_mat = EigenMatrixMapRowMajor( - Y->template MutableData(), - M, - N); - output_mat.setZero(); - - auto& b_shape = B->Shape(); - // if B is (), (1,) or (1, 1), add the scalar - if (b_shape.Size() == 1) { - output_mat.array() += *(B->template Data()); - } - // B is (N,) - else if (b_shape.NumDimensions() == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - N); - output_mat.rowwise() += bias_vec.transpose(); - } else if (b_shape.NumDimensions() == 2) { - // B is (M, 1) - if (b_shape[1] == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - M); - output_mat.colwise() += bias_vec; - } - // B is (1, N) - else if (b_shape[0] == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - N); - output_mat.rowwise() += bias_vec.transpose(); - } - // B is (M, N), no broadcast needed. - else { - auto bias_mat = ConstEigenMatrixMapRowMajor( - B->template Data(), - M, - N); - output_mat += bias_mat; - } - } - } - - // dnnl_sgemm expects row major matrices, so no need to swap the operands A and B - auto status = dnnl_sgemm(trans_A_ ? 'T' : 'N', - trans_B_ ? 'T' : 'N', - M, N, K, - alpha_, X->template Data(), trans_A_ ? M : K, - W->template Data(), trans_B_ ? K : N, - beta_, Y->template MutableData(), N); - if (status == dnnl_success) { - return Status::OK(); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DNNL_sgemm failed with status: ", status); - } -} - -} // namespace ort_dnnl -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/math/gemm.h b/onnxruntime/core/providers/dnnl/math/gemm.h deleted file mode 100644 index 6e51cb6c94..0000000000 --- a/onnxruntime/core/providers/dnnl/math/gemm.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { -namespace ort_dnnl { -template -class Gemm final : public OpKernel { - public: - Gemm(const OpKernelInfo& info) : OpKernel(info) { - int64_t temp; - ORT_ENFORCE(info.GetAttr("transA", &temp).IsOK()); - trans_A_ = (temp != 0); - - ORT_ENFORCE(info.GetAttr("transB", &temp).IsOK()); - trans_B_ = (temp != 0); - - ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK()); - ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK()); - } - - Status Compute(OpKernelContext* context) const override; - - private: - bool trans_A_; - bool trans_B_; - float alpha_; - float beta_; -}; -} // namespace ort_dnnl -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc new file mode 100644 index 0000000000..1a85b2b50b --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc @@ -0,0 +1,181 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#include "dnnl_gemm.h" +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +DnnlGemm::DnnlGemm() {} + +/* +Gemm implementation: +Gemm: + Inputs: + 0) A - Input Tensor + 1) B - Input Tensor + 2) C - Input Tensor (optional if Opset is 11 or later) + Outputs: + 0) Y - Output Tensor + + +-----------+ + (A) | | + ---------->+ | AB +------+ + (B) | MatMul +--------------------->+ | alphaAB + ---------->+ | (alpha) | Mul +---+ + | | *--------------->+ | | +------+ + +-----------+ +------+ +---->+ | (Y) alphaAB + betaC + | Add +----------------------> + (C) +------+ +---->+ | + --------------------------------------------->+ | | +------+ + (beta) | Mul +---+ + *--------------->+ | betaC + +------+ + +Attributes (alpha, beta, transA, transB) + +To compose Gemm: (algorithm) +(1) perform `MatMul` on input tensors A and B result (AB) +(2) if `Mul` the result of (1) by alpha attribute (alphaAB) +(3) if C is optional return result from (2) and end +(4) if C is avalible `Mul` input C tensor by beta attribute (betaC) +(5) `Add` result from (2) to result from (4) (alphaAB + betaC) +(6) Return output from (5) and end + +OneDNN algorithm: +(1) perform `MatMul` of tensor A and tensor B with `Output scales` set to alpha (0) +(2) if C is optional return output from (1) and end +(3) if C is avalible perform binary `Add` of output from (0) and input C with input C's `scale` attribute set to beta +(4) return output from (4) and end + +*/ + + +void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { + auto eng = sp.GetEngine(); + + auto a_dims = sp.GetMemory(node.Input(IN_A).Name()).get_desc().dims(); + auto b_dims = sp.GetMemory(node.Input(IN_B).Name()).get_desc().dims(); + + bool input_c_exists = node.Input(IN_C).Exists(); + + if (a_dims.size() != b_dims.size()) { + while (a_dims.size() < b_dims.size()) { + a_dims.insert(a_dims.begin(), 1); + } + while (a_dims.size() > b_dims.size()) { + b_dims.insert(b_dims.begin(), 1); + } + } + + + dnnl::memory::desc a_md; + dnnl::memory::desc b_md; + + bool transA = GetTransA(node); + bool transB = GetTransB(node); + + dnnl::memory::dim M = (transA) ? a_dims[1] : a_dims[0]; + dnnl::memory::dim K = (transA) ? a_dims[0] : a_dims[1]; + dnnl::memory::dim N = (transB) ? b_dims[0] : b_dims[1]; + + dnnl::memory::dims a_strides = (transA) ? dnnl::memory::dims{dnnl::memory::dim(1), M} : dnnl::memory::dims{K, dnnl::memory::dim(1)}; + dnnl::memory::dims b_strides = (transB) ? dnnl::memory::dims{dnnl::memory::dim(1), K} : dnnl::memory::dims{N, dnnl::memory::dim(1)}; + + a_md = dnnl::memory::desc({M, K}, node.Input(IN_A).Type(), a_strides); + b_md = dnnl::memory::desc({K, N}, node.Input(IN_B).Type(), b_strides); + + dnnl::memory::dims output_shape{M, N}; + + dnnl::primitive_attr matmul_attr; + // scale the output from MatMul to alpha + float alpha = GetAlpha(node); + std::vector alphaScale({alpha}); + matmul_attr.set_output_scales(0, alphaScale); + + auto matmul_dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), {N, 1}); + + auto matmul_d = dnnl::matmul::desc(a_md, b_md, matmul_dst_md); + dnnl::matmul::primitive_desc matmul_pd; + matmul_pd = dnnl::matmul::primitive_desc(matmul_d, matmul_attr, eng); + + auto matmul_a_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng, transA); + auto matmul_b_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng, transB); + auto gemm_dst_mem = dnnl::memory(matmul_pd.dst_desc(), eng); + + auto matmul_op = dnnl::matmul(matmul_pd); + + std::unordered_map args; + args.insert({DNNL_ARG_SRC, matmul_a_mem}); + args.insert({DNNL_ARG_WEIGHTS, matmul_b_mem}); + args.insert({DNNL_ARG_DST, gemm_dst_mem}); + + sp.AddPrimitive(matmul_op, args); + + if (input_c_exists) { + auto c_original_md = sp.GetMemory(node.Input(IN_C).Name()).get_desc(); + auto c_dims = c_original_md.dims(); + if (c_dims.size() != a_dims.size()) { + while (c_dims.size() < a_dims.size()) { + c_dims.insert(c_dims.begin(), 1); + } + } + + auto c_md = c_original_md.reshape(c_dims); + + auto y_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any); + + auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_add, matmul_pd.dst_desc(), c_md, y_md); + + // Scale input C by beta before adding it to the MatMul output. + dnnl::primitive_attr binary_attr; + float beta = GetBeta(node); + binary_attr.set_scales(DNNL_ARG_SRC_1, 0, {beta}); + + auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr,eng); + + auto binary_c_mem = sp.GetMemoryAndReshape(node.Input(IN_C), binary_pd.src1_desc(), eng); + + auto binary_op = dnnl::binary(binary_pd); + + sp.AddPrimitive(binary_op, {{DNNL_ARG_SRC_0, gemm_dst_mem}, + {DNNL_ARG_SRC_1, binary_c_mem}, + {DNNL_ARG_DST, gemm_dst_mem}}); + } + sp.SetMemory(node.Output(OUT_Y), gemm_dst_mem); +} + +float DnnlGemm::GetAlpha(DnnlNode& node) { + auto attr = node.Attributes().find("alpha"); + if (attr != node.Attributes().end()) { + return attr->second().f(); + } + return 1.0; +} +float DnnlGemm::GetBeta(DnnlNode& node) { + auto attr = node.Attributes().find("beta"); + if (attr != node.Attributes().end()) { + return attr->second().f(); + } + return 1.0; +} + +bool DnnlGemm::GetTransA(DnnlNode& node) { + auto attr = node.Attributes().find("transA"); + if (attr != node.Attributes().end()) { + return (attr->second().i() != 0); + } + return false; +} + +bool DnnlGemm::GetTransB(DnnlNode& node) { + auto attr = node.Attributes().find("transB"); + if (attr != node.Attributes().end()) { + return (attr->second().i() != 0); + } + return false; +} +} // namespace ort_dnnl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h new file mode 100644 index 0000000000..1b8f90e9c4 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h @@ -0,0 +1,34 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#pragma once +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +class DnnlGemm { + public: + enum InputTensors : int { + IN_A = 0, + IN_B = 1, + IN_C = 2 + }; + + enum OutputTensors : int { + OUT_Y = 0 + }; + + DnnlGemm(); + void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node); + + private: + float GetAlpha(DnnlNode& node); + float GetBeta(DnnlNode& node); + bool GetTransA(DnnlNode& node); + bool GetTransB(DnnlNode& node); +}; + +} // namespace ort_dnnl +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc index 6baed09262..101c73941c 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc @@ -31,7 +31,7 @@ void DnnlPool::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { } } - auto kernel_shape = GetKernelShape(node); + auto kernel_shape = GetKernelShape(src_dims, node); PoolShape shape = static_cast(kernel_shape.size()); auto strides = GetStrides(node, shape); @@ -118,7 +118,7 @@ dnnl::memory::dims DnnlPool::GetDilations(DnnlNode& node, PoolShape shape) { return dnnl::memory::dims(dilations.begin(), dilations.end()); } -dnnl::memory::dims DnnlPool::GetKernelShape(DnnlNode& node) { +dnnl::memory::dims DnnlPool::GetKernelShape(const dnnl::memory::dims& src_dims, DnnlNode& node) { auto attr = node.Attributes().find("kernel_shape"); std::vector kernel_shape; if (attr != node.Attributes().end()) { @@ -128,9 +128,8 @@ dnnl::memory::dims DnnlPool::GetKernelShape(DnnlNode& node) { } return kernel_shape; } - // Infer the Kernel shape from the input weights - auto weight_dims = node.Input(IN_X).Dim(); - kernel_shape = std::vector(weight_dims.begin() + 2, weight_dims.end()); + + kernel_shape = std::vector(src_dims.begin() + 2, src_dims.end()); return kernel_shape; } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h index 4edaa69527..5d8e4458ec 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h @@ -34,7 +34,7 @@ class DnnlPool { int64_t GetCeilMode(DnnlNode& node); int64_t GetCountIncludePadding(DnnlNode& node); dnnl::memory::dims GetDilations(DnnlNode& node, PoolShape shape); - dnnl::memory::dims GetKernelShape(DnnlNode& node); + dnnl::memory::dims GetKernelShape(const dnnl::memory::dims& src_dims, DnnlNode& node); /* This will return the calculated padding taking into account the DEPRECATED auto_pad attribute */ std::vector InferPadding(DnnlNode& node, const dnnl::memory::dims& src_dims, const dnnl::memory::dims& kernel_shape, const dnnl::memory::dims& strides); std::vector GetPadding(DnnlNode& node, PoolShape shape); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc index 43f3af57d9..b3b1e73e42 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc @@ -6,6 +6,7 @@ #include "dnnl_batchnorm.h" #include "dnnl_binary.h" #include "dnnl_conv.h" +#include "dnnl_gemm.h" #include "dnnl_lrn.h" #include "dnnl_matmul.h" #include "dnnl_matmul_integer.h" @@ -205,6 +206,8 @@ void DnnlSubgraphPrimitive::AddKernels() { DnnlBatchNorm().CreatePrimitive(*this, node); } else if (node.OpType() == "Conv") { DnnlConv().CreatePrimitive(*this, node); + } else if (node.OpType() == "Gemm") { + DnnlGemm().CreatePrimitive(*this, node); } else if (node.OpType() == "LRN") { DnnlLrn().CreatePrimitive(*this, node); } else if (node.OpType() == "MatMul") { @@ -356,7 +359,7 @@ void DnnlSubgraphPrimitive::SetInitializer(std::string memory_name, dnnl::memory } } -dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng) { +dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng, bool transpose) { // if found just return if (HasMemory(tensor.Name(), mem_desc, eng)) { return GetMemory(tensor.Name(), mem_desc, eng); @@ -372,7 +375,7 @@ dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor ten auto mem_to = dnnl::memory(mem_desc, eng); // if it is a reshape, ensure reorder is possible by making the same dims - if (mem_from.get_desc().dims() != mem_to.get_desc().dims()) { + if (mem_from.get_desc().dims() != mem_to.get_desc().dims() || transpose) { auto mem_from_dims = mem_from.get_desc().dims(); auto mem_to_dims = mem_to.get_desc().dims(); if (Product(mem_from_dims) != Product(mem_to_dims)) { diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h index ed4bbfb5da..e96cd4750f 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h @@ -46,7 +46,7 @@ class DnnlSubgraphPrimitive { dnnl::stream GetStream(); //obtain a dnnl::memory with specified name, memory descriptor and engine, will perform extra reorder/reshape if necessary before returning - dnnl::memory GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng); + dnnl::memory GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng, bool transpose = false); //add dnnl primitive and memory map to subgraph primitive void AddPrimitive(dnnl::primitive prim, std::unordered_map mem_map); //add a reshape (e.g. squeeze, unsqueeze) to subgraph primitive diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index ddec06b678..6c1573b2a0 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -225,6 +225,67 @@ TEST(GemmOpTest, GemmTransB_1) { TestGemmTransB_1(); } +template +void TestGemmAlpha() { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.5f); + test.AddAttribute("beta", 1.0f); + + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {2, 3}, + {6.0f, 6.0f, 6.0f, + -4.0f, -4.0f, -4.0f}); + //test.AddOutput("Y", {2, 3}, + // {5.0f, 5.0f, 5.0f, + // -5.0f, -5.0f, -5.0f}); +#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser +#endif +} + +TEST(GemmOpTest, GemmAlpha) { + TestGemmAlpha(); + TestGemmAlpha(); +} + +template +void TestGemmBeta() { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 2.0f); + + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {2, 3}, + {12.0f, 12.0f, 12.0f, + -8.0f, -8.0f, -8.0f}); +#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser +#endif +} + +TEST(GemmOpTest, GemmBeta) { + TestGemmBeta(); + TestGemmBeta(); +} + template void TestGemmAlphaBeta() { OpTester test("Gemm"); diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index acb1a04a44..809545cd2e 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -192,6 +192,12 @@ void GradientOpTester::Run( //if node is not registered for the provider, skip node.SetExecutionProviderType(provider_type); + + // provider types that don't use the KernelRegistry + if (provider_type == onnxruntime::kDnnlExecutionProvider) { + continue; + } + auto reg = execution_provider->GetKernelRegistry(); const KernelCreateInfo* kci; auto st = reg->TryFindKernel(node, execution_provider->Type(), &kci); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index bec6f7c0dc..721667509c 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -663,14 +663,6 @@ TEST(GradientCheckerTest, ReluGrad) { EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } -#ifdef USE_DNNL -TEST(GradientCheckerTest, ReluGradDnnl) { - std::vector> execution_providers; - execution_providers.push_back(DefaultDnnlExecutionProvider()); - UnaryOpGradientTest("Relu", kOnnxDomain, 9, &execution_providers); -} -#endif // USE_DNNL - TEST(GradientCheckerTest, CastGrad) { // A dummy test that cast float to float // TODO: add more test here @@ -728,7 +720,7 @@ static std::vector> GetRandomValuesForMaxPool(const std::vector>* execution_provider) { +TEST(GradientCheckerTest, MaxPoolGrad) { float max_error; GradientChecker gradient_checker; OpDef op_def{"MaxPool"}; @@ -737,9 +729,7 @@ void MaxpoolGradientCheckerTest(std::vector> { gradient_checker.ComputeGradientError(op_def, {{2, 2, 9}}, {{2, 2, 8}}, &max_error, GetRandomValuesForMaxPool({{2, 2, 9}}), - {MakeAttribute("kernel_shape", std::vector{2})}, - true, false, - execution_provider); + {MakeAttribute("kernel_shape", std::vector{2})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -748,9 +738,7 @@ void MaxpoolGradientCheckerTest(std::vector> gradient_checker.ComputeGradientError(op_def, {{2, 3, 5, 5}}, {{2, 3, 4, 4}}, &max_error, GetRandomValuesForMaxPool({{2, 3, 5, 5}}), {MakeAttribute("kernel_shape", std::vector{2, 2}), - MakeAttribute("strides", std::vector{1, 1})}, - true, false, - execution_provider); + MakeAttribute("strides", std::vector{1, 1})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -759,9 +747,7 @@ void MaxpoolGradientCheckerTest(std::vector> gradient_checker.ComputeGradientError(op_def, {{1, 1, 5, 5}}, {{1, 1, 7, 7}}, &max_error, GetRandomValuesForMaxPool({{1, 1, 5, 5}}), {MakeAttribute("kernel_shape", std::vector{3, 3}), - MakeAttribute("pads", std::vector{2, 2, 2, 2})}, - true, false, - execution_provider); + MakeAttribute("pads", std::vector{2, 2, 2, 2})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -770,9 +756,7 @@ void MaxpoolGradientCheckerTest(std::vector> gradient_checker.ComputeGradientError(op_def, {{1, 1, 32, 32}}, {{1, 1, 10, 10}}, &max_error, GetRandomValuesForMaxPool({{1, 1, 32, 32}}), {MakeAttribute("kernel_shape", std::vector{5, 5}), - MakeAttribute("strides", std::vector{3, 3})}, - true, false, - execution_provider); + MakeAttribute("strides", std::vector{3, 3})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -780,23 +764,11 @@ void MaxpoolGradientCheckerTest(std::vector> { gradient_checker.ComputeGradientError(op_def, {{2, 1, 3, 3, 3}}, {{2, 1, 2, 2, 2}}, &max_error, GetRandomValuesForMaxPool({{2, 1, 3, 3, 3}}), - {MakeAttribute("kernel_shape", std::vector{2, 2, 2})}, - true, false, - execution_provider); + {MakeAttribute("kernel_shape", std::vector{2, 2, 2})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } } -TEST(GradientCheckerTest, MaxPoolGrad) { - MaxpoolGradientCheckerTest(nullptr); - -#ifdef USE_DNNL - std::vector> execution_providers; - execution_providers.push_back(DefaultDnnlExecutionProvider()); - MaxpoolGradientCheckerTest(&execution_providers); -#endif -} - TEST(GradientCheckerTest, GlobalAveragePoolGrad) { float max_error; GradientChecker gradient_checker; @@ -1057,15 +1029,17 @@ void ConvGradientCheckerTest(std::vector>* e TEST(GradientCheckerTest, ConvGrad) { std::vector> execution_providers; +#ifdef USE_DNNL + // Dnnl EP does not run for ConvGrad unless it is pushed first. + execution_providers.push_back(DefaultDnnlExecutionProvider()); +#endif + execution_providers.push_back(DefaultCpuExecutionProvider()); if (HasCudaEnvironment(700)) { execution_providers.push_back(DefaultCudaExecutionProvider()); } -#ifdef USE_DNNL - execution_providers.push_back(DefaultDnnlExecutionProvider()); -#endif ConvGradientCheckerTest(&execution_providers); }