diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 500f4a8658..f29ae5a1b1 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -54,40 +54,6 @@ DNNLExecutionProvider::DNNLExecutionProvider(const DNNLExecutionProviderInfo& in DNNLExecutionProvider::~DNNLExecutionProvider() { } -namespace ort_dnnl { -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm); - -Status RegisterDNNLKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - }; - - for (auto& function_table_entry : function_table) { - ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); - } - return Status::OK(); -} - -} // namespace ort_dnnl - -static std::shared_ptr s_kernel_registry; - -void Shutdown_DeleteRegistry() { - s_kernel_registry.reset(); -} - -std::shared_ptr DNNLExecutionProvider::GetKernelRegistry() const { - if (!s_kernel_registry) { - s_kernel_registry = KernelRegistry::Create(); - auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry); - if (!status.IsOK()) - s_kernel_registry.reset(); - ORT_THROW_IF_ERROR(status); - } - - return s_kernel_registry; -} - std::vector> DNNLExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer) const { std::vector> supported_node_vecs; std::vector supported_node_vec; @@ -220,7 +186,7 @@ std::vector> DNNLExecutionProvider::GetCapabi for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { if (node_set.count(it->GetNode().Index()) == 0) { const auto* output_def = output_defs[it->GetSrcArgIndex()]; - if (subgraph_outputs.count(output_def) == 0) { + if (subgraph_outputs.count(output_def) == 0 && graph_outputs.count(output_def) == 0) { subgraph_outputs.insert(output_def); ordered_subgraph_outputs.push_back(output_def); } 
diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index 2cb2295804..011ae905c0 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -30,8 +30,6 @@ class DNNLExecutionProvider : public IExecutionProvider { explicit DNNLExecutionProvider(const DNNLExecutionProviderInfo& info); virtual ~DNNLExecutionProvider(); - virtual std::shared_ptr GetKernelRegistry() const override; - std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& /*kernel_registries*/) const override; diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index 3117444a33..62d354ca8c 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -339,6 +339,8 @@ bool DnnlSumNodeCapability::IsDimensionSupported(const Node* node) const { return true; } +// DnnlBinaryNodeCapability class +//------------------------------------- bool DnnlBinaryNodeCapability::Supported(const Node* node) const { if (!IsTypeSupported(node)) return false; if (!IsDimensionSupported(node)) return false; @@ -370,4 +372,12 @@ bool DnnlBinaryNodeCapability::IsDimensionSupported(const Node* node) const { return true; } +// DnnlGemmNodeCapability class +//------------------------------------- +bool DnnlGemmNodeCapability::Supported(const Node* node) const { + if (!_matmul.Supported(node)) return false; + if (!_binary.Supported(node)) return false; + return true; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h index 354f9ba7fd..702874b773 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h @@ -207,4 +207,18 @@ class 
DnnlBinaryNodeCapability : public DnnlDefaultNodeCapability { bool IsDimensionSupported(const Node* node) const; }; +/** + * Decide if a Gemm op is supported by DnnlExecutionProvider + */ +class DnnlGemmNodeCapability : public DnnlDefaultNodeCapability { + public: + DnnlGemmNodeCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + DnnlMatMulNodeCapability _matmul; + DnnlBinaryNodeCapability _binary; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc index d65acd3c7b..b7d133c34a 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc @@ -11,6 +11,7 @@ DnnlOpManager::DnnlOpManager() { dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr(new DnnlBatchNormalizationNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr(new DnnlDefaultNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr(new DnnlBinaryNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr(new DnnlGemmNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr(new DnnlDefaultNodeCapability()))); diff --git a/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc b/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc index 123beb4f87..0cf1f56f50 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc @@ -11,8 +11,6 @@ using namespace onnxruntime; namespace onnxruntime { -void Shutdown_DeleteRegistry(); - struct DnnlProviderFactory : IExecutionProviderFactory { 
DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {} ~DnnlProviderFactory() override {} @@ -50,7 +48,7 @@ struct Dnnl_Provider : Provider { } void Shutdown() override { - Shutdown_DeleteRegistry(); + return; } } g_provider; diff --git a/onnxruntime/core/providers/dnnl/math/gemm.cc b/onnxruntime/core/providers/dnnl/math/gemm.cc deleted file mode 100644 index 844df4b6a8..0000000000 --- a/onnxruntime/core/providers/dnnl/math/gemm.cc +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "gemm.h" -#include "dnnl.h" -#include "dnnl.hpp" -#include "core/providers/dnnl/dnnl_fwd.h" -#include "gsl/gsl" -#include "onnxruntime_config.h" -// build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71: -// error: ignoring attributes on template argument "Eigen::PacketType::type {aka __vector(4) float}" [-Werror=ignored-attributes] -#if defined(__GNUC__) -#pragma GCC diagnostic push -#if __GNUC__ >= 6 -#pragma GCC diagnostic ignored "-Wignored-attributes" -#endif -#pragma GCC diagnostic ignored "-Wunused-parameter" -#ifdef HAS_DEPRECATED_COPY -#pragma GCC diagnostic ignored "-Wdeprecated-copy" -#endif -#elif defined(_MSC_VER) -// build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76): -// warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence - -// unsupported\eigen\cxx11\src\Tensor\TensorUInt128.h(150,0): Warning C4245: 'initializing': conversion from '__int64' -// to 'uint64_t', signed/unsigned mismatch -#pragma warning(push) -#pragma warning(disable : 4554) -#pragma warning(disable : 4245) -#pragma warning(disable : 4127) -#endif - -#include "Eigen/Core" - -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif - - - -namespace onnxruntime { -namespace ort_dnnl { - 
-ONNX_OPERATOR_KERNEL_EX( - Gemm, - kOnnxDomain, - 7, - kDnnlExecutionProvider, - KernelDefBuilder::Create()->TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gemm); - -class GemmHelper { - public: - GemmHelper(const TensorShape& left, bool trans_left, const TensorShape& right, bool trans_right, const TensorShape& bias) { - ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1); - ORT_ENFORCE(right.NumDimensions() == 2); - - if (trans_left) { - M_ = left.NumDimensions() == 2 ? left[1] : left[0]; - K_ = left.NumDimensions() == 2 ? left[0] : 1; - } else { - M_ = left.NumDimensions() == 2 ? left[0] : 1; - K_ = left.NumDimensions() == 2 ? left[1] : left[0]; - } - - int k_dim; - if (trans_right) { - N_ = right[0]; - k_dim = 1; - } else { - N_ = right[1]; - k_dim = 0; - } - - if (right[k_dim] != K_) - status_ = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "GEMM: Dimension mismatch, W: ", - right.ToString(), - " K: " + std::to_string(K_), - " N:" + std::to_string(N_)); - - if (!IsValidBroadcast(bias, M_, N_)) - status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast"); - - // it is possible the input is empty tensor, for example the output of roipool in fast rcnn. 
- ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0); - } - - int64_t M() const { return M_; } - int64_t N() const { return N_; } - int64_t K() const { return K_; } - Status State() const { return status_; } - - private: - bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) { - // valid shapes are (,) , (1, N) , (M, 1) , (M, N) - if (bias_shape.NumDimensions() > 2) - return false; - // shape is (1,) or (1, 1), or (,) - if (bias_shape.Size() == 1) - return true; - // valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N), - // In last case no broadcasting needed, so don't fail it - return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) || - (bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) || - (bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N)); - } - - private: - int64_t M_; - int64_t K_; - int64_t N_; - Status status_; -}; - -template -using EigenMatrixMapRowMajor = Eigen::Map>; -template -using ConstEigenVectorMap = Eigen::Map>; -template -using ConstEigenMatrixMapRowMajor = Eigen::Map>; - -template <> -Status Gemm::Compute(OpKernelContext* ctx) const { - const auto X = ctx->Input(0); - const auto W = ctx->Input(1); - const auto B = ctx->Input(2); - GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B->Shape()); - - if (!helper.State().IsOK()) - return helper.State(); - - dnnl::memory::dim M = gsl::narrow_cast(helper.M()); - dnnl::memory::dim N = gsl::narrow_cast(helper.N()); - dnnl::memory::dim K = gsl::narrow_cast(helper.K()); - auto Y = ctx->Output(0, TensorShape({M, N})); - - if (M <= 0) - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Empty Tensor not supported"); - - if (beta_ != 0) { - auto output_mat = EigenMatrixMapRowMajor( - Y->template MutableData(), - M, - N); - output_mat.setZero(); - - auto& b_shape = B->Shape(); - // if B is (), (1,) or (1, 1), add the scalar - if (b_shape.Size() == 1) { - output_mat.array() += 
*(B->template Data()); - } - // B is (N,) - else if (b_shape.NumDimensions() == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - N); - output_mat.rowwise() += bias_vec.transpose(); - } else if (b_shape.NumDimensions() == 2) { - // B is (M, 1) - if (b_shape[1] == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - M); - output_mat.colwise() += bias_vec; - } - // B is (1, N) - else if (b_shape[0] == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - N); - output_mat.rowwise() += bias_vec.transpose(); - } - // B is (M, N), no broadcast needed. - else { - auto bias_mat = ConstEigenMatrixMapRowMajor( - B->template Data(), - M, - N); - output_mat += bias_mat; - } - } - } - - // dnnl_sgemm expects row major matrices, so no need to swap the operands A and B - auto status = dnnl_sgemm(trans_A_ ? 'T' : 'N', - trans_B_ ? 'T' : 'N', - M, N, K, - alpha_, X->template Data(), trans_A_ ? M : K, - W->template Data(), trans_B_ ? K : N, - beta_, Y->template MutableData(), N); - if (status == dnnl_success) { - return Status::OK(); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DNNL_sgemm failed with status: ", status); - } -} - -} // namespace ort_dnnl -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/math/gemm.h b/onnxruntime/core/providers/dnnl/math/gemm.h deleted file mode 100644 index 6e51cb6c94..0000000000 --- a/onnxruntime/core/providers/dnnl/math/gemm.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -namespace onnxruntime { -namespace ort_dnnl { -template -class Gemm final : public OpKernel { - public: - Gemm(const OpKernelInfo& info) : OpKernel(info) { - int64_t temp; - ORT_ENFORCE(info.GetAttr("transA", &temp).IsOK()); - trans_A_ = (temp != 0); - - ORT_ENFORCE(info.GetAttr("transB", &temp).IsOK()); - trans_B_ = (temp != 0); - - ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK()); - ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK()); - } - - Status Compute(OpKernelContext* context) const override; - - private: - bool trans_A_; - bool trans_B_; - float alpha_; - float beta_; -}; -} // namespace ort_dnnl -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc new file mode 100644 index 0000000000..1a85b2b50b --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc @@ -0,0 +1,181 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#include "dnnl_gemm.h" +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +DnnlGemm::DnnlGemm() {} + +/* +Gemm implementation: +Gemm: + Inputs: + 0) A - Input Tensor + 1) B - Input Tensor + 2) C - Input Tensor (optional if Opset is 11 or later) + Outputs: + 0) Y - Output Tensor + + +-----------+ + (A) | | + ---------->+ | AB +------+ + (B) | MatMul +--------------------->+ | alphaAB + ---------->+ | (alpha) | Mul +---+ + | | *--------------->+ | | +------+ + +-----------+ +------+ +---->+ | (Y) alphaAB + betaC + | Add +----------------------> + (C) +------+ +---->+ | + --------------------------------------------->+ | | +------+ + (beta) | Mul +---+ + *--------------->+ | betaC + +------+ + +Attributes (alpha, beta, transA, transB) + +To compose Gemm: (algorithm) +(1) perform `MatMul` on input tensors A and B result (AB) +(2) if `Mul` the result of (1) by alpha attribute (alphaAB) +(3) if C is optional return result 
from (2) and end +(4) if C is available `Mul` input C tensor by beta attribute (betaC) +(5) `Add` result from (2) to result from (4) (alphaAB + betaC) +(6) Return output from (5) and end + +OneDNN algorithm: +(1) perform `MatMul` of tensor A and tensor B with `Output scales` set to alpha (0) +(2) if C is not provided return output from (1) and end +(3) if C is available perform binary `Add` of output from (1) and input C with input C's `scale` attribute set to beta +(4) return output from (3) and end + +*/ + + +void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { + auto eng = sp.GetEngine(); + + auto a_dims = sp.GetMemory(node.Input(IN_A).Name()).get_desc().dims(); + auto b_dims = sp.GetMemory(node.Input(IN_B).Name()).get_desc().dims(); + + bool input_c_exists = node.Input(IN_C).Exists(); + + if (a_dims.size() != b_dims.size()) { + while (a_dims.size() < b_dims.size()) { + a_dims.insert(a_dims.begin(), 1); + } + while (a_dims.size() > b_dims.size()) { + b_dims.insert(b_dims.begin(), 1); + } + } + + + dnnl::memory::desc a_md; + dnnl::memory::desc b_md; + + bool transA = GetTransA(node); + bool transB = GetTransB(node); + + dnnl::memory::dim M = (transA) ? a_dims[1] : a_dims[0]; + dnnl::memory::dim K = (transA) ? a_dims[0] : a_dims[1]; + dnnl::memory::dim N = (transB) ? b_dims[0] : b_dims[1]; + + dnnl::memory::dims a_strides = (transA) ? dnnl::memory::dims{dnnl::memory::dim(1), M} : dnnl::memory::dims{K, dnnl::memory::dim(1)}; + dnnl::memory::dims b_strides = (transB) ?
dnnl::memory::dims{dnnl::memory::dim(1), K} : dnnl::memory::dims{N, dnnl::memory::dim(1)}; + + a_md = dnnl::memory::desc({M, K}, node.Input(IN_A).Type(), a_strides); + b_md = dnnl::memory::desc({K, N}, node.Input(IN_B).Type(), b_strides); + + dnnl::memory::dims output_shape{M, N}; + + dnnl::primitive_attr matmul_attr; + // scale the output from MatMul to alpha + float alpha = GetAlpha(node); + std::vector alphaScale({alpha}); + matmul_attr.set_output_scales(0, alphaScale); + + auto matmul_dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), {N, 1}); + + auto matmul_d = dnnl::matmul::desc(a_md, b_md, matmul_dst_md); + dnnl::matmul::primitive_desc matmul_pd; + matmul_pd = dnnl::matmul::primitive_desc(matmul_d, matmul_attr, eng); + + auto matmul_a_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng, transA); + auto matmul_b_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng, transB); + auto gemm_dst_mem = dnnl::memory(matmul_pd.dst_desc(), eng); + + auto matmul_op = dnnl::matmul(matmul_pd); + + std::unordered_map args; + args.insert({DNNL_ARG_SRC, matmul_a_mem}); + args.insert({DNNL_ARG_WEIGHTS, matmul_b_mem}); + args.insert({DNNL_ARG_DST, gemm_dst_mem}); + + sp.AddPrimitive(matmul_op, args); + + if (input_c_exists) { + auto c_original_md = sp.GetMemory(node.Input(IN_C).Name()).get_desc(); + auto c_dims = c_original_md.dims(); + if (c_dims.size() != a_dims.size()) { + while (c_dims.size() < a_dims.size()) { + c_dims.insert(c_dims.begin(), 1); + } + } + + auto c_md = c_original_md.reshape(c_dims); + + auto y_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any); + + auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_add, matmul_pd.dst_desc(), c_md, y_md); + + // Scale input C by beta before adding it to the MatMul output. 
+ dnnl::primitive_attr binary_attr; + float beta = GetBeta(node); + binary_attr.set_scales(DNNL_ARG_SRC_1, 0, {beta}); + + auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr,eng); + + auto binary_c_mem = sp.GetMemoryAndReshape(node.Input(IN_C), binary_pd.src1_desc(), eng); + + auto binary_op = dnnl::binary(binary_pd); + + sp.AddPrimitive(binary_op, {{DNNL_ARG_SRC_0, gemm_dst_mem}, + {DNNL_ARG_SRC_1, binary_c_mem}, + {DNNL_ARG_DST, gemm_dst_mem}}); + } + sp.SetMemory(node.Output(OUT_Y), gemm_dst_mem); +} + +float DnnlGemm::GetAlpha(DnnlNode& node) { + auto attr = node.Attributes().find("alpha"); + if (attr != node.Attributes().end()) { + return attr->second().f(); + } + return 1.0; +} +float DnnlGemm::GetBeta(DnnlNode& node) { + auto attr = node.Attributes().find("beta"); + if (attr != node.Attributes().end()) { + return attr->second().f(); + } + return 1.0; +} + +bool DnnlGemm::GetTransA(DnnlNode& node) { + auto attr = node.Attributes().find("transA"); + if (attr != node.Attributes().end()) { + return (attr->second().i() != 0); + } + return false; +} + +bool DnnlGemm::GetTransB(DnnlNode& node) { + auto attr = node.Attributes().find("transB"); + if (attr != node.Attributes().end()) { + return (attr->second().i() != 0); + } + return false; +} +} // namespace ort_dnnl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h new file mode 100644 index 0000000000..1b8f90e9c4 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h @@ -0,0 +1,34 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#pragma once +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +class DnnlGemm { + public: + enum InputTensors : int { + IN_A = 0, + IN_B = 1, + IN_C = 2 + }; + + enum OutputTensors : int { + OUT_Y = 0 + }; + + DnnlGemm(); + void 
CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node); + + private: + float GetAlpha(DnnlNode& node); + float GetBeta(DnnlNode& node); + bool GetTransA(DnnlNode& node); + bool GetTransB(DnnlNode& node); +}; + +} // namespace ort_dnnl +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc index 6baed09262..101c73941c 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc @@ -31,7 +31,7 @@ void DnnlPool::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { } } - auto kernel_shape = GetKernelShape(node); + auto kernel_shape = GetKernelShape(src_dims, node); PoolShape shape = static_cast(kernel_shape.size()); auto strides = GetStrides(node, shape); @@ -118,7 +118,7 @@ dnnl::memory::dims DnnlPool::GetDilations(DnnlNode& node, PoolShape shape) { return dnnl::memory::dims(dilations.begin(), dilations.end()); } -dnnl::memory::dims DnnlPool::GetKernelShape(DnnlNode& node) { +dnnl::memory::dims DnnlPool::GetKernelShape(const dnnl::memory::dims& src_dims, DnnlNode& node) { auto attr = node.Attributes().find("kernel_shape"); std::vector kernel_shape; if (attr != node.Attributes().end()) { @@ -128,9 +128,8 @@ dnnl::memory::dims DnnlPool::GetKernelShape(DnnlNode& node) { } return kernel_shape; } - // Infer the Kernel shape from the input weights - auto weight_dims = node.Input(IN_X).Dim(); - kernel_shape = std::vector(weight_dims.begin() + 2, weight_dims.end()); + + kernel_shape = std::vector(src_dims.begin() + 2, src_dims.end()); return kernel_shape; } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h index 4edaa69527..5d8e4458ec 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.h @@ -34,7 +34,7 @@ class DnnlPool { int64_t 
GetCeilMode(DnnlNode& node); int64_t GetCountIncludePadding(DnnlNode& node); dnnl::memory::dims GetDilations(DnnlNode& node, PoolShape shape); - dnnl::memory::dims GetKernelShape(DnnlNode& node); + dnnl::memory::dims GetKernelShape(const dnnl::memory::dims& src_dims, DnnlNode& node); /* This will return the calculated padding taking into account the DEPRECATED auto_pad attribute */ std::vector InferPadding(DnnlNode& node, const dnnl::memory::dims& src_dims, const dnnl::memory::dims& kernel_shape, const dnnl::memory::dims& strides); std::vector GetPadding(DnnlNode& node, PoolShape shape); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc index 43f3af57d9..b3b1e73e42 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc @@ -6,6 +6,7 @@ #include "dnnl_batchnorm.h" #include "dnnl_binary.h" #include "dnnl_conv.h" +#include "dnnl_gemm.h" #include "dnnl_lrn.h" #include "dnnl_matmul.h" #include "dnnl_matmul_integer.h" @@ -205,6 +206,8 @@ void DnnlSubgraphPrimitive::AddKernels() { DnnlBatchNorm().CreatePrimitive(*this, node); } else if (node.OpType() == "Conv") { DnnlConv().CreatePrimitive(*this, node); + } else if (node.OpType() == "Gemm") { + DnnlGemm().CreatePrimitive(*this, node); } else if (node.OpType() == "LRN") { DnnlLrn().CreatePrimitive(*this, node); } else if (node.OpType() == "MatMul") { @@ -356,7 +359,7 @@ void DnnlSubgraphPrimitive::SetInitializer(std::string memory_name, dnnl::memory } } -dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng) { +dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng, bool transpose) { // if found just return if (HasMemory(tensor.Name(), mem_desc, eng)) { return 
GetMemory(tensor.Name(), mem_desc, eng); @@ -372,7 +375,7 @@ dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor ten auto mem_to = dnnl::memory(mem_desc, eng); // if it is a reshape, ensure reorder is possible by making the same dims - if (mem_from.get_desc().dims() != mem_to.get_desc().dims()) { + if (mem_from.get_desc().dims() != mem_to.get_desc().dims() || transpose) { auto mem_from_dims = mem_from.get_desc().dims(); auto mem_to_dims = mem_to.get_desc().dims(); if (Product(mem_from_dims) != Product(mem_to_dims)) { diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h index ed4bbfb5da..e96cd4750f 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h @@ -46,7 +46,7 @@ class DnnlSubgraphPrimitive { dnnl::stream GetStream(); //obtain a dnnl::memory with specified name, memory descriptor and engine, will perform extra reorder/reshape if necessary before returning - dnnl::memory GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng); + dnnl::memory GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng, bool transpose = false); //add dnnl primitive and memory map to subgraph primitive void AddPrimitive(dnnl::primitive prim, std::unordered_map mem_map); //add a reshape (e.g. 
squeeze, unsqueeze) to subgraph primitive diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index ddec06b678..6c1573b2a0 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -225,6 +225,67 @@ TEST(GemmOpTest, GemmTransB_1) { TestGemmTransB_1(); } +template +void TestGemmAlpha() { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.5f); + test.AddAttribute("beta", 1.0f); + + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {2, 3}, + {6.0f, 6.0f, 6.0f, + -4.0f, -4.0f, -4.0f}); + //test.AddOutput("Y", {2, 3}, + // {5.0f, 5.0f, 5.0f, + // -5.0f, -5.0f, -5.0f}); +#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser +#endif +} + +TEST(GemmOpTest, GemmAlpha) { + TestGemmAlpha(); + TestGemmAlpha(); +} + +template +void TestGemmBeta() { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 2.0f); + + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {3}, std::vector(3, 1.0f)); + test.AddOutput("Y", {2, 3}, + {12.0f, 12.0f, 12.0f, + -8.0f, -8.0f, -8.0f}); +#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser +#endif +} + +TEST(GemmOpTest, GemmBeta) { + TestGemmBeta(); + TestGemmBeta(); +} + template void TestGemmAlphaBeta() { OpTester test("Gemm"); diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index acb1a04a44..809545cd2e 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -192,6 +192,12 @@ void GradientOpTester::Run( //if node is not registered for the provider, skip node.SetExecutionProviderType(provider_type); + + // provider types that don't use the KernelRegistry + if (provider_type == onnxruntime::kDnnlExecutionProvider) { + continue; + } + auto reg = execution_provider->GetKernelRegistry(); const KernelCreateInfo* kci; auto st = reg->TryFindKernel(node, execution_provider->Type(), &kci); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index bec6f7c0dc..721667509c 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -663,14 +663,6 @@ TEST(GradientCheckerTest, ReluGrad) { EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } -#ifdef USE_DNNL -TEST(GradientCheckerTest, ReluGradDnnl) { - std::vector> execution_providers; - execution_providers.push_back(DefaultDnnlExecutionProvider()); - UnaryOpGradientTest("Relu", kOnnxDomain, 9, &execution_providers); -} -#endif // USE_DNNL - TEST(GradientCheckerTest, CastGrad) { // A dummy test that cast float to float // TODO: add more test here @@ -728,7 +720,7 @@ static 
std::vector> GetRandomValuesForMaxPool(const std::vector>* execution_provider) { +TEST(GradientCheckerTest, MaxPoolGrad) { float max_error; GradientChecker gradient_checker; OpDef op_def{"MaxPool"}; @@ -737,9 +729,7 @@ void MaxpoolGradientCheckerTest(std::vector> { gradient_checker.ComputeGradientError(op_def, {{2, 2, 9}}, {{2, 2, 8}}, &max_error, GetRandomValuesForMaxPool({{2, 2, 9}}), - {MakeAttribute("kernel_shape", std::vector{2})}, - true, false, - execution_provider); + {MakeAttribute("kernel_shape", std::vector{2})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -748,9 +738,7 @@ void MaxpoolGradientCheckerTest(std::vector> gradient_checker.ComputeGradientError(op_def, {{2, 3, 5, 5}}, {{2, 3, 4, 4}}, &max_error, GetRandomValuesForMaxPool({{2, 3, 5, 5}}), {MakeAttribute("kernel_shape", std::vector{2, 2}), - MakeAttribute("strides", std::vector{1, 1})}, - true, false, - execution_provider); + MakeAttribute("strides", std::vector{1, 1})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -759,9 +747,7 @@ void MaxpoolGradientCheckerTest(std::vector> gradient_checker.ComputeGradientError(op_def, {{1, 1, 5, 5}}, {{1, 1, 7, 7}}, &max_error, GetRandomValuesForMaxPool({{1, 1, 5, 5}}), {MakeAttribute("kernel_shape", std::vector{3, 3}), - MakeAttribute("pads", std::vector{2, 2, 2, 2})}, - true, false, - execution_provider); + MakeAttribute("pads", std::vector{2, 2, 2, 2})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -770,9 +756,7 @@ void MaxpoolGradientCheckerTest(std::vector> gradient_checker.ComputeGradientError(op_def, {{1, 1, 32, 32}}, {{1, 1, 10, 10}}, &max_error, GetRandomValuesForMaxPool({{1, 1, 32, 32}}), {MakeAttribute("kernel_shape", std::vector{5, 5}), - MakeAttribute("strides", std::vector{3, 3})}, - true, false, - execution_provider); + MakeAttribute("strides", std::vector{3, 3})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -780,23 +764,11 @@ void MaxpoolGradientCheckerTest(std::vector> { 
gradient_checker.ComputeGradientError(op_def, {{2, 1, 3, 3, 3}}, {{2, 1, 2, 2, 2}}, &max_error, GetRandomValuesForMaxPool({{2, 1, 3, 3, 3}}), - {MakeAttribute("kernel_shape", std::vector{2, 2, 2})}, - true, false, - execution_provider); + {MakeAttribute("kernel_shape", std::vector{2, 2, 2})}); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } } -TEST(GradientCheckerTest, MaxPoolGrad) { - MaxpoolGradientCheckerTest(nullptr); - -#ifdef USE_DNNL - std::vector> execution_providers; - execution_providers.push_back(DefaultDnnlExecutionProvider()); - MaxpoolGradientCheckerTest(&execution_providers); -#endif -} - TEST(GradientCheckerTest, GlobalAveragePoolGrad) { float max_error; GradientChecker gradient_checker; @@ -1057,15 +1029,17 @@ void ConvGradientCheckerTest(std::vector>* e TEST(GradientCheckerTest, ConvGrad) { std::vector> execution_providers; +#ifdef USE_DNNL + // Dnnl EP does not run for ConvGrad unless it is pushed first. + execution_providers.push_back(DefaultDnnlExecutionProvider()); +#endif + execution_providers.push_back(DefaultCpuExecutionProvider()); if (HasCudaEnvironment(700)) { execution_providers.push_back(DefaultCudaExecutionProvider()); } -#ifdef USE_DNNL - execution_providers.push_back(DefaultDnnlExecutionProvider()); -#endif ConvGradientCheckerTest(&execution_providers); }