Add layout transformer for NNAPI (#10371)

* Add layout transformer for NNAPI

* plus merge fixes

* plus some more merge fixes

* test fixes

* comments + cleanup

* plus updates

* post merge changes

* enable layout transformer in extended minimal build

* plus more comments

* more tests + fix CI

* plus updates per review

* more updates per review

* fix file name

* fix qdq tests

* plus more updates

* plus updates

* typo fix

* fix qdq selection in 2nd optimization pass

* fix typo

* fix a test

* update dependency structure for layout transformer

* plus updates

* more updates

* plus change

* more updates to fix linker error in minimal build

* remove unnecessary headers
This commit is contained in:
Ashwini Khade 2022-02-15 20:25:29 -08:00 committed by GitHub
parent ceb1e2b1a6
commit f436d3437e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 1097 additions and 700 deletions

View file

@ -347,10 +347,11 @@ if (onnxruntime_MINIMAL_BUILD)
if (onnxruntime_EXTENDED_MINIMAL_BUILD)
# enable EPs that compile kernels at runtime
add_compile_definitions(ORT_EXTENDED_MINIMAL_BUILD)
if (onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
add_compile_definitions(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
endif()
# enable runtime optimizations. These are required for NNAPI.
# TODO remove onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD since these optimizations are now
# enabled for all extended minimal builds.
SET(onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD ON)
add_compile_definitions(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
endif()
if (onnxruntime_MINIMAL_BUILD_CUSTOM_OPS)

View file

@ -38,7 +38,6 @@ if (onnxruntime_MINIMAL_BUILD)
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.cc"
)

View file

@ -14,7 +14,7 @@
#include "core/framework/tensor.h"
namespace onnxruntime {
enum class DataLayout;
class GraphViewer;
class Node;
struct ComputeCapability;
@ -274,6 +274,12 @@ class IExecutionProvider {
return {};
}
// Returns the data layout this execution provider prefers.
// The base implementation reports the default ONNX layout (NCHW); EPs that want a different
// layout should override this and return their preferred layout.
virtual DataLayout GetPreferredLayout() const {
  // DataLayout is only forward declared in this header, so its enumerators cannot be named here.
  // The default layout is therefore expressed through its underlying value: 0 is NCHW.
  constexpr int kDefaultLayoutValue = 0;
  return static_cast<DataLayout>(kDefaultLayoutValue);
}
private:
const std::string type_;
AllocatorMap allocators_;

View file

@ -15,6 +15,7 @@ constexpr const char* kMLDomain = "ai.onnx.ml";
constexpr const char* kMSDomain = "com.microsoft";
constexpr const char* kMSExperimentalDomain = "com.microsoft.experimental";
constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc";
constexpr const char* kMSInternalNHWCDomain = "com.ms.internal.nhwc";
constexpr const char* kMSDmlDomain = "com.microsoft.dml";
constexpr const char* kNGraphDomain = "com.intel.ai";
constexpr const char* kMIGraphXDomain = "";

View file

@ -151,6 +151,12 @@ class Node {
*/
int SinceVersion() const noexcept { return since_version_; }
/** Sets the since version (the opset version in which the Node's operator was first defined) for this node.
@remarks Used during layout transformation to set the since version of layout transformed nodes, i.e. nodes
moved to the kMSInternalNHWCDomain domain.
*/
void SetSinceVersion(int since_version) noexcept {
  since_version_ = since_version;
}
#if !defined(ORT_MINIMAL_BUILD)
/** Gets the Node's OpSchema.
@remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */

View file

@ -19,6 +19,12 @@ struct NodeCompare {
bool operator()(const Node* n1, const Node* n2) const;
};
// Tensor data layouts an execution provider can report as its preferred layout.
// The numeric values are spelled out because IExecutionProvider::GetPreferredLayout() returns
// static_cast<DataLayout>(0) as its default, so NCHW has to stay at value 0.
enum class DataLayout {
  NCHW = 0,   // default ONNX standard data layout
  NHWC = 1,   // layout for which the graph partitioner runs layout transformation
  NCHWC = 2,
};
/**
@class GraphViewer
Class that provides a read-only view of the Graph.

View file

@ -52,7 +52,97 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph
.Provider(provider_type);
}
/// <summary>
/// Post-partitioning sanity check: fails if any node in the kMSInternalNHWCDomain domain is still in the graph.
/// Layout sensitive nodes that were transformed for the current EP are expected to have been taken by that EP.
/// Layout transformation is only enabled for compile based EPs, and after compile all the nodes claimed by the EP
/// are fused into one node and removed from the graph, so simply looking for a leftover node in this domain is
/// sufficient. A leftover node indicates a bug, and graph.Resolve would fail on it.
/// </summary>
/// <param name="graph">Graph to validate</param>
/// <returns>Status::OK() when no such node remains, otherwise a FAIL status describing the first one found.</returns>
static Status ValidateGraphPartitioning(const Graph& graph) {
  for (const auto& node : graph.Nodes()) {
    if (node.Domain() != kMSInternalNHWCDomain) {
      continue;
    }

    // Build the diagnostic piece by piece; the text is identical to a single concatenated expression.
    std::string error_message = "Graph contains an invalid node: ";
    error_message += node.Name();
    error_message += " Op Type: ";
    error_message += node.OpType();
    error_message += " with domain: ";
    error_message += kMSInternalNHWCDomain;
    error_message += ". These are temporary nodes added during layout transformations ";
    error_message += " and are not expected to remain in the graph post partitioning. This is a bug in layout transformer.";
    return Status(common::ONNXRUNTIME, common::FAIL, error_message);
  }

  return Status::OK();
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
/// <summary>
/// Assigns every node of an indexed subgraph to the given execution provider, all or nothing.
/// If any node index does not resolve to a node, or any node is already assigned to a different provider,
/// the graph is left untouched. Nodes that are unassigned or already assigned to this provider are accepted.
/// </summary>
/// <param name="graph">Graph in question.</param>
/// <param name="capability">Indexed subgraph which needs to be assigned</param>
/// <param name="provider_type">The EP to assign the Indexed subgraph to</param>
static void AssignNodes(Graph& graph, const IndexedSubGraph& capability,
                        const std::string& provider_type) {
  // Pass 1: verify the whole subgraph is available before touching any node.
  for (const auto candidate_index : capability.nodes) {
    const auto* candidate = graph.GetNode(candidate_index);
    if (candidate == nullptr) {
      return;
    }

    const auto& assigned_ep = candidate->GetExecutionProviderType();
    const bool taken_by_other_ep = !assigned_ep.empty() && assigned_ep != provider_type;
    if (taken_by_other_ep) {
      return;
    }
  }

  // Pass 2: every index was resolved to a node above, so the lookups here are not re-checked.
  for (const auto candidate_index : capability.nodes) {
    graph.GetNode(candidate_index)->SetExecutionProviderType(provider_type);
  }
}
#endif //! defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
// Queries `current_ep` for the subgraphs it can handle and stores them in `capabilities`.
// When runtime optimizations are compiled in, the mode is not kAssignOnly and the EP prefers NHWC, the claimed
// nodes are first assigned to the EP, `transform_layout` is run on the graph, and, if that modified the graph,
// the capabilities are queried again so that nodes created by the transformation can be claimed as well.
// Returns the error from `transform_layout` if it fails, otherwise Status::OK().
static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep,
                                 GraphPartitioner::Mode mode, std::vector<std::unique_ptr<ComputeCapability>>& capabilities,
                                 TransformLayoutFunction transform_layout) {
  // Each query uses its own short-lived GraphViewer, so no viewer is alive while the graph is being modified.
  const auto query_capabilities = [&graph, &kernel_registry_mgr, &current_ep]() {
    GraphViewer viewer(graph);
    return current_ep.GetCapability(viewer,
                                    kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type()));
  };

  capabilities = query_capabilities();

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
  // Layout transformation is run here only for EPs whose preferred layout is NHWC.
  // CPU EP layout transformation happens later when level 3 transformers are run.
  const bool needs_layout_transform =
      mode != GraphPartitioner::Mode::kAssignOnly && current_ep.GetPreferredLayout() == DataLayout::NHWC;
  if (!needs_layout_transform) {
    return Status::OK();
  }

  for (const auto& entry : capabilities) {
    // in theory an EP could return an empty value, so skip those
    if (entry && entry->sub_graph) {
      AssignNodes(graph, *entry->sub_graph, current_ep.Type());
    }
  }

  // Perform layout transformation on the nodes assigned to this EP.
  bool graph_changed = false;
  ORT_RETURN_IF_ERROR(transform_layout(graph, graph_changed, current_ep));

  // The transformation may introduce new nodes: either existing nodes that were reconstructed to update their
  // domain, or completely new nodes required by the layout change. Re-run GetCapability so this EP gets a chance
  // to process them.
  if (graph_changed) {
    capabilities.clear();
    capabilities = query_capabilities();
  }
#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)

  return Status::OK();
}
#if !defined(ORT_MINIMAL_BUILD)
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
auto schema = node.Op();
builder.SetName(schema->Name())
@ -94,23 +184,23 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
// Check whether any node in the <sub_graph> was already assigned. If so it cannot be stolen as assignment is done
// in order of EP priority
bool sub_graph_available_for_assignment = true;
for (auto node_index : capability.nodes) {
const auto* node = graph.GetNode(node_index);
if (nullptr == node || !node->GetExecutionProviderType().empty()) {
// if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned,
// so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by
// preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to
// and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes
// should never be on those lists.
//
// when the ORT format model is loaded we will process it normally with EP priority being applied for
// whichever EPs are enabled at the time.
//
// e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP.
// We want the ORT format model to be able to be run as efficiently as possible on either platform,
// so we want all the nodes that either may take to be preserved. If we did not do this we would
// need to create one ORT format model for Android and one for iOS.
if (mode != GraphPartitioner::Mode::kAssignOnly) {
if (mode != GraphPartitioner::Mode::kAssignOnly) {
// if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned,
// so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by
// preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to
// and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes
// should never be on those lists.
//
// when the ORT format model is loaded we will process it normally with EP priority being applied for
// whichever EPs are enabled at the time.
//
// e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP.
// We want the ORT format model to be able to be run as efficiently as possible on either platform,
// so we want all the nodes that either may take to be preserved. If we did not do this we would
// need to create one ORT format model for Android and one for iOS.
for (auto node_index : capability.nodes) {
const auto* node = graph.GetNode(node_index);
if ((nullptr == node) || (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) {
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
sub_graph_available_for_assignment = false;
@ -165,7 +255,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
KernelRegistry& fused_kernel_registry,
IExecutionProvider& current_ep,
GraphPartitioner::Mode mode,
int& fused_node_unique_id) {
int& fused_node_unique_id,
TransformLayoutFunction transform_layout_function) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
if (graph.NumberOfNodes() == 0) {
@ -178,28 +269,32 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
Graph* subgraph = entry.second;
// we pass through the export_dll value and FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id));
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
transform_layout_function));
}
}
// If an execution provider return the capability that he could run a sub-graph,
// onnxruntime will fuse the sub-graph into a function node. if the execution provider
// says he need to compile the graph at runtime (by need_compile flag),
// onnxruntime will invoke the "Compile" method to get compiled binary.
// There are two mode of compile, one is return the entry point to the compiled binary
// directly, another is export the compiled binary to shared library for future reuse.
// If an execution provider returns the capability that it can run a sub-graph,
// onnxruntime will fuse the sub-graph into a function node. For compilation
// based execution providers (one which needs to compile graph at runtime.
// Indicated by need_compile flag), onnxruntime will invoke the "Compile" method
// to get compiled binary. There are two mode of compile, one is return the entry
// point to the compiled binary directly, another is export the compiled binary to
// shared library for future reuse.
// TODO: when the graph contain a function node, and user pass in the dll which could
// TODO: when the graph contains a function node, and user passes in the dll which could
// run the function by SessionOption, we should create a function kernel for it and
// delegate the compute to the functions inside the dlls.
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
ORT_RETURN_IF_ERROR(GetCapabilityForEP(graph, kernel_registry_mgr, current_ep, mode, capabilities,
transform_layout_function));
if (capabilities.empty()) {
return Status::OK();
}
const std::string& type = current_ep.Type();
auto fusion_style = current_ep.GetFusionStyle();
std::vector<Node*> nodes_to_compile;
GraphViewer graph_viewer(graph);
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
// filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with
// nodes_to_compile.
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_compile;
@ -211,7 +306,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
}));
for (auto& capability : capabilities) {
if (!capability || !capability->sub_graph) { // in theory an EP could return an empty value...
// in theory an EP could return an empty value...
if (!capability || !capability->sub_graph) {
continue;
}
@ -305,6 +401,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
}
}
ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph));
}
// if this is the main graph call Resolve to put the Graph back into a guaranteed good state
@ -362,14 +460,16 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry, Mode mode,
int& fused_node_unique_id) const {
int& fused_node_unique_id,
TransformLayoutFunction transform_layout_function) const {
bool modified_graph = false;
do {
// process full graph with each EP
for (const auto& ep : providers_) {
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
fused_kernel_registry, *ep, mode, fused_node_unique_id));
fused_kernel_registry, *ep, mode, fused_node_unique_id,
transform_layout_function));
}
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
@ -392,13 +492,15 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry,
IExecutionProvider& current_ep,
std::unordered_map<std::string, HashValue>& compiled_kernel_hashes,
int& fused_node_unique_id) {
int& fused_node_unique_id,
TransformLayoutFunction transform_layout_function) {
// recurse into nested graphs first to partition bottom up.
for (auto& node : graph.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
Graph* subgraph = entry.second;
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry,
current_ep, compiled_kernel_hashes, fused_node_unique_id));
current_ep, compiled_kernel_hashes, fused_node_unique_id,
transform_layout_function));
}
}
@ -409,12 +511,10 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
}
const std::string& type = current_ep.Type();
GraphViewer graph_viewer(graph);
std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
ORT_RETURN_IF_ERROR(GetCapabilityForEP(graph, kernel_registry_mgr, current_ep,
GraphPartitioner::Mode::kOrtFormatLoad, capabilities, transform_layout_function));
if (capabilities.empty()) {
return Status::OK();
}
@ -451,40 +551,37 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
for (size_t j = 0, end = nodes_and_viewers.size(); j < end; ++j) {
Node& node = nodes_and_viewers[j].fused_node;
std::vector<NodeComputeInfo> single_node_compute_func;
auto status = current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func);
if (!status.IsOK()) {
// There is compile error with the nodes_and_viewer[j], remove the fused_node and function from the graph
LOGS_DEFAULT(ERROR) << "EP: " << current_ep.Type() << " has Compile error: " << status.ErrorMessage();
graph.CancelFuseSubGraph(node);
} else {
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 elements");
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0])));
ORT_RETURN_IF_ERROR(current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func));
const auto& cur_capability = capabilities[j];
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element.");
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0])));
KernelDefBuilder builder;
BuildFusedKernelDef(builder, metadef, type);
auto kernel_def = builder.Build();
const auto& cur_capability = capabilities[j];
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
// save hash so SessionState can find the kernel. each kernel name should be unique
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
". Execution Provider must generate unique names across the entire model.");
}
KernelDefBuilder builder;
BuildFusedKernelDef(builder, metadef, type);
auto kernel_def = builder.Build();
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
KernelCreateInfo(std::move(kernel_def),
[](FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
return FunctionKernel::Create(func_mgr, info, out);
})));
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
// save hash so SessionState can find the kernel. each kernel name should be unique
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
". Execution Provider must generate unique names across the entire model.");
}
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
KernelCreateInfo(std::move(kernel_def),
[](FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
return FunctionKernel::Create(func_mgr, info, out);
})));
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
}
ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph));
return Status::OK();
}
@ -495,7 +592,8 @@ Status GraphPartitioner::PartitionOrtFormatModel(
Graph& graph, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry,
std::unordered_map<std::string, HashValue>& compiled_kernel_hashes,
int& fused_node_unique_id) const {
int& fused_node_unique_id,
TransformLayoutFunction transform_layout_function) const {
// process full graph with each EP
for (const auto& ep : providers_) {
if (ep->Type() == kCpuExecutionProvider) {
@ -505,13 +603,15 @@ Status GraphPartitioner::PartitionOrtFormatModel(
}
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry,
*ep, compiled_kernel_hashes, fused_node_unique_id));
*ep, compiled_kernel_hashes, fused_node_unique_id,
transform_layout_function));
}
return Status::OK();
}
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, Mode mode,
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
TransformLayoutFunction transform_layout_function, Mode mode,
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes) const {
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
// 1. Execution providers' capabilities are checked one by one.
@ -535,16 +635,16 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode,
fused_node_unique_id));
fused_node_unique_id, transform_layout_function));
#else
ORT_UNUSED_PARAMETER(export_dll);
ORT_THROW("Not supported in this build.");
#endif
#endif //! defined(ORT_MINIMAL_BUILD)
} else {
ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided");
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes,
fused_node_unique_id));
fused_node_unique_id, transform_layout_function));
}
if (!fused_kernel_registry->IsEmpty()) {

View file

@ -15,6 +15,7 @@ namespace onnxruntime {
class ExecutionProviders;
class KernelRegistry;
class KernelRegistryManager;
// Callback the graph partitioner invokes to run layout transformation on `graph` for `current_ep`.
// `modified` is an output flag: when it is set to true the partitioner queries the EP's capabilities again.
// A non-OK Status is propagated to the caller of the partitioner.
using TransformLayoutFunction = std::function<Status(Graph& graph, bool& modified, IExecutionProvider& current_ep)>;
class GraphPartitioner {
public:
@ -31,7 +32,8 @@ class GraphPartitioner {
}
// Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad.
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
TransformLayoutFunction transform_layout_function,
Mode mode = Mode::kNormal,
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes = nullptr) const;
@ -40,12 +42,13 @@ class GraphPartitioner {
#if !defined(ORT_MINIMAL_BUILD)
Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id) const;
KernelRegistry& fused_kernel_registry, Mode mode,
int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const;
#endif
Status PartitionOrtFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry,
std::unordered_map<std::string, HashValue>& compiled_kernel_hashes,
int& fused_node_unique_id) const;
int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const;
KernelRegistryManager& kernel_registry_mgr_;
const ExecutionProviders& providers_;

View file

@ -16,6 +16,7 @@
#include "core/framework/session_state_flatbuffers_utils.h"
#include "core/framework/session_state_utils.h"
#include "core/framework/utils.h"
#include "core/framework/static_kernel_def_hashes.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
@ -601,7 +602,7 @@ Status SessionState::GeneratePatternGroupCache(const gsl::span<const OrtValue>&
auto* node = graph_viewer_->GetNode(node_plan.node_index);
int output_start = node_index + static_cast<int>(node->InputDefs().size()) +
static_cast<int>(node->ImplicitInputDefs().size());
//allocate output
// allocate output
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
if (ml_value_idx == NodeIndexInfo::kInvalidEntry ||
@ -631,7 +632,7 @@ Status SessionState::GeneratePatternGroupCache(const gsl::span<const OrtValue>&
}
}
//release nodes
// release nodes
for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) {
auto ml_value_idx = exe_plan->to_be_freed[index];
const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type;
@ -1015,15 +1016,22 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
// lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices
// lookup the hashes for any nodes we compiled or added during graph partitioning.
// These node indexes for compiled nodes as well as newly added nodes are not in node_indices
// as they were created at runtime.
if (!compiled_kernel_hashes.empty()) {
for (const auto& node : graph_.Nodes()) {
if (kernel_create_info_map_.count(node.Index()) == 0) {
for (const auto& node : graph_.Nodes()) {
if (kernel_create_info_map_.count(node.Index()) == 0) {
if (node.Domain() == kOnnxDomain || node.Domain() == kOnnxDomainAlias) {
auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (kernel_hash.has_value()) {
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, *kernel_hash));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to find kernel hash for node:", node.Name(), " optype:", node.OpType());
}
} else {
const auto hash_info = compiled_kernel_hashes.find(node.OpType());
ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
"Unable to find compiled kernel hash for node '", node.Name(), "'.");
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second));
}
}
@ -1284,7 +1292,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
subgraphs_kernel_create_info_maps,
outer_scope_node_arg_to_location_map,
ort_value_name_idx_map_, context, p_seq_exec_plan_));
//Record the allocation plan
// Record the allocation plan
// Uncomment the below to dump the allocation plan to std::cout
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);
@ -1327,7 +1335,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
},
logger_, data_transfer_mgr_, *p_seq_exec_plan_.get(), session_options));
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
//Record Weight allocation info on device
// Record Weight allocation info on device
MemoryInfo::RecordInitializerAllocInfo(GetInitializedTensors());
#endif

View file

@ -0,0 +1,37 @@
#include "core/framework/static_kernel_def_hashes.h"
namespace onnxruntime {
// Looks up the kernel hash for an ONNX-domain op identified by op type and since version.
// Returns the hash when the "<op_type>_<since_version>" combination is in the static table below,
// otherwise std::nullopt. The lookup is case-sensitive, so keys must match the op type exactly.
std::optional<HashValue> GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version) {
  // Layout transformer can add new nodes to the graph.
  // Since layout transformation can happen in an extended minimal build, if these nodes are not picked up and
  // compiled by NNAPI or other compiling EPs then we need a way to get the hashes for these nodes. Since the
  // infrastructure as well as the op_schema required to generate these hashes is not available in an extended
  // minimal build, we maintain a static map of nodes to hash value. This hash value can then be used to retrieve
  // the kernel for the given op.
  //
  // Keys are "<OpType>_<SinceVersion>". The ONNX op type is spelled "Unsqueeze"; a key spelled "UnSqueeze"
  // would never match a real node.
  static const std::unordered_map<std::string, HashValue> static_kernel_hashes{
      {"Transpose_1", 4324835766923221184ULL},
      {"Transpose_13", 17267477159887372848ULL},
      {"Squeeze_1", 12889825108950034784ULL},
      {"Squeeze_11", 14725795030460042064ULL},
      {"Squeeze_13", 16122603335179721968ULL},
      {"Unsqueeze_1", 15964030255371555232ULL},
      {"Unsqueeze_11", 16989589986691430224ULL},
      {"Unsqueeze_13", 9466011545409597224ULL},
      {"Gather_1", 625186873870077080ULL},
      {"Gather_11", 11761559382112736008ULL},
      {"Gather_13", 7462749543760614528ULL},
      {"Identity_1", 18001636502361632792ULL},
      {"Identity_13", 16879814636194901248ULL},
      {"Identity_14", 16515685968327103576ULL},
      {"Identity_16", 17661628575887109792ULL},
  };

  const auto key = op_type + "_" + std::to_string(since_version);
  const auto iter = static_kernel_hashes.find(key);
  if (iter != static_kernel_hashes.end()) {
    return iter->second;
  }

  return std::nullopt;
}
}

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <unordered_map>
#include <string>
#include <optional>
#include "core/common/basic_types.h"
namespace onnxruntime {
/**
* @brief Gets the hash value for provided op type + version combination if it is available, otherwise
* returns a nullopt. The hash value is available if this node was added by layout transformer. For all other
* nodes, the hash values should be present either in the serialized session state obtained from ort format model
* or from compiled kernel hash map which is generated during partitioning.
* @return std::optional<HashValue>
*/
std::optional<HashValue> GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version);
} // namespace onnxruntime

View file

@ -6,7 +6,7 @@
#include "core/optimizer/initializer.h"
#include "core/optimizer/nhwc_transformer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/transpose_optimizer/api_impl.h"
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
@ -55,7 +55,7 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
if (domain != kMSDomain) {
SwapNodeOpTypeAndDomain(*api_graph, *node, "QLinearConv", "com.microsoft");
SwapNodeOpTypeAndDomain(*api_graph, *node, "QLinearConv", kMSDomain);
}
modified = true;

View file

@ -111,7 +111,9 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
std::vector<NodeGroup> qdq_selections;
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
const auto* node = graph_viewer.GetNode(index);
if (node->Domain() != kOnnxDomain) {
// post layout transformation all the layout sensitive nodes are converted to domain
// kMSInternalNHWCDomain. Therefore need to allow this domain as well.
if (node->Domain() != kOnnxDomain && node->Domain() != kMSInternalNHWCDomain) {
continue;
}

View file

@ -9,6 +9,7 @@
#include <string_view>
#include <unordered_map>
#include <vector>
#include <unordered_set>
namespace onnx_layout_transformation {
namespace api {
@ -204,7 +205,7 @@ class NodeRef {
}
std::string_view node_domain = Domain();
return node_domain == domain ||
((domain == "" || domain == "ai.onnx") && (node_domain == "" || node_domain == "ai.onnx"));
((domain == "" || domain == "ai.onnx") && (node_domain == "" || node_domain == "ai.onnx"));
}
/// <summary>
@ -217,6 +218,19 @@ class NodeRef {
return GetAttributeInt(name).value_or(default_value);
}
/// <summary>
/// Returns the Execution Provider assigned to this node. An empty string means this node is
/// not assigned to any EP.
/// </summary>
/// <returns>EP type or empty string</returns>
virtual const std::string& GetExecutionProviderType() const = 0;
/// <summary>
/// Returns the schema since version for the op_type of this node. A value of -1 means it is not set.
/// </summary>
/// <returns>since version or default value -1</returns>
virtual int SinceVersion() const = 0;
virtual ~NodeRef(){};
};
@ -224,7 +238,6 @@ class NodeRef {
/// Information regarding the consumers of a value.
/// </summary>
struct ValueConsumers {
/// <summary>
/// List of nodes in the current graph containing value as an input
/// </summary>
@ -335,6 +348,14 @@ class GraphRef {
virtual std::unique_ptr<NodeRef> AddNode(std::string_view op_type, const std::vector<std::string_view>& inputs,
size_t num_outputs, std::string_view domain = "") = 0;
/// <summary>
/// Creates a copy of the provided node in the graph with the specified op type and domain.
/// </summary>
/// <param name="source_node">The node to copy</param>
/// <param name="op_type">The new node's op type</param>
/// <param name="domain">The new node's domain. Empty string signifies default onnx domain.</param>
/// <returns>The new node</returns>
virtual std::unique_ptr<NodeRef> CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain = "") = 0;
/// <summary>
/// Deletes a node from the graph. Behavior is undefined if node has any consumers.
/// </summary>
@ -409,6 +430,17 @@ class GraphRef {
constexpr int64_t kMinSupportedOpset = 7;
constexpr int64_t kMaxSupportedOpset = 15;
/// <summary>
/// Selects the context in which the transpose optimizer runs. Passed as the mode argument of Optimize.
/// </summary>
enum class OptimizerMode {
OPTIMIZE_TRANSPOSE, // simple transpose optimization
OPTIMIZE_LAYOUT_TRANSFORM // transpose optimization post layout transformation
};
/// <summary>
/// Gets a list of layout sensitive ops defined by ONNX standard.
/// </summary>
/// <returns>const reference to an unordered set of op_types which are layout sensitive</returns>
const std::unordered_set<std::string_view>& GetLayoutSensitiveOps();
/// <summary>
/// Performs transpose optimization on a graph. Returns true if the graph was modified.
///
@ -420,8 +452,16 @@ constexpr int64_t kMaxSupportedOpset = 15;
/// </summary>
/// <param name="graph">The graph to optimize (or a portion of a graph, see api::GraphRef docs)</param>
/// <param name="allow_extended_ops">Whether com.microsoft ops can be used for optimization</param>
/// <param name="provider_type">Execution provider if applicable.</param>
/// <param name="mode">Current mode. Optimizer can be called in the context of transpose optimizations or during layout transformations.</param>
/// <param name="layout_sensitive_ops">List of ops which are treated as layout sensitive by the ONNX standard as well as any runtime specific ops.
/// These ops should be provided when mode is set to OPTIMIZE_LAYOUT_TRANSFORM. If these ops are not provided, transpose optimizer may convert the
/// layout for these ops </param>
/// <returns>true if the graph was modified</returns>
bool Optimize(api::GraphRef& graph, bool allow_extended_ops);
bool Optimize(api::GraphRef& graph, bool allow_extended_ops,
const std::string& provider_type = "",
OptimizerMode mode = OptimizerMode::OPTIMIZE_TRANSPOSE,
const std::unordered_set<std::string_view>& layout_sensitive_ops = {});
/* Layout Transformation Tools
* These methods help change the channel ordering of layout sensitive ops (like Conv). ONNX currently only supports
@ -429,14 +469,14 @@ bool Optimize(api::GraphRef& graph, bool allow_extended_ops);
* the new ordering. The existence of a robust transpose optimizer means that we can freely add transpose ops during
* conversion and then call Optimize to remove as many as possible. To change the channel ordering of some/all ops
* in a model, a user of this tool should do the following:
*
*
* 1. Iterate over the graph nodes and identify nodes to convert. For each one:
* a. Change the op type and domain (and possibly attributes) to the op/contrib op with the desired ordering.
* b. The model is now invalid since the input tensors are in the original ordering (and all consumers
* expect the original ordering). Use WrapTransposesAroundNode helper to insert transposes around the
* inputs/outputs of the op to correct this.
* 2. The model is now correct but has many unnecessary Transpose ops. Call Optimize on the graph.
*
*
* After step 1, the Transpose ops will wrap converted ops in a similar manner to q/dq ops in quantization.
* The perm attributes essentially encode the information about which ops are being reordered.
*/
@ -444,13 +484,13 @@ bool Optimize(api::GraphRef& graph, bool allow_extended_ops);
/// <summary>
/// Inserts transposes around op inputs/outputs. Alternatively transposes initializers or uses existing Transpose
/// nodes if possible. Populates shape information on affected node inputs/outputs to reflect the change.
///
///
/// Ex:
/// * -> NhwcConv -> **
/// becomes
/// * -> Transpose -> NhwcConv -> Transpose -> **
/// Conv inputs/outputs have new shape. Shapes of * and ** are unchanged (carrying NCHW data).
///
///
/// input_perms/output_perms are matched with node inputs/outputs positionally. Their lengths must be at most equal to
/// the number of inputs/outputs, respectively. nullptr entries indicate an input or output should not be transposed.
/// </summary>

View file

@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/transpose_optimizer/api_impl.h"
#include "optimizer_api.h"
#include "optimizer_utils.h"
#include <deque>
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/execution_provider.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/cpu/tensor/transpose.h"
using namespace ONNX_NAMESPACE;
@ -15,7 +15,6 @@ using namespace ::onnxruntime::common;
using namespace onnx_layout_transformation;
namespace onnxruntime {
class ApiValueInfo final : public api::ValueInfoRef {
private:
NodeArg& node_arg_;
@ -84,6 +83,8 @@ class ApiNode final : public api::NodeRef {
void CopyAttributes(const api::NodeRef& node) override;
void ClearAttribute(std::string_view name) override;
void SetInput(size_t i, std::string_view name) override;
const std::string& GetExecutionProviderType() const override;
// Schema since-version of the wrapped node. `override` already implies virtual, so the redundant
// `virtual` keyword is dropped to match the other override declarations in this class.
int SinceVersion() const override;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode);
@ -114,6 +115,9 @@ class ApiGraph final : public api::GraphRef {
void ReshapeInitializer(std::string_view name, const std::vector<int64_t>& shape) override;
std::unique_ptr<api::NodeRef> AddNode(std::string_view op_type, const std::vector<std::string_view>& inputs,
size_t num_outputs = 1, std::string_view domain = "") override;
std::unique_ptr<api::NodeRef> CopyNode(const api::NodeRef& source_node, std::string_view op_type,
std::string_view domain = "") override;
void RemoveNode(api::NodeRef& node) override;
void RemoveInitializer(std::string_view name) override;
std::string_view AddInitializer(api::DataType dtype, const std::vector<int64_t>& shape,
@ -377,6 +381,15 @@ void ApiNode::SetInput(size_t i, std::string_view name) {
}
}
}
// Forwards to the wrapped ORT node and returns the execution provider type recorded on it.
const std::string& ApiNode::GetExecutionProviderType() const {
  const std::string& assigned_ep = node_.GetExecutionProviderType();
  return assigned_ep;
}
// Forwards to the wrapped ORT node and returns the since-version recorded on it.
int ApiNode::SinceVersion() const {
  const int recorded_version = node_.SinceVersion();
  return recorded_version;
}
// </ApiNode>
std::optional<int64_t> ApiGraph::Opset(std::string_view domain) const {
@ -562,11 +575,11 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vector<int64
node_arg->SetShape(new_shape);
}
std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
const std::vector<std::string_view>& inputs, size_t num_outputs,
std::string_view domain) {
static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type,
const std::vector<std::string_view>& inputs, size_t num_outputs,
std::string_view domain, int since_version, std::string_view node_ep) {
const std::string op_type_str(op_type);
std::string name = graph_.GenerateNodeName(op_type_str);
std::string name = graph.GenerateNodeName(op_type_str);
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
@ -574,48 +587,107 @@ std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
for (const auto& input : inputs) {
NodeArg* arg;
if (input == "") {
arg = &graph_.GetOrCreateNodeArg("", nullptr);
arg = &graph.GetOrCreateNodeArg("", nullptr);
} else {
arg = graph_.GetNodeArg(std::string(input));
arg = graph.GetNodeArg(std::string(input));
}
input_args.push_back(arg);
}
output_args.reserve(num_outputs);
for (size_t i = 0; i < num_outputs; ++i) {
std::string output = graph_.GenerateNodeArgName(name + "_out" + std::to_string(i));
NodeArg* arg = &graph_.GetOrCreateNodeArg(output, nullptr);
std::string output = graph.GenerateNodeArgName(name + "_out" + std::to_string(i));
NodeArg* arg = &graph.GetOrCreateNodeArg(output, nullptr);
output_args.push_back(arg);
}
std::vector<NodeArg*> outputs;
Node& node = graph_.AddNode(name, op_type_str, "Added in transpose optimizer", input_args, output_args, nullptr,
std::string(domain));
Node& node = graph.AddNode(name, op_type_str, "Added in transpose optimizer", input_args, output_args, nullptr,
std::string(domain));
if (new_node_ep_ != nullptr) {
node.SetExecutionProviderType(new_node_ep_);
if (node.SinceVersion() == -1) {
node.SetSinceVersion(since_version);
}
node.SetExecutionProviderType(std::string(node_ep));
for (size_t i = 0; i < input_args.size(); ++i) {
NodeArg* arg = input_args[i];
if (arg->Exists()) {
const std::string& name_str = arg->Name();
graph_.AddConsumerNode(name_str, &node);
const auto* inp_node = graph_.GetProducerNode(name_str);
graph.AddConsumerNode(name_str, &node);
const auto* inp_node = graph.GetProducerNode(name_str);
if (inp_node != nullptr) {
int inp_node_out_index = graph_utils::GetNodeOutputIndexFromOutputName(*inp_node, name_str);
graph_.AddEdge(inp_node->Index(), node.Index(), inp_node_out_index, gsl::narrow_cast<int>(i));
graph.AddEdge(inp_node->Index(), node.Index(), inp_node_out_index, gsl::narrow_cast<int>(i));
}
}
}
for (NodeArg* arg : output_args) {
graph_.UpdateProducerNode(arg->Name(), node.Index());
graph.UpdateProducerNode(arg->Name(), node.Index());
}
return node;
}
// Maps each onnx op type that the transpose optimizer can potentially add to the graph to the list of
// opset versions in which that op's schema was introduced or changed (its "since versions").
// This is needed in minimal build since opschema is not available.
// The versions MUST be sorted in ascending order: GetSinceVersionForNewOp scans each list and keeps the
// last entry that does not exceed the model's imported opset version.
// When the optimizer starts adding a new op type, an entry for it must be added here.
static const std::unordered_map<std::string, std::vector<int>> onnx_ops_available_versions = {
{"Squeeze", {1, 11, 13}},
{"Unsqueeze", {1, 11, 13}},
{"Gather", {1, 11, 13}},
{"Transpose", {1, 13}},
{"Identity", {1, 13, 14, 16}},
};
// Based on the opset version imported for this model, returns the since version for a node that the
// transpose optimizer is about to add.
//
// op_type: op type of the new node; must have an entry in onnx_ops_available_versions.
// domain: domain of the new node; must be the onnx domain (enforced).
// domain_to_version_map: the graph's opset imports (domain -> imported opset version).
//
// Returns the largest listed version that is <= the imported onnx opset version, or -1 if every listed
// version is newer than the imported opset. Fails via ORT_ENFORCE if the domain is not the onnx domain,
// the onnx domain is missing from the opset imports, or the op type is not in the table.
static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view domain,
                                   const std::unordered_map<std::string, int>& domain_to_version_map) {
  int since_version = -1;
  ORT_ENFORCE(domain == kOnnxDomain, "Transpose optimizer is expected to add only onnx domain ops. Domain: ",
              domain, " provided for op: ", op_type);

  auto opset_import_iter = domain_to_version_map.find(std::string(domain));
  ORT_ENFORCE(opset_import_iter != domain_to_version_map.end(), "Onnx domain not found in opset imports.");

  int opset_version = opset_import_iter->second;
  auto iter = onnx_ops_available_versions.find(std::string(op_type));
  // The message parts are concatenated verbatim, so the sentence following op_type needs its own
  // leading separator; otherwise the op name runs straight into "An entry ...".
  ORT_ENFORCE(iter != onnx_ops_available_versions.end(),
              "Transpose Optimizer is adding an unexpected node: ", op_type,
              ". An entry for this node should be added in onnx_ops_available_versions and static_kernel_hashes map.");

  // Versions are listed in ascending order, so the last one that fits the imported opset wins.
  for (auto version : iter->second) {
    if (version <= opset_version) {
      since_version = version;
    }
  }

  return since_version;
}
std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
const std::vector<std::string_view>& inputs, size_t num_outputs,
std::string_view domain) {
int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap());
Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs,
domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : "");
return std::make_unique<ApiNode>(node, graph_);
}
// Creates a node with the given op type/domain that mirrors source_node: it is wired to the same inputs,
// has the same number of outputs, and carries the source's since-version, execution provider and
// attributes. Returns an api wrapper for the new node; source_node itself is left untouched here.
std::unique_ptr<api::NodeRef> ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type,
                                                 std::string_view domain) {
  Node& duplicate = CreateNodeHelper(graph_, op_type, source_node.Inputs(),
                                     source_node.Outputs().size(), domain,
                                     source_node.SinceVersion(), source_node.GetExecutionProviderType());

  std::unique_ptr<api::NodeRef> wrapped = std::make_unique<ApiNode>(duplicate, graph_);
  wrapped->CopyAttributes(source_node);
  return wrapped;
}
void ApiGraph::RemoveNode(api::NodeRef& node) {
Node& ort_node = static_cast<ApiNode&>(node).Node();
for (const auto* node_arg : ort_node.InputDefs()) {
@ -705,6 +777,10 @@ std::unique_ptr<api::GraphRef> MakeApiGraph(onnxruntime::Graph& graph, Allocator
return std::make_unique<ApiGraph>(graph, std::move(cpu_allocator), new_node_ep);
}
// Wraps an ORT node (together with the graph passed alongside it) in the api::NodeRef implementation.
std::unique_ptr<api::NodeRef> MakeApiNode(onnxruntime::Graph& graph, onnxruntime::Node& node) {
  std::unique_ptr<api::NodeRef> wrapped = std::make_unique<ApiNode>(node, graph);
  return wrapped;
}
onnxruntime::Graph& GraphFromApiGraph(onnx_layout_transformation::api::GraphRef& graph) {
return static_cast<ApiGraph&>(graph).Graph();
}
@ -713,4 +789,99 @@ onnxruntime::Node& NodeFromApiNode(onnx_layout_transformation::api::NodeRef& nod
return static_cast<ApiNode&>(node).Node();
}
namespace layout_transformer {
// Returns the set of op types ORT treats as layout sensitive: the ops reported by
// onnx_layout_transformation::GetLayoutSensitiveOps() plus the ORT-specific ops listed below.
// The set is built once, on first call, and the same instance is returned afterwards.
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps() {
  static std::unordered_set<std::string_view> ort_layout_sensitive_ops = []() {
    // Start from the ORT-specific ops, then merge in the ONNX-defined ones.
    std::unordered_set<std::string_view> combined = {"Resize", "FusedConv", "QLinearAveragePool", "QLinearGlobalAveragePool"};
    const auto& onnx_defined_ops = onnx_layout_transformation::GetLayoutSensitiveOps();
    combined.insert(onnx_defined_ops.cbegin(), onnx_defined_ops.cend());
    return combined;
  }();

  return ort_layout_sensitive_ops;
}
// Converts layout sensitive nodes assigned to execution_provider to channels-last form.
//
// For every node whose op type is in GetORTLayoutSensitiveOps(), that is assigned to the given EP and
// lives in the onnx / onnx alias / MS domain:
//   * if its "channels_last" attribute is already 1, only its domain is switched to kMSInternalNHWCDomain;
//   * otherwise, if the shape of its first input is known, Transposes are wrapped around it and its
//     domain is switched to kMSInternalNHWCDomain.
// `modified` is set to true whenever a node is changed. If `modified` is true at the end, the transpose
// optimizer is run in OPTIMIZE_LAYOUT_TRANSFORM mode. Always returns Status::OK().
Status TransformLayout(Graph& graph, bool& modified, IExecutionProvider& execution_provider) {
  // sub graph recurse will be added later
  auto api_graph = MakeApiGraph(graph, execution_provider.GetAllocator(0, OrtMemTypeDefault), nullptr);
  const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();

  for (auto& node : api_graph->Nodes()) {
    // Only layout sensitive ops are of interest.
    if (layout_sensitive_ops.count(node->OpType()) == 0) {
      continue;
    }

    // Only touch nodes assigned to the requesting EP.
    if (node->GetExecutionProviderType() != execution_provider.Type()) {
      continue;
    }

    // Skip if domain is incorrect
    const auto node_domain = node->Domain();
    const bool is_supported_domain =
        node_domain == kOnnxDomain || node_domain == kOnnxDomainAlias || node_domain == kMSDomain;
    if (!is_supported_domain) {
      continue;
    }

    // A node that is already channels-last only needs its domain changed to kMSInternalNHWCDomain so
    // that the EP knows this op is in the expected format.
    if (node->GetAttributeIntDefault("channels_last", 0) == 1) {
      onnx_layout_transformation::SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
      // Changing the domain for the node requires creating a new node and replacing the old one
      // therefore set the modified flag.
      modified = true;
      continue;
    }

    // Skip if unknown rank
    auto first_input_shape = api_graph->GetValueInfo(node->Inputs()[0])->Shape();
    if (!first_input_shape.has_value()) {
      continue;
    }

    // Convert to channels last
    const size_t rank = first_input_shape->size();

    // Only nodes that already carry the attribute get it set; it is never introduced here.
    if (node->GetAttributeInt("channels_last").has_value()) {
      node->SetAttributeInt("channels_last", 1);
    }

    auto to_channels_last = onnx_layout_transformation::ChannelFirstToLastPerm(rank);
    auto to_channels_first = onnx_layout_transformation::ChannelLastToFirstPerm(rank);

    // Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation
    // for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout
    // transformer only converts layout for 0th input, weights should be handled by every EP.
    if (node->OpType() != "Resize") {
      onnx_layout_transformation::WrapTransposesAroundNode(*api_graph, *node, {&to_channels_last},
                                                           {&to_channels_first});
    } else {
      // Older versions of resize have a bug where ROI and Scales cannot be made empty inputs. To handle this case,
      // we need to jump a few extra hoops to make sure its inputs are correctly handled. Current code skips
      // layout conversion for ROI because it needs special handling as ROI size is 2*rank.
      // Enable passing in ROI for layout conversion when an EP which supports ROI starts using layout transformer.
      // NNAPI which currently uses layout transformer does not support it.
      std::vector<const std::vector<int64_t>*> per_input_perms{&to_channels_last, nullptr};
      for (size_t input_idx = 2; input_idx < node->Inputs().size(); input_idx++) {
        auto constant = api_graph->GetConstant(node->Inputs()[input_idx]);
        // Only non-empty constant inputs are permuted; everything else is left as is.
        const bool has_constant_data = constant != nullptr && constant->Data().size() > 0;
        per_input_perms.push_back(has_constant_data ? &to_channels_last : nullptr);
      }
      onnx_layout_transformation::WrapTransposesAroundNode(*api_graph, *node, per_input_perms,
                                                           {&to_channels_first});
    }

    onnx_layout_transformation::SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
    modified = true;
  }

  // Clean up the Transposes that were just inserted.
  if (modified) {
    onnx_layout_transformation::Optimize(*api_graph, /*allow_extended_ops*/ true, execution_provider.Type(),
                                         onnx_layout_transformation::OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM,
                                         layout_sensitive_ops);
  }

  return Status::OK();
}
} // namespace layout_transformer
} // namespace onnxruntime

View file

@ -3,15 +3,11 @@
#pragma once
#include "core/framework/execution_provider.h"
#include "optimizer_api.h"
#include "core/graph/graph.h"
#include "core/optimizer/transpose_optimizer/api.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
#include "core/framework/execution_provider.h"
namespace onnxruntime {
/// <summary>
/// Creates concrete implementation of api for transpose optimizer. IMPORTANT: graph must have up-to-date edges,
/// node_arg-to-producer, and node_arg-to-consumer relationships. Otherwise call Resolve() before this.
@ -24,6 +20,14 @@ std::unique_ptr<onnx_layout_transformation::api::GraphRef> MakeApiGraph(onnxrunt
AllocatorPtr cpu_allocator,
const char* new_node_ep);
/// <summary>
/// Creates NodeRef.
/// </summary>
/// <param name="graph">ORT Graph which owns the node</param>
/// <param name="node">ORT Node to wrap with API.</param>
/// <returns>api::NodeRef for use with transpose optimizer</returns>
std::unique_ptr<onnx_layout_transformation::api::NodeRef> MakeApiNode(onnxruntime::Graph& graph, onnxruntime::Node& node);
/// <summary>
/// Reveals underlying ORT graph from an api::GraphRef
/// </summary>
@ -38,4 +42,25 @@ onnxruntime::Graph& GraphFromApiGraph(onnx_layout_transformation::api::GraphRef&
/// <returns>ORT node</returns>
onnxruntime::Node& NodeFromApiNode(onnx_layout_transformation::api::NodeRef& node);
namespace layout_transformer {
/// <summary>
/// Gets a list of layout sensitive ops for ORT. This list contains onnx standard defined
/// layout sensitive ops + contrib ops + ops which are not layout sensitive but are treated as
/// layout sensitive by ORT EPs (example Resize).
/// </summary>
/// <returns>unordered set of op_types which are layout sensitive</returns>
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps();
/// <summary>
/// Transforms data layout from NCHW to NHWC. Applies transforms to layout sensitive nodes
/// assigned to execution_provider provided by the caller and any other non-layout sensitive
/// nodes in order to optimize the transposes as much as possible.
/// </summary>
/// <param name="graph">graph to transform</param>
/// <param name="modified">indicates whether the graph is modified during transformation</param>
/// <param name="execution_provider">execution provider for which the transformation needs to be performed</param>
/// <returns>Status indicating whether the layout transformation succeeded</returns>
Status TransformLayout(Graph& graph, bool& modified, IExecutionProvider& execution_provider);
} // namespace layout_transformer
} // namespace onnxruntime

View file

@ -1,14 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
#include "ort_transpose_optimizer.h"
#include <deque>
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/transpose_optimizer/api_impl.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "optimizer_utils.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;

View file

@ -1,13 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "api.h"
#include "optimizer_api.h"
#include <algorithm>
#include <gsl/gsl>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <cstring>
namespace onnx_layout_transformation {
@ -16,6 +17,9 @@ struct OptimizerCtx {
api::GraphRef& graph;
bool allow_extended_ops;
bool skip_cost_check;
const std::string provider_type;
OptimizerMode mode;
std::unordered_set<std::string_view> layout_sensitive_ops;
};
// Each op handler points to a (potentially shared) function for determining which input indices are eligible for
@ -44,7 +48,6 @@ struct HandlerInfo {
bool transposes_outputs = true;
};
/////// <Helper Utils> ///////
/* Small utilities for editing nodes and manipulating axes/permutations */
@ -63,14 +66,14 @@ static std::vector<int32_t> DataInt32(api::TensorRef& tensor) {
}
static std::string_view AddInitializerInt64(api::GraphRef& graph, const std::vector<int64_t>& shape,
const std::vector<int64_t>& values) {
const std::vector<int64_t>& values) {
const uint8_t* raw_data = reinterpret_cast<const uint8_t*>(values.data());
std::vector<uint8_t> data(raw_data, raw_data + values.size() * sizeof(int64_t));
return graph.AddInitializer(api::DataType::INT64, shape, data);
}
static std::string_view AddInitializerInt32(api::GraphRef& graph, const std::vector<int64_t>& shape,
const std::vector<int32_t>& values) {
const std::vector<int32_t>& values) {
const uint8_t* raw_data = reinterpret_cast<const uint8_t*>(values.data());
std::vector<uint8_t> data(raw_data, raw_data + values.size() * sizeof(int32_t));
return graph.AddInitializer(api::DataType::INT32, shape, data);
@ -108,7 +111,7 @@ static std::unique_ptr<api::NodeRef> MakeTranspose(api::GraphRef& graph, std::st
}
// Creates a Squeeze/Unsqueeze node. Does not update output ValueInfo.
static std::unique_ptr<api::NodeRef> MakeSqueezeOrUnsqueeze(int64_t opset, api::GraphRef& graph,
static std::unique_ptr<api::NodeRef> MakeSqueezeOrUnsqueeze(int64_t opset, api::GraphRef& graph,
std::string_view op_type, std::string_view input,
const std::vector<int64_t>& axes) {
if (opset < 13) {
@ -141,7 +144,7 @@ static bool IsValidPerm(const std::vector<int64_t>& perm) {
static std::optional<std::vector<int64_t>> GetPermAttrIfValid(const api::NodeRef& node) {
std::optional<std::vector<int64_t>> perm = node.GetAttributeInts("perm");
if (perm != std::nullopt && !IsValidPerm(*perm)) {
if (perm.has_value() && !IsValidPerm(*perm)) {
return std::nullopt;
}
return perm;
@ -224,8 +227,12 @@ static bool IsIdentityPerm(const std::vector<int64_t>& perm) {
}
// Computes permutation from channel last to channel first ordering of given rank. Nearly all handlers work for any
// permutation, but some are restricted. Also used for layout transformation. Rank must be >= 1.
// permutation, but some are restricted. Also used for layout transformation.
std::vector<int64_t> ChannelLastToFirstPerm(size_t rank) {
if (rank < 2) {
return {};
}
std::vector<int64_t> p(rank);
p[0] = 0;
p[1] = rank - 1;
@ -330,9 +337,9 @@ static std::vector<int64_t> SqueezePerm(const std::vector<int64_t>& axes, const
return new_perm;
}
// Computes a new axes attribute for an input that has been permuted using perm. Unsafe if axes/perm are invalid or
// Computes a new axes attribute for an input that has been permuted using perm. Unsafe if axes/perm are invalid or
// have negative values.
//
//
// Ex: perm = [2, 0, 1], axes = [0, 1], new_axes = [2, 0]
static std::vector<int64_t> AxesForTransposedInput(const std::vector<int64_t>& axes,
const std::vector<int64_t>& perm) {
@ -346,7 +353,7 @@ static std::vector<int64_t> AxesForTransposedInput(const std::vector<int64_t>& a
// Computes a new axes attribute for an input that has been permuted using perm and sorts the result. Axes attributes
// are commonly sorted (unless order matters like in Slice). Unsafe if axes/perm are invalid or have negative values.
//
//
// Ex: perm = [2, 0, 1], axes = [0, 1], new_axes = [0, 2]
static std::vector<int64_t> SortedAxesForTransposedInput(const std::vector<int64_t>& axes,
const std::vector<int64_t>& perm) {
@ -376,10 +383,8 @@ static std::vector<int64_t> SortedAxesForTransposedInput(const std::vector<int64
/////// <Core Helpers> ///////
/* These helpers hide the most gnarly parts of the transpose optimizer. */
static std::string_view HelpHandleUnsqueeze(HandlerArgs& args, const std::vector<int64_t>& axes);
// Replaces ith input to node with unsqueezed value. Might create a new Unsqueeze node, find an existing one,
// or reshape an initializer. Unsqueezing can be necessary before transposing inputs of a node that supports
// broadcasting.
@ -442,7 +447,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
// a Transpose, optimize it here.
if (inp_node != nullptr && inp_node->IsOp("Transpose")) {
auto perm = GetPermAttrIfValid(*inp_node);
if (perm != std::nullopt) {
if (perm.has_value()) {
auto perm_inv = InvertPerm(*perm);
std::vector<size_t> indices = {0};
HandlerArgs args{ctx, *inp_node, unsqueeze, *perm, perm_inv, indices};
@ -456,6 +461,29 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
node.SetInput(i, unsq_out);
}
// Permutes the elements of a 1D constant according to perm and makes the result input i of node.
//
// The constant's raw bytes are treated as perm.size() equally sized elements; element j of the new
// initializer is element perm[j] of the original. The new initializer keeps the original dtype and
// shape. The original initializer (input_name) is removed only if it has no consumers left afterwards.
//
// graph: graph that owns the initializers.
// node: node whose ith input is replaced.
// constant: the 1D constant currently feeding that input.
// i: index of the input to replace.
// input_name: name of the original initializer.
// perm: permutation to apply; entries are expected to be valid indices in [0, perm.size()).
static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::TensorRef& constant,
                              size_t i, std::string_view input_name, const std::vector<int64_t>& perm) {
  // Create new transposed initializer
  const size_t rank = perm.size();
  auto shape = constant.Shape();
  std::vector<uint8_t> data = constant.Data();
  std::vector<uint8_t> new_data(data.size());

  // An empty perm would otherwise divide by zero, and an empty constant would hand memcpy the
  // (possibly null) data() pointers of empty vectors. In both cases there is nothing to copy.
  const size_t bytes_per_val = (rank == 0) ? 0 : data.size() / rank;
  if (bytes_per_val > 0) {
    uint8_t* dst = new_data.data();
    for (size_t j = 0; j < rank; ++j) {
      const uint8_t* src = data.data() + static_cast<size_t>(perm[j]) * bytes_per_val;
      std::memcpy(dst, src, bytes_per_val);
      dst += bytes_per_val;
    }
  }

  std::string_view new_initializer = graph.AddInitializer(constant.DType(), shape, new_data);
  node.SetInput(i, new_initializer);
  if (!graph.HasValueConsumers(input_name)) {
    graph.RemoveInitializer(input_name);
  }
}
// Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one,
// or transpose an initializer.
static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
@ -469,6 +497,18 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
// Case 1: input is a constant with a known list of consumer nodes
if (constant != nullptr && consumers->comprehensive) {
// Input is an empty 1D tensor (shape [0]); there is nothing to transpose, so return early.
if (constant->Shape().size() == 1 && constant->Shape()[0] == 0) {
return;
}
// This is a special case where the constant is 1D with length == perm.
// TODO: TransposeInitializer should be updated to handle this case.
// Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if
// there are no other consumers.
if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast<int64_t>(perm.size())) {
Permute1DConstant(graph, node, *constant, i, input, perm);
return;
}
if (consumers->nodes.size() > 0) {
// Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv
// to counteract the effect. These Transposes will hopefully be optimized out later.
@ -496,6 +536,10 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
}
node.SetInput(i, pre_transpose_value);
return;
} else if (*perm2 == perm) {
// we are trying to add a duplicate transpose.
// do nothing and return
return;
}
// Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove
@ -513,7 +557,7 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
return;
}
}
// Case 3: A Transpose op might already exist
for (size_t j = 0; j < consumers->nodes.size(); ++j) {
api::NodeRef& consumer = *consumers->nodes[j];
@ -533,7 +577,7 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
}
// Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank.
static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank,
static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank,
const std::vector<size_t>& input_indices) {
auto inputs = node.Inputs();
@ -563,7 +607,7 @@ static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t tar
}
// Transposes specified inputs according to perm.
// NOTE: if a Transpose is expected to be above an input to this node, use the inverse of its permutation to cancel it.
// NOTE: if a Transpose is expected to be above an input to this node, use the inverse of its permutation to cancel it.
static void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::vector<int64_t>& perm,
const std::vector<size_t>& input_indices) {
auto perm_inv = InvertPerm(perm);
@ -573,7 +617,7 @@ static void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::ve
}
inline static void TransposeFirstInput(OptimizerCtx& ctx, api::NodeRef& node, const std::vector<int64_t>& perm) {
std::vector<size_t> indices {0};
std::vector<size_t> indices{0};
TransposeInputs(ctx, node, perm, indices);
}
@ -630,7 +674,7 @@ static void TransposeOutputs(OptimizerCtx& ctx, api::NodeRef& node, const std::v
// A rank of 5 is used if rank cannot be determined since 5 is the largest rank we expect from something like a Conv
// and an unknown rank likely corresponds to a data-carrying (non-weight) tensor, which will be large.
// Given a value, returns the rank of the value excluding dimensions of value 1. Returns 5 if the rank is unknown.
// Given a value, returns the rank of the value excluding dimensions of value 1. Returns 5 if the rank is unknown.
static int EstimateValueRank(api::GraphRef& graph, std::string_view input) {
auto value_info = graph.GetValueInfo(input);
std::optional<std::vector<int64_t>> shape = value_info->Shape();
@ -774,7 +818,7 @@ std::vector<size_t> NonScalarInputs(OptimizerCtx& ctx, api::NodeRef& node) {
constexpr HandlerInfo broadcast_node_handler = {&NonScalarInputs, &HandleSimpleNodeBroadcast};
// Transposes all inputs and all outputs. Updates axis attribute.
static bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional<int64_t> default_axis=std::nullopt) {
static bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional<int64_t> default_axis = std::nullopt) {
size_t rank = args.perm.size();
std::optional<int64_t> axis = args.node.GetAttributeInt("axis");
if (axis == std::nullopt) {
@ -848,7 +892,6 @@ static bool HandleShape(HandlerArgs& args) {
std::vector<int64_t> new_perm;
// For opset 15, Shape(Transpose(x, perm))[starts:stops] = Gather(Shape(x), perm[starts:stops])
if (args.ctx.opset >= 15) {
// Assign new_perm = perm[starts:stops]
int64_t start = args.node.GetAttributeIntDefault("start", 0);
int64_t end = args.node.GetAttributeIntDefault("end", rank_int);
@ -870,7 +913,7 @@ static bool HandleShape(HandlerArgs& args) {
}
// Make new_perm initializer
std::vector<int64_t> perm_shape {gsl::narrow_cast<int64_t>(new_perm.size())};
std::vector<int64_t> perm_shape{gsl::narrow_cast<int64_t>(new_perm.size())};
std::string_view perm_const = AddInitializerInt64(args.ctx.graph, perm_shape, new_perm);
// X -> Shape -> Y, Gather
@ -900,7 +943,7 @@ static bool HandleShape(HandlerArgs& args) {
constexpr HandlerInfo shape_handler = {&FirstInput, &HandleShape, /*transposes_outputs*/ false};
// Permutes a 1D node input by creating a new initializer or inserting a Gather op
void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector<int64_t>& perm) {
static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector<int64_t>& perm) {
size_t rank = perm.size();
int64_t rank_int = gsl::narrow_cast<int64_t>(rank);
@ -909,25 +952,7 @@ void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std:
if (constant != nullptr) {
auto shape = constant->Shape();
if (shape.size() == 1 && (shape[0] == rank_int || shape[0] == 0)) {
// Create new transposed initializer
std::vector<uint8_t> data = constant->Data();
std::vector<uint8_t> new_data(data.size());
size_t bytes_per_val = data.size() / rank;
uint8_t* dst = new_data.data();
for (size_t j = 0; j < rank; ++j) {
uint8_t* src = data.data() + perm[j] * bytes_per_val;
for (size_t k = 0; k < bytes_per_val; ++k) {
*dst++ = *src++;
}
}
std::string_view new_initializer = graph.AddInitializer(constant->DType(), shape, new_data);
node.SetInput(i, new_initializer);
if (!graph.HasValueConsumers(input)) {
graph.RemoveInitializer(input);
}
Permute1DConstant(graph, node, *constant, i, input, perm);
return;
}
}
@ -942,33 +967,39 @@ void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std:
node.SetInput(i, gather_output);
}
//static bool HandleResize(HandlerArgs& args) {
// static bool HandleResize(HandlerArgs& args) {
// auto inputs = args.node.Inputs();
// int64_t rank_int = gsl::narrow_cast<int64_t>(args.perm.size());
//
// auto p = ChannelFirstToLastPerm(rank_int);
// auto& perm = p == args.perm ? args.perm : args.perm_inv;
// auto& perm_inv = p == args.perm ? args.perm_inv : args.perm;
//
// if (args.ctx.opset < 11) {
// PermuteInput(args.ctx.graph, args.node, 1, args.perm_inv);
// } else {
// if (inputs[1] != "") {
// std::vector<int64_t> double_perm_inv = args.perm_inv;
// double_perm_inv.reserve(2 * args.perm_inv.size());
// for (int64_t p : args.perm_inv) {
// double_perm_inv.push_back(p + rank_int);
// }
// PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv);
// }
// for (size_t i = 2; i < inputs.size(); ++i) {
// if (inputs[i] != "") {
// PermuteInput(args.ctx.graph, args.node, i, args.perm_inv);
// }
// }
// }
// PermuteInput(args.ctx.graph, args.node, 1, perm);
// } else {
// if (inputs[1] != "") {
// std::vector<int64_t> double_perm_inv = perm;
// double_perm_inv.reserve(2 * args.perm.size());
// for (int64_t p1 : perm) {
// double_perm_inv.push_back(p1 + rank_int);
// }
// PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv);
// }
// for (size_t i = 2; i < inputs.size(); ++i) {
// if (inputs[i] != "") {
// PermuteInput(args.ctx.graph, args.node, i, perm);
// }
// }
// }
//
// TransposeFirstInput(args.ctx, args.node, args.perm_inv);
// TransposeOutputs(args.ctx, args.node, args.perm);
// TransposeFirstInput(args.ctx, args.node, perm);
// TransposeOutputs(args.ctx, args.node, perm_inv);
//
// return true;
//}
// SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, args.node.OpType(), "com.microsoft.nhwc");
//
// return true;
// }
// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize};
@ -1028,7 +1059,6 @@ static bool HandleReduceOp(HandlerArgs& args) {
}
} else {
if (!NormalizeAndValidateAxes(*axes, args.perm.size())) {
return false;
}
@ -1143,7 +1173,6 @@ static bool HandleSqueeze(HandlerArgs& args) {
if (!args.ctx.graph.HasValueConsumers(axes_inp)) {
args.ctx.graph.RemoveInitializer(axes_inp);
}
}
// Transpose inputs/outputs
@ -1178,25 +1207,31 @@ static bool HandleUnsqueeze(HandlerArgs& args) {
constexpr HandlerInfo unsqueeze_handler = {&FirstInput, &HandleUnsqueeze};
static bool HandleQuantizeDequantizeLinear(HandlerArgs& args) {
size_t rank = args.perm.size();
if (args.ctx.opset >= 13) {
static bool HandleQuantizeDequantizeScale(const api::GraphRef& graph, const std::vector<int64_t>& perm,
api::NodeRef& node, int64_t opset) {
if (opset >= 13) {
size_t rank = perm.size();
// Update axis in Opset >= 13 if scale/zero_point are non-scalar
auto inputs = args.node.Inputs();
auto inputs = node.Inputs();
std::optional<std::vector<int64_t>> inp_shape = args.ctx.graph.GetValueInfo(inputs[1])->Shape();
bool scalar_params = inp_shape != std::nullopt && inp_shape->size() == 0;
auto inp_shape = graph.GetValueInfo(inputs[1])->Shape();
bool scalar_params = inp_shape.has_value() && inp_shape->size() == 0;
if (!scalar_params) {
int64_t axis = args.node.GetAttributeIntDefault("axis", 1);
int64_t axis = node.GetAttributeIntDefault("axis", 1);
if (!NormalizeAndValidateAxis(axis, rank)) {
return false;
}
args.node.SetAttributeInt("axis", args.perm[gsl::narrow_cast<size_t>(axis)]);
node.SetAttributeInt("axis", perm[gsl::narrow_cast<size_t>(axis)]);
}
}
return true;
}
static bool HandleQuantizeDequantizeLinear(HandlerArgs& args) {
if (!HandleQuantizeDequantizeScale(args.ctx.graph, args.perm, args.node, args.ctx.opset)) {
return false;
}
TransposeFirstInput(args.ctx, args.node, args.perm_inv);
TransposeOutputs(args.ctx, args.node, args.perm);
@ -1215,7 +1250,7 @@ static bool HandleArgMinMax(HandlerArgs& args) {
return false;
}
int64_t new_axis = args.perm[gsl::narrow_cast<size_t>(axis)];
std::vector<int64_t> new_axes {new_axis};
std::vector<int64_t> new_axes{new_axis};
args.node.SetAttributeInt("axis", new_axis);
TransposeInputs(args.ctx, args.node, args.perm_inv, args.transposible_inputs);
@ -1368,7 +1403,7 @@ static bool HandleTile(HandlerArgs& args) {
} else {
// Case 2: Repeats is computed. Insert Gather node.
std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv);
std::vector<std::string_view> gather_inputs {repeats_inp, perm_inv_const};
std::vector<std::string_view> gather_inputs{repeats_inp, perm_inv_const};
auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1);
api::NodeRef& gather_node = *gather_node_ptr;
std::string_view gather_output = gather_node.Outputs()[0];
@ -1398,7 +1433,7 @@ static bool HandleTranspose(HandlerArgs& args) {
if (args.perm_inv == *node_perm) {
// Case 1: Permutations cancel.
auto consumers = args.ctx.graph.GetValueConsumers(args.node.Outputs()[0]);
auto consumers = args.ctx.graph.GetValueConsumers(node_output);
if (consumers->comprehensive) {
// If possible, replace references to output of 2nd transpose with input to 1st
ReplaceValueReferences(consumers->nodes, node_output, transpose_input);
@ -1425,14 +1460,13 @@ static bool HandleTranspose(HandlerArgs& args) {
} else {
// Worst-case scenario: Both parent output and 2nd transpose output cannot be removed (both graph outputs)
// despite computing the same value. Use an Identity op instead.
std::vector<std::string_view> single_empty_input {""};
std::vector<std::string_view> single_empty_input{""};
auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1);
api::NodeRef& identity = *identity_ptr;
args.ctx.graph.MoveOutput(args.node, 0, identity, 0);
identity.SetInput(0, transpose_input);
}
}
// In any case, the 2nd transpose can be removed.
args.ctx.graph.RemoveNode(args.node);
} else {
@ -1442,7 +1476,7 @@ static bool HandleTranspose(HandlerArgs& args) {
args.node.SetInput(0, transpose_input);
}
// 2nd transpose no longer references 1st. Remove 2nd if possible.
// 2nd transpose no longer references 1st. Remove first if possible.
if (!args.ctx.graph.HasValueConsumers(args.transpose.Outputs()[0])) {
args.ctx.graph.RemoveNode(args.transpose);
}
@ -1484,7 +1518,7 @@ constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &Han
static bool HandleQLinearPoolOp(HandlerArgs& args) {
// Swap between channel first/last variants. Only works for applicable values of perm.
int64_t channels_last = args.node.GetAttributeIntDefault("channels_last", 1);
int64_t channels_last = args.node.GetAttributeIntDefault("channels_last", 0);
size_t rank = args.perm.size();
if (rank < 2) return false;
auto p = ChannelLastToFirstPerm(rank);
@ -1500,7 +1534,11 @@ static bool HandleQLinearPoolOp(HandlerArgs& args) {
constexpr HandlerInfo q_linear_pool_op_handler = {&FirstInput, &HandleQLinearPoolOp};
static bool HandleMaxPool(HandlerArgs& args) {
// Replace with NhwcMaxPool if possible. Only int8 and uint8 dtypes are supported by NhwcMaxPool.
// For CPU EP replace with NhwcMaxPool if possible. Only int8 and uint8 dtypes are supported by NhwcMaxPool.
if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") {
return false;
}
auto outputs = args.node.Outputs();
if (outputs.size() == 2 && outputs[1] != "") {
// Can't optimize if optional "indices" output is provided
@ -1520,7 +1558,6 @@ static bool HandleMaxPool(HandlerArgs& args) {
auto new_node = SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, "NhwcMaxPool", "com.microsoft");
new_node->ClearAttribute("storage_order"); // Only relevant for indices output. Prohibited for NhwcMaxPool.
TransposeFirstInput(args.ctx, *new_node, args.perm_inv);
TransposeOutputs(args.ctx, *new_node, args.perm);
return true;
@ -1529,71 +1566,121 @@ static bool HandleMaxPool(HandlerArgs& args) {
constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool};
// TODO: check binary size of this and replace it with constexpr if large
static const std::unordered_map<std::string_view, const HandlerInfo&> handler_map {
static const std::unordered_map<std::string_view, const HandlerInfo&> handler_map{
{"Cast", simple_node_handler}, {"Exp", simple_node_handler}, {"Identity", simple_node_handler},
{"LeakyRelu", simple_node_handler}, {"Log", simple_node_handler}, {"Reciprocal", simple_node_handler},
{"Relu", simple_node_handler}, {"Sigmoid", simple_node_handler}, {"Sqrt", simple_node_handler},
{"Tanh", simple_node_handler}, {"Abs", simple_node_handler}, {"Not", simple_node_handler},
{"Ceil", simple_node_handler}, {"Floor", simple_node_handler}, {"Neg", simple_node_handler},
{"Erf", simple_node_handler}, {"HardSigmoid", simple_node_handler}, {"Round", simple_node_handler},
{"IsInf", simple_node_handler}, {"IsNaN", simple_node_handler},
{"Selu", simple_node_handler}, {"Shrink", simple_node_handler}, {"Sign", simple_node_handler},
{"Softplus", simple_node_handler}, {"Softsign", simple_node_handler}, {"ThresholdedRelu", simple_node_handler},
{"Celu", simple_node_handler}, {"HardSwish", simple_node_handler},
{"Cast", simple_node_handler},
{"Exp", simple_node_handler},
{"Identity", simple_node_handler},
{"LeakyRelu", simple_node_handler},
{"Log", simple_node_handler},
{"Reciprocal", simple_node_handler},
{"Relu", simple_node_handler},
{"Sigmoid", simple_node_handler},
{"Sqrt", simple_node_handler},
{"Tanh", simple_node_handler},
{"Abs", simple_node_handler},
{"Not", simple_node_handler},
{"Ceil", simple_node_handler},
{"Floor", simple_node_handler},
{"Neg", simple_node_handler},
{"Erf", simple_node_handler},
{"HardSigmoid", simple_node_handler},
{"Round", simple_node_handler},
{"IsInf", simple_node_handler},
{"IsNaN", simple_node_handler},
{"Selu", simple_node_handler},
{"Shrink", simple_node_handler},
{"Sign", simple_node_handler},
{"Softplus", simple_node_handler},
{"Softsign", simple_node_handler},
{"ThresholdedRelu", simple_node_handler},
{"Celu", simple_node_handler},
{"HardSwish", simple_node_handler},
{"Sin", simple_node_handler}, {"Cos", simple_node_handler}, {"Tan", simple_node_handler},
{"Sinh", simple_node_handler}, {"Cosh", simple_node_handler}, {"Tanh", simple_node_handler},
{"Asin", simple_node_handler}, {"Acos", simple_node_handler}, {"Atan", simple_node_handler},
{"Asinh", simple_node_handler}, {"Acosh", simple_node_handler}, {"Atanh", simple_node_handler},
{"Sin", simple_node_handler},
{"Cos", simple_node_handler},
{"Tan", simple_node_handler},
{"Sinh", simple_node_handler},
{"Cosh", simple_node_handler},
{"Tanh", simple_node_handler},
{"Asin", simple_node_handler},
{"Acos", simple_node_handler},
{"Atan", simple_node_handler},
{"Asinh", simple_node_handler},
{"Acosh", simple_node_handler},
{"Atanh", simple_node_handler},
{"Add", broadcast_node_handler}, {"Max", broadcast_node_handler}, {"Min", broadcast_node_handler},
{"Mul", broadcast_node_handler}, {"Sub", broadcast_node_handler}, {"Div", broadcast_node_handler},
{"And", broadcast_node_handler}, {"Or", broadcast_node_handler}, {"Xor", broadcast_node_handler},
{"Mod", broadcast_node_handler}, {"PRelu", broadcast_node_handler}, {"BitShift", broadcast_node_handler},
{"Equal", broadcast_node_handler}, {"Greater", broadcast_node_handler}, {"Less", broadcast_node_handler},
{"GreaterOrEqual", broadcast_node_handler}, {"LessOrEqual", broadcast_node_handler},
{"Mean", broadcast_node_handler}, {"Sum", broadcast_node_handler}, {"Pow", broadcast_node_handler},
{"Where", broadcast_node_handler},
{"Add", broadcast_node_handler},
{"Max", broadcast_node_handler},
{"Min", broadcast_node_handler},
{"Mul", broadcast_node_handler},
{"Sub", broadcast_node_handler},
{"Div", broadcast_node_handler},
{"And", broadcast_node_handler},
{"Or", broadcast_node_handler},
{"Xor", broadcast_node_handler},
{"Mod", broadcast_node_handler},
{"PRelu", broadcast_node_handler},
{"BitShift", broadcast_node_handler},
{"Equal", broadcast_node_handler},
{"Greater", broadcast_node_handler},
{"Less", broadcast_node_handler},
{"GreaterOrEqual", broadcast_node_handler},
{"LessOrEqual", broadcast_node_handler},
{"Mean", broadcast_node_handler},
{"Sum", broadcast_node_handler},
{"Pow", broadcast_node_handler},
{"Where", broadcast_node_handler},
{"Clip", node_1_inp_handler}, {"CastLike", node_1_inp_handler},
{"Clip", node_1_inp_handler},
{"CastLike", node_1_inp_handler},
{"Transpose", transpose_handler},
{"Concat", concat_handler},
{"Split", split_handler},
{"Shape", shape_handler},
{"Pad", pad_handler},
// Todo: renable resize handler after adding NHWC support in upsample op on cpu
// https://github.com/microsoft/onnxruntime/issues/9857
//{"Resize", resize_handler},
{"ReduceSum", reduce_sum_handler},
{"Transpose", transpose_handler},
{"Concat", concat_handler},
{"Split", split_handler},
{"Shape", shape_handler},
{"Pad", pad_handler},
// TODO: re-enable the resize handler after adding NHWC support to the Upsample op on CPU
// https://github.com/microsoft/onnxruntime/issues/9857
// {"Resize", resize_handler},
{"ReduceSum", reduce_sum_handler},
{"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, {"ReduceMax", reduce_op_handler},
{"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler},
{"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler},
{"ReduceLogSum", reduce_op_handler},
{"ReduceLogSumExp", reduce_op_handler},
{"ReduceMax", reduce_op_handler},
{"ReduceMean", reduce_op_handler},
{"ReduceMin", reduce_op_handler},
{"ReduceProd", reduce_op_handler},
{"ReduceSumSquare", reduce_op_handler},
{"ReduceL1", reduce_op_handler},
{"ReduceL2", reduce_op_handler},
{"ArgMin", arg_min_max_handler}, {"ArgMax", arg_min_max_handler},
{"ArgMin", arg_min_max_handler},
{"ArgMax", arg_min_max_handler},
{"Squeeze", squeeze_handler},
{"Unsqueeze", unsqueeze_handler},
{"Slice", slice_handler},
{"Tile", tile_handler},
{"Squeeze", squeeze_handler},
{"Unsqueeze", unsqueeze_handler},
{"Slice", slice_handler},
{"Tile", tile_handler},
{"Softmax", soft_hard_max_handler}, {"Hardmax", soft_hard_max_handler}, {"LogSoftmax", soft_hard_max_handler},
{"Softmax", soft_hard_max_handler},
{"Hardmax", soft_hard_max_handler},
{"LogSoftmax", soft_hard_max_handler},
{"QuantizeLinear", quantize_dequantize_linear_handler}, {"DequantizeLinear", quantize_dequantize_linear_handler},
{"QuantizeLinear", quantize_dequantize_linear_handler},
{"DequantizeLinear", quantize_dequantize_linear_handler},
};
static const std::unordered_map<std::string_view, const HandlerInfo&> extended_handler_map{
{"com.microsoft.QLinearReduceMean", reduce_op_handler},
{"com.microsoft.QLinearSigmoid", node_1_inp_handler},
{"com.microsoft.QLinearLeakyRelu", node_1_inp_handler},
{"com.microsoft.QLinearConcat", q_linear_concat_handler},
{"com.microsoft.QLinearAdd", q_linear_binary_op_handler},
{"com.microsoft.QLinearMul", q_linear_binary_op_handler},
{"com.microsoft.QLinearAveragePool", q_linear_pool_op_handler},
{"com.microsoft.QLinearGlobalAveragePool", q_linear_pool_op_handler},
{"MaxPool", max_pool_op_handler},
{"com.microsoft.QLinearReduceMean", reduce_op_handler},
{"com.microsoft.QLinearSigmoid", node_1_inp_handler},
{"com.microsoft.QLinearLeakyRelu", node_1_inp_handler},
{"com.microsoft.QLinearConcat", q_linear_concat_handler},
{"com.microsoft.QLinearAdd", q_linear_binary_op_handler},
{"com.microsoft.QLinearMul", q_linear_binary_op_handler},
{"com.microsoft.QLinearAveragePool", q_linear_pool_op_handler},
{"com.microsoft.QLinearGlobalAveragePool", q_linear_pool_op_handler},
{"MaxPool", max_pool_op_handler},
};
static const HandlerInfo* GetHandler(api::NodeRef& node, bool allow_extended_ops) {
@ -1674,7 +1761,9 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef&
}
// Returns nullopt if graph opset is unsupported.
std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph, bool allow_extended_ops) {
std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph, bool allow_extended_ops,
const std::string& provider_type, OptimizerMode mode,
const std::unordered_set<std::string_view>& layout_sensitive_ops) {
auto opset = graph.Opset("");
if (opset == std::nullopt) {
opset = graph.Opset("ai.onnx");
@ -1688,14 +1777,17 @@ std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph, bool allo
allow_extended_ops = false;
}
}
OptimizerCtx ctx{*opset, graph, allow_extended_ops, /*skip_cost_check*/ false};
// during layout transformation we want to push the transposes as far out as possible.
// it is important that the EP gets the entire graph in the layout it prefers.
bool skip_cost_check = mode == OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM;
OptimizerCtx ctx{*opset, graph, allow_extended_ops, skip_cost_check, provider_type, mode, layout_sensitive_ops};
return ctx;
}
// Performs optimization. General algorithm: iterate over nodes in topological order. If a node has a transpose
// as input, push it through if the transpose cost does not increase and is likely to decrease.
bool OptimizeImpl(OptimizerCtx& ctx) {
const std::vector<std::unique_ptr<api::NodeRef>> nodes = ctx.graph.Nodes();
std::unordered_set<std::string> outputs_leading_to_transpose;
@ -1703,7 +1795,6 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
// First iterate over sorted nodes in reverse order to find which outputs have paths through supported ops to
// transpose nodes. We push transposes towards these outputs.
for (size_t i = 0; i < nodes.size(); ++i) {
api::NodeRef& node = *nodes[nodes.size() - i - 1];
if (node.IsOp("Transpose")) {
outputs_leading_to_transpose.insert(std::string(node.Inputs()[0]));
@ -1731,6 +1822,13 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
// New transpose nodes are inserted, but always as an input to an existing node.
for (size_t i = 0; i < nodes.size(); ++i) {
api::NodeRef& node = *nodes[i];
if (ctx.mode == OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM &&
ctx.layout_sensitive_ops.count(node.OpType()) && node.GetExecutionProviderType() != ctx.provider_type) {
// If the current op is layout sensitive and it is not assigned to the given provider
// then do not process transpose.
continue;
}
std::vector<std::string_view> inputs = node.Inputs();
for (size_t j = 0; j < inputs.size(); ++j) {
std::string_view inp = inputs[j];
@ -1741,7 +1839,6 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
if (transpose != nullptr && transpose->IsOp("Transpose")) {
std::optional<std::vector<int64_t>> perm = GetPermAttrIfValid(*transpose);
if (perm != std::nullopt) {
std::vector<int64_t> perm_inv = InvertPerm(*perm);
if (ProcessTranspose(ctx, *transpose, node, *perm, j, outputs_leading_to_transpose)) {
changed = true;
// Subsequent inputs may have changed and node may have been removed.
@ -1751,11 +1848,69 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
}
}
}
// Currently limiting the second optimization pass to layout transform mode
// TODO: Enable this for both the modes.
if (ctx.mode == OptimizerMode::OPTIMIZE_TRANSPOSE) {
return changed;
}
// Run second optimization pass.
// If any transpose succeeds a DQ node, move it above the DQ node.
// In case of QDQ models this helps to preserve the QDQ node unit
// In all other scenarios this is beneficial as well because moving transpose above DQ node is more efficient as
// transpose node now handles less data.
auto graph_nodes = ctx.graph.Nodes();
for (size_t i = 1; i < graph_nodes.size(); i++) {
if (graph_nodes[i]->OpType() == "Transpose") {
auto& transpose_node = *graph_nodes[i];
auto dq_node = ctx.graph.GetNodeProducingOutput(transpose_node.Inputs()[0]);
if (!dq_node || dq_node->OpType() != "DequantizeLinear") {
continue;
}
auto consumers = ctx.graph.GetValueConsumers(transpose_node.Outputs()[0]);
bool is_part_of_qdq_unit = std::find_if(consumers->nodes.cbegin(), consumers->nodes.cend(),
[](const std::unique_ptr<api::NodeRef>& node) {
return node->OpType() == "QuantizeLinear";
}) != consumers->nodes.cend();
if (is_part_of_qdq_unit) {
continue;
}
// Update Dequantize Node and move the transpose above it
auto perm = GetPermAttrIfValid(transpose_node);
if (!perm.has_value()) {
continue;
}
if (!HandleQuantizeDequantizeScale(ctx.graph, *perm, *dq_node, ctx.opset)) {
continue;
}
TransposeFirstInput(ctx, *dq_node, *perm);
// remove existing transpose node
transpose_node.SetInput(0, "");
ctx.graph.MoveOutput(transpose_node, 0, *dq_node, 0);
ctx.graph.RemoveNode(transpose_node);
changed = true;
}
}
return changed;
}
bool Optimize(api::GraphRef& graph, bool allow_extended_ops) {
auto ctx = MakeOptimizerContext(graph, allow_extended_ops);
// Returns the shared set of layout sensitive op type names.
// The set is constructed once, on the first call, and every call returns a reference to that same
// instance. It is declared const because callers only ever receive it by const reference.
const std::unordered_set<std::string_view>& GetLayoutSensitiveOps() {
  // List of all layout sensitive ops defined in ONNX standard.
  static const std::unordered_set<std::string_view> layout_sensitive_ops = {
      "Conv", "QLinearConv", "BatchNormalization", "AveragePool",
      "GlobalAveragePool", "MaxPool", "GlobalMaxPool", "LRN"};
  return layout_sensitive_ops;
}
bool Optimize(api::GraphRef& graph, bool allow_extended_ops,
const std::string& provider_type, OptimizerMode mode,
const std::unordered_set<std::string_view>& layout_sensitive_ops) {
auto ctx = MakeOptimizerContext(graph, allow_extended_ops, provider_type, mode, layout_sensitive_ops);
if (ctx == std::nullopt) {
return false;
}
@ -1781,15 +1936,14 @@ void WrapTransposesAroundNode(api::GraphRef& graph, api::NodeRef& node,
std::unique_ptr<api::NodeRef> SwapNodeOpTypeAndDomain(api::GraphRef& graph, api::NodeRef& node,
std::string_view op_type, std::string_view domain) {
auto inputs = node.Inputs();
auto outputs = node.Outputs();
auto new_node = graph.AddNode(op_type, inputs, outputs.size(), domain);
auto new_node = graph.CopyNode(node, op_type, domain);
for (size_t j = 0; j < outputs.size(); ++j) {
if (outputs[j] != "") {
graph.MoveOutput(node, j, *new_node, j);
}
}
new_node->CopyAttributes(node);
graph.RemoveNode(node);
return new_node;
}

View file

@ -69,7 +69,7 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
};
result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
gen_metadef_name, COREML);
gen_metadef_name, COREML, kCoreMLExecutionProvider);
const auto num_of_partitions = result.size();
const auto num_of_supported_nodes = std::transform_reduce(

View file

@ -262,7 +262,7 @@ Status ModelBuilder::RegisterInitializers() {
shaper_.AddShape(name, operand_type.dimensions);
uint32_t index = 0;
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, false /* is_nhwc */, index));
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, index));
const size_t size = operand_type.GetOperandBlobByteSize();
const size_t padded_size = GetPaddedByteSize(size);
sizeAll += padded_size;
@ -352,7 +352,7 @@ Status ModelBuilder::RegisterModelInputs() {
shaper_.AddShape(input_name, operand_type.dimensions);
uint32_t index = 0;
ORT_RETURN_IF_ERROR(AddNewOperand(input_name, operand_type, false /* is_nhwc */, index));
ORT_RETURN_IF_ERROR(AddNewOperand(input_name, operand_type, index));
input_index_vec_.push_back(index);
nnapi_model_->AddInput(input_name, operand_type);
}
@ -386,11 +386,6 @@ Status ModelBuilder::RegisterModelOutputs() {
}
std::string nnapi_output_name = output_name;
if (IsOperandNHWC(output_name)) {
// We need to transpose the output still in nhwc back to nchw
nnapi_output_name = GetUniqueName(output_name + "_nhwc_to_nchw");
ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(*this, output_name, nnapi_output_name));
}
output_index_vec_.push_back(operand_indices_[nnapi_output_name]);
nnapi_model_->AddOutput(output_name, nnapi_output_name, operand_types_.at(nnapi_output_name));
@ -405,10 +400,10 @@ void ModelBuilder::RegisterModelShaper() {
Status ModelBuilder::AddNewOperand(const std::string& name,
const OperandType& operand_type,
bool is_nhwc, uint32_t& index) {
uint32_t& index) {
LOGS_DEFAULT(VERBOSE) << "operand name: " << name;
ORT_RETURN_IF_ERROR(AddNewNNAPIOperand(operand_type, index));
RegisterOperand(name, index, operand_type, is_nhwc);
RegisterOperand(name, index, operand_type);
return Status::OK();
}
@ -432,13 +427,10 @@ Status ModelBuilder::AddNewNNAPIOperand(const OperandType& operand_type, uint32_
}
void ModelBuilder::RegisterOperand(const std::string& name, uint32_t index,
const OperandType& operand_type, bool is_nhwc) {
const OperandType& operand_type) {
operand_indices_[name] = index;
operand_types_.emplace(name, operand_type);
operands_.insert(name);
if (is_nhwc)
RegisterNHWCOperand(name);
}
Status ModelBuilder::SetOperandValue(uint32_t index,
@ -466,7 +458,7 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
const android::nn::wrapper::OperandType& operand_type) {
shaper_.AddShape(name, operand_type.dimensions);
uint32_t index = 0;
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, false /* is_nhwc */, index));
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, index));
const size_t size = operand_type.GetOperandBlobByteSize();
// for small size operand, the value will be copied
@ -528,12 +520,11 @@ Status ModelBuilder::AddOperations() {
Status ModelBuilder::AddOperation(int op, const std::vector<uint32_t>& input_indices,
const std::vector<std::string>& output_names,
const std::vector<OperandType>& types,
const std::vector<bool>& is_nhwc_vec) {
const std::vector<OperandType>& types) {
std::vector<uint32_t> output_indices;
for (size_t i = 0; i < types.size(); i++) {
uint32_t index = 0;
ORT_RETURN_IF_ERROR(AddNewOperand(output_names[i], types[i], is_nhwc_vec[i], index));
ORT_RETURN_IF_ERROR(AddNewOperand(output_names[i], types[i], index));
output_indices.push_back(index);
}
@ -693,7 +684,6 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {
const auto& op_type = node_unit.GetNode().OpType();
if (!Contains(op_builders, op_type))
return nullptr;
return op_builders.at(op_type);
}
@ -708,47 +698,13 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) {
return unique_name;
}
// Reports the data layout this builder prefers: NCHW when the use_nchw_ flag is set, NHWC otherwise.
DataLayout ModelBuilder::GetPreferredLayout() const {
  if (use_nchw_) {
    return DataLayout::NCHW;
  }
  return DataLayout::NHWC;
}
// Forwards to the graph viewer and hands back its set of all initialized tensors.
const InitializedTensorSet& ModelBuilder::GetInitializerTensors() const {
  const InitializedTensorSet& initializers = graph_viewer_.GetAllInitializedTensors();
  return initializers;
}
// Adds `name` to the set of operands tracked as NHWC. Adding a name that is already present is a no-op.
void ModelBuilder::RegisterNHWCOperand(const std::string& name) {
  nhwc_operands_.emplace(name);
}
// Returns true when `name` has been recorded in the set of NHWC operands.
bool ModelBuilder::IsOperandNHWC(const std::string& name) const {
  const bool is_registered = nhwc_operands_.find(name) != nhwc_operands_.end();
  return is_registered;
}
// Looks up the NCHW operand recorded for `nhwc_name`.
// On a hit the mapped name is written to `nchw_name` and true is returned; on a miss the
// function returns false and leaves `nchw_name` untouched.
bool ModelBuilder::GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name) {
  const auto entry = nhwc_to_nchw_map_.find(nhwc_name);
  if (entry == nhwc_to_nchw_map_.end()) {
    return false;
  }

  nchw_name = entry->second;
  return true;
}
// Looks up the NHWC operand recorded for `nchw_name`.
// On a hit the mapped name is written to `nhwc_name` and true is returned; on a miss the
// function returns false and leaves `nhwc_name` untouched.
bool ModelBuilder::GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name) {
  const auto entry = nchw_to_nhwc_map_.find(nchw_name);
  if (entry == nchw_to_nhwc_map_.end()) {
    return false;
  }

  nhwc_name = entry->second;
  return true;
}
// Records that the NHWC operand `nhwc_name` has the NCHW counterpart `nchw_name`.
// Returns an error status if a mapping for `nhwc_name` was already recorded; otherwise stores the
// pair and returns Status::OK().
Status ModelBuilder::SetNHWCToNCHWOperandMap(const std::string& nhwc_name,
                                             const std::string& nchw_name) {
  // The message names the nhwc -> nchw direction, which is the map this function guards.
  ORT_RETURN_IF_NOT(!Contains(nhwc_to_nchw_map_, nhwc_name), "A previous nhwc to nchw map exists");
  nhwc_to_nchw_map_[nhwc_name] = nchw_name;
  return Status::OK();
}
// Records that the NCHW operand `nchw_name` has the NHWC counterpart `nhwc_name`.
// Returns an error status if a mapping for `nchw_name` was already recorded; otherwise stores the
// pair and returns Status::OK().
Status ModelBuilder::SetNCHWToNHWCOperandMap(const std::string& nchw_name,
                                             const std::string& nhwc_name) {
  ORT_RETURN_IF_NOT(!Contains(nchw_to_nhwc_map_, nchw_name), "A previous nchw to nhwc map exists");
  // The key is known to be absent at this point, so emplace always inserts.
  nchw_to_nhwc_map_.emplace(nchw_name, nhwc_name);
  return Status::OK();
}
} // namespace nnapi
} // namespace onnxruntime

View file

@ -13,6 +13,7 @@
namespace onnxruntime {
class GraphViewer;
enum class DataLayout;
class NodeUnit;
class Node;
class NodeArg;
@ -47,8 +48,7 @@ class ModelBuilder {
// Add an NNAPI operation (operator)
common::Status AddOperation(int op, const std::vector<uint32_t>& input_indices,
const std::vector<std::string>& output_names,
const std::vector<android::nn::wrapper::OperandType>& types,
const std::vector<bool>& is_nhwc_vec);
const std::vector<android::nn::wrapper::OperandType>& types);
// Find if the given node_unit has a fuseable activation (Relu/Relu1/Relu6)
// For now we only support node_unit with a single output
@ -69,8 +69,7 @@ class ModelBuilder {
// Register information for a particular operand
void RegisterOperand(const std::string& name, uint32_t index,
const android::nn::wrapper::OperandType& operand_type,
bool is_nhwc);
const android::nn::wrapper::OperandType& operand_type);
// Generate a unique name for an intermediate result
std::string GetUniqueName(const std::string& base_name);
@ -79,6 +78,9 @@ class ModelBuilder {
void SetUseNCHW(bool use_nchw) { use_nchw_ = use_nchw; }
bool UseNCHW() const { return use_nchw_; }
// Returns the preferred layout for this EP.
DataLayout GetPreferredLayout() const;
// Relax fp32 computation to fp16
// It is off by default
void SetUseFp16(bool use_fp16) { use_fp16_ = use_fp16; }
@ -106,21 +108,9 @@ class ModelBuilder {
const GraphViewer& GetGraphViewer() const { return graph_viewer_; }
void RegisterNHWCOperand(const std::string& name);
bool IsOperandNHWC(const std::string& name) const;
// Get the operand transposed to nchw/nhwc from given nhwc/nchw operand, if it exists
bool GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name);
bool GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name);
// Get the NodeUnit which contains the given node
const NodeUnit& GetNodeUnit(const Node* node) const;
common::Status SetNHWCToNCHWOperandMap(const std::string& nhwc_name,
const std::string& nchw_name);
common::Status SetNCHWToNHWCOperandMap(const std::string& nchw_name,
const std::string& nhwc_name);
private:
const NnApi* nnapi_{nullptr};
const GraphViewer& graph_viewer_;
@ -148,12 +138,6 @@ class ModelBuilder {
std::unordered_map<std::string, std::shared_ptr<IOpSupportChecker>> op_support_checkers_;
// Operands in nhwc
std::unordered_set<std::string> nhwc_operands_;
// Maps between nhwc and nchw, and vice versa
std::unordered_map<std::string, std::string> nhwc_to_nchw_map_;
std::unordered_map<std::string, std::string> nchw_to_nhwc_map_;
std::vector<uint32_t> input_index_vec_;
std::vector<uint32_t> output_index_vec_;
@ -206,7 +190,6 @@ class ModelBuilder {
common::Status AddNewNNAPIOperand(const android::nn::wrapper::OperandType& type, uint32_t& index);
common::Status AddNewOperand(const std::string& name,
const android::nn::wrapper::OperandType& operand_type,
bool is_nhwc,
uint32_t& index);
static const IOpBuilder* GetOpBuilder(const NodeUnit& node_unit);

View file

@ -40,8 +40,7 @@ Status AddTransposeOperator(ModelBuilder& model_builder,
const std::string& input,
const std::string& perm_name,
std::vector<int32_t> perm,
const std::string& output,
bool output_is_nhwc) {
const std::string& output) {
auto& shaper(model_builder.GetShaper());
const auto& operand_indices(model_builder.GetOperandIndices());
const auto& operand_types(model_builder.GetOperandTypes());
@ -59,102 +58,7 @@ Status AddTransposeOperator(ModelBuilder& model_builder,
OperandType output_operand_type = operand_types.at(input);
output_operand_type.SetDimensions(shaper[output]);
return model_builder.AddOperation(ANEURALNETWORKS_TRANSPOSE, input_indices, {output},
{output_operand_type}, {output_is_nhwc});
}
// Adds a transpose operation converting the 4d operand `input` between the nchw and nhwc
// layouts, producing the operand `output`.
//   nchw_to_nhwc == true  : uses perm {0, 2, 3, 1} and records the pair in the nchw->nhwc map
//   nchw_to_nhwc == false : uses perm {0, 3, 1, 2} and records the pair in the nhwc->nchw map
// Returns an error if the model builder is set to use NCHW, if `input` is not a 4d tensor,
// or if adding the transpose / recording the mapping fails.
Status TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder,
                                   const std::string& input,
                                   const std::string& output,
                                   bool nchw_to_nhwc) {
  ORT_RETURN_IF_NOT(!model_builder.UseNCHW(), "model_builder.UseNCHW() is on");

  const auto& shaper(model_builder.GetShaper());
  ORT_RETURN_IF_NOT(4 == shaper[input].size(),
                    "TransposeBetweenNCHWAndNHWC input has to be a 4d tensor, actual dimensions: ", shaper[input].size());

  // The direction tag is shared by the perm operand name and the log message below
  const char* const direction = nchw_to_nhwc ? "nchw_to_nhwc" : "nhwc_to_nchw";
  const std::string perm_name = model_builder.GetUniqueName(input + direction + "_perm");
  const std::vector<int32_t> perm = nchw_to_nhwc ? std::vector<int32_t>{0, 2, 3, 1}
                                                 : std::vector<int32_t>{0, 3, 1, 2};

  ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output, nchw_to_nhwc));

  // Remember the transposed operand so later lookups can reuse it
  ORT_RETURN_IF_ERROR(nchw_to_nhwc ? model_builder.SetNCHWToNHWCOperandMap(input, output)
                                   : model_builder.SetNHWCToNCHWOperandMap(input, output));

  LOGS_DEFAULT(VERBOSE) << "Operand [" << input << "] with shape "
                        << Shape2String(shaper[input])
                        << " is transposed "
                        << direction
                        << " to [" << output << "] with shape "
                        << Shape2String(shaper[output]);

  return Status::OK();
}
// Transposes the nhwc operand `input` into the nchw operand `output`
// by delegating to TransposeBetweenNCHWAndNHWC.
Status TransposeNHWCToNCHW(ModelBuilder& model_builder,
                           const std::string& input,
                           const std::string& output) {
  constexpr bool kNchwToNhwc = false;
  return TransposeBetweenNCHWAndNHWC(model_builder, input, output, kNchwToNhwc);
}
// Transposes the nchw operand `input` into the nhwc operand `output`
// by delegating to TransposeBetweenNCHWAndNHWC.
Status TransposeNCHWToNHWC(ModelBuilder& model_builder,
                           const std::string& input,
                           const std::string& output) {
  constexpr bool kNchwToNhwc = true;
  return TransposeBetweenNCHWAndNHWC(model_builder, input, output, kNchwToNhwc);
}
// Produce the nhwc version of the node unit's input at `input_index` and store its name in `nhwc_input`.
// If the model builder already has an nhwc operand recorded for this input it is reused,
// otherwise a new transpose (nchw -> nhwc) is added under a unique name.
// The caller is responsible for making sure the input is currently in nchw format
// (see ModelBuilder::IsOperandNHWC).
Status GetNHWCInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nhwc_input) {
  const auto& source_name = node_unit.Inputs()[input_index].node_arg.Name();

  // Reuse an existing transposed operand when there is one
  if (model_builder.GetNHWCOperand(source_name, nhwc_input)) {
    return Status::OK();
  }

  nhwc_input = model_builder.GetUniqueName(source_name + "_nchw_to_nhwc");
  ORT_RETURN_IF_ERROR(TransposeNCHWToNHWC(model_builder, source_name, nhwc_input));
  return Status::OK();
}
// Produce the nchw version of the node unit's input at `input_index` and store its name in `nchw_input`.
// If the model builder already has an nchw operand recorded for this input it is reused,
// otherwise a new transpose (nhwc -> nchw) is added under a unique name.
// The caller is responsible for making sure the input is currently in nhwc format
// (see ModelBuilder::IsOperandNHWC).
Status GetNCHWInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nchw_input) {
  const auto& source_name = node_unit.Inputs()[input_index].node_arg.Name();

  // Reuse an existing transposed operand when there is one
  if (model_builder.GetNCHWOperand(source_name, nchw_input)) {
    return Status::OK();
  }

  nchw_input = model_builder.GetUniqueName(source_name + "_nhwc_to_nchw");
  ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(model_builder, source_name, nchw_input));
  return Status::OK();
}
// Transpose layouts if necessary for element wise operators with 2 inputs
// and return the layout type of output tensor
// If both inputs have same layout, the output will have the same layout
// Otherwise we will need to transpose the nhwc input back to nchw, and the output will be nchw
Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const NodeUnit& node_unit,
std::string& input1, std::string& input2,
bool& output_is_nhwc) {
bool input1_is_nhwc = model_builder.IsOperandNHWC(input1);
bool input2_is_nhwc = model_builder.IsOperandNHWC(input2);
output_is_nhwc = false;
if (input1_is_nhwc == input2_is_nhwc) {
output_is_nhwc = input1_is_nhwc;
} else if (input1_is_nhwc) {
// need transpose input1 back to nchw
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input1));
} else { // input2_is_nhwc
// need transpose input2 back to nchw
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 1, input2));
}
return Status::OK();
{output_operand_type});
}
static Status AddBinaryOperator(int32_t op_type,
@ -164,7 +68,6 @@ static Status AddBinaryOperator(int32_t op_type,
bool add_activation,
int32_t fuse_code,
const std::string& output,
bool output_is_nhwc,
float output_scale = 0.0f,
int32_t output_zero_point = 0) {
auto& shaper(model_builder.GetShaper());
@ -183,7 +86,7 @@ static Status AddBinaryOperator(int32_t op_type,
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output],
output_scale, output_zero_point);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_type, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -231,7 +134,7 @@ static Status AddSqueezeOp(ModelBuilder& model_builder,
ORT_RETURN_IF_ERROR(shaper.Squeeze(input, axes, output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_SQUEEZE, input_indices,
{output}, {output_operand_type}, {false}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -595,6 +498,16 @@ static void AddInputToSkip(ModelBuilder& model_builder, const NodeUnitIODef& io_
AddQuantizationScaleAndZeroPointToSkip(model_builder, *io_def.quant_param);
}
// Checks that the layout of the operator in `node_unit` is compatible with the requested layout.
// Fails when the operator is in the kMSInternalNHWCDomain domain while `use_nchw` is set;
// every other combination is accepted.
static Status IsOpInRequiredLayout(bool use_nchw, const NodeUnit& node_unit) {
  const bool node_is_nhwc = (node_unit.Domain() == kMSInternalNHWCDomain);
  if (!use_nchw || !node_is_nhwc) {
    return Status::OK();
  }

  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                         "Expected layout and operator layout do not match. Possible bug in layout optimizer.");
}
template <class T>
void CreateSharedOpBuilderImpl(const std::string& op_type,
OpBuilderRegistrations& op_registrations,
@ -715,10 +628,6 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
std::string input2 = inputs[1].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = false;
ORT_RETURN_IF_ERROR(
TransposeBinaryOpInputLayout(model_builder, node_unit, input1, input2, output_is_nhwc));
float a_scale = 0.0f,
b_scale = 0.0f,
y_scale = 0.0f;
@ -747,7 +656,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
return AddBinaryOperator(op_code, model_builder,
input1, input2,
add_activation, fuse_code,
output, output_is_nhwc, y_scale, y_zero_point);
output, y_scale, y_zero_point);
}
#pragma endregion
@ -766,19 +675,18 @@ Status ReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
const auto& input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
// skip this relu if it is some op's fuse output
if (Contains(model_builder.GetFusedActivations(), input)) {
LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node_unit.Name() << "] fused";
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc);
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
} else {
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input));
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RELU, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
}
return Status::OK();
@ -825,15 +733,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
ORT_RETURN_IF_NOT(perm.size() == input_dims, "Perm and input should have same dimension");
}
if (model_builder.IsOperandNHWC(input)) {
ORT_RETURN_IF_NOT(input_dims == 4, "Only 4D shape can be nhwc");
// we are using nhwc here, but the axis is in nchw, need to transpose axis from nchw to nhwc
const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2};
for (size_t i = 0; i < perm.size(); i++)
perm[i] = axis_nchw_to_nhwc[perm[i]];
}
// Check if the quantization scale and ZP are correct
if (IsQuantizedOp(node_unit)) {
float x_scale = 0.0f;
@ -845,11 +744,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
std::string perm_name = model_builder.GetUniqueName(node_unit.Name() + input + "perm");
// It is possible for this onnx transpose operator to be nchw->nhwc, but so far I have not seen
// any scenario that does this since onnx is nchw only, so assume the output is never nhwc.
// Even if it is, there will be an extra transpose in the onnx model to convert it back to nchw
// before conv/pool/... operators
ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output, false /* is_nhwc */));
ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output));
return Status::OK();
}
@ -884,7 +779,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
}
// We can skip the Reshape if all the output edges satisfies both the following conditions
// 1. The output the reshape/flatten is not an output of the graph
// 1. The output of the reshape/flatten is not an output of the graph
// 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
// and not any other types of operator,
// and the input rank >= 2 and output_rank == 2
@ -975,7 +870,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
// NNAPI CPU impl and NNAPI hardware accelerator impl
if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) {
// Since reshape can be skipped, only register the dimension and type, with same index and new name
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
} else {
// We still need to perform a reshape here
// Add input
@ -987,7 +882,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen);
ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type));
input_indices.push_back(operand_indices.at(shape_name));
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false}));
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}));
}
return Status::OK();
@ -998,10 +893,6 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
const auto& initializers(model_builder.GetInitializerTensors());
auto input = node_unit.Inputs()[0].node_arg.Name();
if (model_builder.IsOperandNHWC(input)) {
// We want to transpose nhwc operand back to nchw before reshape
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
}
const auto& shape_tensor = *initializers.at(node_unit.Inputs()[1].node_arg.Name());
std::vector<uint8_t> unpacked_tensor;
@ -1099,10 +990,10 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
const auto tensor_imm_product_name = model_builder.GetUniqueName(node_unit.Name() + input + "_imm_mul");
Shape tensor_a_dimen = {size};
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
bool output_is_nhwc = input_is_nhwc;
bool use_nchw = model_builder.UseNCHW();
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
if (!input_is_nhwc) {
if (use_nchw) {
// the batch normalization is applied on C channel,
// if the input is NC[HW], will need correct shape for tensor_a/b
// to make sure we are broadcasting on the correct channel,
@ -1126,8 +1017,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
model_builder,
input, tensor_a_name,
true /* add_activation */, ANEURALNETWORKS_FUSED_NONE,
tensor_imm_product_name,
output_is_nhwc));
tensor_imm_product_name));
// Add
int32_t fuse_code = model_builder.FindActivation(node_unit);
@ -1135,8 +1025,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
model_builder,
tensor_imm_product_name, tensor_b_name,
true /* add_activation */, fuse_code,
output,
output_is_nhwc));
output));
return Status::OK();
}
@ -1190,16 +1079,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
auto input = node_unit.Inputs()[0].node_arg.Name();
bool use_nchw = model_builder.UseNCHW();
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
bool output_is_nhwc = false;
if (use_nchw) {
ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC");
} else {
output_is_nhwc = true;
if (!input_is_nhwc) {
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
}
}
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
const auto& output = node_unit.Outputs()[0].node_arg.Name();
const auto& op_type = node_unit.OpType();
@ -1290,7 +1170,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -1361,16 +1241,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
auto input = inputs[0].node_arg.Name();
bool use_nchw = model_builder.UseNCHW();
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
bool output_is_nhwc = false;
if (use_nchw) {
ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC");
} else {
output_is_nhwc = true;
if (!input_is_nhwc) {
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
}
}
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
const auto& weight = inputs[1].node_arg.Name();
const auto& weight_tensor = *initializers.at(weight);
@ -1559,7 +1430,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(operationCode, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -1579,7 +1450,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
const auto& input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
auto to = helper.Get("to", 0);
Type type;
@ -1599,7 +1469,7 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
const OperandType output_operand_type(type, shaper[output]);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CAST, input_indices, {output},
{output_operand_type}, {output_is_nhwc}));
{output_operand_type}));
return Status::OK();
}
@ -1620,21 +1490,14 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
NodeAttrHelper helper(node_unit);
auto input = node_unit.Inputs()[0].node_arg.Name();
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
bool output_is_nhwc = input_is_nhwc;
// TODO: Needs fix.
if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) {
if (model_builder.IsOperandNHWC(input)) {
output_is_nhwc = false;
// We want to transpose nhwc operand back to nchw before softmax
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
}
ORT_ENFORCE(model_builder.UseNCHW(),
"For Android API Level < 29 input for softmax needs to be NCHW.");
}
int32_t axis = helper.Get("axis", 1);
if (output_is_nhwc) {
const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2};
axis = axis_nchw_to_nhwc[axis];
}
const auto& output = node_unit.Outputs()[0].node_arg.Name();
float beta = 1.f;
@ -1650,7 +1513,7 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -1672,14 +1535,13 @@ Status IdentityOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, con
const auto& input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input)); // input
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc);
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
return Status::OK();
}
@ -1839,7 +1701,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
ORT_RETURN_IF_ERROR(shaper.FC(input1, input2, output));
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output], y_scale, y_zero_point);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_FULLY_CONNECTED, input_indices,
{output}, {output_operand_type}, {false}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -1896,8 +1758,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
const auto& input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
bool is_qlinear_sigmoid = op_type == "QLinearSigmoid";
@ -1945,7 +1805,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
input_indices.push_back(operand_indices.at(input));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -1967,8 +1827,6 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
std::vector<uint32_t> input_indices;
const auto& input0 = inputs[0].node_arg.Name();
bool all_input_have_same_layout = true;
bool output_is_nhwc = false;
const auto node_input_size = inputs.size();
// First, if the inputs are uint8, we need to verify that all the inputs have the same scale and zero point
@ -1989,48 +1847,17 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
}
}
// First we want to see if all the input are same layout
for (size_t i = 0; i < node_input_size - 1; i++) {
all_input_have_same_layout =
all_input_have_same_layout &&
model_builder.IsOperandNHWC(inputs[i].node_arg.Name()) ==
model_builder.IsOperandNHWC(inputs[i + 1].node_arg.Name());
}
std::vector<std::string> input_names;
input_names.reserve(node_input_size);
if (all_input_have_same_layout) {
// if all the inputs are of same layout, output will be the same layout
output_is_nhwc = model_builder.IsOperandNHWC(input0);
for (size_t i = 0; i < node_input_size; i++) {
const auto& input = inputs[i].node_arg.Name();
input_indices.push_back(operand_indices.at(input));
input_names.push_back(input);
}
} else {
// if the inputs do not all have the same layout,
// we will need to transpose those nhwc tensors back to nchw
for (size_t i = 0; i < node_input_size; i++) {
auto input = inputs[i].node_arg.Name();
if (model_builder.IsOperandNHWC(input)) {
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, i, input));
}
input_indices.push_back(operand_indices.at(input));
input_names.push_back(input);
}
for (size_t i = 0; i < node_input_size; i++) {
const auto& input = inputs[i].node_arg.Name();
input_indices.push_back(operand_indices.at(input));
input_names.push_back(input);
}
int rank = shaper[input0].size();
int32_t axis = static_cast<int32_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
if (output_is_nhwc) {
ORT_RETURN_IF_NOT(rank == 4,
"nhwc is only on 4d shape, input ", input0, " has rank: ", rank);
// we are using nhwc here, but the axis is in nchw, need to transpose axis from nchw to nhwc
const uint32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2};
axis = axis_nchw_to_nhwc[axis];
}
ADD_SCALAR_OPERAND(model_builder, input_indices, axis);
const auto& output = node_unit.Outputs()[0].node_arg.Name();
@ -2038,7 +1865,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
OperandType output_operand_type = operand_types.at(input0);
output_operand_type.SetDimensions(shaper[output]);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -2089,10 +1916,6 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
auto input = node_unit.Inputs()[0].node_arg.Name();
if (model_builder.IsOperandNHWC(input)) {
// We want to transpose nhwc operand back to nchw before squeeze
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
}
std::vector<int32_t> axes;
ORT_RETURN_IF_ERROR(GetAxes(model_builder, node_unit, axes));
@ -2121,7 +1944,6 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
const auto& input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
float scale = 0.0f;
int32_t zero_point = 0;
@ -2134,7 +1956,7 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input));
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_QUANTIZE, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -2161,7 +1983,6 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil
const auto& input = inputs[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
float scale = 0.0;
int32_t zero_point = 0;
@ -2176,7 +1997,7 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input));
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_DEQUANTIZE, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -2198,13 +2019,14 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
auto input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
auto use_nchw = model_builder.UseNCHW();
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) {
// on android api level 28, we need to transpose the nchw input to nhwc
output_is_nhwc = true;
if (!model_builder.IsOperandNHWC(input)) {
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
}
// it is very rare that users set nchw format when using nnapi. Therefore, instead of
// adding the ability to support conversion we fail and stop.
ORT_ENFORCE(!use_nchw, "NCHW format is not supported on android api level 28");
}
auto alpha = helper.Get("alpha", 0.0001f);
@ -2225,16 +2047,16 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
// specify axis is only available on api level >= 29
if (android_feature_level > ANEURALNETWORKS_FEATURE_LEVEL_2) {
// ONNX LRN is always performed on C dimension
int32_t axis = output_is_nhwc
? 3 // nhwc
: 1; // nchw
int32_t axis = use_nchw
? 1 // nchw
: 3; // nhwc
ADD_SCALAR_OPERAND(model_builder, input_indices, axis);
}
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -2266,14 +2088,13 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
const auto& input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
if (Contains(model_builder.GetFusedActivations(), input)) {
LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node_unit.Name() << "] fused";
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc);
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
return Status::OK();
}
@ -2293,7 +2114,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input));
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -2344,16 +2165,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
auto input = inputs[0].node_arg.Name();
bool use_nchw = model_builder.UseNCHW();
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
bool output_is_nhwc = false;
if (use_nchw) {
ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC");
} else {
output_is_nhwc = true;
if (!input_is_nhwc) {
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
}
}
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
// Check if the quantization scale and ZP are correct
if (IsQuantizedOp(node_unit)) {
@ -2373,16 +2185,19 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
bool using_half_pixel = coord_trans_mode == "half_pixel";
bool using_align_corners = coord_trans_mode == "align_corners";
// if the node domain is NHWC it means all the node inputs are converted to NHWC format by the layout transformer.
// pick the index for height and width based on the format.
int h_idx = use_nchw ? 2 : 1;
int w_idx = use_nchw ? 3 : 2;
if (inputs.size() == 3) { // we are using scales
const auto& scales_name = inputs[2].node_arg.Name();
const auto& scales_tensor = *initializers.at(scales_name);
std::vector<uint8_t> unpacked_tensor;
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor));
const float* scales_data = reinterpret_cast<const float*>(unpacked_tensor.data());
float scale_h = scales_data[2];
float scale_w = scales_data[3];
ORT_RETURN_IF_ERROR(
shaper.ResizeUsingScales(input, scale_h, scale_w, use_nchw, output));
shaper.ResizeUsingScales(input, scales_data[h_idx], scales_data[w_idx], use_nchw, output));
} else { // we are using sizes
const auto& sizes_name = inputs[3].node_arg.Name();
const auto& sizes_tensor = *initializers.at(sizes_name);
@ -2390,12 +2205,12 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor));
const int64_t* sizes_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
ORT_RETURN_IF_ERROR(
shaper.ResizeUsingOutputSizes(input, SafeInt<uint32_t>(sizes_data[2]), SafeInt<uint32_t>(sizes_data[3]), use_nchw, output));
shaper.ResizeUsingOutputSizes(input, SafeInt<uint32_t>(sizes_data[h_idx]), SafeInt<uint32_t>(sizes_data[w_idx]), use_nchw, output));
}
const auto& output_shape = shaper[output];
int32_t output_h = use_nchw ? output_shape[2] : output_shape[1];
int32_t output_w = use_nchw ? output_shape[3] : output_shape[2];
int32_t output_h = output_shape[h_idx];
int32_t output_w = output_shape[w_idx];
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input));
@ -2420,7 +2235,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
OperandType output_operand_type = operand_types.at(input);
output_operand_type.SetDimensions(output_shape);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(operationCode, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -2436,10 +2251,6 @@ class FlattenOpBuilder : public BaseOpBuilder {
Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
auto input = node_unit.Inputs()[0].node_arg.Name();
if (model_builder.IsOperandNHWC(input)) {
// We want to transpose nhwc operand back to nchw before reshape
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
}
// Flatten is basically a reshape to 2d tensor
// Get the shape for Reshape here
@ -2467,8 +2278,7 @@ class MinMaxOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
static Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
const std::string& input1, const std::string& input2,
bool output_is_nhwc);
const std::string& input1, const std::string& input2);
};
/* static */ void MinMaxOpBuilder::CreateSharedOpBuilder(
@ -2482,8 +2292,7 @@ class MinMaxOpBuilder : public BaseOpBuilder {
}
/* static */ Status MinMaxOpBuilder::AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
const std::string& input1, const std::string& input2,
bool output_is_nhwc) {
const std::string& input1, const std::string& input2) {
auto& shaper(model_builder.GetShaper());
const auto& operand_indices(model_builder.GetOperandIndices());
const auto& operand_types(model_builder.GetOperandTypes());
@ -2506,7 +2315,7 @@ class MinMaxOpBuilder : public BaseOpBuilder {
ORT_RETURN_IF_ERROR(shaper.Eltwise(input1, input2, output));
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output]);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
{output}, {output_operand_type}, {output_is_nhwc}));
{output}, {output_operand_type}));
return Status::OK();
}
@ -2515,11 +2324,8 @@ Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
const auto& inputs = node_unit.Inputs();
std::string input1 = inputs[0].node_arg.Name();
std::string input2 = inputs[1].node_arg.Name();
bool output_is_nhwc = false;
ORT_RETURN_IF_ERROR(TransposeBinaryOpInputLayout(model_builder, node_unit,
input1, input2, output_is_nhwc));
return AddMinMaxOperator(model_builder, node_unit, input1, input2, output_is_nhwc);
return AddMinMaxOperator(model_builder, node_unit, input1, input2);
}
#pragma endregion
@ -2537,7 +2343,6 @@ Status EluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
const auto& operand_types(model_builder.GetOperandTypes());
const auto& input = node_unit.Inputs()[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
NodeAttrHelper helper(node_unit);
@ -2546,7 +2351,7 @@ Status EluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
input_indices.push_back(operand_indices.at(input));
ADD_SCALAR_OPERAND(model_builder, input_indices, alpha);
return model_builder.AddOperation(ANEURALNETWORKS_ELU, input_indices,
{output}, {output_operand_type}, {output_is_nhwc});
{output}, {output_operand_type});
}
#pragma endregion
@ -2644,7 +2449,6 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
const auto& input = inputs[0].node_arg.Name();
const auto& output = node_unit.Outputs()[0].node_arg.Name();
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
// No shape inference for Slice, everything is calculated here, we only need to add the output shape
// to the shaper
@ -2718,7 +2522,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
ADD_SCALAR_OPERAND(model_builder, input_indices, 0); // end_mask
ADD_SCALAR_OPERAND(model_builder, input_indices, 0); // shrink_axis_mask
}
return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}, {output_is_nhwc});
return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type});
}
#pragma endregion

