diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index cdd9d4986d..6ede25fd17 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -10,6 +10,7 @@ #include "subgraph/dnnl_func_kernel.h" #include "dnnl_execution_provider.h" #include "dnnl_fwd.h" +#include "dnnl_node_capability.h" namespace onnxruntime { @@ -85,7 +86,7 @@ bool DNNLExecutionProvider::UseSubgraph(const GraphViewer& graph_viewer) const { continue; if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) { - FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos; + FP16_graph = node->InputDefs()[0]->Type()->find("float16") != std::string::npos; break; } } @@ -96,8 +97,7 @@ bool DNNLExecutionProvider::UseSubgraph(const GraphViewer& graph_viewer) const { continue; } - auto op_it = dnnl_ops_.find(node->OpType()); - if (op_it != dnnl_ops_.end()) { + if (opManager_.IsOpTypeAvalible(node->OpType())) { dnnl_nodes_in_the_graph = true; break; } @@ -248,7 +248,7 @@ std::vector> DNNLExecutionProvider::GetCapabi continue; } - if (IsDimensionSupported(node) == false) { + if (!opManager_.IsNodeSupported(node)) { node_index++; if (subgraph_ptr->dnnl_nodes.size() > 0) { CreateMetaDef(graph_viewer, *subgraph_attributes, subgraph_ptr, sub_var, result); @@ -259,8 +259,7 @@ std::vector> DNNLExecutionProvider::GetCapabi continue; } - auto op_it = dnnl_ops_.find(node->OpType()); - if (op_it != dnnl_ops_.end()) { + if (opManager_.IsOpTypeAvalible(node->OpType())) { sub_var.subgraph_node_indexes.push_back(node->Index()); // can we fuse (at Dnnl level) nodes? @@ -299,8 +298,7 @@ std::vector> DNNLExecutionProvider::GetCapabi temp_index++; next_node = graph_viewer.GetNode(temp_index); } - auto sub_it = dnnl_ops_.find(next_node->OpType()); - if (sub_it != dnnl_ops_.end()) { + if (opManager_.IsOpTypeAvalible(next_node->OpType())) { const auto& next_node_inputs = next_node->InputDefs(); bool input_from_subgraph = true; size_t inputs_count = 1; @@ -343,8 +341,7 @@ std::vector> DNNLExecutionProvider::GetCapabi } // inner nodes. if inner nodes are not Dnnl nodes // create subgraph (inception v2) - auto sub_it = dnnl_ops_.find(next_node->OpType()); - if (sub_it == dnnl_ops_.end()) { + if (!opManager_.IsOpTypeAvalible(next_node->OpType())) { // break and create a sub-graph break_loop = true; create_subgraph = true; diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index 0178c3876e..524dff3183 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -11,6 +11,7 @@ #include "core/platform/ort_mutex.h" #include "core/providers/dnnl/subgraph/subgraph.h" #include "core/platform/ort_mutex.h" +#include "dnnl_op_manager.h" namespace dnnl { struct memory; @@ -165,57 +166,6 @@ class DNNLExecutionProvider : public IExecutionProvider { bool UseSubgraph(const GraphViewer& graph_viewer) const; - // Some dimensions are not supported by DNNL - // example: Pool with NumDimensions <= 3 is not supported - // Fall back to CPU implementation - bool IsDimensionSupported(const Node* node) const { - bool supported = true; - if (node->OpType() == "BatchNormalization") { - auto node_inputs = node->InputDefs(); - if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 3) { - supported = false; - } - } - if (node->OpType().find("Pool") != std::string::npos) { - auto node_inputs = node->InputDefs(); -#ifdef ENABLE_TRAINING - if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() < 3) { -#else - if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() <= 3) { -#endif // ENABLE_TRAINING - supported = false; - } - -#ifdef ENABLE_TRAINING - if (node->OutputDefs().size() > 2) - supported = false; -#else - if (node->OutputDefs().size() > 1) - supported = false; -#endif // ENABLE_TRAINING - } - if (node->OpType().find("MatMul") != std::string::npos) { - auto node_inputs = node->InputDefs(); - if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) && - (node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) && - (node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) { - supported = true; - for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) { - if (utils::HasDimValue(dim) && dim.dim_value() == 0) { - supported = false; - } - } - for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) { - if (utils::HasDimValue(dim) && dim.dim_value() == 0) { - supported = false; - } - } - } else { - supported = false; - } - } - return supported; - } void CreateOrUpdateDnnlNode(const Node* node, std::shared_ptr& subgraph_ptr, @@ -238,15 +188,8 @@ class DNNLExecutionProvider : public IExecutionProvider { } private: -// supported Dnnl Operators -#ifdef ENABLE_TRAINING - std::set dnnl_ops_ = {"Conv", "ConvGrad", "BatchNormalization", "Relu", "ReluGrad", "Sum", - "AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "MaxPoolGrad", "LRN"}; -#else - std::set dnnl_ops_ = {"Conv", "BatchNormalization", "Relu", "Sum", - "AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "LRN", "MatMul"}; -#endif // ENABLE_TRAINING - + // DnnlOpManager contains information about supported Dnnl Operators + DnnlOpManager opManager_; mutable std::unordered_map> mkl_subgraphs_; }; diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc new file mode 100644 index 0000000000..47e6949ca0 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -0,0 +1,146 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#include "dnnl_node_capability.h" + +namespace onnxruntime { +// DnnlDefaultNodeCapability class +//------------------------------------- +DnnlDefaultNodeCapability::DnnlDefaultNodeCapability() { + inputTypes_.push_back("float"); +} + +DnnlDefaultNodeCapability::DnnlDefaultNodeCapability(std::vector inputTypes) { + for (std::string s : inputTypes) + inputTypes_.push_back(s); +} + +bool DnnlDefaultNodeCapability::Supported(const Node* node) const { + if (!IsTypeSupported(node)) return false; + return true; +} + +bool DnnlDefaultNodeCapability::IsTypeSupported(const Node* node) const { + auto node_inputs = node->InputDefs(); + if (!node_inputs.empty() && node_inputs[0]->Type() != nullptr) { + for (auto inputType : inputTypes_) { + if (node_inputs[0]->Type()->find(inputType) != std::string::npos) { + return true; + } + } + } + return false; +} + +// DnnlPoolNodeCapability class +//------------------------------------- +bool DnnlPoolNodeCapability::Supported(const Node* node) const { + if (!IsTypeSupported(node)) return false; + if (!IsAttributeSupported(node)) return false; + if (!IsDimensionSupported(node)) return false; + return true; +} + +bool DnnlPoolNodeCapability::IsAttributeSupported(const Node* node) const { + const NodeAttributes& attributes = node->GetAttributes(); + if (node->OpType() == "MaxPool") { + auto attr = attributes.find("dilations"); + if (attr != attributes.end()) { + for (int i = 0; i < attr->second().ints_size(); ++i) { + if (attr->second().ints(i) > 1) { + return false; + } + } + } + } + auto attr = attributes.find("ceil_mode"); + if (attr != attributes.end()) { + if (attr->second().i() != 0) { + return false; + } + } + return true; +} + +bool DnnlPoolNodeCapability::IsDimensionSupported(const Node* node) const { + auto node_inputs = node->InputDefs(); +#ifdef ENABLE_TRAINING + if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() < 3) { +#else + if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() <= 3) { +#endif // ENABLE_TRAINING + return false; + } + +#ifdef ENABLE_TRAINING + if (node->OutputDefs().size() > 2) + return false; +#else + if (node->OutputDefs().size() > 1) + return false; +#endif // ENABLE_TRAINING + return true; +} + +// DnnlBatchNormalizationNodeCapability class +//------------------------------------- +bool DnnlBatchNormalizationNodeCapability::Supported(const Node* node) const { + if (!IsTypeSupported(node)) return false; + if (!IsDimensionSupported(node)) return false; + return true; +} + +bool DnnlBatchNormalizationNodeCapability::IsDimensionSupported(const Node* node) const { + auto node_inputs = node->InputDefs(); + if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 3) { + return false; + } + return true; +} + +// DnnlReduceMeanNodeCapability class +//------------------------------------- +bool DnnlReduceMeanNodeCapability::Supported(const Node* node) const { + if (!IsTypeSupported(node)) return false; + if (!IsAttributeSupported(node)) return false; + return true; +} + +bool DnnlReduceMeanNodeCapability::IsAttributeSupported(const Node* node) const { + const NodeAttributes& attributes = node->GetAttributes(); + auto attr = attributes.find("keepdims"); + if (attr != attributes.end() && attr->second().i() == 0) { + return false; + } + return true; +} + +// DnnlMatMulNodeCapability class +bool DnnlMatMulNodeCapability::Supported(const Node* node) const { + if (!IsTypeSupported(node)) return false; + if (!IsDimensionSupported(node)) return false; + return true; +} + +bool DnnlMatMulNodeCapability::IsDimensionSupported(const Node* node) const { + auto node_inputs = node->InputDefs(); + if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) && + (node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) && + (node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) { + for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) { + if (utils::HasDimValue(dim) && dim.dim_value() == 0) { + return false; + } + } + for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) { + if (utils::HasDimValue(dim) && dim.dim_value() == 0) { + return false; + } + } + } else { + return false; + } + return true; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h new file mode 100644 index 0000000000..1c1ea0aaef --- /dev/null +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h @@ -0,0 +1,149 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { + +/** + * Pure virtual base class + * + * Individual implementations of this class are expected + * to implement the Supported() member function. + * + * The Supported() function is expected to use the contents + * of the onnxruntime::Node to decided if that node is + * supported in the DnnlExecutionProvider + */ +class DnnlNodeCapability { + public: + virtual ~DnnlNodeCapability(){}; + /** + * virtual function expected to be implemented for different node + * types. + * @param node a onnxruntime::Node from the model + * + * @return true if the onnxRuntime::Node is supported in the + * DnnlExecutionProvider return false otherwise. + */ + virtual bool Supported(const Node* node) const = 0; +}; + +/** + * Default impelementation of the DnnlNodeCapability interface + * This class can be used if the only thing needed to + * decided if the we are capable of running the node using + * the DnnlExecutionProvider is the input data type. + * + * The default constructor assumes that "float" data + * type is supported and no other data types. + * + * To add additional data types an array of data types + * can be passed as strings i.e. + * `DnnlDefaultNodeCapability({"float", "int8"})` + * Would indicate that "float" and "int8" are supported. + * + * At this time the possible data types strings are: + * - "float" + * - "float16" + * - "bfloat16" + * - "double" + * - "int8" + * - "int16" + * - "int32" + * - "int64" + * - "uint8" + * - "uint16" + * - "uint32" + * - "uint64" + * - "complex64" + * - "complex128" + * - "string" + * - "bool" + * + * The strings are from the data_type_utils.cc + * TypesWrapper::TypesWrapper() member function. If a type + * is expected but not found in the above list see if it is + * assigned in the data_type_utils.cc file. + * + * This currently only checks the data type of input[0]. If + * this does not work for the Node then this class will need + * to be updated or another DnnlNodeCapability class will need to be + * implemented for the operator in question. + */ +class DnnlDefaultNodeCapability : public DnnlNodeCapability { + public: + DnnlDefaultNodeCapability(); + DnnlDefaultNodeCapability(std::vector inputTypes); + + bool Supported(const Node* node) const override; + + protected: + bool IsTypeSupported(const Node* node) const; + + private: + std::vector inputTypes_; +}; + +/** + * Decide if a Pool op is supported by DnnlExecutionProvider + * + * Dnnl does not support all dimension types for Pooling operators + * In addition the "dilations" attribute is not yet supported for + * MaxPool operator. + */ +class DnnlPoolNodeCapability : public DnnlDefaultNodeCapability { + public: + DnnlPoolNodeCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + bool IsAttributeSupported(const Node* node) const; + bool IsDimensionSupported(const Node* node) const; +}; + +/** + * Decide if a BatchNormalization op is supported by DnnlExecutionProvider + */ +class DnnlBatchNormalizationNodeCapability : public DnnlDefaultNodeCapability { + public: + DnnlBatchNormalizationNodeCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + bool IsDimensionSupported(const Node* node) const; +}; + +/** + * Decide if a ReduceMean op is supported by DnnlExecutionProvider + * + * Dnnl does not support the "keepdims" attribute when it is `0` + */ +class DnnlReduceMeanNodeCapability : public DnnlDefaultNodeCapability { + public: + DnnlReduceMeanNodeCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + bool IsAttributeSupported(const Node* node) const; +}; + +/** + * Decide if a MatMul op is supported by DnnlExecutionProvider + */ +class DnnlMatMulNodeCapability : public DnnlDefaultNodeCapability { + public: + DnnlMatMulNodeCapability() : DnnlDefaultNodeCapability({"float"}) {} + + bool Supported(const Node* node) const override; + + private: + bool IsDimensionSupported(const Node* node) const; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc new file mode 100644 index 0000000000..0aac1c8c86 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc @@ -0,0 +1,38 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#include "dnnl_op_manager.h" + +namespace onnxruntime { +DnnlOpManager::DnnlOpManager() { + dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr(new DnnlPoolNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr(new DnnlBatchNormalizationNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr(new DnnlPoolNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr(new DnnlPoolNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("MatMul", std::unique_ptr(new DnnlMatMulNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("MaxPool", std::unique_ptr(new DnnlPoolNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Sum", std::unique_ptr(new DnnlDefaultNodeCapability()))); +#if defined(ENABLE_TRAINING) + // TODO re-enable ConvGrad currently there is a bug in the ConvGrad code bug was not known till after PR7083 + //dnnl_ops_map_.emplace(std::make_pair("ConvGrad", std::unique_ptr(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("ReluGrad", std::unique_ptr(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("MaxPoolGrad", std::unique_ptr(new DnnlPoolNodeCapability()))); +#endif // ENABLE_TRAINING +} + +bool DnnlOpManager::IsNodeSupported(const Node* node) const { + auto it = dnnl_ops_map_.find(node->OpType()); + if (it == dnnl_ops_map_.end()) { + return false; + } + return it->second->Supported(node); +} + +bool DnnlOpManager::IsOpTypeAvalible(const std::string& opType) const { + auto op_it = dnnl_ops_map_.find(opType); + return (op_it != dnnl_ops_map_.end()); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.h b/onnxruntime/core/providers/dnnl/dnnl_op_manager.h new file mode 100644 index 0000000000..ffdf9ad9e0 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.h @@ -0,0 +1,46 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#pragma once +#include "dnnl_node_capability.h" +#include +#include +#include + +namespace onnxruntime { +class DnnlOpManager { + public: + DnnlOpManager(); + + /** + * This will check if the ORT node is Supported by the DNNL execution provider + * + * Several things will be checked from the node + * - Is the OpType is regestered with the DNNL execution provider? + * - Are the tensor dimensions Supported by the DNNL execution provider + * - Are operator attributes Supported by the DNNL execution provider + * + * @param node the node that is being checked + * + * @return true if the node is Supported by the DNNL execution provider + * false is returned otherwise. + */ + bool IsNodeSupported(const Node* node) const; + + /** + * Find out if the OpType is one of the OpTypes Supported by the DNNL execution provider + * + * This only looks at the OpType it does not look at other factors that may mean + * the operator is not Supported. + * + * @param opType the name of the operator i.e. "Add" or "Conv" etc. + * + * @return true is the OpType is one of those Supported by the DNNL execution provider + * false is returned otherwise. + */ + bool IsOpTypeAvalible(const std::string& opType) const; + + private: + std::map> dnnl_ops_map_; +}; +} // namespace onnxruntime \ No newline at end of file