Add elementwise operators to DNNL execution provider (#8899)

The following ops have been added to the DNNL execution provider
Abs, Elu, Exp, Log, *Relu, Round, Sigmoid, Softplus, Sqrt, and Tanh

*Relu op was moved from its individual file to the elementwise operators

The error tolerance for the LogGrad unit test had to be decreased slightly
when using OneDNN.  Still investigating why a differet tolerance value is
needed.

DnnlSubgraph::AddKernels() member function was moved to the top of the file
since this is eddited every time a new operator is added to the the execution
provider this places the code at the top which mean less scrooling when
adding new kernels.

Signed-off-by: George Nash <george.nash@intel.com>
This commit is contained in:
George Nash 2021-08-31 12:20:49 -07:00 committed by GitHub
parent 2e37fe3f68
commit dc75a135c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 213 additions and 105 deletions

View file

@ -372,6 +372,30 @@ bool DnnlBinaryNodeCapability::IsDimensionSupported(const Node* node) const {
return true;
}
// DnnlBinaryNodeCapability class
//-------------------------------------
bool DnnlElementwiseCapability::Supported(const Node* node) const {
if (!IsTypeSupported(node)) return false;
if (!IsDimensionSupported(node)) return false;
return true;
}
bool DnnlElementwiseCapability::IsDimensionSupported(const Node* node) const {
auto node_inputs = node->InputDefs();
if (node_inputs[0]->Shape() == nullptr) {
return true;
}
// OneDNN will silently convert scaler values to a {1} tensor which causes issues for
// for Onnruntime when it expects an empty tensor i.e. {}
// TODO convert {1} outputs back to scaler {} once that is done DnnlElementwiseCapability
// can be removed and just us the DnnlDefaultNodeCapability.
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 0) {
return false;
}
return true;
}
// DnnlGemmNodeCapability class
//-------------------------------------
bool DnnlGemmNodeCapability::Supported(const Node* node) const {

View file

@ -207,6 +207,16 @@ class DnnlBinaryNodeCapability : public DnnlDefaultNodeCapability {
bool IsDimensionSupported(const Node* node) const;
};
class DnnlElementwiseCapability : public DnnlDefaultNodeCapability {
public:
DnnlElementwiseCapability() : DnnlDefaultNodeCapability({"float"}) {}
bool Supported(const Node* node) const override;
private:
bool IsDimensionSupported(const Node* node) const;
};
/**
* Decide if a Gemm op is supported by DnnlExecutionProvider
*/

View file

@ -6,24 +6,33 @@
namespace onnxruntime {
DnnlOpManager::DnnlOpManager() {
dnnl_ops_map_.emplace(std::make_pair("Abs", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Add", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Elu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Exp", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr<DnnlNodeCapability>(new DnnlGemmNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Log", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("MatMul", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("MatMulInteger", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulIntegerNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("MaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Mul", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("ReduceMean", std::unique_ptr<DnnlNodeCapability>(new DnnlReduceMeanNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Round", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Sigmoid", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Softmax", std::unique_ptr<DnnlNodeCapability>(new DnnlSoftmaxNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Softplus", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Sqrt", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("Sub", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Sum", std::unique_ptr<DnnlNodeCapability>(new DnnlSumNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Tanh", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
#if defined(ENABLE_TRAINING)
dnnl_ops_map_.emplace(std::make_pair("AveragePoolGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("ConvGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));

View file

@ -0,0 +1,80 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#include "dnnl_elementwise.h"
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
DnnlElementwise::DnnlElementwise() {
default_alpha_ = 1.0;
}
void DnnlElementwise::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
auto dnnl_engine = sp.GetEngine();
auto elementwise_src_mem = sp.GetMemory(node.Input(IN_X).Name());
auto src_md = elementwise_src_mem.get_desc();
dnnl::algorithm algo;
bool requires_alpha = false;
float alpha = 0.0;
if (node.OpType() == "Abs") {
algo = dnnl::algorithm::eltwise_abs;
} else if (node.OpType() == "Elu") {
requires_alpha = true;
default_alpha_ = 1.0;
alpha = GetAlpha(node);
algo = dnnl::algorithm::eltwise_elu;
} else if (node.OpType() == "Exp") {
algo = dnnl::algorithm::eltwise_exp;
} else if (node.OpType() == "Log") {
algo = dnnl::algorithm::eltwise_log;
} else if (node.OpType() == "Relu") {
algo = dnnl::algorithm::eltwise_relu;
} else if (node.OpType() == "Round") {
algo = dnnl::algorithm::eltwise_round;
} else if (node.OpType() == "Sigmoid") {
// in OneDNN eltwise_logistic is defined as 1/(1 + exp(-x)) which matches the definition of "Sigmoid" in Onnx
algo = dnnl::algorithm::eltwise_logistic;
} else if (node.OpType() == "Softplus") {
// in OneDNN eltwise_soft_relu is defined as ln(1 + exp(x)) which matches the definition of "Softplus" in Onnx
algo = dnnl::algorithm::eltwise_soft_relu;
} else if (node.OpType() == "Sqrt") {
algo = dnnl::algorithm::eltwise_sqrt;
} else if (node.OpType() == "Tanh") {
algo = dnnl::algorithm::eltwise_tanh;
} else {
ORT_THROW("op type not supported");
}
dnnl::eltwise_forward::primitive_desc elementwise_pd;
if (requires_alpha) {
auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md, alpha);
elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
} else {
auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md);
elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
}
// If using GPU this will move the memory from the CPU to the GPU.
elementwise_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), elementwise_pd.src_desc(), dnnl_engine);
auto elementwise_dst_mem = dnnl::memory(elementwise_pd.dst_desc(), dnnl_engine);
auto elemenwise_primitive = dnnl::eltwise_forward(elementwise_pd);
sp.AddPrimitive(elemenwise_primitive, {{DNNL_ARG_SRC, elementwise_src_mem},
{DNNL_ARG_DST, elementwise_dst_mem}});
sp.SetMemory(node.Output(OUT_Y), elementwise_dst_mem);
}
float DnnlElementwise::GetAlpha(DnnlNode& node) {
auto attr = node.Attributes().find("alpha");
if (attr != node.Attributes().end()) {
return attr->second().f();
}
return default_alpha_;
}
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -0,0 +1,41 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#pragma once
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
class DnnlElementwise {
public:
enum InputTensors : int {
IN_X = 0,
};
enum OutputTensors : int {
OUT_Y = 0
};
DnnlElementwise();
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
private:
/*
* GetAlpha will get the 'alpha' attribute if the attribute is not found
* the the `default_alph_` will be returned instead. This is set to 1.0
* by the `DnnlElementwise` constructor but should be updated for any operator
* that has an 'alpha' property.
*
* See how `GetAlpha` is called for the 'Elu' operator in the `CreatePrimitive` code.
*
* Note: The number of operators that use the 'alpha' attribute is much smaller than
* initially expected.
*/
float GetAlpha(DnnlNode& node);
float default_alpha_;
};
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -1,35 +0,0 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#include "dnnl_relu.h"
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
DnnlRelu::DnnlRelu() {}
void DnnlRelu::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
auto dnnl_engine = sp.GetEngine();
auto relu_src_mem = sp.GetMemory(node.Input(IN_X).Name());
auto src_md = relu_src_mem.get_desc();
auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_relu, src_md);
auto relu_pd = dnnl::eltwise_forward::primitive_desc(relu_desc, dnnl_engine);
// If using GPU this will move the memory from the CPU to the GPU.
relu_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), relu_pd.src_desc(), dnnl_engine);
auto relu_dst_mem = dnnl::memory(relu_pd.dst_desc(), dnnl_engine);
auto relu_op = dnnl::eltwise_forward(relu_pd);
sp.AddPrimitive(relu_op, {{DNNL_ARG_SRC, relu_src_mem},
{DNNL_ARG_DST, relu_dst_mem}});
sp.SetMemory(node.Output(OUT_Y), relu_dst_mem);
}
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -1,26 +0,0 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#pragma once
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
class DnnlRelu {
public:
enum InputTensors : int {
IN_X = 0,
};
enum OutputTensors : int {
OUT_Y = 0
};
DnnlRelu();
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
};
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -6,13 +6,13 @@
#include "dnnl_batchnorm.h"
#include "dnnl_binary.h"
#include "dnnl_conv.h"
#include "dnnl_elementwise.h"
#include "dnnl_gemm.h"
#include "dnnl_lrn.h"
#include "dnnl_matmul.h"
#include "dnnl_matmul_integer.h"
#include "dnnl_pool.h"
#include "dnnl_reducemean.h"
#include "dnnl_relu.h"
#include "dnnl_softmax.h"
#include "dnnl_sum.h"
@ -37,6 +37,49 @@ int Product(dnnl::memory::dims d) {
return result;
}
void DnnlSubgraphPrimitive::AddKernels() {
std::unordered_set<std::string> binary_ops = {"Add", "Div", "Mul", "Sub"};
std::unordered_set<std::string> elementwise_ops = {"Abs", "Elu", "Exp","Log", "Relu", "Round", "Sigmoid", "Softplus", "Sqrt", "Tanh"};
std::unordered_set<std::string> pool_ops = {"AveragePool", "GlobalAveragePool", "GlobalMaxPool", "MaxPool"};
for (auto& node : subgraph_->GetDnnlNodes()) {
if (node.OpType() == "BatchNormalization") {
DnnlBatchNorm().CreatePrimitive(*this, node);
} else if (binary_ops.count(node.OpType())) {
DnnlBinary().CreatePrimitive(*this, node);
} else if (node.OpType() == "Conv") {
DnnlConv().CreatePrimitive(*this, node);
} else if (elementwise_ops.count(node.OpType())) {
DnnlElementwise().CreatePrimitive(*this, node);
} else if (node.OpType() == "Gemm") {
DnnlGemm().CreatePrimitive(*this, node);
} else if (node.OpType() == "LRN") {
DnnlLrn().CreatePrimitive(*this, node);
} else if (node.OpType() == "MatMul") {
DnnlMatMul().CreatePrimitive(*this, node);
} else if (node.OpType() == "MatMulInteger") {
DnnlMatMulInteger().CreatePrimitive(*this, node);
} else if (pool_ops.count(node.OpType())) {
DnnlPool().CreatePrimitive(*this, node);
} else if (node.OpType() == "ReduceMean") {
DnnlReduceMean().CreatePrimitive(*this, node);
} else if (node.OpType() == "Softmax") {
DnnlSoftmax().CreatePrimitive(*this, node);
} else if (node.OpType() == "Sum") {
DnnlSum().CreatePrimitive(*this, node);
#if defined(ENABLE_TRAINING)
} else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") {
DnnlPoolGrad().CreatePrimitive(*this, node);
} else if (node.OpType() == "ConvGrad") {
DnnlConvGrad().CreatePrimitive(*this, node);
} else if (node.OpType() == "ReluGrad") {
DnnlReluGrad().CreatePrimitive(*this, node);
#endif
} else {
throw std::invalid_argument("Kernel not found");
}
}
}
DnnlSubgraphPrimitive::DnnlSubgraphPrimitive(ort_dnnl::DnnlSubgraph& dnnl_subgraph) {
subgraph_ = &dnnl_subgraph;
if (dnnl_engine_get_count(dnnl_engine_kind_t::dnnl_cpu)) {
@ -194,48 +237,6 @@ void DnnlSubgraphPrimitive::AddInitializers() {
}
}
void DnnlSubgraphPrimitive::AddKernels() {
std::unordered_set<std::string> binary_ops = {"Add", "Mul", "Sub", "Div"};
for (auto& node : subgraph_->GetDnnlNodes()) {
if (node.OpType() == "AveragePool" || node.OpType() == "GlobalAveragePool" ||
node.OpType() == "GlobalMaxPool" || node.OpType() == "MaxPool") {
DnnlPool().CreatePrimitive(*this, node);
} else if (binary_ops.count(node.OpType())) {
DnnlBinary().CreatePrimitive(*this, node);
} else if (node.OpType() == "BatchNormalization") {
DnnlBatchNorm().CreatePrimitive(*this, node);
} else if (node.OpType() == "Conv") {
DnnlConv().CreatePrimitive(*this, node);
} else if (node.OpType() == "Gemm") {
DnnlGemm().CreatePrimitive(*this, node);
} else if (node.OpType() == "LRN") {
DnnlLrn().CreatePrimitive(*this, node);
} else if (node.OpType() == "MatMul") {
DnnlMatMul().CreatePrimitive(*this, node);
} else if (node.OpType() == "MatMulInteger") {
DnnlMatMulInteger().CreatePrimitive(*this, node);
} else if (node.OpType() == "ReduceMean") {
DnnlReduceMean().CreatePrimitive(*this, node);
} else if (node.OpType() == "Relu") {
DnnlRelu().CreatePrimitive(*this, node);
} else if (node.OpType() == "Softmax") {
DnnlSoftmax().CreatePrimitive(*this, node);
} else if (node.OpType() == "Sum") {
DnnlSum().CreatePrimitive(*this, node);
#if defined(ENABLE_TRAINING)
} else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") {
DnnlPoolGrad().CreatePrimitive(*this, node);
} else if (node.OpType() == "ConvGrad") {
DnnlConvGrad().CreatePrimitive(*this, node);
} else if (node.OpType() == "ReluGrad") {
DnnlReluGrad().CreatePrimitive(*this, node);
#endif
} else {
throw std::invalid_argument("Kernel not found");
}
}
}
void DnnlSubgraphPrimitive::AddOutputs() {
for (auto& tensor : subgraph_->GetDnnlOutputs()) {
auto dnnl_data_type = tensor.Type();

View file

@ -441,7 +441,11 @@ TEST(GradientCheckerTest, LogGrad) {
TensorInfo x_info{shape, true, &transformer};
float max_error;
#ifdef USE_DNNL
float error_tolerance = 3e-3f;
#else
float error_tolerance = 1e-3f;
#endif
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Log"};