View file

@ -72,6 +72,10 @@ static bool HasExternalInitializer(const InitializedTensorSet& initializers, con
return false;
}
// Returns true when the node unit's domain is the internal NHWC domain (kMSInternalNHWCDomain).
inline bool IsNodeLayoutNHWC(const NodeUnit& node_unit) {
  const bool in_nhwc_domain = (node_unit.Domain() == kMSInternalNHWCDomain);
  return in_nhwc_domain;
}
static bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers,
const NodeUnitIODef& io_def,
const OpSupportCheckParams& params,
@ -1768,7 +1772,7 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi
}
const float* scales_data = reinterpret_cast<const float*>(unpacked_tensor.data());
float scale_n = scales_data[0];
float scale_c = scales_data[1];
float scale_c = IsNodeLayoutNHWC(node_unit) ? scales_data[3] : scales_data[1];
if (scale_n != 1.0f || scale_c != 1.0f) {
LOGS_DEFAULT(VERBOSE) << "Scales of N/C channel should be 1"
<< "Resize of N/C channels are not supported"
@ -1785,14 +1789,16 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi
LOGS_DEFAULT(ERROR) << "Error while unpacking sizes_tensor: " << status.ErrorMessage();
return false;
}
int channel_idx = IsNodeLayoutNHWC(node_unit) ? 3 : 1;
const int64_t* sizes_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
uint32_t size_n = SafeInt<uint32_t>(sizes_data[0]);
uint32_t size_c = SafeInt<uint32_t>(sizes_data[1]);
if (size_n != input_shape[0] || size_c != input_shape[1]) {
LOGS_DEFAULT(VERBOSE) << "Output sizes of N/C chanel should match the input sizes, "
uint32_t size_c = SafeInt<uint32_t>(sizes_data[channel_idx]);
if (size_n != input_shape[0] || size_c != input_shape[channel_idx]) {
LOGS_DEFAULT(VERBOSE) << "Output sizes of N/C channel should match the input sizes, "
<< "Resize of N/C channels are not supported"
<< ", input_size_n, " << input_shape[0] << ", output_size_n, " << size_n
<< ". input_size_c, " << input_shape[1] << ", output_size_c, " << size_c;
<< ". input_size_c, " << input_shape[channel_idx] << ", output_size_c, " << size_c;
return false;
}
}

