mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Add DnnlOpManager (#7521)
* Add DnnlOpManager The DnnlOpManager is able to more accurately check if a node is supported by the DNNLExecutionProvider. The DNNLExecutionProvider::GetCapability function has been updated to use the DnnlOpManager. This commit adds the ability to check if data type, attributes, and tensor dimensions of the node are supported. The IsDimensionSupported function is no longer needed since the checks it was doing have been moved into the individual implementations of the virtual class DnnlNodeCapability. Signed-off-by: George Nash <george.nash@intel.com> * Fix AveragePool entry in the DnnlOpManager Added check for ceil_mode attribute in the PoolNodeCapability check. DnnlExecutionProvider does not support ceil_mode other than the default value. Signed-off-by: George Nash <george.nash@intel.com>
This commit is contained in:
parent
dac24f7d63
commit
b4e8e9b004
6 changed files with 389 additions and 70 deletions
|
|
@ -10,6 +10,7 @@
|
|||
#include "subgraph/dnnl_func_kernel.h"
|
||||
#include "dnnl_execution_provider.h"
|
||||
#include "dnnl_fwd.h"
|
||||
#include "dnnl_node_capability.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -85,7 +86,7 @@ bool DNNLExecutionProvider::UseSubgraph(const GraphViewer& graph_viewer) const {
|
|||
continue;
|
||||
|
||||
if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) {
|
||||
FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos;
|
||||
FP16_graph = node->InputDefs()[0]->Type()->find("float16") != std::string::npos;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -96,8 +97,7 @@ bool DNNLExecutionProvider::UseSubgraph(const GraphViewer& graph_viewer) const {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto op_it = dnnl_ops_.find(node->OpType());
|
||||
if (op_it != dnnl_ops_.end()) {
|
||||
if (opManager_.IsOpTypeAvalible(node->OpType())) {
|
||||
dnnl_nodes_in_the_graph = true;
|
||||
break;
|
||||
}
|
||||
|
|
@ -248,7 +248,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
|
|||
continue;
|
||||
}
|
||||
|
||||
if (IsDimensionSupported(node) == false) {
|
||||
if (!opManager_.IsNodeSupported(node)) {
|
||||
node_index++;
|
||||
if (subgraph_ptr->dnnl_nodes.size() > 0) {
|
||||
CreateMetaDef(graph_viewer, *subgraph_attributes, subgraph_ptr, sub_var, result);
|
||||
|
|
@ -259,8 +259,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
|
|||
continue;
|
||||
}
|
||||
|
||||
auto op_it = dnnl_ops_.find(node->OpType());
|
||||
if (op_it != dnnl_ops_.end()) {
|
||||
if (opManager_.IsOpTypeAvalible(node->OpType())) {
|
||||
sub_var.subgraph_node_indexes.push_back(node->Index());
|
||||
|
||||
// can we fuse (at Dnnl level) nodes?
|
||||
|
|
@ -299,8 +298,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
|
|||
temp_index++;
|
||||
next_node = graph_viewer.GetNode(temp_index);
|
||||
}
|
||||
auto sub_it = dnnl_ops_.find(next_node->OpType());
|
||||
if (sub_it != dnnl_ops_.end()) {
|
||||
if (opManager_.IsOpTypeAvalible(next_node->OpType())) {
|
||||
const auto& next_node_inputs = next_node->InputDefs();
|
||||
bool input_from_subgraph = true;
|
||||
size_t inputs_count = 1;
|
||||
|
|
@ -343,8 +341,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
|
|||
}
|
||||
// inner nodes. if inner nodes are not Dnnl nodes
|
||||
// create subgraph (inception v2)
|
||||
auto sub_it = dnnl_ops_.find(next_node->OpType());
|
||||
if (sub_it == dnnl_ops_.end()) {
|
||||
if (!opManager_.IsOpTypeAvalible(next_node->OpType())) {
|
||||
// break and create a sub-graph
|
||||
break_loop = true;
|
||||
create_subgraph = true;
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
#include "core/platform/ort_mutex.h"
|
||||
#include "core/providers/dnnl/subgraph/subgraph.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "dnnl_op_manager.h"
|
||||
|
||||
namespace dnnl {
|
||||
struct memory;
|
||||
|
|
@ -165,57 +166,6 @@ class DNNLExecutionProvider : public IExecutionProvider {
|
|||
|
||||
bool UseSubgraph(const GraphViewer& graph_viewer) const;
|
||||
|
||||
// Some dimensions are not supported by DNNL
|
||||
// example: Pool with NumDimensions <= 3 is not supported
|
||||
// Fall back to CPU implementation
|
||||
bool IsDimensionSupported(const Node* node) const {
|
||||
bool supported = true;
|
||||
if (node->OpType() == "BatchNormalization") {
|
||||
auto node_inputs = node->InputDefs();
|
||||
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 3) {
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
if (node->OpType().find("Pool") != std::string::npos) {
|
||||
auto node_inputs = node->InputDefs();
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() < 3) {
|
||||
#else
|
||||
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() <= 3) {
|
||||
#endif // ENABLE_TRAINING
|
||||
supported = false;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (node->OutputDefs().size() > 2)
|
||||
supported = false;
|
||||
#else
|
||||
if (node->OutputDefs().size() > 1)
|
||||
supported = false;
|
||||
#endif // ENABLE_TRAINING
|
||||
}
|
||||
if (node->OpType().find("MatMul") != std::string::npos) {
|
||||
auto node_inputs = node->InputDefs();
|
||||
if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) &&
|
||||
(node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) &&
|
||||
(node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) {
|
||||
supported = true;
|
||||
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) {
|
||||
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) {
|
||||
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
|
||||
void CreateOrUpdateDnnlNode(const Node* node,
|
||||
std::shared_ptr<ort_dnnl::Subgraph>& subgraph_ptr,
|
||||
|
|
@ -238,15 +188,8 @@ class DNNLExecutionProvider : public IExecutionProvider {
|
|||
}
|
||||
|
||||
private:
|
||||
// supported Dnnl Operators
|
||||
#ifdef ENABLE_TRAINING
|
||||
std::set<std::string> dnnl_ops_ = {"Conv", "ConvGrad", "BatchNormalization", "Relu", "ReluGrad", "Sum",
|
||||
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "MaxPoolGrad", "LRN"};
|
||||
#else
|
||||
std::set<std::string> dnnl_ops_ = {"Conv", "BatchNormalization", "Relu", "Sum",
|
||||
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "LRN", "MatMul"};
|
||||
#endif // ENABLE_TRAINING
|
||||
|
||||
// DnnlOpManager contains information about supported Dnnl Operators
|
||||
DnnlOpManager opManager_;
|
||||
mutable std::unordered_map<std::string, std::shared_ptr<ort_dnnl::Subgraph>> mkl_subgraphs_;
|
||||
};
|
||||
|
||||
|
|
|
|||
146
onnxruntime/core/providers/dnnl/dnnl_node_capability.cc
Normal file
146
onnxruntime/core/providers/dnnl/dnnl_node_capability.cc
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#include "dnnl_node_capability.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
// DnnlDefaultNodeCapability class
|
||||
//-------------------------------------
|
||||
DnnlDefaultNodeCapability::DnnlDefaultNodeCapability() {
|
||||
inputTypes_.push_back("float");
|
||||
}
|
||||
|
||||
DnnlDefaultNodeCapability::DnnlDefaultNodeCapability(std::vector<std::string> inputTypes) {
|
||||
for (std::string s : inputTypes)
|
||||
inputTypes_.push_back(s);
|
||||
}
|
||||
|
||||
bool DnnlDefaultNodeCapability::Supported(const Node* node) const {
|
||||
if (!IsTypeSupported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlDefaultNodeCapability::IsTypeSupported(const Node* node) const {
|
||||
auto node_inputs = node->InputDefs();
|
||||
if (!node_inputs.empty() && node_inputs[0]->Type() != nullptr) {
|
||||
for (auto inputType : inputTypes_) {
|
||||
if (node_inputs[0]->Type()->find(inputType) != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// DnnlPoolNodeCapability class
|
||||
//-------------------------------------
|
||||
bool DnnlPoolNodeCapability::Supported(const Node* node) const {
|
||||
if (!IsTypeSupported(node)) return false;
|
||||
if (!IsAttributeSupported(node)) return false;
|
||||
if (!IsDimensionSupported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlPoolNodeCapability::IsAttributeSupported(const Node* node) const {
|
||||
const NodeAttributes& attributes = node->GetAttributes();
|
||||
if (node->OpType() == "MaxPool") {
|
||||
auto attr = attributes.find("dilations");
|
||||
if (attr != attributes.end()) {
|
||||
for (int i = 0; i < attr->second().ints_size(); ++i) {
|
||||
if (attr->second().ints(i) > 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto attr = attributes.find("ceil_mode");
|
||||
if (attr != attributes.end()) {
|
||||
if (attr->second().i() != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlPoolNodeCapability::IsDimensionSupported(const Node* node) const {
|
||||
auto node_inputs = node->InputDefs();
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() < 3) {
|
||||
#else
|
||||
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() <= 3) {
|
||||
#endif // ENABLE_TRAINING
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (node->OutputDefs().size() > 2)
|
||||
return false;
|
||||
#else
|
||||
if (node->OutputDefs().size() > 1)
|
||||
return false;
|
||||
#endif // ENABLE_TRAINING
|
||||
return true;
|
||||
}
|
||||
|
||||
// DnnlBatchNormalizationNodeCapability class
|
||||
//-------------------------------------
|
||||
bool DnnlBatchNormalizationNodeCapability::Supported(const Node* node) const {
|
||||
if (!IsTypeSupported(node)) return false;
|
||||
if (!IsDimensionSupported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlBatchNormalizationNodeCapability::IsDimensionSupported(const Node* node) const {
|
||||
auto node_inputs = node->InputDefs();
|
||||
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 3) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// DnnlReduceMeanNodeCapability class
|
||||
//-------------------------------------
|
||||
bool DnnlReduceMeanNodeCapability::Supported(const Node* node) const {
|
||||
if (!IsTypeSupported(node)) return false;
|
||||
if (!IsAttributeSupported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlReduceMeanNodeCapability::IsAttributeSupported(const Node* node) const {
|
||||
const NodeAttributes& attributes = node->GetAttributes();
|
||||
auto attr = attributes.find("keepdims");
|
||||
if (attr != attributes.end() && attr->second().i() == 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// DnnlMatMulNodeCapability class
|
||||
bool DnnlMatMulNodeCapability::Supported(const Node* node) const {
|
||||
if (!IsTypeSupported(node)) return false;
|
||||
if (!IsDimensionSupported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlMatMulNodeCapability::IsDimensionSupported(const Node* node) const {
|
||||
auto node_inputs = node->InputDefs();
|
||||
if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) &&
|
||||
(node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) &&
|
||||
(node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) {
|
||||
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) {
|
||||
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) {
|
||||
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
149
onnxruntime/core/providers/dnnl/dnnl_node_capability.h
Normal file
149
onnxruntime/core/providers/dnnl/dnnl_node_capability.h
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
* Pure virtual base class
|
||||
*
|
||||
* Individual implementations of this class are expected
|
||||
* to implement the Supported() member function.
|
||||
*
|
||||
* The Supported() function is expected to use the contents
|
||||
* of the onnxruntime::Node to decided if that node is
|
||||
* supported in the DnnlExecutionProvider
|
||||
*/
|
||||
class DnnlNodeCapability {
|
||||
public:
|
||||
virtual ~DnnlNodeCapability(){};
|
||||
/**
|
||||
* virtual function expected to be implemented for different node
|
||||
* types.
|
||||
* @param node a onnxruntime::Node from the model
|
||||
*
|
||||
* @return true if the onnxRuntime::Node is supported in the
|
||||
* DnnlExecutionProvider return false otherwise.
|
||||
*/
|
||||
virtual bool Supported(const Node* node) const = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* Default impelementation of the DnnlNodeCapability interface
|
||||
* This class can be used if the only thing needed to
|
||||
* decided if the we are capable of running the node using
|
||||
* the DnnlExecutionProvider is the input data type.
|
||||
*
|
||||
* The default constructor assumes that "float" data
|
||||
* type is supported and no other data types.
|
||||
*
|
||||
* To add additional data types an array of data types
|
||||
* can be passed as strings i.e.
|
||||
* `DnnlDefaultNodeCapability({"float", "int8"})`
|
||||
* Would indicate that "float" and "int8" are supported.
|
||||
*
|
||||
* At this time the possible data types strings are:
|
||||
* - "float"
|
||||
* - "float16"
|
||||
* - "bfloat16"
|
||||
* - "double"
|
||||
* - "int8"
|
||||
* - "int16"
|
||||
* - "int32"
|
||||
* - "int64"
|
||||
* - "uint8"
|
||||
* - "uint16"
|
||||
* - "uint32"
|
||||
* - "uint64"
|
||||
* - "complex64"
|
||||
* - "complex128"
|
||||
* - "string"
|
||||
* - "bool"
|
||||
*
|
||||
* The strings are from the data_type_utils.cc
|
||||
* TypesWrapper::TypesWrapper() member function. If a type
|
||||
* is expected but not found in the above list see if it is
|
||||
* assigned in the data_type_utils.cc file.
|
||||
*
|
||||
* This currently only checks the data type of input[0]. If
|
||||
* this does not work for the Node then this class will need
|
||||
* to be updated or another DnnlNodeCapability class will need to be
|
||||
* implemented for the operator in question.
|
||||
*/
|
||||
class DnnlDefaultNodeCapability : public DnnlNodeCapability {
|
||||
public:
|
||||
DnnlDefaultNodeCapability();
|
||||
DnnlDefaultNodeCapability(std::vector<std::string> inputTypes);
|
||||
|
||||
bool Supported(const Node* node) const override;
|
||||
|
||||
protected:
|
||||
bool IsTypeSupported(const Node* node) const;
|
||||
|
||||
private:
|
||||
std::vector<std::string> inputTypes_;
|
||||
};
|
||||
|
||||
/**
|
||||
* Decide if a Pool op is supported by DnnlExecutionProvider
|
||||
*
|
||||
* Dnnl does not support all dimension types for Pooling operators
|
||||
* In addition the "dilations" attribute is not yet supported for
|
||||
* MaxPool operator.
|
||||
*/
|
||||
class DnnlPoolNodeCapability : public DnnlDefaultNodeCapability {
|
||||
public:
|
||||
DnnlPoolNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
|
||||
|
||||
bool Supported(const Node* node) const override;
|
||||
|
||||
private:
|
||||
bool IsAttributeSupported(const Node* node) const;
|
||||
bool IsDimensionSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
/**
|
||||
* Decide if a BatchNormalization op is supported by DnnlExecutionProvider
|
||||
*/
|
||||
class DnnlBatchNormalizationNodeCapability : public DnnlDefaultNodeCapability {
|
||||
public:
|
||||
DnnlBatchNormalizationNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
|
||||
|
||||
bool Supported(const Node* node) const override;
|
||||
|
||||
private:
|
||||
bool IsDimensionSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
/**
|
||||
* Decide if a ReduceMean op is supported by DnnlExecutionProvider
|
||||
*
|
||||
* Dnnl does not support the "keepdims" attribute when it is `0`
|
||||
*/
|
||||
class DnnlReduceMeanNodeCapability : public DnnlDefaultNodeCapability {
|
||||
public:
|
||||
DnnlReduceMeanNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
|
||||
|
||||
bool Supported(const Node* node) const override;
|
||||
|
||||
private:
|
||||
bool IsAttributeSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
/**
|
||||
* Decide if a MatMul op is supported by DnnlExecutionProvider
|
||||
*/
|
||||
class DnnlMatMulNodeCapability : public DnnlDefaultNodeCapability {
|
||||
public:
|
||||
DnnlMatMulNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
|
||||
|
||||
bool Supported(const Node* node) const override;
|
||||
|
||||
private:
|
||||
bool IsDimensionSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
38
onnxruntime/core/providers/dnnl/dnnl_op_manager.cc
Normal file
38
onnxruntime/core/providers/dnnl/dnnl_op_manager.cc
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#include "dnnl_op_manager.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
DnnlOpManager::DnnlOpManager() {
|
||||
dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("MatMul", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("MaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Sum", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
#if defined(ENABLE_TRAINING)
|
||||
// TODO re-enable ConvGrad currently there is a bug in the ConvGrad code bug was not known till after PR7083
|
||||
//dnnl_ops_map_.emplace(std::make_pair("ConvGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("ReluGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("MaxPoolGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
#endif // ENABLE_TRAINING
|
||||
}
|
||||
|
||||
bool DnnlOpManager::IsNodeSupported(const Node* node) const {
|
||||
auto it = dnnl_ops_map_.find(node->OpType());
|
||||
if (it == dnnl_ops_map_.end()) {
|
||||
return false;
|
||||
}
|
||||
return it->second->Supported(node);
|
||||
}
|
||||
|
||||
bool DnnlOpManager::IsOpTypeAvalible(const std::string& opType) const {
|
||||
auto op_it = dnnl_ops_map_.find(opType);
|
||||
return (op_it != dnnl_ops_map_.end());
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
46
onnxruntime/core/providers/dnnl/dnnl_op_manager.h
Normal file
46
onnxruntime/core/providers/dnnl/dnnl_op_manager.h
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#pragma once
|
||||
#include "dnnl_node_capability.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
namespace onnxruntime {
|
||||
class DnnlOpManager {
|
||||
public:
|
||||
DnnlOpManager();
|
||||
|
||||
/**
|
||||
* This will check if the ORT node is Supported by the DNNL execution provider
|
||||
*
|
||||
* Several things will be checked from the node
|
||||
* - Is the OpType is regestered with the DNNL execution provider?
|
||||
* - Are the tensor dimensions Supported by the DNNL execution provider
|
||||
* - Are operator attributes Supported by the DNNL execution provider
|
||||
*
|
||||
* @param node the node that is being checked
|
||||
*
|
||||
* @return true if the node is Supported by the DNNL execution provider
|
||||
* false is returned otherwise.
|
||||
*/
|
||||
bool IsNodeSupported(const Node* node) const;
|
||||
|
||||
/**
|
||||
* Find out if the OpType is one of the OpTypes Supported by the DNNL execution provider
|
||||
*
|
||||
* This only looks at the OpType it does not look at other factors that may mean
|
||||
* the operator is not Supported.
|
||||
*
|
||||
* @param opType the name of the operator i.e. "Add" or "Conv" etc.
|
||||
*
|
||||
* @return true is the OpType is one of those Supported by the DNNL execution provider
|
||||
* false is returned otherwise.
|
||||
*/
|
||||
bool IsOpTypeAvalible(const std::string& opType) const;
|
||||
|
||||
private:
|
||||
std::map<std::string, std::unique_ptr<DnnlNodeCapability>> dnnl_ops_map_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue