Add DnnlOpManager (#7521)

* Add  DnnlOpManager

The DnnlOpManager is able to more accurately check if a node is
supported by the DNNLExecutionProvider.

The DNNLExecutionProvider::GetCapability function has been updated
to use the DnnlOpManager.

This commit adds the ability to check if data type, attributes,
and tensor dimensions of the node are supported.

The IsDimensionSupported function is no longer needed since the checks
it was doing have been moved into the individual implementations of
the virtual class DnnlNodeCapability.

Signed-off-by: George Nash <george.nash@intel.com>

* Fix AveragePool entry in the DnnlOpManager

Added check for ceil_mode attribute in the PoolNodeCapability
check.  DnnlExecutionProvider does not support ceil_mode other
than the default value.

Signed-off-by: George Nash <george.nash@intel.com>
This commit is contained in:
George Nash 2021-05-12 20:04:26 -07:00 committed by GitHub
parent dac24f7d63
commit b4e8e9b004
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 389 additions and 70 deletions

View file

@ -10,6 +10,7 @@
#include "subgraph/dnnl_func_kernel.h"
#include "dnnl_execution_provider.h"
#include "dnnl_fwd.h"
#include "dnnl_node_capability.h"
namespace onnxruntime {
@ -85,7 +86,7 @@ bool DNNLExecutionProvider::UseSubgraph(const GraphViewer& graph_viewer) const {
continue;
if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) {
FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos;
FP16_graph = node->InputDefs()[0]->Type()->find("float16") != std::string::npos;
break;
}
}
@ -96,8 +97,7 @@ bool DNNLExecutionProvider::UseSubgraph(const GraphViewer& graph_viewer) const {
continue;
}
auto op_it = dnnl_ops_.find(node->OpType());
if (op_it != dnnl_ops_.end()) {
if (opManager_.IsOpTypeAvalible(node->OpType())) {
dnnl_nodes_in_the_graph = true;
break;
}
@ -248,7 +248,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
continue;
}
if (IsDimensionSupported(node) == false) {
if (!opManager_.IsNodeSupported(node)) {
node_index++;
if (subgraph_ptr->dnnl_nodes.size() > 0) {
CreateMetaDef(graph_viewer, *subgraph_attributes, subgraph_ptr, sub_var, result);
@ -259,8 +259,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
continue;
}
auto op_it = dnnl_ops_.find(node->OpType());
if (op_it != dnnl_ops_.end()) {
if (opManager_.IsOpTypeAvalible(node->OpType())) {
sub_var.subgraph_node_indexes.push_back(node->Index());
// can we fuse (at Dnnl level) nodes?
@ -299,8 +298,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
temp_index++;
next_node = graph_viewer.GetNode(temp_index);
}
auto sub_it = dnnl_ops_.find(next_node->OpType());
if (sub_it != dnnl_ops_.end()) {
if (opManager_.IsOpTypeAvalible(next_node->OpType())) {
const auto& next_node_inputs = next_node->InputDefs();
bool input_from_subgraph = true;
size_t inputs_count = 1;
@ -343,8 +341,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
}
// inner nodes. if inner nodes are not Dnnl nodes
// create subgraph (inception v2)
auto sub_it = dnnl_ops_.find(next_node->OpType());
if (sub_it == dnnl_ops_.end()) {
if (!opManager_.IsOpTypeAvalible(next_node->OpType())) {
// break and create a sub-graph
break_loop = true;
create_subgraph = true;

View file

@ -11,6 +11,7 @@
#include "core/platform/ort_mutex.h"
#include "core/providers/dnnl/subgraph/subgraph.h"
#include "core/platform/ort_mutex.h"
#include "dnnl_op_manager.h"
namespace dnnl {
struct memory;
@ -165,57 +166,6 @@ class DNNLExecutionProvider : public IExecutionProvider {
bool UseSubgraph(const GraphViewer& graph_viewer) const;
// Some dimensions are not supported by DNNL
// example: Pool with NumDimensions <= 3 is not supported
// Fall back to CPU implementation
bool IsDimensionSupported(const Node* node) const {
bool supported = true;
if (node->OpType() == "BatchNormalization") {
auto node_inputs = node->InputDefs();
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 3) {
supported = false;
}
}
if (node->OpType().find("Pool") != std::string::npos) {
auto node_inputs = node->InputDefs();
#ifdef ENABLE_TRAINING
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() < 3) {
#else
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() <= 3) {
#endif // ENABLE_TRAINING
supported = false;
}
#ifdef ENABLE_TRAINING
if (node->OutputDefs().size() > 2)
supported = false;
#else
if (node->OutputDefs().size() > 1)
supported = false;
#endif // ENABLE_TRAINING
}
if (node->OpType().find("MatMul") != std::string::npos) {
auto node_inputs = node->InputDefs();
if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) &&
(node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) &&
(node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) {
supported = true;
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) {
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
supported = false;
}
}
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) {
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
supported = false;
}
}
} else {
supported = false;
}
}
return supported;
}
void CreateOrUpdateDnnlNode(const Node* node,
std::shared_ptr<ort_dnnl::Subgraph>& subgraph_ptr,
@ -238,15 +188,8 @@ class DNNLExecutionProvider : public IExecutionProvider {
}
private:
// supported Dnnl Operators
#ifdef ENABLE_TRAINING
std::set<std::string> dnnl_ops_ = {"Conv", "ConvGrad", "BatchNormalization", "Relu", "ReluGrad", "Sum",
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "MaxPoolGrad", "LRN"};
#else
std::set<std::string> dnnl_ops_ = {"Conv", "BatchNormalization", "Relu", "Sum",
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "LRN", "MatMul"};
#endif // ENABLE_TRAINING
// DnnlOpManager contains information about supported Dnnl Operators
DnnlOpManager opManager_;
mutable std::unordered_map<std::string, std::shared_ptr<ort_dnnl::Subgraph>> mkl_subgraphs_;
};

View file

@ -0,0 +1,146 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#include "dnnl_node_capability.h"
namespace onnxruntime {
// DnnlDefaultNodeCapability class
//-------------------------------------
DnnlDefaultNodeCapability::DnnlDefaultNodeCapability() {
inputTypes_.push_back("float");
}
DnnlDefaultNodeCapability::DnnlDefaultNodeCapability(std::vector<std::string> inputTypes) {
for (std::string s : inputTypes)
inputTypes_.push_back(s);
}
bool DnnlDefaultNodeCapability::Supported(const Node* node) const {
if (!IsTypeSupported(node)) return false;
return true;
}
bool DnnlDefaultNodeCapability::IsTypeSupported(const Node* node) const {
auto node_inputs = node->InputDefs();
if (!node_inputs.empty() && node_inputs[0]->Type() != nullptr) {
for (auto inputType : inputTypes_) {
if (node_inputs[0]->Type()->find(inputType) != std::string::npos) {
return true;
}
}
}
return false;
}
// DnnlPoolNodeCapability class
//-------------------------------------
bool DnnlPoolNodeCapability::Supported(const Node* node) const {
if (!IsTypeSupported(node)) return false;
if (!IsAttributeSupported(node)) return false;
if (!IsDimensionSupported(node)) return false;
return true;
}
bool DnnlPoolNodeCapability::IsAttributeSupported(const Node* node) const {
const NodeAttributes& attributes = node->GetAttributes();
if (node->OpType() == "MaxPool") {
auto attr = attributes.find("dilations");
if (attr != attributes.end()) {
for (int i = 0; i < attr->second().ints_size(); ++i) {
if (attr->second().ints(i) > 1) {
return false;
}
}
}
}
auto attr = attributes.find("ceil_mode");
if (attr != attributes.end()) {
if (attr->second().i() != 0) {
return false;
}
}
return true;
}
bool DnnlPoolNodeCapability::IsDimensionSupported(const Node* node) const {
auto node_inputs = node->InputDefs();
#ifdef ENABLE_TRAINING
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() < 3) {
#else
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() <= 3) {
#endif // ENABLE_TRAINING
return false;
}
#ifdef ENABLE_TRAINING
if (node->OutputDefs().size() > 2)
return false;
#else
if (node->OutputDefs().size() > 1)
return false;
#endif // ENABLE_TRAINING
return true;
}
// DnnlBatchNormalizationNodeCapability class
//-------------------------------------
bool DnnlBatchNormalizationNodeCapability::Supported(const Node* node) const {
if (!IsTypeSupported(node)) return false;
if (!IsDimensionSupported(node)) return false;
return true;
}
bool DnnlBatchNormalizationNodeCapability::IsDimensionSupported(const Node* node) const {
auto node_inputs = node->InputDefs();
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 3) {
return false;
}
return true;
}
// DnnlReduceMeanNodeCapability class
//-------------------------------------
bool DnnlReduceMeanNodeCapability::Supported(const Node* node) const {
if (!IsTypeSupported(node)) return false;
if (!IsAttributeSupported(node)) return false;
return true;
}
bool DnnlReduceMeanNodeCapability::IsAttributeSupported(const Node* node) const {
const NodeAttributes& attributes = node->GetAttributes();
auto attr = attributes.find("keepdims");
if (attr != attributes.end() && attr->second().i() == 0) {
return false;
}
return true;
}
// DnnlMatMulNodeCapability class
bool DnnlMatMulNodeCapability::Supported(const Node* node) const {
if (!IsTypeSupported(node)) return false;
if (!IsDimensionSupported(node)) return false;
return true;
}
bool DnnlMatMulNodeCapability::IsDimensionSupported(const Node* node) const {
auto node_inputs = node->InputDefs();
if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) &&
(node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) &&
(node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) {
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) {
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
return false;
}
}
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) {
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
return false;
}
}
} else {
return false;
}
return true;
}
} // namespace onnxruntime