View file

@ -176,7 +176,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
};
result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
gen_metadef_name, NNAPI);
gen_metadef_name, NNAPI, kNnapiExecutionProvider);
const auto num_of_partitions = result.size();
const auto num_of_supported_nodes = std::transform_reduce(
@ -203,6 +203,10 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
return result;
}
// Reports the data layout this EP prefers: NCHW when the NNAPI_FLAG_USE_NCHW bit is set in
// nnapi_flags_, NHWC otherwise.
DataLayout NnapiExecutionProvider::GetPreferredLayout() const {
  if ((nnapi_flags_ & NNAPI_FLAG_USE_NCHW) != 0) {
    return DataLayout::NCHW;
  }
  return DataLayout::NHWC;
}
#ifdef __ANDROID__
static Status GetOutputBuffer(Ort::CustomOpApi& ort,
OrtKernelContext* context,

View file

@ -33,6 +33,8 @@ class NnapiExecutionProvider : public IExecutionProvider {
uint32_t GetNNAPIFlags() const { return nnapi_flags_; }
DataLayout GetPreferredLayout() const override;
private:
// The bit flags which define bool options for NNAPI EP, bits are defined as
// NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h

View file

@ -87,6 +87,7 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
const GraphViewer& graph_viewer,
const IsNodeSupportedFn& is_node_supported_fn,
const OnGroupClosedFn& on_group_closed_fn,
const std::string& execution_provider_type,
bool debug_output) {
#ifdef NDEBUG
ORT_UNUSED_PARAMETER(debug_output);
@ -104,7 +105,8 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
std::deque<const Node*> nodes_to_process_with_next_group{};
// initialize in-degrees and find root nodes
for (const auto& node : graph_viewer.Nodes()) {
for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) {
const auto& node = *graph_viewer.GetNode(node_index);
const auto node_input_edge_count = node.GetInputEdgesCount();
in_degree.insert({node.Index(), node_input_edge_count});
if (node_input_edge_count == 0) {
@ -156,9 +158,9 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
const Node& node = *nodes_to_process.front();
nodes_to_process.pop_front();
// a node that is already assigned to an EP other than current EP is unsupported
const bool is_node_supported =
node.GetExecutionProviderType().empty() && // a node that is already assigned to an EP is unsupported
is_node_supported_fn(node);
(node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node);
if (!is_node_supported && Contains(supported_group_border, &node)) {
// an unsupported node on the border will be processed after the current partition node group
@ -203,9 +205,7 @@ std::unordered_set<const Node*> CreateExcludedNodeSet(const GraphViewer& graph_v
const std::unordered_set<std::string>& stop_ops) {
std::unordered_set<const Node*> excluded_nodes;
for (const NodeIndex node_index : graph_viewer.GetNodesInTopologicalOrder()) {
const Node& node = *graph_viewer.GetNode(node_index);
for (const auto& node : graph_viewer.Nodes()) {
if (!Contains(excluded_nodes, &node) && Contains(stop_ops, node.OpType())) {
excluded_nodes.insert(&node);
@ -309,10 +309,12 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const OnGroupClosedFn& on_partition_closed_fn,
const GenerateMetadefNameFn& generate_metadef_name_fn,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
bool debug_output) {
const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer,
is_node_supported_fn,
on_partition_closed_fn,
execution_provider_type,
debug_output);
std::vector<std::unique_ptr<ComputeCapability>> partitions{};
@ -335,6 +337,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const std::unordered_set<std::string>& stop_ops,
const GenerateMetadefNameFn& generate_metadef_name_fn,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
bool debug_output) {
const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops);
const bool check_excluded_nodes = !excluded_nodes.empty();
@ -348,6 +351,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
{},
generate_metadef_name_fn,
execution_provider_name,
execution_provider_type,
debug_output);
}

View file

@ -54,6 +54,7 @@ Create the supported partitions for the execution provider.
@param on_group_closed_fn Callback to indicate a completed partition node group.
@param generate_metadef_name_fn Callback to create the name for the MetaDef.
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
No-op in a release build.
@ -65,6 +66,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const OnGroupClosedFn& on_group_closed_fn,
const GenerateMetadefNameFn& generate_metadef_name_fn,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
bool debug_output = false);
/**
@ -75,6 +77,7 @@ Create the supported partitions for the execution provider.
@param stop_ops Set of operator names at which we stop considering nodes for assignment to this execution provider.
@param generate_metadef_name Functor to create the name for the MetaDef.
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
No-op in a release build.
@ -86,6 +89,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const std::unordered_set<std::string>& stop_ops,
const GenerateMetadefNameFn& generate_metadef_name,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
bool debug_output = false);
/**

View file

@ -135,6 +135,7 @@ struct KernelRegistry;
struct Function;
struct Graph;
struct GraphViewer;
enum class DataLayout;
struct Model;
struct Path;
struct Node;

View file

@ -222,6 +222,7 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
}
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1);
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1);
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSInternalNHWCDomain, 1, 1);
#ifdef USE_DML
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1);
#endif

View file

@ -32,6 +32,7 @@
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/utils.h"
#include "core/framework/static_kernel_def_hashes.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/graph_transformer_utils.h"
@ -40,6 +41,7 @@
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
#include "core/optimizer/transformer_memcpy.h"
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
#include "core/platform/Barrier.h"
#include "core/platform/ort_mutex.h"
#include "core/platform/threadpool.h"
@ -913,7 +915,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// Do partitioning based on execution providers' capability.
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
session_state.GetMutableFuncMgr(), mode));
session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayout, mode));
// apply transformers except default transformers
// Default transformers are required for correctness and they are owned and run by inference session
@ -1138,6 +1141,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(),
session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayout,
GraphPartitioner::Mode::kOrtFormatLoad,
&compiled_kernel_hashes));
@ -1211,6 +1215,17 @@ Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs
graph.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
ORT_RETURN_IF_ERROR(set_node_ep(node_index, kernel_def_hash));
}
// layout transformer which is enabled in extended minimal build can add new nodes.
// The following loop fetches the hash values for these nodes.
for (const auto& node : graph.Nodes()) {
if (node.GetExecutionProviderType().empty()) {
auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (kernel_hash.has_value()) {
ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value()));
}
}
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
return Status::OK();
@ -1236,7 +1251,7 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) {
}
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
//VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
// VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
#pragma warning(disable : 26117)
#endif
common::Status InferenceSession::Initialize() {

View file

@ -20,6 +20,7 @@
#include "gtest/gtest.h"
#include "test/test_environment.h"
#include "test/util/include/default_providers.h"
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
using namespace ONNX_NAMESPACE;
using namespace std;
@ -142,7 +143,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
DefaultLoggingManager().DefaultLogger(), profiler);
GraphPartitioner partitioner(krm, execution_providers);
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayout);
ASSERT_TRUE(status.IsOK()) << status;
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
@ -207,7 +209,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
// Partition the graph
GraphPartitioner partitioner(krm, execution_providers);
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayout);
ASSERT_TRUE(status.IsOK()) << status;
// Finalize the session state
@ -256,7 +259,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
// Partition the graph
GraphPartitioner partitioner(krm, execution_providers);
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayout);
ASSERT_TRUE(status.IsOK()) << status;
// Finalize the session state

View file

@ -689,14 +689,15 @@ class ResizeOpTester : public OpTester {
OpTester::AddNodes(graph, graph_input_defs, graph_output_defs, add_attribute_funcs);
// set the Graph inputs to just X and roi (exclude 'scales') so the 'scales' are a constant initializer
if (scales_in_initializer_) {
// this isn't intended to work with a scenario where the optional 'sizes' input is provided
ASSERT_TRUE(graph_input_defs.size() == 3);
if(scales_in_initializer_) {
graph.SetInputs({graph.GetNodeArg(graph_input_defs[0]->Name()),
graph.GetNodeArg(graph_input_defs[1]->Name())});
}
if (sizes_in_initializer_) {
if(sizes_in_initializer_) {
ASSERT_TRUE(graph_input_defs.size() == 4);
} else {
ASSERT_TRUE(graph_input_defs.size() == 3);
}
} else if (sizes_in_initializer_) {
ASSERT_TRUE(graph_input_defs.size() == 4); // 'sizes' is 4th input
graph.SetInputs({graph.GetNodeArg(graph_input_defs[0]->Name()),
graph.GetNodeArg(graph_input_defs[1]->Name()),
@ -744,9 +745,10 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Scales) {
TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Sizes) {
auto run_test = [](bool sizes_in_initializer) {
ResizeOpTester test(false, sizes_in_initializer);
ResizeOpTester test(sizes_in_initializer, sizes_in_initializer);
std::vector<float> roi{};
std::vector<float> scales{};
std::vector<int64_t> sizes{1, 1, 4, 4};
test.AddAttribute("mode", "nearest");
@ -760,7 +762,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Sizes) {
test.AddInput<float>("X", {N, C, H, W}, X);
test.AddInput<float>("roi", {0}, roi);
test.AddInput<float>("scales", {0}, {});
test.AddInput<float>("scales", {0}, scales, sizes_in_initializer);
test.AddInput<int64_t>("sizes", {4}, sizes, sizes_in_initializer);
std::vector<float> Y = {1.0f, 1.0f, 2.0f, 2.0f,

View file

@ -13,6 +13,7 @@
#include "core/graph/model.h"
#include "core/providers/partitioning_utils.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
#include <queue>
@ -22,12 +23,14 @@ constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP";
InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops,
const std::unordered_set<std::string>& stop_ops,
bool debug_output)
bool debug_output,
DataLayout preferred_layout)
: IExecutionProvider{utils::kInternalTestingExecutionProvider, true},
ep_name_{INTERNAL_TESTING_EP},
ops_{ops},
stop_ops_{stop_ops},
debug_output_{debug_output} {
debug_output_{debug_output},
preferred_layout_{preferred_layout} {
//
// TODO: Allocation planner calls GetAllocator for the individual EP. It would be better if it goes through
// the session state to get the allocator so it's per-device (or for the allocation planner to try the EP first
@ -44,6 +47,10 @@ InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::un
// Nothing to release explicitly; defined out of line to match the declaration in the header.
InternalTestingExecutionProvider::~InternalTestingExecutionProvider() = default;
// Returns the layout stored in preferred_layout_, which is set from the constructor argument.
DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const {
return preferred_layout_;
}
std::vector<std::unique_ptr<ComputeCapability>>
InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& /*registries*/) const {
@ -89,7 +96,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer&
};
return utils::CreateSupportedPartitions(graph_viewer, supported_nodes, stop_ops_,
generate_metadef_name, ep_name_, debug_output_);
generate_metadef_name, ep_name_, onnxruntime::utils::kInternalTestingExecutionProvider, debug_output_);
}
common::Status InternalTestingExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
@ -99,14 +106,19 @@ common::Status InternalTestingExecutionProvider::Compile(const std::vector<Fused
NodeComputeInfo compute_info;
const Node& node = node_and_viewer.fused_node;
//{
// const GraphViewer& graph_viewer = node_and_viewer.filtered_graph;
// std::cout << "Fusing nodes: ";
// for (const auto& unfused_node : graph_viewer.Nodes()) {
// std::cout << " '" << unfused_node.Name() << "':" << unfused_node.Index();
// }
// std::cout << std::endl;
//}
if (preferred_layout_ == DataLayout::NHWC) {
const GraphViewer& graph_viewer = node_and_viewer.filtered_graph;
auto layout_sensitive_ops = layout_transformer::GetORTLayoutSensitiveOps();
for (const auto& unfused_node : graph_viewer.Nodes()) {
std::cout << unfused_node.OpType() << std::endl;
if (layout_sensitive_ops.count(unfused_node.OpType()) && unfused_node.Domain() != kMSInternalNHWCDomain) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Found a layout sensitive op which is still in NCHW format. Node: ",
unfused_node.OpType(), " ", unfused_node.Name(),
" The preferrd layout for this EP is NHWC. This is a possible bug in layout transformer.");
}
}
}
compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) {
return 0;

View file

@ -10,7 +10,8 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
public:
InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops,
const std::unordered_set<std::string>& stop_ops = {},
bool debug_output = false);
bool debug_output = false,
DataLayout preferred_layout = static_cast<DataLayout>(0));
virtual ~InternalTestingExecutionProvider();
std::vector<std::unique_ptr<ComputeCapability>>
@ -24,6 +25,8 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
return FusionStyle::FilteredGraphViewer;
}
DataLayout GetPreferredLayout() const override;
private:
const std::string ep_name_;
@ -38,5 +41,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
const std::unordered_set<std::string> stop_ops_;
const bool debug_output_;
DataLayout preferred_layout_;
};
} // namespace onnxruntime

