diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index 62d354ca8c..79e440c954 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -372,6 +372,30 @@ bool DnnlBinaryNodeCapability::IsDimensionSupported(const Node* node) const { return true; } +// DnnlElementwiseCapability class +//------------------------------------- +bool DnnlElementwiseCapability::Supported(const Node* node) const { + if (!IsTypeSupported(node)) return false; + if (!IsDimensionSupported(node)) return false; + return true; +} + +bool DnnlElementwiseCapability::IsDimensionSupported(const Node* node) const { + auto node_inputs = node->InputDefs(); + if (node_inputs[0]->Shape() == nullptr) { + return true; + } + + // OneDNN will silently convert scalar values to a {1} tensor which causes issues for + for Onnxruntime when it expects an empty tensor i.e. {} + // TODO convert {1} outputs back to scalar {} once that is done DnnlElementwiseCapability + can be removed and just use the DnnlDefaultNodeCapability. 
+ if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 0) { + return false; + } + return true; +} + // DnnlGemmNodeCapability class //------------------------------------- bool DnnlGemmNodeCapability::Supported(const Node* node) const { diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h index 702874b773..f8ccc7d5b0 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h @@ -207,6 +207,16 @@ class DnnlBinaryNodeCapability : public DnnlDefaultNodeCapability { bool IsDimensionSupported(const Node* node) const; }; +class DnnlElementwiseCapability : public DnnlDefaultNodeCapability { + public: + DnnlElementwiseCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + bool IsDimensionSupported(const Node* node) const; +}; + /** * Decide if a Gemm op is supported by DnnlExecutionProvider */ diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc index b7d133c34a..d8b56ba789 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc @@ -6,24 +6,33 @@ namespace onnxruntime { DnnlOpManager::DnnlOpManager() { + dnnl_ops_map_.emplace(std::make_pair("Abs", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Add", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability()))); + 
dnnl_ops_map_.emplace(std::make_pair("Elu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Exp", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr<DnnlNodeCapability>(new DnnlGemmNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Log", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("MatMul", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("MatMulInteger", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulIntegerNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("MaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Mul", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("ReduceMean", std::unique_ptr<DnnlNodeCapability>(new DnnlReduceMeanNodeCapability()))); - dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Round", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Sigmoid", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Softmax", std::unique_ptr<DnnlNodeCapability>(new DnnlSoftmaxNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Softplus", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Sqrt", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); dnnl_ops_map_.emplace(std::make_pair("Sub", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability()))); 
dnnl_ops_map_.emplace(std::make_pair("Sum", std::unique_ptr<DnnlNodeCapability>(new DnnlSumNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Tanh", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability()))); #if defined(ENABLE_TRAINING) dnnl_ops_map_.emplace(std::make_pair("AveragePoolGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("ConvGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability()))); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc new file mode 100644 index 0000000000..640039f4f9 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc @@ -0,0 +1,80 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#include "dnnl_elementwise.h" +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +DnnlElementwise::DnnlElementwise() { + default_alpha_ = 1.0; +} + +void DnnlElementwise::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { + auto dnnl_engine = sp.GetEngine(); + + auto elementwise_src_mem = sp.GetMemory(node.Input(IN_X).Name()); + auto src_md = elementwise_src_mem.get_desc(); + dnnl::algorithm algo; + bool requires_alpha = false; + float alpha = 0.0; + if (node.OpType() == "Abs") { + algo = dnnl::algorithm::eltwise_abs; + } else if (node.OpType() == "Elu") { + requires_alpha = true; + default_alpha_ = 1.0; + alpha = GetAlpha(node); + algo = dnnl::algorithm::eltwise_elu; + } else if (node.OpType() == "Exp") { + algo = dnnl::algorithm::eltwise_exp; + } else if (node.OpType() == "Log") { + algo = dnnl::algorithm::eltwise_log; + } else if (node.OpType() == "Relu") { + algo = dnnl::algorithm::eltwise_relu; + } else if (node.OpType() == "Round") { + algo = dnnl::algorithm::eltwise_round; + } else if (node.OpType() == "Sigmoid") { + // in OneDNN eltwise_logistic is defined as 1/(1 + exp(-x)) which matches the definition of 
"Sigmoid" in Onnx + algo = dnnl::algorithm::eltwise_logistic; + } else if (node.OpType() == "Softplus") { + // in OneDNN eltwise_soft_relu is defined as ln(1 + exp(x)) which matches the definition of "Softplus" in Onnx + algo = dnnl::algorithm::eltwise_soft_relu; + } else if (node.OpType() == "Sqrt") { + algo = dnnl::algorithm::eltwise_sqrt; + } else if (node.OpType() == "Tanh") { + algo = dnnl::algorithm::eltwise_tanh; + } else { + ORT_THROW("op type not supported"); + } + dnnl::eltwise_forward::primitive_desc elementwise_pd; + if (requires_alpha) { + auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md, alpha); + elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine); + } else { + auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md); + elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine); + } + + // If using GPU this will move the memory from the CPU to the GPU. 
+ elementwise_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), elementwise_pd.src_desc(), dnnl_engine); + auto elementwise_dst_mem = dnnl::memory(elementwise_pd.dst_desc(), dnnl_engine); + + auto elemenwise_primitive = dnnl::eltwise_forward(elementwise_pd); + sp.AddPrimitive(elemenwise_primitive, {{DNNL_ARG_SRC, elementwise_src_mem}, + {DNNL_ARG_DST, elementwise_dst_mem}}); + + sp.SetMemory(node.Output(OUT_Y), elementwise_dst_mem); +} + +float DnnlElementwise::GetAlpha(DnnlNode& node) { + auto attr = node.Attributes().find("alpha"); + if (attr != node.Attributes().end()) { + return attr->second().f(); + } + return default_alpha_; +} + +} // namespace ort_dnnl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h new file mode 100644 index 0000000000..e366ccb2ae --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h @@ -0,0 +1,41 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#pragma once +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +class DnnlElementwise { + public: + enum InputTensors : int { + IN_X = 0, + }; + + enum OutputTensors : int { + OUT_Y = 0 + }; + + DnnlElementwise(); + void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node); + + private: + /* + * GetAlpha will get the 'alpha' attribute if the attribute is not found + * then the `default_alpha_` will be returned instead. This is set to 1.0 + * by the `DnnlElementwise` constructor but should be updated for any operator + * that has an 'alpha' property. + * + * See how `GetAlpha` is called for the 'Elu' operator in the `CreatePrimitive` code. + * + * Note: The number of operators that use the 'alpha' attribute is much smaller than + * initially expected. 
+ */ + float GetAlpha(DnnlNode& node); + float default_alpha_; +}; + +} // namespace ort_dnnl +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.cc deleted file mode 100644 index bb7c4e566b..0000000000 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.cc +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright(C) 2021 Intel Corporation -// Licensed under the MIT License - -#include "dnnl_relu.h" -#include "dnnl_subgraph.h" -#include "dnnl_subgraph_primitive.h" - -namespace onnxruntime { -namespace ort_dnnl { - -DnnlRelu::DnnlRelu() {} - -void DnnlRelu::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { - - auto dnnl_engine = sp.GetEngine(); - - auto relu_src_mem = sp.GetMemory(node.Input(IN_X).Name()); - auto src_md = relu_src_mem.get_desc(); - - auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_relu, src_md); - auto relu_pd = dnnl::eltwise_forward::primitive_desc(relu_desc, dnnl_engine); - - // If using GPU this will move the memory from the CPU to the GPU. 
- relu_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), relu_pd.src_desc(), dnnl_engine); - auto relu_dst_mem = dnnl::memory(relu_pd.dst_desc(), dnnl_engine); - - auto relu_op = dnnl::eltwise_forward(relu_pd); - sp.AddPrimitive(relu_op, {{DNNL_ARG_SRC, relu_src_mem}, - {DNNL_ARG_DST, relu_dst_mem}}); - - sp.SetMemory(node.Output(OUT_Y), relu_dst_mem); -} - -} // namespace ort_dnnl -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.h deleted file mode 100644 index 942be16b86..0000000000 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_relu.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright(C) 2021 Intel Corporation -// Licensed under the MIT License - -#pragma once -#include "dnnl_subgraph.h" -#include "dnnl_subgraph_primitive.h" - -namespace onnxruntime { -namespace ort_dnnl { - -class DnnlRelu { - public: - enum InputTensors : int { - IN_X = 0, - }; - - enum OutputTensors : int { - OUT_Y = 0 - }; - - DnnlRelu(); - void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node); -}; - -} // namespace ort_dnnl -} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc index b3b1e73e42..6432ef9632 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc @@ -6,13 +6,13 @@ #include "dnnl_batchnorm.h" #include "dnnl_binary.h" #include "dnnl_conv.h" +#include "dnnl_elementwise.h" #include "dnnl_gemm.h" #include "dnnl_lrn.h" #include "dnnl_matmul.h" #include "dnnl_matmul_integer.h" #include "dnnl_pool.h" #include "dnnl_reducemean.h" -#include "dnnl_relu.h" #include "dnnl_softmax.h" #include "dnnl_sum.h" @@ -37,6 +37,49 @@ int Product(dnnl::memory::dims d) { return result; } +void DnnlSubgraphPrimitive::AddKernels() { + 
std::unordered_set<std::string> binary_ops = {"Add", "Div", "Mul", "Sub"}; + std::unordered_set<std::string> elementwise_ops = {"Abs", "Elu", "Exp","Log", "Relu", "Round", "Sigmoid", "Softplus", "Sqrt", "Tanh"}; + std::unordered_set<std::string> pool_ops = {"AveragePool", "GlobalAveragePool", "GlobalMaxPool", "MaxPool"}; + for (auto& node : subgraph_->GetDnnlNodes()) { + if (node.OpType() == "BatchNormalization") { + DnnlBatchNorm().CreatePrimitive(*this, node); + } else if (binary_ops.count(node.OpType())) { + DnnlBinary().CreatePrimitive(*this, node); + } else if (node.OpType() == "Conv") { + DnnlConv().CreatePrimitive(*this, node); + } else if (elementwise_ops.count(node.OpType())) { + DnnlElementwise().CreatePrimitive(*this, node); + } else if (node.OpType() == "Gemm") { + DnnlGemm().CreatePrimitive(*this, node); + } else if (node.OpType() == "LRN") { + DnnlLrn().CreatePrimitive(*this, node); + } else if (node.OpType() == "MatMul") { + DnnlMatMul().CreatePrimitive(*this, node); + } else if (node.OpType() == "MatMulInteger") { + DnnlMatMulInteger().CreatePrimitive(*this, node); + } else if (pool_ops.count(node.OpType())) { + DnnlPool().CreatePrimitive(*this, node); + } else if (node.OpType() == "ReduceMean") { + DnnlReduceMean().CreatePrimitive(*this, node); + } else if (node.OpType() == "Softmax") { + DnnlSoftmax().CreatePrimitive(*this, node); + } else if (node.OpType() == "Sum") { + DnnlSum().CreatePrimitive(*this, node); +#if defined(ENABLE_TRAINING) + } else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") { + DnnlPoolGrad().CreatePrimitive(*this, node); + } else if (node.OpType() == "ConvGrad") { + DnnlConvGrad().CreatePrimitive(*this, node); + } else if (node.OpType() == "ReluGrad") { + DnnlReluGrad().CreatePrimitive(*this, node); +#endif + } else { + throw std::invalid_argument("Kernel not found"); + } + } +} + DnnlSubgraphPrimitive::DnnlSubgraphPrimitive(ort_dnnl::DnnlSubgraph& dnnl_subgraph) { subgraph_ = &dnnl_subgraph; if 
(dnnl_engine_get_count(dnnl_engine_kind_t::dnnl_cpu)) { @@ -194,48 +237,6 @@ void DnnlSubgraphPrimitive::AddInitializers() { } } -void DnnlSubgraphPrimitive::AddKernels() { - std::unordered_set<std::string> binary_ops = {"Add", "Mul", "Sub", "Div"}; - for (auto& node : subgraph_->GetDnnlNodes()) { - if (node.OpType() == "AveragePool" || node.OpType() == "GlobalAveragePool" || - node.OpType() == "GlobalMaxPool" || node.OpType() == "MaxPool") { - DnnlPool().CreatePrimitive(*this, node); - } else if (binary_ops.count(node.OpType())) { - DnnlBinary().CreatePrimitive(*this, node); - } else if (node.OpType() == "BatchNormalization") { - DnnlBatchNorm().CreatePrimitive(*this, node); - } else if (node.OpType() == "Conv") { - DnnlConv().CreatePrimitive(*this, node); - } else if (node.OpType() == "Gemm") { - DnnlGemm().CreatePrimitive(*this, node); - } else if (node.OpType() == "LRN") { - DnnlLrn().CreatePrimitive(*this, node); - } else if (node.OpType() == "MatMul") { - DnnlMatMul().CreatePrimitive(*this, node); - } else if (node.OpType() == "MatMulInteger") { - DnnlMatMulInteger().CreatePrimitive(*this, node); - } else if (node.OpType() == "ReduceMean") { - DnnlReduceMean().CreatePrimitive(*this, node); - } else if (node.OpType() == "Relu") { - DnnlRelu().CreatePrimitive(*this, node); - } else if (node.OpType() == "Softmax") { - DnnlSoftmax().CreatePrimitive(*this, node); - } else if (node.OpType() == "Sum") { - DnnlSum().CreatePrimitive(*this, node); -#if defined(ENABLE_TRAINING) - } else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") { - DnnlPoolGrad().CreatePrimitive(*this, node); - } else if (node.OpType() == "ConvGrad") { - DnnlConvGrad().CreatePrimitive(*this, node); - } else if (node.OpType() == "ReluGrad") { - DnnlReluGrad().CreatePrimitive(*this, node); -#endif - } else { - throw std::invalid_argument("Kernel not found"); - } - } -} - void DnnlSubgraphPrimitive::AddOutputs() { for (auto& tensor : subgraph_->GetDnnlOutputs()) { auto dnnl_data_type = 
tensor.Type(); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 721667509c..79e13d50c9 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -441,7 +441,11 @@ TEST(GradientCheckerTest, LogGrad) { TensorInfo x_info{shape, true, &transformer}; float max_error; + #ifdef USE_DNNL + float error_tolerance = 3e-3f; + #else float error_tolerance = 1e-3f; + #endif GradientChecker<float, float, float> gradient_checker; OpDef op_def{"Log"};