View file

@ -0,0 +1,149 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#pragma once
#include "core/providers/shared_library/provider_api.h"
namespace onnxruntime {
/**
* Pure virtual base class
*
* Individual implementations of this class are expected
* to implement the Supported() member function.
*
* The Supported() function is expected to use the contents
* of the onnxruntime::Node to decided if that node is
* supported in the DnnlExecutionProvider
*/
class DnnlNodeCapability {
public:
virtual ~DnnlNodeCapability(){};
/**
* virtual function expected to be implemented for different node
* types.
* @param node a onnxruntime::Node from the model
*
* @return true if the onnxRuntime::Node is supported in the
* DnnlExecutionProvider return false otherwise.
*/
virtual bool Supported(const Node* node) const = 0;
};
/**
* Default impelementation of the DnnlNodeCapability interface
* This class can be used if the only thing needed to
* decided if the we are capable of running the node using
* the DnnlExecutionProvider is the input data type.
*
* The default constructor assumes that "float" data
* type is supported and no other data types.
*
* To add additional data types an array of data types
* can be passed as strings i.e.
* `DnnlDefaultNodeCapability({"float", "int8"})`
* Would indicate that "float" and "int8" are supported.
*
* At this time the possible data types strings are:
* - "float"
* - "float16"
* - "bfloat16"
* - "double"
* - "int8"
* - "int16"
* - "int32"
* - "int64"
* - "uint8"
* - "uint16"
* - "uint32"
* - "uint64"
* - "complex64"
* - "complex128"
* - "string"
* - "bool"
*
* The strings are from the data_type_utils.cc
* TypesWrapper::TypesWrapper() member function. If a type
* is expected but not found in the above list see if it is
* assigned in the data_type_utils.cc file.
*
* This currently only checks the data type of input[0]. If
* this does not work for the Node then this class will need
* to be updated or another DnnlNodeCapability class will need to be
* implemented for the operator in question.
*/
class DnnlDefaultNodeCapability : public DnnlNodeCapability {
public:
DnnlDefaultNodeCapability();
DnnlDefaultNodeCapability(std::vector<std::string> inputTypes);
bool Supported(const Node* node) const override;
protected:
bool IsTypeSupported(const Node* node) const;
private:
std::vector<std::string> inputTypes_;
};
/**
* Decide if a Pool op is supported by DnnlExecutionProvider
*
* Dnnl does not support all dimension types for Pooling operators
* In addition the "dilations" attribute is not yet supported for
* MaxPool operator.
*/
class DnnlPoolNodeCapability : public DnnlDefaultNodeCapability {
public:
DnnlPoolNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
bool Supported(const Node* node) const override;
private:
bool IsAttributeSupported(const Node* node) const;
bool IsDimensionSupported(const Node* node) const;
};
/**
* Decide if a BatchNormalization op is supported by DnnlExecutionProvider
*/
class DnnlBatchNormalizationNodeCapability : public DnnlDefaultNodeCapability {
public:
DnnlBatchNormalizationNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
bool Supported(const Node* node) const override;
private:
bool IsDimensionSupported(const Node* node) const;
};
/**
* Decide if a ReduceMean op is supported by DnnlExecutionProvider
*
* Dnnl does not support the "keepdims" attribute when it is `0`
*/
class DnnlReduceMeanNodeCapability : public DnnlDefaultNodeCapability {
public:
DnnlReduceMeanNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
bool Supported(const Node* node) const override;
private:
bool IsAttributeSupported(const Node* node) const;
};
/**
* Decide if a MatMul op is supported by DnnlExecutionProvider
*/
class DnnlMatMulNodeCapability : public DnnlDefaultNodeCapability {
public:
DnnlMatMulNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
bool Supported(const Node* node) const override;
private:
bool IsDimensionSupported(const Node* node) const;
};
} // namespace onnxruntime

View file

@ -0,0 +1,38 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#include "dnnl_op_manager.h"
namespace onnxruntime {
DnnlOpManager::DnnlOpManager() {
dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("MatMul", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("MaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Sum", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
#if defined(ENABLE_TRAINING)
// TODO re-enable ConvGrad currently there is a bug in the ConvGrad code bug was not known till after PR7083
//dnnl_ops_map_.emplace(std::make_pair("ConvGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("ReluGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("MaxPoolGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
#endif // ENABLE_TRAINING
}
bool DnnlOpManager::IsNodeSupported(const Node* node) const {
auto it = dnnl_ops_map_.find(node->OpType());
if (it == dnnl_ops_map_.end()) {
return false;
}
return it->second->Supported(node);
}
bool DnnlOpManager::IsOpTypeAvalible(const std::string& opType) const {
auto op_it = dnnl_ops_map_.find(opType);
return (op_it != dnnl_ops_map_.end());
}
} // namespace onnxruntime

View file

@ -0,0 +1,46 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#pragma once
#include "dnnl_node_capability.h"
#include <map>
#include <string>
#include <memory>
namespace onnxruntime {
class DnnlOpManager {
public:
DnnlOpManager();
/**
* This will check if the ORT node is Supported by the DNNL execution provider
*
* Several things will be checked from the node
* - Is the OpType is regestered with the DNNL execution provider?
* - Are the tensor dimensions Supported by the DNNL execution provider
* - Are operator attributes Supported by the DNNL execution provider
*
* @param node the node that is being checked
*
* @return true if the node is Supported by the DNNL execution provider
* false is returned otherwise.
*/
bool IsNodeSupported(const Node* node) const;
/**
* Find out if the OpType is one of the OpTypes Supported by the DNNL execution provider
*
* This only looks at the OpType it does not look at other factors that may mean
* the operator is not Supported.
*
* @param opType the name of the operator i.e. "Add" or "Conv" etc.
*
* @return true is the OpType is one of those Supported by the DNNL execution provider
* false is returned otherwise.
*/
bool IsOpTypeAvalible(const std::string& opType) const;
private:
std::map<std::string, std::unique_ptr<DnnlNodeCapability>> dnnl_ops_map_;
};
} // namespace onnxruntime