View file

@ -31,6 +31,7 @@ namespace test {
// it would be possible to use ORT format models but the same partitioning code would run either way
#if !defined(ORT_MINIMAL_BUILD)
#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")
// model has an unsupported node between the supported nodes after the initial topo sort.
// the partition aware topo sort should result in the unsupported node moving to earlier in the order,
// and allow a single partition of supported nodes to be created.
@ -44,7 +45,7 @@ TEST(InternalTestingEP, TestSortResultsInSinglePartition) {
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
std::make_unique<InternalTestingExecutionProvider>(supported_ops)));
const ORTCHAR_T* model_path = ORT_TSTR("testdata/ep_partitioning_test_1.onnx");
const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "ep_partitioning_test_1.onnx";
ASSERT_STATUS_OK(session->Load(model_path));
const auto& graph = session->GetGraph();
GraphViewer viewer{graph};
@ -89,7 +90,7 @@ TEST(InternalTestingEP, TestDependenciesCorrectlyHandled) {
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
std::make_unique<InternalTestingExecutionProvider>(supported_ops)));
const ORTCHAR_T* model_path = ORT_TSTR("testdata/ep_partitioning_test_2.onnx");
const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "ep_partitioning_test_2.onnx";
ASSERT_STATUS_OK(session->Load(model_path));
const auto& graph = session->GetGraph();
GraphViewer viewer{graph};
@ -195,7 +196,7 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin
}
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
std::make_unique<InternalTestingExecutionProvider>(ops, stop_ops, debug_output)));
std::make_unique<InternalTestingExecutionProvider>(ops, stop_ops, debug_output, DataLayout::NHWC)));
ASSERT_STATUS_OK(session->Load(model_uri));
const auto& graph = session->GetGraph();
@ -310,6 +311,7 @@ TEST(InternalTestingEP, DISABLED_TestNnapiPartitioningMlPerfModels) {
"deeplabv3_mnv2_ade20k_float.onnx",
"mobilebert.onnx",
"mobiledet.onnx",
};
for (const auto& model_uri : model_paths) {

View file

@ -24,8 +24,10 @@ namespace onnxruntime {
namespace test {
#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")
static void CreateSession(const SessionOptions& so, std::unique_ptr<InferenceSessionWrapper>& session,
const ORTCHAR_T* model_path = ORT_TSTR("testdata/mnist.onnx"), // arbitrary test model
const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "mnist.onnx", // arbitrary test model
bool enable_custom_ep = true,
const std::unordered_set<std::string>* override_supported_ops = nullptr) {
session = std::make_unique<InferenceSessionWrapper>(so, GetEnvironment());
@ -86,7 +88,7 @@ static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enable
#if !defined(DISABLE_SPARSE_TENSORS)
#if !defined(ORT_MINIMAL_BUILD)
TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.test_output.ort");
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.test_output.ort";
//
// First load the onnx format model and save as an ORT model.
@ -130,7 +132,7 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
}
TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
// make sure we can't save a model with compiled ops. input/output model format doesn't matter
SessionOptions so;
@ -152,7 +154,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
// test to validate a minimal build
TEST(InternalTestingEP, TestLoadOrtModel) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
std::unique_ptr<InferenceSessionWrapper> session;
bool enable_custom_ep = true;
@ -164,7 +166,7 @@ TEST(InternalTestingEP, TestLoadOrtModel) {
// test that is the custom EP cannot take all nodes due to device limitations
// that we fallback to the CPU implementations and can execute the model
TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
const std::unordered_set<std::string> supported_ops{"Conv", "Add", "Relu" /*, "MaxPool"*/};
std::unique_ptr<InferenceSessionWrapper> session;
@ -225,7 +227,7 @@ static int CountAndValidateAssignedNodes(const Graph& current_graph,
// Test model that contains a subgraph. This model has a Loop and an If so multiple layers of nested subgraphs.
// There are Add nodes in the Loop and If subgraphs so we should see the custom EP taking nodes at both these levels.
TEST(InternalTestingEP, TestModelWithSubgraph) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "ort_github_issue_4031.onnx.ort";
const std::unordered_set<std::string> supported_ops{"Add"};
std::unique_ptr<InferenceSessionWrapper> session;
@ -302,8 +304,11 @@ TEST(InternalTestingEP, TestOrtModelWithCompileFailure) {
// In the test file, there are 2 Conv and 1 Gemm nodes, all disconnected
// So we should have 3 partitions be taken by InternalTestingExecutionProvider/CompileFailureTestExecutionProvider
// But CompileFailureTestExecutionProvider will fail the Compile for partition contains "Gemm" node
// This is to test the model initialization won't fail and Gemm node will not be replaced by the fused_node
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
// Post layout transformations we cannot revert back if compile fails because
// the layout transformation for this EP is already done at this stage and reverting
// can result in more failures.
// This is to test the model initialization fails if compile fails.
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
const std::unordered_set<std::string>& supported_ops{"Conv", "Gemm"};
const std::unordered_set<std::string>& compile_failure_ops{"Gemm"};
@ -325,33 +330,12 @@ TEST(InternalTestingEP, TestOrtModelWithCompileFailure) {
}
// Use CompileFailureTestExecutionProvider which will fail Compile on "Gemm"
// We should have 2 partitions taken by the EP
// 2 Conv
{
InferenceSessionWrapper session(SessionOptions(), GetEnvironment());
ASSERT_STATUS_OK(session.RegisterExecutionProvider(
std::make_unique<CompileFailureTestExecutionProvider>(supported_ops, compile_failure_ops)));
ASSERT_STATUS_OK(session.Load(ort_model_path));
ASSERT_STATUS_OK(session.Initialize());
// 2 Conv nodes shoule be replaced with fused nodes
const auto& graph = session.GetGraph();
int num_replaced_nodes = CountAndValidateAssignedNodes(
session.GetGraph(), {"Conv"}, const_cast<SessionState&>(session.GetSessionState()).GetMutableFuncMgr());
ASSERT_EQ(num_replaced_nodes, 2);
// The Gemm node should still not have been replaced
int count_compile_failure_nodes = 0;
for (const auto& node : graph.Nodes()) {
if (compile_failure_ops.find(node.OpType()) != compile_failure_ops.end())
count_compile_failure_nodes++;
}
ASSERT_EQ(count_compile_failure_nodes, 1);
// Execute the session, since the last node is Gemm, and its input 0 is all 0s
// So the result should be the bias initializer of the Gemm node
ExecuteMnist(session, true /* enable_custom_ep */);
ASSERT_STATUS_NOT_OK(session.Initialize());
}
}
} // namespace test

View file

@ -80,6 +80,7 @@
#include "core/mlas/inc/mlas.h"
#include "core/platform/env_var_utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "gtest/gtest.h"
using json = nlohmann::json;
@ -189,5 +190,30 @@ TEST(KernelDefHashTest, ExpectedCpuKernelDefHashes) {
CheckKernelDefHashes(cpu_kernel_def_hashes, expected_cpu_kernel_def_hashes, is_strict);
}
// This test ensures that the latest opset versions of the ops which can be added
// during the layout transformation step are covered. If this test fails then it means
// there is a new version available for one of the ops in the map.
// Adding this test here because resolution for this test failure requires fetching the hash
// for one of the ops in the list below and this file has information around that.
// Please update the following 3 places:
// 1. api_impl.cc "onnx_ops_available_versions" map, include the latest version in the map
// 2. static_kernel_def_hashes.cc "static_kernel_hashes" include an entry for latest version and its associated hash
// 3. This file "onnx_ops_available_versions" map, include the latest version in the map
TEST(KernelDefHashTest, TestNewOpsVersionSupportDuringLayoutTransform) {
  // Each op maps to the opset versions it is expected to have; the last entry is the newest
  // version this test expects the schema registry to report.
  static const std::unordered_map<std::string, std::vector<int>> onnx_ops_available_versions = {
      {"Squeeze", {1, 11, 13}},
      {"Unsqueeze", {1, 11, 13}},
      {"Gather", {1, 11, 13}},
      {"Transpose", {1, 13}},
      {"Identity", {1, 13, 14, 16}},
  };

  auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
  for (const auto& [op_type, version_list] : onnx_ops_available_versions) {
    // INT_MAX as the max inclusive version asks the registry for the newest registered schema.
    auto schema = schema_registry->GetSchema(op_type, INT_MAX, kOnnxDomain);

    // Fail with a readable message instead of dereferencing a missing schema below.
    ASSERT_TRUE(schema != nullptr) << "No ONNX schema found for op: " << op_type;

    EXPECT_EQ(schema->SinceVersion(), version_list.back())
        << "A new version for op: " << op_type
        << " is available. Please update the files mentioned in the comments of this test.";
  }
}
} // namespace test
} // namespace onnxruntime

View file

@ -315,12 +315,17 @@ TEST(NnapiExecutionProviderTest, TestQDQConv) {
TEST(NnapiExecutionProviderTest, TestQDQResize) {
// NNAPI EP does not support the default setting of Resize Op
// Use bi-linear and asymmetric for NNAPI EP only
// Setting verify_entire_graph_use_ep for this test as false. This is because layout transformation adds
// Transpose (NCHW -> NHWC) nodes. Post transformation, the graph looks like this: Transpose -> DQ -> Resize -> Q -> Transpose.
// NNAPI does not pick the first Transpose as its input is graph/partition input
// See https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc#L305
// onnxruntime::nnapi::IsInternalQuantizationSupported
RunQDQModelTest(BuildQDQResizeTestCase({1, 3, 64, 64} /* input_shape */,
{1, 3, 32, 32} /* sizes_data */,
"linear" /* mode */,
"asymmetric" /* coordinate_transformation_mode */),
"nnapi_qdq_test_graph_resize",
{true /* verify_entire_graph_use_ep */});
{false /* verify_entire_graph_use_ep */});
}
TEST(NnapiExecutionProviderTest, TestQDQAveragePool) {