From dc75a135c8fdcd7877e429fad9e87d5357c04772 Mon Sep 17 00:00:00 2001 From: George Nash Date: Tue, 31 Aug 2021 12:20:49 -0700 Subject: [PATCH] Add elementwise operators to DNNL execution provider (#8899) The following ops have been added to the DNNL execution provider Abs, Elu, Exp, Log, *Relu, Round, Sigmoid, Softplus, Sqrt, and Tanh *Relu op was moved from its individual file to the elementwise operators The error tolerance for the LogGrad unit test had to be decreased slightly when using OneDNN. Still investigating why a differet tolerance value is needed. DnnlSubgraph::AddKernels() member function was moved to the top of the file since this is eddited every time a new operator is added to the the execution provider this places the code at the top which mean less scrooling when adding new kernels. Signed-off-by: George Nash --- .../providers/dnnl/dnnl_node_capability.cc | 24 +++++ .../providers/dnnl/dnnl_node_capability.h | 10 +++ .../core/providers/dnnl/dnnl_op_manager.cc | 11 ++- .../dnnl/subgraph/dnnl_elementwise.cc | 80 +++++++++++++++++ .../dnnl/subgraph/dnnl_elementwise.h | 41 +++++++++ .../core/providers/dnnl/subgraph/dnnl_relu.cc | 35 -------- .../core/providers/dnnl/subgraph/dnnl_relu.h | 26 ------ .../dnnl/subgraph/dnnl_subgraph_primitive.cc | 87 ++++++++++--------- .../test/gradient/gradient_ops_test.cc | 4 + 9 files changed, 213 insertions(+), 105 deletions(-) create mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc create mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h delete mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.cc delete mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.h diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index 62d354ca8c..79e440c954 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -372,6 +372,30 @@ bool DnnlBinaryNodeCapability::IsDimensionSupported(const Node* node) const { return true; } +// DnnlBinaryNodeCapability class +//------------------------------------- +bool DnnlElementwiseCapability::Supported(const Node* node) const { + if (!IsTypeSupported(node)) return false; + if (!IsDimensionSupported(node)) return false; + return true; +} + +bool DnnlElementwiseCapability::IsDimensionSupported(const Node* node) const { + auto node_inputs = node->InputDefs(); + if (node_inputs[0]->Shape() == nullptr) { + return true; + } + + // OneDNN will silently convert scaler values to a {1} tensor which causes issues for + // for Onnruntime when it expects an empty tensor i.e. {} + // TODO convert {1} outputs back to scaler {} once that is done DnnlElementwiseCapability + // can be removed and just us the DnnlDefaultNodeCapability. + if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 0) { + return false; + } + return true; +} + // DnnlGemmNodeCapability class //------------------------------------- bool DnnlGemmNodeCapability::Supported(const Node* node) const { diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h index 702874b773..f8ccc7d5b0 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h @@ -207,6 +207,16 @@ class DnnlBinaryNodeCapability : public DnnlDefaultNodeCapability { bool IsDimensionSupported(const Node* node) const; }; +class DnnlElementwiseCapability : public DnnlDefaultNodeCapability { + public: + DnnlElementwiseCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + bool IsDimensionSupported(const Node* node) const; +}; + /** * Decide if a Gemm op is supported by DnnlExecutionProvider */ diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc index b7d133c34a..d8b56ba789 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc @@ -6,24 +6,33 @@ namespace onnxruntime { DnnlOpManager::DnnlOpManager() { + dnnl_ops_map_.emplace(std::make_pair("Abs", std::unique_ptr(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Add", std::unique_ptr(new DnnlBinaryNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr(new DnnlBatchNormalizationNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr(new DnnlDefaultNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr(new DnnlBinaryNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Elu", std::unique_ptr(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Exp", std::unique_ptr(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr(new DnnlGemmNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr(new DnnlPoolNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Log", std::unique_ptr(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr(new DnnlDefaultNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("MatMul", std::unique_ptr(new DnnlMatMulNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("MatMulInteger", std::unique_ptr(new DnnlMatMulIntegerNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("MaxPool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Mul", std::unique_ptr(new DnnlBinaryNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("ReduceMean", std::unique_ptr(new DnnlReduceMeanNodeCapability()))); - dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Round", std::unique_ptr(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Sigmoid", std::unique_ptr(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Softmax", std::unique_ptr(new DnnlSoftmaxNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Softplus", std::unique_ptr(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Sqrt", std::unique_ptr(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Sub", std::unique_ptr(new DnnlBinaryNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Sum", std::unique_ptr(new DnnlSumNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Tanh", std::unique_ptr(new DnnlElementwiseCapability()))); #if defined(ENABLE_TRAINING) dnnl_ops_map_.emplace(std::make_pair("AveragePoolGrad", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("ConvGrad", std::unique_ptr(new DnnlDefaultNodeCapability()))); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc new file mode 100644 index 0000000000..640039f4f9 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc @@ -0,0 +1,80 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#include "dnnl_elementwise.h" +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +DnnlElementwise::DnnlElementwise() { + default_alpha_ = 1.0; +} + +void DnnlElementwise::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { + auto dnnl_engine = sp.GetEngine(); + + auto elementwise_src_mem = sp.GetMemory(node.Input(IN_X).Name()); + auto src_md = elementwise_src_mem.get_desc(); + dnnl::algorithm algo; + bool requires_alpha = false; + float alpha = 0.0; + if (node.OpType() == "Abs") { + algo = dnnl::algorithm::eltwise_abs; + } else if (node.OpType() == "Elu") { + requires_alpha = true; + default_alpha_ = 1.0; + alpha = GetAlpha(node); + algo = dnnl::algorithm::eltwise_elu; + } else if (node.OpType() == "Exp") { + algo = dnnl::algorithm::eltwise_exp; + } else if (node.OpType() == "Log") { + algo = dnnl::algorithm::eltwise_log; + } else if (node.OpType() == "Relu") { + algo = dnnl::algorithm::eltwise_relu; + } else if (node.OpType() == "Round") { + algo = dnnl::algorithm::eltwise_round; + } else if (node.OpType() == "Sigmoid") { + // in OneDNN eltwise_logistic is defined as 1/(1 + exp(-x)) which matches the definition of "Sigmoid" in Onnx + algo = dnnl::algorithm::eltwise_logistic; + } else if (node.OpType() == "Softplus") { + // in OneDNN eltwise_soft_relu is defined as ln(1 + exp(x)) which matches the definition of "Softplus" in Onnx + algo = dnnl::algorithm::eltwise_soft_relu; + } else if (node.OpType() == "Sqrt") { + algo = dnnl::algorithm::eltwise_sqrt; + } else if (node.OpType() == "Tanh") { + algo = dnnl::algorithm::eltwise_tanh; + } else { + ORT_THROW("op type not supported"); + } + dnnl::eltwise_forward::primitive_desc elementwise_pd; + if (requires_alpha) { + auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md, alpha); + elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine); + } else { + auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md); + elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine); + } + + // If using GPU this will move the memory from the CPU to the GPU. + elementwise_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), elementwise_pd.src_desc(), dnnl_engine); + auto elementwise_dst_mem = dnnl::memory(elementwise_pd.dst_desc(), dnnl_engine); + + auto elemenwise_primitive = dnnl::eltwise_forward(elementwise_pd); + sp.AddPrimitive(elemenwise_primitive, {{DNNL_ARG_SRC, elementwise_src_mem}, + {DNNL_ARG_DST, elementwise_dst_mem}}); + + sp.SetMemory(node.Output(OUT_Y), elementwise_dst_mem); +} + +float DnnlElementwise::GetAlpha(DnnlNode& node) { + auto attr = node.Attributes().find("alpha"); + if (attr != node.Attributes().end()) { + return attr->second().f(); + } + return default_alpha_; +} + +} // namespace ort_dnnl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h new file mode 100644 index 0000000000..e366ccb2ae --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h @@ -0,0 +1,41 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#pragma once +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +class DnnlElementwise { + public: + enum InputTensors : int { + IN_X = 0, + }; + + enum OutputTensors : int { + OUT_Y = 0 + }; + + DnnlElementwise(); + void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node); + + private: + /* + * GetAlpha will get the 'alpha' attribute if the attribute is not found + * the the `default_alph_` will be returned instead. This is set to 1.0 + * by the `DnnlElementwise` constructor but should be updated for any operator + * that has an 'alpha' property. + * + * See how `GetAlpha` is called for the 'Elu' operator in the `CreatePrimitive` code. + * + * Note: The number of operators that use the 'alpha' attribute is much smaller than + * initially expected. + */ + float GetAlpha(DnnlNode& node); + float default_alpha_; +}; + +} // namespace ort_dnnl +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.cc deleted file mode 100644 index bb7c4e566b..0000000000 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.cc +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright(C) 2021 Intel Corporation -// Licensed under the MIT License - -#include "dnnl_relu.h" -#include "dnnl_subgraph.h" -#include "dnnl_subgraph_primitive.h" - -namespace onnxruntime { -namespace ort_dnnl { - -DnnlRelu::DnnlRelu() {} - -void DnnlRelu::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { - - auto dnnl_engine = sp.GetEngine(); - - auto relu_src_mem = sp.GetMemory(node.Input(IN_X).Name()); - auto src_md = relu_src_mem.get_desc(); - - auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_relu, src_md); - auto relu_pd = dnnl::eltwise_forward::primitive_desc(relu_desc, dnnl_engine); - - // If using GPU this will move the memory from the CPU to the GPU. - relu_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), relu_pd.src_desc(), dnnl_engine); - auto relu_dst_mem = dnnl::memory(relu_pd.dst_desc(), dnnl_engine); - - auto relu_op = dnnl::eltwise_forward(relu_pd); - sp.AddPrimitive(relu_op, {{DNNL_ARG_SRC, relu_src_mem}, - {DNNL_ARG_DST, relu_dst_mem}}); - - sp.SetMemory(node.Output(OUT_Y), relu_dst_mem); -} - -} // namespace ort_dnnl -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.h deleted file mode 100644 index 942be16b86..0000000000 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright(C) 2021 Intel Corporation -// Licensed under the MIT License - -#pragma once -#include "dnnl_subgraph.h" -#include "dnnl_subgraph_primitive.h" - -namespace onnxruntime { -namespace ort_dnnl { - -class DnnlRelu { - public: - enum InputTensors : int { - IN_X = 0, - }; - - enum OutputTensors : int { - OUT_Y = 0 - }; - - DnnlRelu(); - void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node); -}; - -} // namespace ort_dnnl -} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc index b3b1e73e42..6432ef9632 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc @@ -6,13 +6,13 @@ #include "dnnl_batchnorm.h" #include "dnnl_binary.h" #include "dnnl_conv.h" +#include "dnnl_elementwise.h" #include "dnnl_gemm.h" #include "dnnl_lrn.h" #include "dnnl_matmul.h" #include "dnnl_matmul_integer.h" #include "dnnl_pool.h" #include "dnnl_reducemean.h" -#include "dnnl_relu.h" #include "dnnl_softmax.h" #include "dnnl_sum.h" @@ -37,6 +37,49 @@ int Product(dnnl::memory::dims d) { return result; } +void DnnlSubgraphPrimitive::AddKernels() { + std::unordered_set binary_ops = {"Add", "Div", "Mul", "Sub"}; + std::unordered_set elementwise_ops = {"Abs", "Elu", "Exp","Log", "Relu", "Round", "Sigmoid", "Softplus", "Sqrt", "Tanh"}; + std::unordered_set pool_ops = {"AveragePool", "GlobalAveragePool", "GlobalMaxPool", "MaxPool"}; + for (auto& node : subgraph_->GetDnnlNodes()) { + if (node.OpType() == "BatchNormalization") { + DnnlBatchNorm().CreatePrimitive(*this, node); + } else if (binary_ops.count(node.OpType())) { + DnnlBinary().CreatePrimitive(*this, node); + } else if (node.OpType() == "Conv") { + DnnlConv().CreatePrimitive(*this, node); + } else if (elementwise_ops.count(node.OpType())) { + DnnlElementwise().CreatePrimitive(*this, node); + } else if (node.OpType() == "Gemm") { + DnnlGemm().CreatePrimitive(*this, node); + } else if (node.OpType() == "LRN") { + DnnlLrn().CreatePrimitive(*this, node); + } else if (node.OpType() == "MatMul") { + DnnlMatMul().CreatePrimitive(*this, node); + } else if (node.OpType() == "MatMulInteger") { + DnnlMatMulInteger().CreatePrimitive(*this, node); + } else if (pool_ops.count(node.OpType())) { + DnnlPool().CreatePrimitive(*this, node); + } else if (node.OpType() == "ReduceMean") { + DnnlReduceMean().CreatePrimitive(*this, node); + } else if (node.OpType() == "Softmax") { + DnnlSoftmax().CreatePrimitive(*this, node); + } else if (node.OpType() == "Sum") { + DnnlSum().CreatePrimitive(*this, node); +#if defined(ENABLE_TRAINING) + } else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") { + DnnlPoolGrad().CreatePrimitive(*this, node); + } else if (node.OpType() == "ConvGrad") { + DnnlConvGrad().CreatePrimitive(*this, node); + } else if (node.OpType() == "ReluGrad") { + DnnlReluGrad().CreatePrimitive(*this, node); +#endif + } else { + throw std::invalid_argument("Kernel not found"); + } + } +} + DnnlSubgraphPrimitive::DnnlSubgraphPrimitive(ort_dnnl::DnnlSubgraph& dnnl_subgraph) { subgraph_ = &dnnl_subgraph; if (dnnl_engine_get_count(dnnl_engine_kind_t::dnnl_cpu)) { @@ -194,48 +237,6 @@ void DnnlSubgraphPrimitive::AddInitializers() { } } -void DnnlSubgraphPrimitive::AddKernels() { - std::unordered_set binary_ops = {"Add", "Mul", "Sub", "Div"}; - for (auto& node : subgraph_->GetDnnlNodes()) { - if (node.OpType() == "AveragePool" || node.OpType() == "GlobalAveragePool" || - node.OpType() == "GlobalMaxPool" || node.OpType() == "MaxPool") { - DnnlPool().CreatePrimitive(*this, node); - } else if (binary_ops.count(node.OpType())) { - DnnlBinary().CreatePrimitive(*this, node); - } else if (node.OpType() == "BatchNormalization") { - DnnlBatchNorm().CreatePrimitive(*this, node); - } else if (node.OpType() == "Conv") { - DnnlConv().CreatePrimitive(*this, node); - } else if (node.OpType() == "Gemm") { - DnnlGemm().CreatePrimitive(*this, node); - } else if (node.OpType() == "LRN") { - DnnlLrn().CreatePrimitive(*this, node); - } else if (node.OpType() == "MatMul") { - DnnlMatMul().CreatePrimitive(*this, node); - } else if (node.OpType() == "MatMulInteger") { - DnnlMatMulInteger().CreatePrimitive(*this, node); - } else if (node.OpType() == "ReduceMean") { - DnnlReduceMean().CreatePrimitive(*this, node); - } else if (node.OpType() == "Relu") { - DnnlRelu().CreatePrimitive(*this, node); - } else if (node.OpType() == "Softmax") { - DnnlSoftmax().CreatePrimitive(*this, node); - } else if (node.OpType() == "Sum") { - DnnlSum().CreatePrimitive(*this, node); -#if defined(ENABLE_TRAINING) - } else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") { - DnnlPoolGrad().CreatePrimitive(*this, node); - } else if (node.OpType() == "ConvGrad") { - DnnlConvGrad().CreatePrimitive(*this, node); - } else if (node.OpType() == "ReluGrad") { - DnnlReluGrad().CreatePrimitive(*this, node); -#endif - } else { - throw std::invalid_argument("Kernel not found"); - } - } -} - void DnnlSubgraphPrimitive::AddOutputs() { for (auto& tensor : subgraph_->GetDnnlOutputs()) { auto dnnl_data_type = tensor.Type(); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 721667509c..79e13d50c9 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -441,7 +441,11 @@ TEST(GradientCheckerTest, LogGrad) { TensorInfo x_info{shape, true, &transformer}; float max_error; + #ifdef USE_DNNL + float error_tolerance = 3e-3f; + #else float error_tolerance = 1e-3f; + #endif GradientChecker gradient_checker; OpDef op_def{"Log"};