From f436d3437e85fa0c28a6264dd361958ae4d63e08 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 15 Feb 2022 20:25:29 -0800 Subject: [PATCH] Add layout transformer for NNAPI (#10371) * Add layout transformer for NNAPI * plus merge fixes * plus some more merge fixes * test fixes * comments + cleanup * plus updates * post merge changes * enable layout transformer in extended minimal build * plus more comments * more tests + fix CI * plus updates per review * more updates per review * fix file name * fix qdq tests * plus more updates * plus updates * typo fix * fix qdq selection in 2nd optimization pass * fix typo * fix a test * update dependency structure for layout transformer * plus updates * more updates * plus change * more updates to fix linker error in minimal build * remove unnecessary headers --- cmake/CMakeLists.txt | 9 +- cmake/onnxruntime_optimizer.cmake | 1 - .../core/framework/execution_provider.h | 8 +- include/onnxruntime/core/graph/constants.h | 1 + include/onnxruntime/core/graph/graph.h | 6 + include/onnxruntime/core/graph/graph_viewer.h | 6 + .../core/framework/graph_partitioner.cc | 248 +++++++--- .../core/framework/graph_partitioner.h | 9 +- onnxruntime/core/framework/session_state.cc | 26 +- .../framework/static_kernel_def_hashes.cc | 37 ++ .../core/framework/static_kernel_def_hashes.h | 19 + .../core/optimizer/nhwc_transformer.cc | 4 +- .../selectors_actions/shared/utils.cc | 4 +- .../{api.h => optimizer_api.h} | 54 ++- .../{api_impl.cc => optimizer_api_impl.cc} | 215 ++++++++- .../{api_impl.h => optimizer_utils.h} | 37 +- .../ort_transpose_optimizer.cc | 5 +- .../transpose_optimizer.cc | 444 ++++++++++++------ .../coreml/coreml_execution_provider.cc | 2 +- .../nnapi_builtin/builders/model_builder.cc | 68 +-- .../nnapi_builtin/builders/model_builder.h | 29 +- .../nnapi_builtin/builders/op_builder.cc | 346 +++----------- .../builders/op_support_checker.cc | 16 +- .../nnapi_builtin/nnapi_execution_provider.cc | 6 +- .../nnapi_builtin/nnapi_execution_provider.h | 2 + .../core/providers/partitioning_utils.cc | 16 +- .../core/providers/partitioning_utils.h | 4 + .../providers/shared_library/provider_api.h | 1 + onnxruntime/core/session/environment.cc | 1 + onnxruntime/core/session/inference_session.cc | 19 +- .../test/framework/session_state_test.cc | 10 +- .../providers/cpu/tensor/resize_op_test.cc | 18 +- .../internal_testing_execution_provider.cc | 34 +- .../internal_testing_execution_provider.h | 7 +- .../internal_testing_partitioning_tests.cc | 8 +- .../internal_testing_tests.cc | 44 +- .../test/providers/kernel_def_hash_test.cc | 26 + .../test/providers/nnapi/nnapi_basic_test.cc | 7 +- 38 files changed, 1097 insertions(+), 700 deletions(-) create mode 100644 onnxruntime/core/framework/static_kernel_def_hashes.cc create mode 100644 onnxruntime/core/framework/static_kernel_def_hashes.h rename onnxruntime/core/optimizer/transpose_optimizer/{api.h => optimizer_api.h} (91%) rename onnxruntime/core/optimizer/transpose_optimizer/{api_impl.cc => optimizer_api_impl.cc} (73%) rename onnxruntime/core/optimizer/transpose_optimizer/{api_impl.h => optimizer_utils.h} (52%) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5abf37a6d5..c6353a01dd 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -347,10 +347,11 @@ if (onnxruntime_MINIMAL_BUILD) if (onnxruntime_EXTENDED_MINIMAL_BUILD) # enable EPs that compile kernels at runtime add_compile_definitions(ORT_EXTENDED_MINIMAL_BUILD) - - if (onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) - add_compile_definitions(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) - endif() + # enable runtime optimizations. These are required for NNAPI. + # TODO remove onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD since these optimzations are now + # enabled for all extended minimal builds. + SET(onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD ON) + add_compile_definitions(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) endif() if (onnxruntime_MINIMAL_BUILD_CUSTOM_OPS) diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index ecf28ffb82..9311c88c70 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -38,7 +38,6 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.cc" "${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.h" "${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.cc" - "${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h" "${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.cc" ) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 1091271966..0b351c865a 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -14,7 +14,7 @@ #include "core/framework/tensor.h" namespace onnxruntime { - +enum class DataLayout; class GraphViewer; class Node; struct ComputeCapability; @@ -274,6 +274,12 @@ class IExecutionProvider { return {}; } + virtual DataLayout GetPreferredLayout() const { + // NCHW is the default ONNX standard data layout. So default to it. + // EPs which prefer a different layout should override to return their preferred layout. + return static_cast(0); + } + private: const std::string type_; AllocatorMap allocators_; diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index d7cd4a919a..0990e09bbf 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -15,6 +15,7 @@ constexpr const char* kMLDomain = "ai.onnx.ml"; constexpr const char* kMSDomain = "com.microsoft"; constexpr const char* kMSExperimentalDomain = "com.microsoft.experimental"; constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc"; +constexpr const char* kMSInternalNHWCDomain = "com.ms.internal.nhwc"; constexpr const char* kMSDmlDomain = "com.microsoft.dml"; constexpr const char* kNGraphDomain = "com.intel.ai"; constexpr const char* kMIGraphXDomain = ""; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 421d01fedd..307ee4701b 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -151,6 +151,12 @@ class Node { */ int SinceVersion() const noexcept { return since_version_; } + /** Sets the since version (opset version that the Node's operator was first defined in.) for this node. + @remarks Used during layout transformation for setting since vesion for layout transformed nodes with + domain kMSNHWC. + */ + void SetSinceVersion(int since_version) noexcept { since_version_ = since_version; } + #if !defined(ORT_MINIMAL_BUILD) /** Gets the Node's OpSchema. @remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */ diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index c1041bd58a..fd2af7b7d2 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -19,6 +19,12 @@ struct NodeCompare { bool operator()(const Node* n1, const Node* n2) const; }; +enum class DataLayout { + NCHW, + NHWC, + NCHWC, +}; + /** @class GraphViewer Class that provides a read-only view of the Graph. diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 1ebccacad4..57b5d6c7f8 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -52,7 +52,97 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph .Provider(provider_type); } +/// +/// Validate all the layout sensitive nodes which were transformed for current EP are indeed taken by current EP. +/// If not, then we have a bug. If a node with domain kMSNHWC is left in the graph at this point then +/// graph.Resolve will fail. +/// Since layout transformation is only enabled for compile based EPs, just checking that graph does not contain +/// a node with kMSNHWC domain is enough. This is because after compile all the nodes which the EP claims are fused +/// into 1 and removed from the graph. +/// +/// Graph to validate +/// +static Status ValidateGraphPartitioning(const Graph& graph) { + for (const auto& node : graph.Nodes()) { + if (node.Domain() == kMSInternalNHWCDomain) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Graph contains an invalid node: " + node.Name() + " Op Type: " + node.OpType() + + " with domain: " + kMSInternalNHWCDomain + ". These are temporary nodes added during layout transformations " + + " and are not expected to remain in the graph post partitioning. This is a bug in layout transformer."); + } + } + return Status::OK(); +} + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) + +/// +/// Check if a node can be placed on a specific provider. If yes, then set the nodes execution provider. +/// Do nothing if the node is already assigned. +/// +/// Graph in question. +/// Indexed subgraph which needs to be assigned +/// The EP to assign the Indexed subgraph to +static void AssignNodes(Graph& graph, const IndexedSubGraph& capability, + const std::string& provider_type) { + // Before assigning the ep to any node, first walk through all the nodes and ensure + // none of the nodes have already been assigned. If a node is assigned, simply return. + for (auto node_index : capability.nodes) { + const auto* node = graph.GetNode(node_index); + if ((nullptr == node) || (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { + return; + } + } + + for (auto node_index : capability.nodes) { + auto* node = graph.GetNode(node_index); + node->SetExecutionProviderType(provider_type); + } +} + +#endif //! defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) + +static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep, + GraphPartitioner::Mode mode, std::vector>& capabilities, + TransformLayoutFunction transform_layout) { + { + GraphViewer graph_viewer(graph); + capabilities = current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type())); + } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) + // Run layout transformer only for EPs other than CPU EP and provided the preferred layout is NHWC + // CPU EP layout transformation happens later when level 3 transformers are run. + if (mode != GraphPartitioner::Mode::kAssignOnly && + current_ep.GetPreferredLayout() == DataLayout::NHWC) { + for (auto& capability : capabilities) { + // in theory an EP could return an empty value... + if (!capability || !capability->sub_graph) { + continue; + } + AssignNodes(graph, *capability->sub_graph, current_ep.Type()); + } + + // Perform layout transformation on the specific EP assigned graph + bool modified = false; + ORT_RETURN_IF_ERROR(transform_layout(graph, modified, current_ep)); + + // It is possible some new nodes are introduced during transformation. These nodes can be either existing nodes + // which are reconstructed to update domain or completly new nodes which are necessary for layout transformation. + // Therefore, we re-run GetCapability so that these new nodes can be processed by this EP. + if (modified) { + capabilities.clear(); + GraphViewer graph_viewer(graph); + capabilities = current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type())); + } + } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) + + return Status::OK(); +} + #if !defined(ORT_MINIMAL_BUILD) + static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) { auto schema = node.Op(); builder.SetName(schema->Name()) @@ -94,23 +184,23 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done // in order of EP priority bool sub_graph_available_for_assignment = true; - for (auto node_index : capability.nodes) { - const auto* node = graph.GetNode(node_index); - if (nullptr == node || !node->GetExecutionProviderType().empty()) { - // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, - // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by - // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to - // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes - // should never be on those lists. - // - // when the ORT format model is loaded we will process it normally with EP priority being applied for - // whichever EPs are enabled at the time. - // - // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. - // We want the ORT format model to be able to be run as efficiently as possible on either platform, - // so we want all the nodes that either may take to be preserved. If we did not do this we would - // need to create one ORT format model for Android and one for iOS. - if (mode != GraphPartitioner::Mode::kAssignOnly) { + if (mode != GraphPartitioner::Mode::kAssignOnly) { + // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, + // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by + // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to + // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes + // should never be on those lists. + // + // when the ORT format model is loaded we will process it normally with EP priority being applied for + // whichever EPs are enabled at the time. + // + // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. + // We want the ORT format model to be able to be run as efficiently as possible on either platform, + // so we want all the nodes that either may take to be preserved. If we did not do this we would + // need to create one ORT format model for Android and one for iOS. + for (auto node_index : capability.nodes) { + const auto* node = graph.GetNode(node_index); + if ((nullptr == node) || (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { // The node was fused or assigned, so that the whole sub-graph will not be assigned to this // The assumption is that this can only run the sub-graph as a whole unit. sub_graph_available_for_assignment = false; @@ -165,7 +255,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa KernelRegistry& fused_kernel_registry, IExecutionProvider& current_ep, GraphPartitioner::Mode mode, - int& fused_node_unique_id) { + int& fused_node_unique_id, + TransformLayoutFunction transform_layout_function) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -178,28 +269,32 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa Graph* subgraph = entry.second; // we pass through the export_dll value and FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr, - fused_kernel_registry, current_ep, mode, fused_node_unique_id)); + fused_kernel_registry, current_ep, mode, fused_node_unique_id, + transform_layout_function)); } } - // If an execution provider return the capability that he could run a sub-graph, - // onnxruntime will fuse the sub-graph into a function node. if the execution provider - // says he need to compile the graph at runtime (by need_compile flag), - // onnxruntime will invoke the "Compile" method to get compiled binary. - // There are two mode of compile, one is return the entry point to the compiled binary - // directly, another is export the compiled binary to shared library for future reuse. + // If an execution provider returns the capability that it can run a sub-graph, + // onnxruntime will fuse the sub-graph into a function node. For compilation + // based execution providers (one which needs to compile graph at runtime. + // Indicated by need_compile flag), onnxruntime will invoke the "Compile" method + // to get compiled binary. There are two mode of compile, one is return the entry + // point to the compiled binary directly, another is export the compiled binary to + // shared library for future reuse. - // TODO: when the graph contain a function node, and user pass in the dll which could + // TODO: when the graph contains a function node, and user passes in the dll which could // run the function by SessionOption, we should create a function kernel for it and // delegate the compute to the functions inside the dlls. + std::vector> capabilities; + ORT_RETURN_IF_ERROR(GetCapabilityForEP(graph, kernel_registry_mgr, current_ep, mode, capabilities, + transform_layout_function)); + if (capabilities.empty()) { + return Status::OK(); + } + const std::string& type = current_ep.Type(); auto fusion_style = current_ep.GetFusionStyle(); std::vector nodes_to_compile; - - GraphViewer graph_viewer(graph); - std::vector> capabilities = - current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type)); - // filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with // nodes_to_compile. std::vector> capabilities_to_compile; @@ -211,7 +306,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa })); for (auto& capability : capabilities) { - if (!capability || !capability->sub_graph) { // in theory an EP could return an empty value... + // in theory an EP could return an empty value... + if (!capability || !capability->sub_graph) { continue; } @@ -305,6 +401,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa graph.FinalizeFuseSubGraph(indexed_sub_graph, *node); } } + + ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph)); } // if this is the main graph call Resolve to put the Graph back into a guaranteed good state @@ -362,14 +460,16 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, Mode mode, - int& fused_node_unique_id) const { + int& fused_node_unique_id, + TransformLayoutFunction transform_layout_function) const { bool modified_graph = false; do { // process full graph with each EP for (const auto& ep : providers_) { ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_, - fused_kernel_registry, *ep, mode, fused_node_unique_id)); + fused_kernel_registry, *ep, mode, fused_node_unique_id, + transform_layout_function)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -392,13 +492,15 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, IExecutionProvider& current_ep, std::unordered_map& compiled_kernel_hashes, - int& fused_node_unique_id) { + int& fused_node_unique_id, + TransformLayoutFunction transform_layout_function) { // recurse into nested graphs first to partition bottom up. for (auto& node : graph.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { Graph* subgraph = entry.second; ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, - current_ep, compiled_kernel_hashes, fused_node_unique_id)); + current_ep, compiled_kernel_hashes, fused_node_unique_id, + transform_layout_function)); } } @@ -409,12 +511,10 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, } const std::string& type = current_ep.Type(); - GraphViewer graph_viewer(graph); std::vector nodes_and_viewers; - - std::vector> capabilities = - current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type)); - + std::vector> capabilities; + ORT_RETURN_IF_ERROR(GetCapabilityForEP(graph, kernel_registry_mgr, current_ep, + GraphPartitioner::Mode::kOrtFormatLoad, capabilities, transform_layout_function)); if (capabilities.empty()) { return Status::OK(); } @@ -451,40 +551,37 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, for (size_t j = 0, end = nodes_and_viewers.size(); j < end; ++j) { Node& node = nodes_and_viewers[j].fused_node; std::vector single_node_compute_func; - auto status = current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func); - if (!status.IsOK()) { - // There is compile error with the nodes_and_viewer[j], remove the fused_node and function from the graph - LOGS_DEFAULT(ERROR) << "EP: " << current_ep.Type() << " has Compile error: " << status.ErrorMessage(); - graph.CancelFuseSubGraph(node); - } else { - ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 elements"); - ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0]))); + ORT_RETURN_IF_ERROR(current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func)); - const auto& cur_capability = capabilities[j]; - const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; - const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef(); + ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element."); + ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0]))); - KernelDefBuilder builder; - BuildFusedKernelDef(builder, metadef, type); - auto kernel_def = builder.Build(); + const auto& cur_capability = capabilities[j]; + const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; + const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef(); - // save hash so SessionState can find the kernel. each kernel name should be unique - if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) { - ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name, - ". Execution Provider must generate unique names across the entire model."); - } + KernelDefBuilder builder; + BuildFusedKernelDef(builder, metadef, type); + auto kernel_def = builder.Build(); - ORT_RETURN_IF_ERROR(fused_kernel_registry.Register( - KernelCreateInfo(std::move(kernel_def), - [](FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr& out) -> Status { - return FunctionKernel::Create(func_mgr, info, out); - }))); - - // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one - graph.FinalizeFuseSubGraph(indexed_sub_graph, node); + // save hash so SessionState can find the kernel. each kernel name should be unique + if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) { + ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name, + ". Execution Provider must generate unique names across the entire model."); } + + ORT_RETURN_IF_ERROR(fused_kernel_registry.Register( + KernelCreateInfo(std::move(kernel_def), + [](FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + return FunctionKernel::Create(func_mgr, info, out); + }))); + + // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one + graph.FinalizeFuseSubGraph(indexed_sub_graph, node); } + ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph)); + return Status::OK(); } @@ -495,7 +592,8 @@ Status GraphPartitioner::PartitionOrtFormatModel( Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, std::unordered_map& compiled_kernel_hashes, - int& fused_node_unique_id) const { + int& fused_node_unique_id, + TransformLayoutFunction transform_layout_function) const { // process full graph with each EP for (const auto& ep : providers_) { if (ep->Type() == kCpuExecutionProvider) { @@ -505,13 +603,15 @@ Status GraphPartitioner::PartitionOrtFormatModel( } ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry, - *ep, compiled_kernel_hashes, fused_node_unique_id)); + *ep, compiled_kernel_hashes, fused_node_unique_id, + transform_layout_function)); } return Status::OK(); } -Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, Mode mode, +Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, + TransformLayoutFunction transform_layout_function, Mode mode, std::unordered_map* compiled_kernel_hashes) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. @@ -535,16 +635,16 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode, - fused_node_unique_id)); + fused_node_unique_id, transform_layout_function)); #else ORT_UNUSED_PARAMETER(export_dll); ORT_THROW("Not supported in this build."); -#endif +#endif //! defined(ORT_MINIMAL_BUILD) } else { ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided"); ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes, - fused_node_unique_id)); + fused_node_unique_id, transform_layout_function)); } if (!fused_kernel_registry->IsEmpty()) { diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 2926aaf410..db9c3fb31d 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -15,6 +15,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistry; class KernelRegistryManager; +using TransformLayoutFunction = std::function; class GraphPartitioner { public: @@ -31,7 +32,8 @@ class GraphPartitioner { } // Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad. - Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, + Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, + TransformLayoutFunction transform_layout_function, Mode mode = Mode::kNormal, std::unordered_map* compiled_kernel_hashes = nullptr) const; @@ -40,12 +42,13 @@ class GraphPartitioner { #if !defined(ORT_MINIMAL_BUILD) Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr, - KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id) const; + KernelRegistry& fused_kernel_registry, Mode mode, + int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const; #endif Status PartitionOrtFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, std::unordered_map& compiled_kernel_hashes, - int& fused_node_unique_id) const; + int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const; KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 7d91f03829..7ec09dc58e 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -16,6 +16,7 @@ #include "core/framework/session_state_flatbuffers_utils.h" #include "core/framework/session_state_utils.h" #include "core/framework/utils.h" +#include "core/framework/static_kernel_def_hashes.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -601,7 +602,7 @@ Status SessionState::GeneratePatternGroupCache(const gsl::span& auto* node = graph_viewer_->GetNode(node_plan.node_index); int output_start = node_index + static_cast(node->InputDefs().size()) + static_cast(node->ImplicitInputDefs().size()); - //allocate output + // allocate output for (int i = 0, end = static_cast(node->OutputDefs().size()); i < end; ++i) { const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i); if (ml_value_idx == NodeIndexInfo::kInvalidEntry || @@ -631,7 +632,7 @@ Status SessionState::GeneratePatternGroupCache(const gsl::span& } } - //release nodes + // release nodes for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) { auto ml_value_idx = exe_plan->to_be_freed[index]; const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type; @@ -1015,15 +1016,22 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) - // lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices + // lookup the hashes for any nodes we compiled or added during graph partitioning. + // These node indexes for compiled nodes as well as newly added nodes are not in node_indices // as they were created at runtime. - if (!compiled_kernel_hashes.empty()) { - for (const auto& node : graph_.Nodes()) { - if (kernel_create_info_map_.count(node.Index()) == 0) { + for (const auto& node : graph_.Nodes()) { + if (kernel_create_info_map_.count(node.Index()) == 0) { + if (node.Domain() == kOnnxDomain || node.Domain() == kOnnxDomainAlias) { + auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion()); + if (kernel_hash.has_value()) { + ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, *kernel_hash)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to find kernel hash for node:", node.Name(), " optype:", node.OpType()); + } + } else { const auto hash_info = compiled_kernel_hashes.find(node.OpType()); ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(), "Unable to find compiled kernel hash for node '", node.Name(), "'."); - ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second)); } } @@ -1284,7 +1292,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version) { + // Layout tranformer can add new nodes to the graph. + // Since layout transformation can happen in an extended build, if these nodes are not picked up and compiled by + // NNAPI or other compiling EPs then we need a way to get the hashes for these nodes. Since the infrastructure + // as well as op_schema required to generate these hashes is not available in an extended minimal build, + // we maintain a static map of nodes to hash value. This hash value can then be used to retireive the + // kernel for the given op. + static std::unordered_map static_kernel_hashes{ + {"Transpose_1", 4324835766923221184ULL}, + {"Transpose_13", 17267477159887372848ULL}, + {"Squeeze_1", 12889825108950034784ULL}, + {"Squeeze_11", 14725795030460042064ULL}, + {"Squeeze_13", 16122603335179721968ULL}, + {"UnSqueeze_1", 15964030255371555232ULL}, + {"UnSqueeze_11", 16989589986691430224ULL}, + {"UnSqueeze_13", 9466011545409597224ULL}, + {"Gather_1", 625186873870077080ULL}, + {"Gather_11", 11761559382112736008ULL}, + {"Gather_13", 7462749543760614528ULL}, + {"Identity_1", 18001636502361632792ULL}, + {"Identity_13", 16879814636194901248ULL}, + {"Identity_14", 16515685968327103576ULL}, + {"Identity_16", 17661628575887109792ULL}, + }; + + auto key = op_type + "_" + std::to_string(since_version); + auto iter = static_kernel_hashes.find(key); + if (iter != static_kernel_hashes.end()) { + return iter->second; + } + + return std::nullopt; +} +} \ No newline at end of file diff --git a/onnxruntime/core/framework/static_kernel_def_hashes.h b/onnxruntime/core/framework/static_kernel_def_hashes.h new file mode 100644 index 0000000000..cfe9698935 --- /dev/null +++ b/onnxruntime/core/framework/static_kernel_def_hashes.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include "core/common/basic_types.h" +namespace onnxruntime { +/** + * @brief Gets the hash value for provided op type + version combination if it is available, otherwise + * returns a nullopt. The hash value is available if this node was added by layout transformer. For all other + * nodes, the hash values should be present either in the serialized session state obtained form ort format model + * or from compiled kernel hash map which is generated during partitioning. + * @return std::optional + */ +std::optional GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version); +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index 36ed8b38e9..01fe2bb5d3 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -6,7 +6,7 @@ #include "core/optimizer/initializer.h" #include "core/optimizer/nhwc_transformer.h" #include "core/optimizer/utils.h" -#include "core/optimizer/transpose_optimizer/api_impl.h" +#include "core/optimizer/transpose_optimizer/optimizer_utils.h" using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; @@ -55,7 +55,7 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); if (domain != kMSDomain) { - SwapNodeOpTypeAndDomain(*api_graph, *node, "QLinearConv", "com.microsoft"); + SwapNodeOpTypeAndDomain(*api_graph, *node, "QLinearConv", kMSDomain); } modified = true; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index b734278653..b9c27625d0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -111,7 +111,9 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap std::vector qdq_selections; for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { const auto* node = graph_viewer.GetNode(index); - if (node->Domain() != kOnnxDomain) { + // post layout transformation all the layout sensitive nodes are converted to domain + // kMSInternalNHWCDomain. Therefore need to allow this domain as well. + if (node->Domain() != kOnnxDomain && node->Domain() != kMSInternalNHWCDomain) { continue; } diff --git a/onnxruntime/core/optimizer/transpose_optimizer/api.h b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h similarity index 91% rename from onnxruntime/core/optimizer/transpose_optimizer/api.h rename to onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h index 61445c101f..0e31160357 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/api.h +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace onnx_layout_transformation { namespace api { @@ -204,7 +205,7 @@ class NodeRef { } std::string_view node_domain = Domain(); return node_domain == domain || - ((domain == "" || domain == "ai.onnx") && (node_domain == "" || node_domain == "ai.onnx")); + ((domain == "" || domain == "ai.onnx") && (node_domain == "" || node_domain == "ai.onnx")); } /// @@ -217,6 +218,19 @@ class NodeRef { return GetAttributeInt(name).value_or(default_value); } + /// + /// Returns the Execution Provider assigned to this node. Any empty string means this node is + /// not assigned to any EP. + /// + /// EP type or empty string + virtual const std::string& GetExecutionProviderType() const = 0; + + /// + /// Returns the schema since version for the op_type of this node. Value os -1 means it is not set. + /// + /// since version or default value -1 + virtual int SinceVersion() const = 0; + virtual ~NodeRef(){}; }; @@ -224,7 +238,6 @@ class NodeRef { /// Information regarding the consumers of a value. /// struct ValueConsumers { - /// /// List of nodes in the current graph containing value as an input /// @@ -335,6 +348,14 @@ class GraphRef { virtual std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain = "") = 0; + /// + /// Creates a copy of the provided node in the graph with the specified op type and domain. + /// + /// The new node's op type + /// The new node's domain. Empty string signifies default onnx domain. + /// The new node + virtual std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain = "") = 0; + /// /// Deletes a node from the graph. Behavior is undefined if node has any consumers. /// @@ -409,6 +430,17 @@ class GraphRef { constexpr int64_t kMinSupportedOpset = 7; constexpr int64_t kMaxSupportedOpset = 15; +enum class OptimizerMode { + OPTIMIZE_TRANSPOSE, // simple transpose optimization + OPTIMIZE_LAYOUT_TRANSFORM // transpose optimization post layout transformation +}; + +/// +/// Gets a list of layout sensitive ops defined by ONNX standard. +/// +/// const reference to an unordered set of op_types which are layout sensitive +const std::unordered_set& GetLayoutSensitiveOps(); + /// /// Performs transpose optimization on a graph. Returns true if the graph was modified. /// @@ -420,8 +452,16 @@ constexpr int64_t kMaxSupportedOpset = 15; /// /// The graph to optimize (or a portion of a graph, see api::GraphRef docs) /// Whether com.microsoft ops can be used for optimization +/// Execution provider if applicable. +/// Current mode. Optimizer can be called in the context of transpose optimizations or during layout transformations. +/// List of ops which are treated as layout sensitive by the ONNX standard as well as any runtime specific ops. +/// These ops should be provided when mode is set to OPTIMIZE_LAYOUT_TRANSFORM. If these ops are not provided, transpose optimizer may convert the +/// layout for these ops /// true if the graph was modified -bool Optimize(api::GraphRef& graph, bool allow_extended_ops); +bool Optimize(api::GraphRef& graph, bool allow_extended_ops, + const std::string& provider_type = "", + OptimizerMode mode = OptimizerMode::OPTIMIZE_TRANSPOSE, + const std::unordered_set& layout_sensitive_ops = {}); /* Layout Transformation Tools * These methods help change the channel ordering of layout sensitive ops (like Conv). ONNX currently only supports @@ -429,14 +469,14 @@ bool Optimize(api::GraphRef& graph, bool allow_extended_ops); * the new ordering. The existence of a robust transpose optimizer means that we can freely add transpose ops during * conversion and then call Optimize to remove as many as possible. To change the channel ordering of some/all ops * in a model, a user of this tool should do the following: - * + * * 1. Iterate over the graph nodes and identify nodes to convert. For each one: * a. Change the op type and domain (and possibly attributes) to the op/contrib op with the desired ordering. * b. The model is now invalid since the input tensors are in the original ordering (and all consumers * expect the original ordering). Use WrapTransposesAroundNode helper to insert transposes around the * inputs/outputs of the op to correct this. * 2. The model is now correct but has many unnecessary Transpose ops. Call Optimize on the graph. - * + * * After step 1, the Transpose ops will wrap converted ops in a similar manner to q/dq ops in quantization. * The perm attributes essentially encode the information about which ops are being reordered. */ @@ -444,13 +484,13 @@ bool Optimize(api::GraphRef& graph, bool allow_extended_ops); /// /// Inserts transposes around op inputs/outputs. Alternatively transposes initializers or uses existing Transpose /// nodes if possible. Populates shape information on affected node inputs/outputs to reflect the change. -/// +/// /// Ex: /// * -> NhwcConv -> ** /// becomes /// * -> Transpose -> NhwcConv -> Transpose -> ** /// Conv inputs/outputs have new shape. Shapes of * and ** are unchanged (carrying NCHW data). -/// +/// /// input_perms/output_perms are matched with node inputs/outputs positionally. Their lengths must be at most equal to /// the number of inputs/outputs, respectively. nullptr entires indicate an input or output should not be transposed. /// diff --git a/onnxruntime/core/optimizer/transpose_optimizer/api_impl.cc b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc similarity index 73% rename from onnxruntime/core/optimizer/transpose_optimizer/api_impl.cc rename to onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc index a5bec56a39..503e912bd4 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/transpose_optimizer/api_impl.h" - +#include "optimizer_api.h" +#include "optimizer_utils.h" #include #include "core/graph/graph_utils.h" -#include "core/optimizer/initializer.h" -#include "core/optimizer/utils.h" -#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/execution_provider.h" +#include "core/graph/graph_viewer.h" #include "core/providers/cpu/tensor/transpose.h" using namespace ONNX_NAMESPACE; @@ -15,7 +15,6 @@ using namespace ::onnxruntime::common; using namespace onnx_layout_transformation; namespace onnxruntime { - class ApiValueInfo final : public api::ValueInfoRef { private: NodeArg& node_arg_; @@ -84,6 +83,8 @@ class ApiNode final : public api::NodeRef { void CopyAttributes(const api::NodeRef& node) override; void ClearAttribute(std::string_view name) override; void SetInput(size_t i, std::string_view name) override; + const std::string& GetExecutionProviderType() const override; + virtual int SinceVersion() const override; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode); @@ -114,6 +115,9 @@ class ApiGraph final : public api::GraphRef { void ReshapeInitializer(std::string_view name, const std::vector& shape) override; std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, size_t num_outputs = 1, std::string_view domain = "") override; + + std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, + std::string_view domain = "") override; void RemoveNode(api::NodeRef& node) override; void RemoveInitializer(std::string_view name) override; std::string_view AddInitializer(api::DataType dtype, const std::vector& shape, @@ -377,6 +381,15 @@ void ApiNode::SetInput(size_t i, std::string_view name) { } } } + +const std::string& ApiNode::GetExecutionProviderType() const { + return node_.GetExecutionProviderType(); +} + +int ApiNode::SinceVersion() const { + return node_.SinceVersion(); +} + // std::optional ApiGraph::Opset(std::string_view domain) const { @@ -562,11 +575,11 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vectorSetShape(new_shape); } -std::unique_ptr ApiGraph::AddNode(std::string_view op_type, - const std::vector& inputs, size_t num_outputs, - std::string_view domain) { +static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type, + const std::vector& inputs, size_t num_outputs, + std::string_view domain, int since_version, std::string_view node_ep) { const std::string op_type_str(op_type); - std::string name = graph_.GenerateNodeName(op_type_str); + std::string name = graph.GenerateNodeName(op_type_str); std::vector input_args; std::vector output_args; @@ -574,48 +587,107 @@ std::unique_ptr ApiGraph::AddNode(std::string_view op_type, for (const auto& input : inputs) { NodeArg* arg; if (input == "") { - arg = &graph_.GetOrCreateNodeArg("", nullptr); + arg = &graph.GetOrCreateNodeArg("", nullptr); } else { - arg = graph_.GetNodeArg(std::string(input)); + arg = graph.GetNodeArg(std::string(input)); } input_args.push_back(arg); } output_args.reserve(num_outputs); for (size_t i = 0; i < num_outputs; ++i) { - std::string output = graph_.GenerateNodeArgName(name + "_out" + std::to_string(i)); - NodeArg* arg = &graph_.GetOrCreateNodeArg(output, nullptr); + std::string output = graph.GenerateNodeArgName(name + "_out" + std::to_string(i)); + NodeArg* arg = &graph.GetOrCreateNodeArg(output, nullptr); output_args.push_back(arg); } std::vector outputs; - Node& node = graph_.AddNode(name, op_type_str, "Added in transpose optimizer", input_args, output_args, nullptr, - std::string(domain)); + Node& node = graph.AddNode(name, op_type_str, "Added in transpose optimizer", input_args, output_args, nullptr, + std::string(domain)); - if (new_node_ep_ != nullptr) { - node.SetExecutionProviderType(new_node_ep_); + if (node.SinceVersion() == -1) { + node.SetSinceVersion(since_version); } + node.SetExecutionProviderType(std::string(node_ep)); + for (size_t i = 0; i < input_args.size(); ++i) { NodeArg* arg = input_args[i]; if (arg->Exists()) { const std::string& name_str = arg->Name(); - graph_.AddConsumerNode(name_str, &node); - const auto* inp_node = graph_.GetProducerNode(name_str); + graph.AddConsumerNode(name_str, &node); + const auto* inp_node = graph.GetProducerNode(name_str); if (inp_node != nullptr) { int inp_node_out_index = graph_utils::GetNodeOutputIndexFromOutputName(*inp_node, name_str); - graph_.AddEdge(inp_node->Index(), node.Index(), inp_node_out_index, gsl::narrow_cast(i)); + graph.AddEdge(inp_node->Index(), node.Index(), inp_node_out_index, gsl::narrow_cast(i)); } } } for (NodeArg* arg : output_args) { - graph_.UpdateProducerNode(arg->Name(), node.Index()); + graph.UpdateProducerNode(arg->Name(), node.Index()); } + return node; +} + +// This is a list of onnx ops and their versions which transpose_optimizer can potentially add to the graph. +// This is needed in minimal build since opschema is not available. +// The versions MUST be sorted due to how the model opset is matched with the most recent operator version. +static const std::unordered_map> onnx_ops_available_versions = { + {"Squeeze", {1, 11, 13}}, + {"Unsqueeze", {1, 11, 13}}, + {"Gather", {1, 11, 13}}, + {"Transpose", {1, 13}}, + {"Identity", {1, 13, 14, 16}}, +}; + +// Based on the opset version imported for this model, returns the since version for the node. +static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view domain, + const std::unordered_map& domain_to_version_map) { + int since_version = -1; + ORT_ENFORCE(domain == kOnnxDomain, "Transpose optimizer is expected to add only onnx domain ops. Domain: ", + domain, " provided for op: ", op_type); + + auto opset_import_iter = domain_to_version_map.find(std::string(domain)); + ORT_ENFORCE(opset_import_iter != domain_to_version_map.end(), "Onnx domain not found in opset imports."); + + int opset_version = opset_import_iter->second; + auto iter = onnx_ops_available_versions.find(std::string(op_type)); + ORT_ENFORCE(iter != onnx_ops_available_versions.end(), + "Transpose Optimizer is adding an unexpected node: ", op_type, + "An entry for this node should be added in onnx_ops_available_versions and static_kernel_hashes map."); + + for (auto version : iter->second) { + if (version <= opset_version) { + since_version = version; + } + } + + return since_version; +} + +std::unique_ptr ApiGraph::AddNode(std::string_view op_type, + const std::vector& inputs, size_t num_outputs, + std::string_view domain) { + int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap()); + Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs, + domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : ""); + return std::make_unique(node, graph_); } +std::unique_ptr ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type, + std::string_view domain) { + Node& node = CreateNodeHelper(graph_, op_type, source_node.Inputs(), + source_node.Outputs().size(), domain, source_node.SinceVersion(), source_node.GetExecutionProviderType()); + + std::unique_ptr new_node = std::make_unique(node, graph_); + new_node->CopyAttributes(source_node); + + return new_node; +} + void ApiGraph::RemoveNode(api::NodeRef& node) { Node& ort_node = static_cast(node).Node(); for (const auto* node_arg : ort_node.InputDefs()) { @@ -705,6 +777,10 @@ std::unique_ptr MakeApiGraph(onnxruntime::Graph& graph, Allocator return std::make_unique(graph, std::move(cpu_allocator), new_node_ep); } +std::unique_ptr MakeApiNode(onnxruntime::Graph& graph, onnxruntime::Node& node) { + return std::make_unique(node, graph); +} + onnxruntime::Graph& GraphFromApiGraph(onnx_layout_transformation::api::GraphRef& graph) { return static_cast(graph).Graph(); } @@ -713,4 +789,99 @@ onnxruntime::Node& NodeFromApiNode(onnx_layout_transformation::api::NodeRef& nod return static_cast(node).Node(); } +namespace layout_transformer { + +const std::unordered_set& GetORTLayoutSensitiveOps() { + static std::unordered_set ort_layout_senstive_ops = []() { + const auto& layout_sensitive_ops = onnx_layout_transformation::GetLayoutSensitiveOps(); + std::unordered_set ort_specific_ops = {"Resize", "FusedConv", "QLinearAveragePool", "QLinearGlobalAveragePool"}; + ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); + return ort_specific_ops; + }(); + + return ort_layout_senstive_ops; +} + +Status TransformLayout(Graph& graph, bool& modified, IExecutionProvider& execution_provider) { + // sub graph recurse will be added later + auto api_graph = MakeApiGraph(graph, execution_provider.GetAllocator(0, OrtMemTypeDefault), nullptr); + const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + + for (auto& node : api_graph->Nodes()) { + if (layout_sensitive_ops.count(node->OpType())) { + if (node->GetExecutionProviderType() != execution_provider.Type()) { + continue; + } + + auto domain = node->Domain(); + // Skip if domain is incorrect + if (domain != kOnnxDomain && domain != kOnnxDomainAlias && domain != kMSDomain) { + continue; + } + + // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP + // knows this op is in the expected format. + if (node->GetAttributeIntDefault("channels_last", 0) == 1) { + onnx_layout_transformation::SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); + // Changing the domain for the node requires creating a new node and replacing the old one + // therefore set the modified flag. + modified = true; + continue; + } + + // Skip if unknown rank + auto shape = api_graph->GetValueInfo(node->Inputs()[0])->Shape(); + if (!shape.has_value()) { + continue; + } + + // Convert to channels last + size_t rank = shape->size(); + + bool has_channel_last_attr = node->GetAttributeInt("channels_last").has_value() ? true : false; + if (has_channel_last_attr) { + node->SetAttributeInt("channels_last", 1); + } + + auto input_perm = onnx_layout_transformation::ChannelFirstToLastPerm(rank); + auto output_perm = onnx_layout_transformation::ChannelLastToFirstPerm(rank); + + // Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation + // for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout + // transformer only converts layout for 0th input, weights should be handled by every EP. + if (node->OpType() == "Resize") { + // Older versions of resize have a bug where ROI and Scales cannot be made empty inputs. To handle this case, + // we need to jump a few extra hoops to make sure its inputs are correctly handled. Current code skips + // layout conversion for ROI because it needs special handling as ROI size is 2*rank. + // Enable passing in ROI for layout conversion when an EP which supports ROI starts using layout transformer. + // NNAPI which currently uses layout transformer does not support it. + std::vector*> input_perms{&input_perm, nullptr}; + for (size_t i = 2; i < node->Inputs().size(); i++) { + auto constant = api_graph->GetConstant(node->Inputs()[i]); + if (constant != nullptr && constant->Data().size() > 0) { + input_perms.push_back(&input_perm); + } else { + input_perms.push_back(nullptr); + } + } + onnx_layout_transformation::WrapTransposesAroundNode(*api_graph, *node, input_perms, {&output_perm}); + } else { + onnx_layout_transformation::WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); + } + + onnx_layout_transformation::SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); + modified = true; + } + } + + if (modified) { + onnx_layout_transformation::Optimize(*api_graph, /*allow_extended_ops*/ true, execution_provider.Type(), + onnx_layout_transformation::OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM, + layout_sensitive_ops); + } + + return Status::OK(); +} + +} // namespace layout_transformer } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transpose_optimizer/api_impl.h b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_utils.h similarity index 52% rename from onnxruntime/core/optimizer/transpose_optimizer/api_impl.h rename to onnxruntime/core/optimizer/transpose_optimizer/optimizer_utils.h index e84d1a7d2a..872f68483e 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/api_impl.h +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_utils.h @@ -3,15 +3,11 @@ #pragma once -#include "core/framework/execution_provider.h" +#include "optimizer_api.h" #include "core/graph/graph.h" -#include "core/optimizer/transpose_optimizer/api.h" - -using namespace ONNX_NAMESPACE; -using namespace ::onnxruntime::common; +#include "core/framework/execution_provider.h" namespace onnxruntime { - /// /// Creates concrete implementation of api for transpose optimizer. IMPORTANT: graph must have up-to-date edges, /// node_arg-to-producer, and node_arg-to-consumer relationships. Otherwise call Resolve() before this. @@ -24,6 +20,14 @@ std::unique_ptr MakeApiGraph(onnxrunt AllocatorPtr cpu_allocator, const char* new_node_ep); +/// +/// Creates NodeRef. +/// +/// ORT Graph which owns the node +/// ORT Node to wrap with API. +/// api::NodeRef for use with transpose optimizer +std::unique_ptr MakeApiNode(onnxruntime::Graph& graph, onnxruntime::Node& node); + /// /// Reveals underlying ORT graph from an api::GraphRef /// @@ -38,4 +42,25 @@ onnxruntime::Graph& GraphFromApiGraph(onnx_layout_transformation::api::GraphRef& /// ORT node onnxruntime::Node& NodeFromApiNode(onnx_layout_transformation::api::NodeRef& node); +namespace layout_transformer { +/// +/// Gets a list of layout sensitive ops for ORT. This list contains onnx standard defined +/// layout senstive ops + contrib ops + ops which are not layout sensitive but are treated as +/// layout sensitive by ORT EPs (exmaple Resize). +/// +/// unordered set of op_types which are layout sensitive +const std::unordered_set& GetORTLayoutSensitiveOps(); + +/// +/// Transforms data layout from NCHW to NHWC. Applies transforms to layout sensitive nodes +/// assigned to execution_provider provided by the caller and any other non-layout sensitive +/// nodes in order to optimize the transposes as much as possible. +/// +/// graph to transform +/// indicates whether the graph is modified during transformation +/// execution provider for which the transformation needs to be performed +/// +Status TransformLayout(Graph& graph, bool& modified, IExecutionProvider& execution_provider); + +} // namespace layout_transformer } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.cc index 80a1472cde..582842b3ea 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.cc @@ -1,14 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h" - +#include "ort_transpose_optimizer.h" #include #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" -#include "core/optimizer/transpose_optimizer/api_impl.h" #include "core/providers/cpu/tensor/transpose.h" +#include "optimizer_utils.h" using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; diff --git a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc index eacc0e3fd3..6acec0ef17 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "api.h" +#include "optimizer_api.h" #include #include #include #include #include +#include namespace onnx_layout_transformation { @@ -16,6 +17,9 @@ struct OptimizerCtx { api::GraphRef& graph; bool allow_extended_ops; bool skip_cost_check; + const std::string provider_type; + OptimizerMode mode; + std::unordered_set layout_sensitive_ops; }; // Each op handler points to a (potentially shared) function for determining which input indices are eligible for @@ -44,7 +48,6 @@ struct HandlerInfo { bool transposes_outputs = true; }; - /////// /////// /* Small utilities for editing nodes and manipulating axes/permutations */ @@ -63,14 +66,14 @@ static std::vector DataInt32(api::TensorRef& tensor) { } static std::string_view AddInitializerInt64(api::GraphRef& graph, const std::vector& shape, - const std::vector& values) { + const std::vector& values) { const uint8_t* raw_data = reinterpret_cast(values.data()); std::vector data(raw_data, raw_data + values.size() * sizeof(int64_t)); return graph.AddInitializer(api::DataType::INT64, shape, data); } static std::string_view AddInitializerInt32(api::GraphRef& graph, const std::vector& shape, - const std::vector& values) { + const std::vector& values) { const uint8_t* raw_data = reinterpret_cast(values.data()); std::vector data(raw_data, raw_data + values.size() * sizeof(int32_t)); return graph.AddInitializer(api::DataType::INT32, shape, data); @@ -108,7 +111,7 @@ static std::unique_ptr MakeTranspose(api::GraphRef& graph, std::st } // Creates a Squeeze/Unsqueeze node. Does not update output ValueInfo. -static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api::GraphRef& graph, +static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api::GraphRef& graph, std::string_view op_type, std::string_view input, const std::vector& axes) { if (opset < 13) { @@ -141,7 +144,7 @@ static bool IsValidPerm(const std::vector& perm) { static std::optional> GetPermAttrIfValid(const api::NodeRef& node) { std::optional> perm = node.GetAttributeInts("perm"); - if (perm != std::nullopt && !IsValidPerm(*perm)) { + if (perm.has_value() && !IsValidPerm(*perm)) { return std::nullopt; } return perm; @@ -224,8 +227,12 @@ static bool IsIdentityPerm(const std::vector& perm) { } // Computes permutation from channel last to channel first ordering of given rank. Nearly all handlers work for any -// permutation, but some are restricted. Also used for layout transformation. Rank must be >= 1. +// permutation, but some are restricted. Also used for layout transformation. std::vector ChannelLastToFirstPerm(size_t rank) { + if (rank < 2) { + return {}; + } + std::vector p(rank); p[0] = 0; p[1] = rank - 1; @@ -330,9 +337,9 @@ static std::vector SqueezePerm(const std::vector& axes, const return new_perm; } -// Computes a new axes attribute for an input that has been permuted using perm. Unsafe if axes/perm are invalid or +// Computes a new axes attribute for an input that has been permuted using perm. Unsafe if axes/perm are invalid or // have negative values. -// +// // Ex: perm = [2, 0, 1], axes = [0, 1], new_axes = [2, 0] static std::vector AxesForTransposedInput(const std::vector& axes, const std::vector& perm) { @@ -346,7 +353,7 @@ static std::vector AxesForTransposedInput(const std::vector& a // Computes a new axes attribute for an input that has been permuted using perm and sorts the result. Axes attributes // are commonly sorted (unless order matters like in Slice). Unsafe if axes/perm are invalid or have negative values. -// +// // Ex: perm = [2, 0, 1], axes = [0, 1], new_axes = [0, 2] static std::vector SortedAxesForTransposedInput(const std::vector& axes, const std::vector& perm) { @@ -376,10 +383,8 @@ static std::vector SortedAxesForTransposedInput(const std::vector /////// /* These helpers hide the most gnarly parts of the transpose optimizer. */ - static std::string_view HelpHandleUnsqueeze(HandlerArgs& args, const std::vector& axes); - // Replaces ith input to node with unsqueezed value. Might create a new Unsqueeze node, find an existing one, // or reshape an initializer. Unsqueezing can be necessary before transposing inputs of a node that supports // broadcasting. @@ -442,7 +447,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // a Transpose, optimize it here. if (inp_node != nullptr && inp_node->IsOp("Transpose")) { auto perm = GetPermAttrIfValid(*inp_node); - if (perm != std::nullopt) { + if (perm.has_value()) { auto perm_inv = InvertPerm(*perm); std::vector indices = {0}; HandlerArgs args{ctx, *inp_node, unsqueeze, *perm, perm_inv, indices}; @@ -456,6 +461,29 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons node.SetInput(i, unsq_out); } +static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::TensorRef& constant, + size_t i, std::string_view input_name, const std::vector& perm) { + // Create new transposed initializer + auto rank = perm.size(); + auto shape = constant.Shape(); + std::vector data = constant.Data(); + std::vector new_data(data.size()); + size_t bytes_per_val = data.size() / rank; + + uint8_t* dst = new_data.data(); + for (size_t j = 0; j < rank; ++j) { + uint8_t* src = data.data() + perm[j] * bytes_per_val; + std::memcpy(dst, src, bytes_per_val); + dst += bytes_per_val; + } + + std::string_view new_initializer = graph.AddInitializer(constant.DType(), shape, new_data); + node.SetInput(i, new_initializer); + if (!graph.HasValueConsumers(input_name)) { + graph.RemoveInitializer(input_name); + } +} + // Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one, // or transpose an initializer. static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, @@ -469,6 +497,18 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { + // Input is scalar, return early. + if (constant->Shape().size() == 1 && constant->Shape()[0] == 0) { + return; + } + // This is a special case where the constant is 1D with length == perm. + // TODO: TransposeInitializer should be updated to handle this case. + // Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if + // there are no other consumers. + if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast(perm.size())) { + Permute1DConstant(graph, node, *constant, i, input, perm); + return; + } if (consumers->nodes.size() > 0) { // Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv // to counteract the effect. These Transposes will hopefully be optimized out later. @@ -496,6 +536,10 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, } node.SetInput(i, pre_transpose_value); return; + } else if (*perm2 == perm) { + // we are trying to add a duplicate transpose. + // do nothing and return + return; } // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove @@ -513,7 +557,7 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, return; } } - + // Case 3: A Transpose op might already exist for (size_t j = 0; j < consumers->nodes.size(); ++j) { api::NodeRef& consumer = *consumers->nodes[j]; @@ -533,7 +577,7 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, } // Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank. -static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank, +static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank, const std::vector& input_indices) { auto inputs = node.Inputs(); @@ -563,7 +607,7 @@ static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t tar } // Transposes specified inputs according to perm. -// NOTE: if a Transpose is expected to be above an input to this node, use the inverse of its permutation to cancel it. +// NOTE: if a Transpose is expected to be above an input to this node, use the inverse of its permutation to cancel it. static void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::vector& perm, const std::vector& input_indices) { auto perm_inv = InvertPerm(perm); @@ -573,7 +617,7 @@ static void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::ve } inline static void TransposeFirstInput(OptimizerCtx& ctx, api::NodeRef& node, const std::vector& perm) { - std::vector indices {0}; + std::vector indices{0}; TransposeInputs(ctx, node, perm, indices); } @@ -630,7 +674,7 @@ static void TransposeOutputs(OptimizerCtx& ctx, api::NodeRef& node, const std::v // A rank of 5 is used if rank cannot be determined since 5 is the largest rank we expect from something like a Conv // and an unknown rank likely corresponds to a data-carrying (non-weight) tensor, which will be large. -// Given a value, returns the rank of the value excluding dimensions of value 1. Returns 5 if the rank is unknown. +// Given a value, returns the rank of the value excluding dimensions of value 1. Returns 5 if the rank is unknown. static int EstimateValueRank(api::GraphRef& graph, std::string_view input) { auto value_info = graph.GetValueInfo(input); std::optional> shape = value_info->Shape(); @@ -774,7 +818,7 @@ std::vector NonScalarInputs(OptimizerCtx& ctx, api::NodeRef& node) { constexpr HandlerInfo broadcast_node_handler = {&NonScalarInputs, &HandleSimpleNodeBroadcast}; // Transposes all inputs and all outputs. Updates axis attribute. -static bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_axis=std::nullopt) { +static bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_axis = std::nullopt) { size_t rank = args.perm.size(); std::optional axis = args.node.GetAttributeInt("axis"); if (axis == std::nullopt) { @@ -848,7 +892,6 @@ static bool HandleShape(HandlerArgs& args) { std::vector new_perm; // For opset 15, Shape(Transpose(x, perm))[starts:stops] = Gather(Shape(x), perm[starts:stops]) if (args.ctx.opset >= 15) { - // Assign new_perm = perm[starts:stops] int64_t start = args.node.GetAttributeIntDefault("start", 0); int64_t end = args.node.GetAttributeIntDefault("end", rank_int); @@ -870,7 +913,7 @@ static bool HandleShape(HandlerArgs& args) { } // Make new_perm initializer - std::vector perm_shape {gsl::narrow_cast(new_perm.size())}; + std::vector perm_shape{gsl::narrow_cast(new_perm.size())}; std::string_view perm_const = AddInitializerInt64(args.ctx.graph, perm_shape, new_perm); // X -> Shape -> Y, Gather @@ -900,7 +943,7 @@ static bool HandleShape(HandlerArgs& args) { constexpr HandlerInfo shape_handler = {&FirstInput, &HandleShape, /*transposes_outputs*/ false}; // Permutes a 1D node input by creating a new initializer or inserting a Gather op -void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm) { +static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm) { size_t rank = perm.size(); int64_t rank_int = gsl::narrow_cast(rank); @@ -909,25 +952,7 @@ void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std: if (constant != nullptr) { auto shape = constant->Shape(); if (shape.size() == 1 && (shape[0] == rank_int || shape[0] == 0)) { - // Create new transposed initializer - std::vector data = constant->Data(); - std::vector new_data(data.size()); - size_t bytes_per_val = data.size() / rank; - - uint8_t* dst = new_data.data(); - for (size_t j = 0; j < rank; ++j) { - uint8_t* src = data.data() + perm[j] * bytes_per_val; - for (size_t k = 0; k < bytes_per_val; ++k) { - *dst++ = *src++; - } - } - - std::string_view new_initializer = graph.AddInitializer(constant->DType(), shape, new_data); - node.SetInput(i, new_initializer); - if (!graph.HasValueConsumers(input)) { - graph.RemoveInitializer(input); - } - + Permute1DConstant(graph, node, *constant, i, input, perm); return; } } @@ -942,33 +967,39 @@ void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std: node.SetInput(i, gather_output); } -//static bool HandleResize(HandlerArgs& args) { +// static bool HandleResize(HandlerArgs& args) { // auto inputs = args.node.Inputs(); // int64_t rank_int = gsl::narrow_cast(args.perm.size()); // +// auto p = ChannelFirstToLastPerm(rank_int); +// auto& perm = p == args.perm ? args.perm : args.perm_inv; +// auto& perm_inv = p == args.perm ? args.perm_inv : args.perm; +// // if (args.ctx.opset < 11) { -// PermuteInput(args.ctx.graph, args.node, 1, args.perm_inv); -// } else { -// if (inputs[1] != "") { -// std::vector double_perm_inv = args.perm_inv; -// double_perm_inv.reserve(2 * args.perm_inv.size()); -// for (int64_t p : args.perm_inv) { -// double_perm_inv.push_back(p + rank_int); -// } -// PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv); -// } -// for (size_t i = 2; i < inputs.size(); ++i) { -// if (inputs[i] != "") { -// PermuteInput(args.ctx.graph, args.node, i, args.perm_inv); -// } -// } -// } +// PermuteInput(args.ctx.graph, args.node, 1, perm); +// } else { +// if (inputs[1] != "") { +// std::vector double_perm_inv = perm; +// double_perm_inv.reserve(2 * args.perm.size()); +// for (int64_t p1 : perm) { +// double_perm_inv.push_back(p1 + rank_int); +// } +// PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv); +// } +// for (size_t i = 2; i < inputs.size(); ++i) { +// if (inputs[i] != "") { +// PermuteInput(args.ctx.graph, args.node, i, perm); +// } +// } +// } // -// TransposeFirstInput(args.ctx, args.node, args.perm_inv); -// TransposeOutputs(args.ctx, args.node, args.perm); +// TransposeFirstInput(args.ctx, args.node, perm); +// TransposeOutputs(args.ctx, args.node, perm_inv); // -// return true; -//} +// SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, args.node.OpType(), "com.microsoft.nhwc"); +// +// return true; +// } // constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; @@ -1028,7 +1059,6 @@ static bool HandleReduceOp(HandlerArgs& args) { } } else { - if (!NormalizeAndValidateAxes(*axes, args.perm.size())) { return false; } @@ -1143,7 +1173,6 @@ static bool HandleSqueeze(HandlerArgs& args) { if (!args.ctx.graph.HasValueConsumers(axes_inp)) { args.ctx.graph.RemoveInitializer(axes_inp); } - } // Transpose inputs/outputs @@ -1178,25 +1207,31 @@ static bool HandleUnsqueeze(HandlerArgs& args) { constexpr HandlerInfo unsqueeze_handler = {&FirstInput, &HandleUnsqueeze}; -static bool HandleQuantizeDequantizeLinear(HandlerArgs& args) { - size_t rank = args.perm.size(); - - if (args.ctx.opset >= 13) { +static bool HandleQuantizeDequantizeScale(const api::GraphRef& graph, const std::vector& perm, + api::NodeRef& node, int64_t opset) { + if (opset >= 13) { + size_t rank = perm.size(); // Update axis in Opset >= 13 if scale/zero_point are non-scalar - auto inputs = args.node.Inputs(); + auto inputs = node.Inputs(); - std::optional> inp_shape = args.ctx.graph.GetValueInfo(inputs[1])->Shape(); - bool scalar_params = inp_shape != std::nullopt && inp_shape->size() == 0; + auto inp_shape = graph.GetValueInfo(inputs[1])->Shape(); + bool scalar_params = inp_shape.has_value() && inp_shape->size() == 0; if (!scalar_params) { - int64_t axis = args.node.GetAttributeIntDefault("axis", 1); + int64_t axis = node.GetAttributeIntDefault("axis", 1); if (!NormalizeAndValidateAxis(axis, rank)) { return false; } - - args.node.SetAttributeInt("axis", args.perm[gsl::narrow_cast(axis)]); + node.SetAttributeInt("axis", perm[gsl::narrow_cast(axis)]); } } + return true; +} + +static bool HandleQuantizeDequantizeLinear(HandlerArgs& args) { + if (!HandleQuantizeDequantizeScale(args.ctx.graph, args.perm, args.node, args.ctx.opset)) { + return false; + } TransposeFirstInput(args.ctx, args.node, args.perm_inv); TransposeOutputs(args.ctx, args.node, args.perm); @@ -1215,7 +1250,7 @@ static bool HandleArgMinMax(HandlerArgs& args) { return false; } int64_t new_axis = args.perm[gsl::narrow_cast(axis)]; - std::vector new_axes {new_axis}; + std::vector new_axes{new_axis}; args.node.SetAttributeInt("axis", new_axis); TransposeInputs(args.ctx, args.node, args.perm_inv, args.transposible_inputs); @@ -1368,7 +1403,7 @@ static bool HandleTile(HandlerArgs& args) { } else { // Case 2: Repeats is computed. Insert Gather node. std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv); - std::vector gather_inputs {repeats_inp, perm_inv_const}; + std::vector gather_inputs{repeats_inp, perm_inv_const}; auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather_node = *gather_node_ptr; std::string_view gather_output = gather_node.Outputs()[0]; @@ -1398,7 +1433,7 @@ static bool HandleTranspose(HandlerArgs& args) { if (args.perm_inv == *node_perm) { // Case 1: Permutations cancel. - auto consumers = args.ctx.graph.GetValueConsumers(args.node.Outputs()[0]); + auto consumers = args.ctx.graph.GetValueConsumers(node_output); if (consumers->comprehensive) { // If possible, replace references to output of 2nd transpose with input to 1st ReplaceValueReferences(consumers->nodes, node_output, transpose_input); @@ -1425,14 +1460,13 @@ static bool HandleTranspose(HandlerArgs& args) { } else { // Worst-case scenario: Both parent output and 2nd transpose output cannot be removed (both graph outputs) // despite computing the same value. Use an Identity op instead. - std::vector single_empty_input {""}; + std::vector single_empty_input{""}; auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1); api::NodeRef& identity = *identity_ptr; args.ctx.graph.MoveOutput(args.node, 0, identity, 0); identity.SetInput(0, transpose_input); } } - // In any case, the 2nd transpose can be removed. args.ctx.graph.RemoveNode(args.node); } else { @@ -1442,7 +1476,7 @@ static bool HandleTranspose(HandlerArgs& args) { args.node.SetInput(0, transpose_input); } - // 2nd transpose no longer references 1st. Remove 2nd if possible. + // 2nd transpose no longer references 1st. Remove first if possible. if (!args.ctx.graph.HasValueConsumers(args.transpose.Outputs()[0])) { args.ctx.graph.RemoveNode(args.transpose); } @@ -1484,7 +1518,7 @@ constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &Han static bool HandleQLinearPoolOp(HandlerArgs& args) { // Swap between channel first/last variants. Only works for applicable values of perm. - int64_t channels_last = args.node.GetAttributeIntDefault("channels_last", 1); + int64_t channels_last = args.node.GetAttributeIntDefault("channels_last", 0); size_t rank = args.perm.size(); if (rank < 2) return false; auto p = ChannelLastToFirstPerm(rank); @@ -1500,7 +1534,11 @@ static bool HandleQLinearPoolOp(HandlerArgs& args) { constexpr HandlerInfo q_linear_pool_op_handler = {&FirstInput, &HandleQLinearPoolOp}; static bool HandleMaxPool(HandlerArgs& args) { - // Replace with NhwcMaxPool if possible. Only int8 and uint8 dtypes are supported by NhwcMaxPool. + // For CPU EP replace with NhwcMaxPool if possible. Only int8 and uint8 dtypes are supported by NhwcMaxPool. + if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") { + return false; + } + auto outputs = args.node.Outputs(); if (outputs.size() == 2 && outputs[1] != "") { // Can't optimize if optional "indices" output is provided @@ -1520,7 +1558,6 @@ static bool HandleMaxPool(HandlerArgs& args) { auto new_node = SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, "NhwcMaxPool", "com.microsoft"); new_node->ClearAttribute("storage_order"); // Only relevant for indices output. Prohibited for NhwcMaxPool. - TransposeFirstInput(args.ctx, *new_node, args.perm_inv); TransposeOutputs(args.ctx, *new_node, args.perm); return true; @@ -1529,71 +1566,121 @@ static bool HandleMaxPool(HandlerArgs& args) { constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; // TODO: check binary size of this and replace it with constexpr if large -static const std::unordered_map handler_map { +static const std::unordered_map handler_map{ - {"Cast", simple_node_handler}, {"Exp", simple_node_handler}, {"Identity", simple_node_handler}, - {"LeakyRelu", simple_node_handler}, {"Log", simple_node_handler}, {"Reciprocal", simple_node_handler}, - {"Relu", simple_node_handler}, {"Sigmoid", simple_node_handler}, {"Sqrt", simple_node_handler}, - {"Tanh", simple_node_handler}, {"Abs", simple_node_handler}, {"Not", simple_node_handler}, - {"Ceil", simple_node_handler}, {"Floor", simple_node_handler}, {"Neg", simple_node_handler}, - {"Erf", simple_node_handler}, {"HardSigmoid", simple_node_handler}, {"Round", simple_node_handler}, - {"IsInf", simple_node_handler}, {"IsNaN", simple_node_handler}, - {"Selu", simple_node_handler}, {"Shrink", simple_node_handler}, {"Sign", simple_node_handler}, - {"Softplus", simple_node_handler}, {"Softsign", simple_node_handler}, {"ThresholdedRelu", simple_node_handler}, - {"Celu", simple_node_handler}, {"HardSwish", simple_node_handler}, + {"Cast", simple_node_handler}, + {"Exp", simple_node_handler}, + {"Identity", simple_node_handler}, + {"LeakyRelu", simple_node_handler}, + {"Log", simple_node_handler}, + {"Reciprocal", simple_node_handler}, + {"Relu", simple_node_handler}, + {"Sigmoid", simple_node_handler}, + {"Sqrt", simple_node_handler}, + {"Tanh", simple_node_handler}, + {"Abs", simple_node_handler}, + {"Not", simple_node_handler}, + {"Ceil", simple_node_handler}, + {"Floor", simple_node_handler}, + {"Neg", simple_node_handler}, + {"Erf", simple_node_handler}, + {"HardSigmoid", simple_node_handler}, + {"Round", simple_node_handler}, + {"IsInf", simple_node_handler}, + {"IsNaN", simple_node_handler}, + {"Selu", simple_node_handler}, + {"Shrink", simple_node_handler}, + {"Sign", simple_node_handler}, + {"Softplus", simple_node_handler}, + {"Softsign", simple_node_handler}, + {"ThresholdedRelu", simple_node_handler}, + {"Celu", simple_node_handler}, + {"HardSwish", simple_node_handler}, - {"Sin", simple_node_handler}, {"Cos", simple_node_handler}, {"Tan", simple_node_handler}, - {"Sinh", simple_node_handler}, {"Cosh", simple_node_handler}, {"Tanh", simple_node_handler}, - {"Asin", simple_node_handler}, {"Acos", simple_node_handler}, {"Atan", simple_node_handler}, - {"Asinh", simple_node_handler}, {"Acosh", simple_node_handler}, {"Atanh", simple_node_handler}, + {"Sin", simple_node_handler}, + {"Cos", simple_node_handler}, + {"Tan", simple_node_handler}, + {"Sinh", simple_node_handler}, + {"Cosh", simple_node_handler}, + {"Tanh", simple_node_handler}, + {"Asin", simple_node_handler}, + {"Acos", simple_node_handler}, + {"Atan", simple_node_handler}, + {"Asinh", simple_node_handler}, + {"Acosh", simple_node_handler}, + {"Atanh", simple_node_handler}, - {"Add", broadcast_node_handler}, {"Max", broadcast_node_handler}, {"Min", broadcast_node_handler}, - {"Mul", broadcast_node_handler}, {"Sub", broadcast_node_handler}, {"Div", broadcast_node_handler}, - {"And", broadcast_node_handler}, {"Or", broadcast_node_handler}, {"Xor", broadcast_node_handler}, - {"Mod", broadcast_node_handler}, {"PRelu", broadcast_node_handler}, {"BitShift", broadcast_node_handler}, - {"Equal", broadcast_node_handler}, {"Greater", broadcast_node_handler}, {"Less", broadcast_node_handler}, - {"GreaterOrEqual", broadcast_node_handler}, {"LessOrEqual", broadcast_node_handler}, - {"Mean", broadcast_node_handler}, {"Sum", broadcast_node_handler}, {"Pow", broadcast_node_handler}, - {"Where", broadcast_node_handler}, + {"Add", broadcast_node_handler}, + {"Max", broadcast_node_handler}, + {"Min", broadcast_node_handler}, + {"Mul", broadcast_node_handler}, + {"Sub", broadcast_node_handler}, + {"Div", broadcast_node_handler}, + {"And", broadcast_node_handler}, + {"Or", broadcast_node_handler}, + {"Xor", broadcast_node_handler}, + {"Mod", broadcast_node_handler}, + {"PRelu", broadcast_node_handler}, + {"BitShift", broadcast_node_handler}, + {"Equal", broadcast_node_handler}, + {"Greater", broadcast_node_handler}, + {"Less", broadcast_node_handler}, + {"GreaterOrEqual", broadcast_node_handler}, + {"LessOrEqual", broadcast_node_handler}, + {"Mean", broadcast_node_handler}, + {"Sum", broadcast_node_handler}, + {"Pow", broadcast_node_handler}, + {"Where", broadcast_node_handler}, - {"Clip", node_1_inp_handler}, {"CastLike", node_1_inp_handler}, + {"Clip", node_1_inp_handler}, + {"CastLike", node_1_inp_handler}, - {"Transpose", transpose_handler}, - {"Concat", concat_handler}, - {"Split", split_handler}, - {"Shape", shape_handler}, - {"Pad", pad_handler}, - // Todo: renable resize handler after adding NHWC support in upsample op on cpu - // https://github.com/microsoft/onnxruntime/issues/9857 - //{"Resize", resize_handler}, - {"ReduceSum", reduce_sum_handler}, + {"Transpose", transpose_handler}, + {"Concat", concat_handler}, + {"Split", split_handler}, + {"Shape", shape_handler}, + {"Pad", pad_handler}, + // Todo: renable resize handler after adding NHWC support in upsample op on cpu + // https://github.com/microsoft/onnxruntime/issues/9857 + // {"Resize", resize_handler}, + {"ReduceSum", reduce_sum_handler}, - {"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, {"ReduceMax", reduce_op_handler}, - {"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler}, - {"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler}, + {"ReduceLogSum", reduce_op_handler}, + {"ReduceLogSumExp", reduce_op_handler}, + {"ReduceMax", reduce_op_handler}, + {"ReduceMean", reduce_op_handler}, + {"ReduceMin", reduce_op_handler}, + {"ReduceProd", reduce_op_handler}, + {"ReduceSumSquare", reduce_op_handler}, + {"ReduceL1", reduce_op_handler}, + {"ReduceL2", reduce_op_handler}, - {"ArgMin", arg_min_max_handler}, {"ArgMax", arg_min_max_handler}, + {"ArgMin", arg_min_max_handler}, + {"ArgMax", arg_min_max_handler}, - {"Squeeze", squeeze_handler}, - {"Unsqueeze", unsqueeze_handler}, - {"Slice", slice_handler}, - {"Tile", tile_handler}, + {"Squeeze", squeeze_handler}, + {"Unsqueeze", unsqueeze_handler}, + {"Slice", slice_handler}, + {"Tile", tile_handler}, - {"Softmax", soft_hard_max_handler}, {"Hardmax", soft_hard_max_handler}, {"LogSoftmax", soft_hard_max_handler}, + {"Softmax", soft_hard_max_handler}, + {"Hardmax", soft_hard_max_handler}, + {"LogSoftmax", soft_hard_max_handler}, - {"QuantizeLinear", quantize_dequantize_linear_handler}, {"DequantizeLinear", quantize_dequantize_linear_handler}, + {"QuantizeLinear", quantize_dequantize_linear_handler}, + {"DequantizeLinear", quantize_dequantize_linear_handler}, }; static const std::unordered_map extended_handler_map{ - {"com.microsoft.QLinearReduceMean", reduce_op_handler}, - {"com.microsoft.QLinearSigmoid", node_1_inp_handler}, - {"com.microsoft.QLinearLeakyRelu", node_1_inp_handler}, - {"com.microsoft.QLinearConcat", q_linear_concat_handler}, - {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, - {"com.microsoft.QLinearMul", q_linear_binary_op_handler}, - {"com.microsoft.QLinearAveragePool", q_linear_pool_op_handler}, - {"com.microsoft.QLinearGlobalAveragePool", q_linear_pool_op_handler}, - {"MaxPool", max_pool_op_handler}, + {"com.microsoft.QLinearReduceMean", reduce_op_handler}, + {"com.microsoft.QLinearSigmoid", node_1_inp_handler}, + {"com.microsoft.QLinearLeakyRelu", node_1_inp_handler}, + {"com.microsoft.QLinearConcat", q_linear_concat_handler}, + {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, + {"com.microsoft.QLinearMul", q_linear_binary_op_handler}, + {"com.microsoft.QLinearAveragePool", q_linear_pool_op_handler}, + {"com.microsoft.QLinearGlobalAveragePool", q_linear_pool_op_handler}, + {"MaxPool", max_pool_op_handler}, }; static const HandlerInfo* GetHandler(api::NodeRef& node, bool allow_extended_ops) { @@ -1674,7 +1761,9 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef& } // Returns nullopt if graph opset is unsupported. -std::optional MakeOptimizerContext(api::GraphRef& graph, bool allow_extended_ops) { +std::optional MakeOptimizerContext(api::GraphRef& graph, bool allow_extended_ops, + const std::string& provider_type, OptimizerMode mode, + const std::unordered_set& layout_sensitive_ops) { auto opset = graph.Opset(""); if (opset == std::nullopt) { opset = graph.Opset("ai.onnx"); @@ -1688,14 +1777,17 @@ std::optional MakeOptimizerContext(api::GraphRef& graph, bool allo allow_extended_ops = false; } } - OptimizerCtx ctx{*opset, graph, allow_extended_ops, /*skip_cost_check*/ false}; + + // during layout transformation we want to push the transposes as far out as possible. + // it is important that the EP gets the entire graph in the layout it prefers. + bool skip_cost_check = mode == OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM; + OptimizerCtx ctx{*opset, graph, allow_extended_ops, skip_cost_check, provider_type, mode, layout_sensitive_ops}; return ctx; } // Performs optimization. General algorithm: iterate over nodes in topological order. If a node has a transpose // as input, push it through if the transpose cost does not increase and is likely to decrease. bool OptimizeImpl(OptimizerCtx& ctx) { - const std::vector> nodes = ctx.graph.Nodes(); std::unordered_set outputs_leading_to_transpose; @@ -1703,7 +1795,6 @@ bool OptimizeImpl(OptimizerCtx& ctx) { // First iterate over sorted nodes in reverse order to find which outputs have paths through supported ops to // transpose nodes. We pull push transposes towards these outputs. for (size_t i = 0; i < nodes.size(); ++i) { - api::NodeRef& node = *nodes[nodes.size() - i - 1]; if (node.IsOp("Transpose")) { outputs_leading_to_transpose.insert(std::string(node.Inputs()[0])); @@ -1731,6 +1822,13 @@ bool OptimizeImpl(OptimizerCtx& ctx) { // New transpose nodes are inserted, but always as an input to an existing node. for (size_t i = 0; i < nodes.size(); ++i) { api::NodeRef& node = *nodes[i]; + if (ctx.mode == OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM && + ctx.layout_sensitive_ops.count(node.OpType()) && node.GetExecutionProviderType() != ctx.provider_type) { + // If the current op is layout sensitive and it is not assigned to the given provider + // then do not process transpose. + continue; + } + std::vector inputs = node.Inputs(); for (size_t j = 0; j < inputs.size(); ++j) { std::string_view inp = inputs[j]; @@ -1741,7 +1839,6 @@ bool OptimizeImpl(OptimizerCtx& ctx) { if (transpose != nullptr && transpose->IsOp("Transpose")) { std::optional> perm = GetPermAttrIfValid(*transpose); if (perm != std::nullopt) { - std::vector perm_inv = InvertPerm(*perm); if (ProcessTranspose(ctx, *transpose, node, *perm, j, outputs_leading_to_transpose)) { changed = true; // Subsequent inputs may have changed and node may have been removed. @@ -1751,11 +1848,69 @@ bool OptimizeImpl(OptimizerCtx& ctx) { } } } + + // Currently limiting the second optimization pass to layout transform mode + // TODO: Enable this for both the modes. + if (ctx.mode == OptimizerMode::OPTIMIZE_TRANSPOSE) { + return changed; + } + + // Run second optimization pass. + // If any transpose succeeds a DQ node, move it above the DQ node. + // In case of QDQ models this helps to preserve the QDQ node unit + // In all other scenarios this is beneficial as well because moving transpose above DQ node is more efficient as + // transpose node now handles less data. + auto graph_nodes = ctx.graph.Nodes(); + for (size_t i = 1; i < graph_nodes.size(); i++) { + if (graph_nodes[i]->OpType() == "Transpose") { + auto& transpose_node = *graph_nodes[i]; + auto dq_node = ctx.graph.GetNodeProducingOutput(transpose_node.Inputs()[0]); + if (!dq_node || dq_node->OpType() != "DequantizeLinear") { + continue; + } + + auto consumers = ctx.graph.GetValueConsumers(transpose_node.Outputs()[0]); + bool is_part_of_qdq_unit = std::find_if(consumers->nodes.cbegin(), consumers->nodes.cend(), + [](const std::unique_ptr& node) { + return node->OpType() == "QuantizeLinear"; + }) != consumers->nodes.cend(); + if (is_part_of_qdq_unit) { + continue; + } + + // Update Dequantize Node and move the transpose above it + auto perm = GetPermAttrIfValid(transpose_node); + if (!perm.has_value()) { + continue; + } + if (!HandleQuantizeDequantizeScale(ctx.graph, *perm, *dq_node, ctx.opset)) { + continue; + } + TransposeFirstInput(ctx, *dq_node, *perm); + + // remove existing transpose node + transpose_node.SetInput(0, ""); + ctx.graph.MoveOutput(transpose_node, 0, *dq_node, 0); + ctx.graph.RemoveNode(transpose_node); + changed = true; + } + } return changed; } -bool Optimize(api::GraphRef& graph, bool allow_extended_ops) { - auto ctx = MakeOptimizerContext(graph, allow_extended_ops); +const std::unordered_set& GetLayoutSensitiveOps() { + // List of all layout sensitive ops defined in ONNX standard. + static std::unordered_set layout_sensitive_ops = {"Conv", "QLinearConv", "BatchNormalization", + "AveragePool", "GlobalAveragePool", "MaxPool", + "GlobalMaxPool", "LRN"}; + + return layout_sensitive_ops; +} + +bool Optimize(api::GraphRef& graph, bool allow_extended_ops, + const std::string& provider_type, OptimizerMode mode, + const std::unordered_set& layout_sensitive_ops) { + auto ctx = MakeOptimizerContext(graph, allow_extended_ops, provider_type, mode, layout_sensitive_ops); if (ctx == std::nullopt) { return false; } @@ -1781,15 +1936,14 @@ void WrapTransposesAroundNode(api::GraphRef& graph, api::NodeRef& node, std::unique_ptr SwapNodeOpTypeAndDomain(api::GraphRef& graph, api::NodeRef& node, std::string_view op_type, std::string_view domain) { - auto inputs = node.Inputs(); auto outputs = node.Outputs(); - auto new_node = graph.AddNode(op_type, inputs, outputs.size(), domain); + auto new_node = graph.CopyNode(node, op_type, domain); + for (size_t j = 0; j < outputs.size(); ++j) { if (outputs[j] != "") { graph.MoveOutput(node, j, *new_node, j); } } - new_node->CopyAttributes(node); graph.RemoveNode(node); return new_node; } diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 68254d870d..28f2d4a23d 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -69,7 +69,7 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, - gen_metadef_name, COREML); + gen_metadef_name, COREML, kCoreMLExecutionProvider); const auto num_of_partitions = result.size(); const auto num_of_supported_nodes = std::transform_reduce( diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 3f95048335..5757470e7c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -262,7 +262,7 @@ Status ModelBuilder::RegisterInitializers() { shaper_.AddShape(name, operand_type.dimensions); uint32_t index = 0; - ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, false /* is_nhwc */, index)); + ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, index)); const size_t size = operand_type.GetOperandBlobByteSize(); const size_t padded_size = GetPaddedByteSize(size); sizeAll += padded_size; @@ -352,7 +352,7 @@ Status ModelBuilder::RegisterModelInputs() { shaper_.AddShape(input_name, operand_type.dimensions); uint32_t index = 0; - ORT_RETURN_IF_ERROR(AddNewOperand(input_name, operand_type, false /* is_nhwc */, index)); + ORT_RETURN_IF_ERROR(AddNewOperand(input_name, operand_type, index)); input_index_vec_.push_back(index); nnapi_model_->AddInput(input_name, operand_type); } @@ -386,11 +386,6 @@ Status ModelBuilder::RegisterModelOutputs() { } std::string nnapi_output_name = output_name; - if (IsOperandNHWC(output_name)) { - // We need to transpose the output still in nhwc back to nchw - nnapi_output_name = GetUniqueName(output_name + "_nhwc_to_nchw"); - ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(*this, output_name, nnapi_output_name)); - } output_index_vec_.push_back(operand_indices_[nnapi_output_name]); nnapi_model_->AddOutput(output_name, nnapi_output_name, operand_types_.at(nnapi_output_name)); @@ -405,10 +400,10 @@ void ModelBuilder::RegisterModelShaper() { Status ModelBuilder::AddNewOperand(const std::string& name, const OperandType& operand_type, - bool is_nhwc, uint32_t& index) { + uint32_t& index) { LOGS_DEFAULT(VERBOSE) << "operand name: " << name; ORT_RETURN_IF_ERROR(AddNewNNAPIOperand(operand_type, index)); - RegisterOperand(name, index, operand_type, is_nhwc); + RegisterOperand(name, index, operand_type); return Status::OK(); } @@ -432,13 +427,10 @@ Status ModelBuilder::AddNewNNAPIOperand(const OperandType& operand_type, uint32_ } void ModelBuilder::RegisterOperand(const std::string& name, uint32_t index, - const OperandType& operand_type, bool is_nhwc) { + const OperandType& operand_type) { operand_indices_[name] = index; operand_types_.emplace(name, operand_type); operands_.insert(name); - - if (is_nhwc) - RegisterNHWCOperand(name); } Status ModelBuilder::SetOperandValue(uint32_t index, @@ -466,7 +458,7 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( const android::nn::wrapper::OperandType& operand_type) { shaper_.AddShape(name, operand_type.dimensions); uint32_t index = 0; - ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, false /* is_nhwc */, index)); + ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, index)); const size_t size = operand_type.GetOperandBlobByteSize(); // for small size operand, the value will be copied @@ -528,12 +520,11 @@ Status ModelBuilder::AddOperations() { Status ModelBuilder::AddOperation(int op, const std::vector& input_indices, const std::vector& output_names, - const std::vector& types, - const std::vector& is_nhwc_vec) { + const std::vector& types) { std::vector output_indices; for (size_t i = 0; i < types.size(); i++) { uint32_t index = 0; - ORT_RETURN_IF_ERROR(AddNewOperand(output_names[i], types[i], is_nhwc_vec[i], index)); + ORT_RETURN_IF_ERROR(AddNewOperand(output_names[i], types[i], index)); output_indices.push_back(index); } @@ -693,7 +684,6 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) { const auto& op_type = node_unit.GetNode().OpType(); if (!Contains(op_builders, op_type)) return nullptr; - return op_builders.at(op_type); } @@ -708,47 +698,13 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) { return unique_name; } +DataLayout ModelBuilder::GetPreferredLayout() const { + return use_nchw_ ? DataLayout::NCHW : DataLayout::NHWC; +} + const InitializedTensorSet& ModelBuilder::GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } -void ModelBuilder::RegisterNHWCOperand(const std::string& name) { - nhwc_operands_.insert(name); -} - -bool ModelBuilder::IsOperandNHWC(const std::string& name) const { - return Contains(nhwc_operands_, name); -} - -bool ModelBuilder::GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name) { - if (Contains(nhwc_to_nchw_map_, nhwc_name)) { - nchw_name = nhwc_to_nchw_map_[nhwc_name]; - return true; - } - return false; -} - -bool ModelBuilder::GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name) { - if (Contains(nchw_to_nhwc_map_, nchw_name)) { - nhwc_name = nchw_to_nhwc_map_[nchw_name]; - return true; - } - return false; -} - -Status ModelBuilder::SetNHWCToNCHWOperandMap(const std::string& nhwc_name, - const std::string& nchw_name) { - ORT_RETURN_IF_NOT(!Contains(nhwc_to_nchw_map_, nhwc_name), "A previous nchw to nhwc map exists"); - nhwc_to_nchw_map_[nhwc_name] = nchw_name; - return Status::OK(); -} - -Status ModelBuilder::SetNCHWToNHWCOperandMap(const std::string& nchw_name, - const std::string& nhwc_name) { - ORT_RETURN_IF_NOT(!Contains(nchw_to_nhwc_map_, nchw_name), "A previous nchw to nhwc map exists"); - nchw_to_nhwc_map_[nchw_name] = nhwc_name; - return Status::OK(); -} - } // namespace nnapi } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index 2269c986f6..76178f440e 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -13,6 +13,7 @@ namespace onnxruntime { class GraphViewer; +enum class DataLayout; class NodeUnit; class Node; class NodeArg; @@ -47,8 +48,7 @@ class ModelBuilder { // Add an NNAPI operation (operator) common::Status AddOperation(int op, const std::vector& input_indices, const std::vector& output_names, - const std::vector& types, - const std::vector& is_nhwc_vec); + const std::vector& types); // Find if the given node_unit has a fuseable activation (Relu/Relu1/Relu6) // For now we only support node_unit with a single output @@ -69,8 +69,7 @@ class ModelBuilder { // Register informations for a particular operand void RegisterOperand(const std::string& name, uint32_t index, - const android::nn::wrapper::OperandType& operand_type, - bool is_nhwc); + const android::nn::wrapper::OperandType& operand_type); // Generate an unique name for intermediate result std::string GetUniqueName(const std::string& base_name); @@ -79,6 +78,9 @@ class ModelBuilder { void SetUseNCHW(bool use_nchw) { use_nchw_ = use_nchw; } bool UseNCHW() const { return use_nchw_; } + // Returns the preferred layout for this EP. + DataLayout GetPreferredLayout() const; + // Relax fp32 computation to fp16 // It is off by default void SetUseFp16(bool use_fp16) { use_fp16_ = use_fp16; } @@ -106,21 +108,9 @@ class ModelBuilder { const GraphViewer& GetGraphViewer() const { return graph_viewer_; } - void RegisterNHWCOperand(const std::string& name); - bool IsOperandNHWC(const std::string& name) const; - - // Get the operand transposed to nchw/nhwc from given nhwc/nchw operand, if it exists - bool GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name); - bool GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name); - // Get the NodeUnit which contains the given node const NodeUnit& GetNodeUnit(const Node* node) const; - common::Status SetNHWCToNCHWOperandMap(const std::string& nhwc_name, - const std::string& nchw_name); - common::Status SetNCHWToNHWCOperandMap(const std::string& nchw_name, - const std::string& nhwc_name); - private: const NnApi* nnapi_{nullptr}; const GraphViewer& graph_viewer_; @@ -148,12 +138,6 @@ class ModelBuilder { std::unordered_map> op_support_checkers_; - // Operands in nhwc - std::unordered_set nhwc_operands_; - - // Maps between nhwc and nchw, and vice versa - std::unordered_map nhwc_to_nchw_map_; - std::unordered_map nchw_to_nhwc_map_; std::vector input_index_vec_; std::vector output_index_vec_; @@ -206,7 +190,6 @@ class ModelBuilder { common::Status AddNewNNAPIOperand(const android::nn::wrapper::OperandType& type, uint32_t& index); common::Status AddNewOperand(const std::string& name, const android::nn::wrapper::OperandType& operand_type, - bool is_nhwc, uint32_t& index); static const IOpBuilder* GetOpBuilder(const NodeUnit& node_unit); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index 07d310ac2b..41c7fdba9f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -40,8 +40,7 @@ Status AddTransposeOperator(ModelBuilder& model_builder, const std::string& input, const std::string& perm_name, std::vector perm, - const std::string& output, - bool output_is_nhwc) { + const std::string& output) { auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); @@ -59,102 +58,7 @@ Status AddTransposeOperator(ModelBuilder& model_builder, OperandType output_operand_type = operand_types.at(input); output_operand_type.SetDimensions(shaper[output]); return model_builder.AddOperation(ANEURALNETWORKS_TRANSPOSE, input_indices, {output}, - {output_operand_type}, {output_is_nhwc}); -} - -Status TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder, - const std::string& input, - const std::string& output, - bool nchw_to_nhwc) { - ORT_RETURN_IF_NOT(!model_builder.UseNCHW(), "model_builder.UseNCHW() is on"); - const auto& shaper(model_builder.GetShaper()); - ORT_RETURN_IF_NOT(4 == shaper[input].size(), - "TransposeBetweenNCHWAndNHWC input has to be a 4d tensor, actual dimensions: ", shaper[input].size()); - - std::string perm_name; - std::vector perm; - if (nchw_to_nhwc) { - perm_name = model_builder.GetUniqueName(input + "nchw_to_nhwc_perm"); - perm = {0, 2, 3, 1}; - } else { // nhwc_to_nchw - perm_name = model_builder.GetUniqueName(input + "nhwc_to_nchw_perm"); - perm = {0, 3, 1, 2}; - } - - ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output, nchw_to_nhwc)); - - if (nchw_to_nhwc) { - ORT_RETURN_IF_ERROR(model_builder.SetNCHWToNHWCOperandMap(input, output)); - } else { // nhwc_to_nchw - ORT_RETURN_IF_ERROR(model_builder.SetNHWCToNCHWOperandMap(input, output)); - } - - LOGS_DEFAULT(VERBOSE) << "Operand [" << input << "] with shape " - << Shape2String(shaper[input]) - << " is transposed " - << (nchw_to_nhwc ? "nchw_to_nhwc" : "nhwc_to_nchw") - << " to [" << output << "] with shape " - << Shape2String(shaper[output]); - - return Status::OK(); -} - -Status TransposeNHWCToNCHW(ModelBuilder& model_builder, - const std::string& input, - const std::string& output) { - return TransposeBetweenNCHWAndNHWC(model_builder, input, output, false /* nchw_to_nhwc */); -} - -Status TransposeNCHWToNHWC(ModelBuilder& model_builder, - const std::string& input, - const std::string& output) { - return TransposeBetweenNCHWAndNHWC(model_builder, input, output, true /* nchw_to_nhwc */); -} - -// Convert the input from nchw to nhwc -// Caller should ensure input is currently in nchw format using ModelBuilder::IsOperandNHWC -Status GetNHWCInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nhwc_input) { - const auto& nchw_input = node_unit.Inputs()[input_index].node_arg.Name(); - if (!model_builder.GetNHWCOperand(nchw_input, nhwc_input)) { - nhwc_input = model_builder.GetUniqueName(nchw_input + "_nchw_to_nhwc"); - ORT_RETURN_IF_ERROR(TransposeNCHWToNHWC(model_builder, nchw_input, nhwc_input)); - } - return Status::OK(); -} - -// Convert the input from nhwc to nchw -// Caller should ensure input is currently in nhwc format using ModelBuilder::IsOperandNHWC -Status GetNCHWInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nchw_input) { - const auto& nhwc_input = node_unit.Inputs()[input_index].node_arg.Name(); - if (!model_builder.GetNCHWOperand(nhwc_input, nchw_input)) { - nchw_input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw"); - ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(model_builder, nhwc_input, nchw_input)); - } - return Status::OK(); -} - -// Transpose layouts if necessary for element wise operators with 2 inputs -// and return the layout type of output tensor -// If both inputs have same layout, the output will have the same layout -// Otherwise we will need transpose the nhwc input back to nchw, and output will be nchw -Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const NodeUnit& node_unit, - std::string& input1, std::string& input2, - bool& output_is_nhwc) { - bool input1_is_nhwc = model_builder.IsOperandNHWC(input1); - bool input2_is_nhwc = model_builder.IsOperandNHWC(input2); - output_is_nhwc = false; - - if (input1_is_nhwc == input2_is_nhwc) { - output_is_nhwc = input1_is_nhwc; - } else if (input1_is_nhwc) { - // need transpose input1 back to nchw - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input1)); - } else { // input2_is_nhwc - // need transpose input2 back to nchw - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 1, input2)); - } - - return Status::OK(); + {output_operand_type}); } static Status AddBinaryOperator(int32_t op_type, @@ -164,7 +68,6 @@ static Status AddBinaryOperator(int32_t op_type, bool add_activation, int32_t fuse_code, const std::string& output, - bool output_is_nhwc, float output_scale = 0.0f, int32_t output_zero_point = 0) { auto& shaper(model_builder.GetShaper()); @@ -183,7 +86,7 @@ static Status AddBinaryOperator(int32_t op_type, const OperandType output_operand_type(operand_types.at(input1).type, shaper[output], output_scale, output_zero_point); ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_type, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -231,7 +134,7 @@ static Status AddSqueezeOp(ModelBuilder& model_builder, ORT_RETURN_IF_ERROR(shaper.Squeeze(input, axes, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_SQUEEZE, input_indices, - {output}, {output_operand_type}, {false})); + {output}, {output_operand_type})); return Status::OK(); } @@ -595,6 +498,16 @@ static void AddInputToSkip(ModelBuilder& model_builder, const NodeUnitIODef& io_ AddQuantizationScaleAndZeroPointToSkip(model_builder, *io_def.quant_param); } +static Status IsOpInRequiredLayout(bool use_nchw, const NodeUnit& node_unit) { + bool is_op_nhwc = node_unit.Domain() == kMSInternalNHWCDomain; + if (is_op_nhwc && use_nchw) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Expected layout and operator layout do not match. Possible bug in layout optimizer."); + } + + return Status::OK(); +} + template void CreateSharedOpBuilderImpl(const std::string& op_type, OpBuilderRegistrations& op_registrations, @@ -715,10 +628,6 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::string input2 = inputs[1].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = false; - ORT_RETURN_IF_ERROR( - TransposeBinaryOpInputLayout(model_builder, node_unit, input1, input2, output_is_nhwc)); - float a_scale = 0.0f, b_scale = 0.0f, y_scale = 0.0f; @@ -747,7 +656,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return AddBinaryOperator(op_code, model_builder, input1, input2, add_activation, fuse_code, - output, output_is_nhwc, y_scale, y_zero_point); + output, y_scale, y_zero_point); } #pragma endregion @@ -766,19 +675,18 @@ Status ReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); // skip this relu if it is some op's fuse output if (Contains(model_builder.GetFusedActivations(), input)) { LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node_unit.Name() << "] fused"; - model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc); + model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); } else { std::vector input_indices; input_indices.push_back(operand_indices.at(input)); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RELU, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); } return Status::OK(); @@ -825,15 +733,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co ORT_RETURN_IF_NOT(perm.size() == input_dims, "Perm and input should have same dimension"); } - if (model_builder.IsOperandNHWC(input)) { - ORT_RETURN_IF_NOT(input_dims == 4, "Only 4D shape can be nhwc"); - - // we are using nhwc here, but the axis is in nchw, need to transpose axis from nchw to nhwc - const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2}; - for (size_t i = 0; i < perm.size(); i++) - perm[i] = axis_nchw_to_nhwc[perm[i]]; - } - // Check if the quantization scale and ZP are correct if (IsQuantizedOp(node_unit)) { float x_scale = 0.0f; @@ -845,11 +744,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co std::string perm_name = model_builder.GetUniqueName(node_unit.Name() + input + "perm"); - // It is possible this onnx transpose operator can be nchw->nhwc, but so far I don't see - // any scenario will do this since onnx is nchw only, assume the output is always not nhwc - // even it is, there will be extra transpose in the onnx model to convert it back to nchw - // before conv/pool/... operators - ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output, false /* is_nhwc */)); + ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output)); return Status::OK(); } @@ -884,7 +779,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { } // We can skip the Reshape if all the output edges satisfies both the following conditions -// 1. The output the reshape/flatten is not an output of the graph +// 1. The output of the reshape/flatten is not an output of the graph // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators, // and not any other types of operator, // and the input rank >= 2 and output_rank == 2 @@ -975,7 +870,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { // NNAPI CPU impl and NNAPI hardware accelerator impl if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) { // Since reshape can be skipped, only register the dimension and type, with same index and new name - model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false); + model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); } else { // We still need to perform a reshape here // Add input @@ -987,7 +882,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen); ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type)); input_indices.push_back(operand_indices.at(shape_name)); - ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false})); + ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type})); } return Status::OK(); @@ -998,10 +893,6 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons const auto& initializers(model_builder.GetInitializerTensors()); auto input = node_unit.Inputs()[0].node_arg.Name(); - if (model_builder.IsOperandNHWC(input)) { - // We want to transpose nhwc operand back to nchw before reshape - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); - } const auto& shape_tensor = *initializers.at(node_unit.Inputs()[1].node_arg.Name()); std::vector unpacked_tensor; @@ -1099,10 +990,10 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu const auto tensor_imm_product_name = model_builder.GetUniqueName(node_unit.Name() + input + "_imm_mul"); Shape tensor_a_dimen = {size}; - bool input_is_nhwc = model_builder.IsOperandNHWC(input); - bool output_is_nhwc = input_is_nhwc; + bool use_nchw = model_builder.UseNCHW(); + ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit)); - if (!input_is_nhwc) { + if (use_nchw) { // the batch normalization is applied on C channel, // if the input is NC[HW], will need correct shape for tensor_a/b // to make sure we are broadcasting on the correct channel, @@ -1126,8 +1017,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu model_builder, input, tensor_a_name, true /* add_activation */, ANEURALNETWORKS_FUSED_NONE, - tensor_imm_product_name, - output_is_nhwc)); + tensor_imm_product_name)); // Add int32_t fuse_code = model_builder.FindActivation(node_unit); @@ -1135,8 +1025,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu model_builder, tensor_imm_product_name, tensor_b_name, true /* add_activation */, fuse_code, - output, - output_is_nhwc)); + output)); return Status::OK(); } @@ -1190,16 +1079,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N auto input = node_unit.Inputs()[0].node_arg.Name(); bool use_nchw = model_builder.UseNCHW(); - bool input_is_nhwc = model_builder.IsOperandNHWC(input); - bool output_is_nhwc = false; - if (use_nchw) { - ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC"); - } else { - output_is_nhwc = true; - if (!input_is_nhwc) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); - } - } + ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit)); const auto& output = node_unit.Outputs()[0].node_arg.Name(); const auto& op_type = node_unit.OpType(); @@ -1290,7 +1170,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point); ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -1361,16 +1241,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N auto input = inputs[0].node_arg.Name(); bool use_nchw = model_builder.UseNCHW(); - bool input_is_nhwc = model_builder.IsOperandNHWC(input); - bool output_is_nhwc = false; - if (use_nchw) { - ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC"); - } else { - output_is_nhwc = true; - if (!input_is_nhwc) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); - } - } + ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit)); const auto& weight = inputs[1].node_arg.Name(); const auto& weight_tensor = *initializers.at(weight); @@ -1559,7 +1430,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point); ORT_RETURN_IF_ERROR(model_builder.AddOperation(operationCode, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -1579,7 +1450,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); auto to = helper.Get("to", 0); Type type; @@ -1599,7 +1469,7 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(type, shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CAST, input_indices, {output}, - {output_operand_type}, {output_is_nhwc})); + {output_operand_type})); return Status::OK(); } @@ -1620,21 +1490,14 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons NodeAttrHelper helper(node_unit); auto input = node_unit.Inputs()[0].node_arg.Name(); - bool input_is_nhwc = model_builder.IsOperandNHWC(input); - bool output_is_nhwc = input_is_nhwc; + + // TODO: Needs fix. if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) { - if (model_builder.IsOperandNHWC(input)) { - output_is_nhwc = false; - // We want to transpose nhwc operand back to nchw before softmax - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); - } + ORT_ENFORCE(model_builder.UseNCHW(), + "For Android API Level < 29 input for softmax needs to be NCHW."); } int32_t axis = helper.Get("axis", 1); - if (output_is_nhwc) { - const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2}; - axis = axis_nchw_to_nhwc[axis]; - } const auto& output = node_unit.Outputs()[0].node_arg.Name(); float beta = 1.f; @@ -1650,7 +1513,7 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -1672,14 +1535,13 @@ Status IdentityOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, con const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); std::vector input_indices; input_indices.push_back(operand_indices.at(input)); // input ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc); + model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); return Status::OK(); } @@ -1839,7 +1701,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N ORT_RETURN_IF_ERROR(shaper.FC(input1, input2, output)); const OperandType output_operand_type(operand_types.at(input1).type, shaper[output], y_scale, y_zero_point); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_FULLY_CONNECTED, input_indices, - {output}, {output_operand_type}, {false})); + {output}, {output_operand_type})); return Status::OK(); } @@ -1896,8 +1758,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); - ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); bool is_qlinear_sigmoid = op_type == "QLinearSigmoid"; @@ -1945,7 +1805,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const input_indices.push_back(operand_indices.at(input)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point); ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -1967,8 +1827,6 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::vector input_indices; const auto& input0 = inputs[0].node_arg.Name(); - bool all_input_have_same_layout = true; - bool output_is_nhwc = false; const auto node_input_size = inputs.size(); // First if the inputs are uint8, we need verify all the inputs have same scale and zero points @@ -1989,48 +1847,17 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } } - // First we want to see if all the input are same layout - for (size_t i = 0; i < node_input_size - 1; i++) { - all_input_have_same_layout = - all_input_have_same_layout && - model_builder.IsOperandNHWC(inputs[i].node_arg.Name()) == - model_builder.IsOperandNHWC(inputs[i + 1].node_arg.Name()); - } - std::vector input_names; input_names.reserve(node_input_size); - if (all_input_have_same_layout) { - // if all the inputs are of same layout, output will be the same layout - output_is_nhwc = model_builder.IsOperandNHWC(input0); - - for (size_t i = 0; i < node_input_size; i++) { - const auto& input = inputs[i].node_arg.Name(); - input_indices.push_back(operand_indices.at(input)); - input_names.push_back(input); - } - } else { - // if all the inputs are not same layout, - // will need transpos those nhwc tensors back to nchw - for (size_t i = 0; i < node_input_size; i++) { - auto input = inputs[i].node_arg.Name(); - if (model_builder.IsOperandNHWC(input)) { - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, i, input)); - } - input_indices.push_back(operand_indices.at(input)); - input_names.push_back(input); - } + for (size_t i = 0; i < node_input_size; i++) { + const auto& input = inputs[i].node_arg.Name(); + input_indices.push_back(operand_indices.at(input)); + input_names.push_back(input); } int rank = shaper[input0].size(); int32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); - if (output_is_nhwc) { - ORT_RETURN_IF_NOT(rank == 4, - "nhwc is only on 4d shape, input ", input0, " has rank: ", rank); - // we are using nhwc here, but the axis is in nchw, need to transpose axis from nchw to nhwc - const uint32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2}; - axis = axis_nchw_to_nhwc[axis]; - } ADD_SCALAR_OPERAND(model_builder, input_indices, axis); const auto& output = node_unit.Outputs()[0].node_arg.Name(); @@ -2038,7 +1865,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const OperandType output_operand_type = operand_types.at(input0); output_operand_type.SetDimensions(shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -2089,10 +1916,6 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { auto input = node_unit.Inputs()[0].node_arg.Name(); - if (model_builder.IsOperandNHWC(input)) { - // We want to transpose nhwc operand back to nchw before squeeze - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); - } std::vector axes; ORT_RETURN_IF_ERROR(GetAxes(model_builder, node_unit, axes)); @@ -2121,7 +1944,6 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); float scale = 0.0f; int32_t zero_point = 0; @@ -2134,7 +1956,7 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde std::vector input_indices; input_indices.push_back(operand_indices.at(input)); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_QUANTIZE, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -2161,7 +1983,6 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil const auto& input = inputs[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); float scale = 0.0; int32_t zero_point = 0; @@ -2176,7 +1997,7 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil std::vector input_indices; input_indices.push_back(operand_indices.at(input)); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_DEQUANTIZE, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -2198,13 +2019,14 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No auto input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); + auto use_nchw = model_builder.UseNCHW(); + ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit)); + if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) { // on android api level 28, we need to transpose the nchw input to nhwc - output_is_nhwc = true; - if (!model_builder.IsOperandNHWC(input)) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); - } + // it is very rare that users set nchw format when using nnapi. Therefore, instead of + // adding the ability to support conversion we fail and stop. + ORT_ENFORCE(!use_nchw, "NCHW format is not supported on android api level 28"); } auto alpha = helper.Get("alpha", 0.0001f); @@ -2225,16 +2047,16 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No // specify axis is only available on api level >= 29 if (android_feature_level > ANEURALNETWORKS_FEATURE_LEVEL_2) { // ONNX LRN is always performed on C dimension - int32_t axis = output_is_nhwc - ? 3 // nhwc - : 1; // nchw + int32_t axis = use_nchw + ? 1 // nchw + : 3; // nhwc ADD_SCALAR_OPERAND(model_builder, input_indices, axis); } ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -2266,14 +2088,13 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); if (Contains(model_builder.GetFusedActivations(), input)) { LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node_unit.Name() << "] fused"; - model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc); + model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); return Status::OK(); } @@ -2293,7 +2114,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector input_indices; input_indices.push_back(operand_indices.at(input)); ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -2344,16 +2165,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto input = inputs[0].node_arg.Name(); bool use_nchw = model_builder.UseNCHW(); - bool input_is_nhwc = model_builder.IsOperandNHWC(input); - bool output_is_nhwc = false; - if (use_nchw) { - ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC"); - } else { - output_is_nhwc = true; - if (!input_is_nhwc) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); - } - } + ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit)); // Check if the quantization scale and ZP is correct if (IsQuantizedOp(node_unit)) { @@ -2373,16 +2185,19 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const bool using_half_pixel = coord_trans_mode == "half_pixel"; bool using_align_corners = coord_trans_mode == "align_corners"; + // if the node domain is NHWC it means all the node inputs are converted to NHWC format by the layout transformer. + // pick the index for height and width based on the format. + int h_idx = use_nchw ? 2 : 1; + int w_idx = use_nchw ? 3 : 2; + if (inputs.size() == 3) { // we are using scales const auto& scales_name = inputs[2].node_arg.Name(); const auto& scales_tensor = *initializers.at(scales_name); std::vector unpacked_tensor; ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor)); const float* scales_data = reinterpret_cast(unpacked_tensor.data()); - float scale_h = scales_data[2]; - float scale_w = scales_data[3]; ORT_RETURN_IF_ERROR( - shaper.ResizeUsingScales(input, scale_h, scale_w, use_nchw, output)); + shaper.ResizeUsingScales(input, scales_data[h_idx], scales_data[w_idx], use_nchw, output)); } else { // we are using sizes const auto& sizes_name = inputs[3].node_arg.Name(); const auto& sizes_tensor = *initializers.at(sizes_name); @@ -2390,12 +2205,12 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor)); const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); ORT_RETURN_IF_ERROR( - shaper.ResizeUsingOutputSizes(input, SafeInt(sizes_data[2]), SafeInt(sizes_data[3]), use_nchw, output)); + shaper.ResizeUsingOutputSizes(input, SafeInt(sizes_data[h_idx]), SafeInt(sizes_data[w_idx]), use_nchw, output)); } const auto& output_shape = shaper[output]; - int32_t output_h = use_nchw ? output_shape[2] : output_shape[1]; - int32_t output_w = use_nchw ? output_shape[3] : output_shape[2]; + int32_t output_h = output_shape[h_idx]; + int32_t output_w = output_shape[w_idx]; std::vector input_indices; input_indices.push_back(operand_indices.at(input)); @@ -2420,7 +2235,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const OperandType output_operand_type = operand_types.at(input); output_operand_type.SetDimensions(output_shape); ORT_RETURN_IF_ERROR(model_builder.AddOperation(operationCode, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -2436,10 +2251,6 @@ class FlattenOpBuilder : public BaseOpBuilder { Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { auto input = node_unit.Inputs()[0].node_arg.Name(); - if (model_builder.IsOperandNHWC(input)) { - // We want to transpose nhwc operand back to nchw before reshape - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); - } // Flatten is basically a reshape to 2d tensor // Get the shape for Reshape here @@ -2467,8 +2278,7 @@ class MinMaxOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; static Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, - const std::string& input1, const std::string& input2, - bool output_is_nhwc); + const std::string& input1, const std::string& input2); }; /* static */ void MinMaxOpBuilder::CreateSharedOpBuilder( @@ -2482,8 +2292,7 @@ class MinMaxOpBuilder : public BaseOpBuilder { } /* static */ Status MinMaxOpBuilder::AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, - const std::string& input1, const std::string& input2, - bool output_is_nhwc) { + const std::string& input1, const std::string& input2) { auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); @@ -2506,7 +2315,7 @@ class MinMaxOpBuilder : public BaseOpBuilder { ORT_RETURN_IF_ERROR(shaper.Eltwise(input1, input2, output)); const OperandType output_operand_type(operand_types.at(input1).type, shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices, - {output}, {output_operand_type}, {output_is_nhwc})); + {output}, {output_operand_type})); return Status::OK(); } @@ -2515,11 +2324,8 @@ Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& inputs = node_unit.Inputs(); std::string input1 = inputs[0].node_arg.Name(); std::string input2 = inputs[1].node_arg.Name(); - bool output_is_nhwc = false; - ORT_RETURN_IF_ERROR(TransposeBinaryOpInputLayout(model_builder, node_unit, - input1, input2, output_is_nhwc)); - return AddMinMaxOperator(model_builder, node_unit, input1, input2, output_is_nhwc); + return AddMinMaxOperator(model_builder, node_unit, input1, input2); } #pragma endregion @@ -2537,7 +2343,6 @@ Status EluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No const auto& operand_types(model_builder.GetOperandTypes()); const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); NodeAttrHelper helper(node_unit); @@ -2546,7 +2351,7 @@ Status EluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No input_indices.push_back(operand_indices.at(input)); ADD_SCALAR_OPERAND(model_builder, input_indices, alpha); return model_builder.AddOperation(ANEURALNETWORKS_ELU, input_indices, - {output}, {output_operand_type}, {output_is_nhwc}); + {output}, {output_operand_type}); } #pragma endregion @@ -2644,7 +2449,6 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& input = inputs[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - bool output_is_nhwc = model_builder.IsOperandNHWC(input); // No shape inference for Slice, everything is calculated here, we only need to add the output shape // to the shaper @@ -2718,7 +2522,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ADD_SCALAR_OPERAND(model_builder, input_indices, 0); // end_mask ADD_SCALAR_OPERAND(model_builder, input_indices, 0); // shrink_axis_mask } - return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}, {output_is_nhwc}); + return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}); } #pragma endregion diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index 5150d7ae37..e4a752550f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -72,6 +72,10 @@ static bool HasExternalInitializer(const InitializedTensorSet& initializers, con return false; } +inline bool IsNodeLayoutNHWC(const NodeUnit& node_unit) { + return node_unit.Domain() == kMSInternalNHWCDomain; +} + static bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const OpSupportCheckParams& params, @@ -1768,7 +1772,7 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi } const float* scales_data = reinterpret_cast(unpacked_tensor.data()); float scale_n = scales_data[0]; - float scale_c = scales_data[1]; + float scale_c = IsNodeLayoutNHWC(node_unit) ? scales_data[3] : scales_data[1]; if (scale_n != 1.0f || scale_c != 1.0f) { LOGS_DEFAULT(VERBOSE) << "Scales of N/C channel should be 1" << "Resize of N/C channels are not supported" @@ -1785,14 +1789,16 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi LOGS_DEFAULT(ERROR) << "Error while unpacking sizes_tensor: " << status.ErrorMessage(); return false; } + + int channel_idx = IsNodeLayoutNHWC(node_unit) ? 3 : 1; const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); uint32_t size_n = SafeInt(sizes_data[0]); - uint32_t size_c = SafeInt(sizes_data[1]); - if (size_n != input_shape[0] || size_c != input_shape[1]) { - LOGS_DEFAULT(VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " + uint32_t size_c = SafeInt(sizes_data[channel_idx]); + if (size_n != input_shape[0] || size_c != input_shape[channel_idx]) { + LOGS_DEFAULT(VERBOSE) << "Output sizes of N/C channel should match the input sizes, " << "Resize of N/C channels are not supported" << ", input_size_n, " << input_shape[0] << ", output_size_n, " << size_n - << ". input_size_c, " << input_shape[1] << ", output_size_c, " << size_c; + << ". input_size_c, " << input_shape[channel_idx] << ", output_size_c, " << size_c; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 4ff0b41a51..36209fd2e1 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -176,7 +176,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view }; result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, - gen_metadef_name, NNAPI); + gen_metadef_name, NNAPI, kNnapiExecutionProvider); const auto num_of_partitions = result.size(); const auto num_of_supported_nodes = std::transform_reduce( @@ -203,6 +203,10 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view return result; } +DataLayout NnapiExecutionProvider::GetPreferredLayout() const { + return nnapi_flags_ & NNAPI_FLAG_USE_NCHW ? DataLayout::NCHW : DataLayout::NHWC; +} + #ifdef __ANDROID__ static Status GetOutputBuffer(Ort::CustomOpApi& ort, OrtKernelContext* context, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index 528b7671ea..f42e17e713 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -33,6 +33,8 @@ class NnapiExecutionProvider : public IExecutionProvider { uint32_t GetNNAPIFlags() const { return nnapi_flags_; } + DataLayout GetPreferredLayout() const override; + private: // The bit flags which define bool options for NNAPI EP, bits are defined as // NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index 12d0f9d602..f0bdd36d92 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -87,6 +87,7 @@ std::vector> CreateSupportedPartitionNodeGroups( const GraphViewer& graph_viewer, const IsNodeSupportedFn& is_node_supported_fn, const OnGroupClosedFn& on_group_closed_fn, + const std::string& execution_provider_type, bool debug_output) { #ifdef NDEBUG ORT_UNUSED_PARAMETER(debug_output); @@ -104,7 +105,8 @@ std::vector> CreateSupportedPartitionNodeGroups( std::deque nodes_to_process_with_next_group{}; // initialize in-degrees and find root nodes - for (const auto& node : graph_viewer.Nodes()) { + for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) { + const auto& node = *graph_viewer.GetNode(node_index); const auto node_input_edge_count = node.GetInputEdgesCount(); in_degree.insert({node.Index(), node_input_edge_count}); if (node_input_edge_count == 0) { @@ -156,9 +158,9 @@ std::vector> CreateSupportedPartitionNodeGroups( const Node& node = *nodes_to_process.front(); nodes_to_process.pop_front(); + // a node that is already assigned to an EP other than current EP is unsupported const bool is_node_supported = - node.GetExecutionProviderType().empty() && // a node that is already assigned to an EP is unsupported - is_node_supported_fn(node); + (node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node); if (!is_node_supported && Contains(supported_group_border, &node)) { // an unsupported node on the border will be processed after the current partition node group @@ -203,9 +205,7 @@ std::unordered_set CreateExcludedNodeSet(const GraphViewer& graph_v const std::unordered_set& stop_ops) { std::unordered_set excluded_nodes; - for (const NodeIndex node_index : graph_viewer.GetNodesInTopologicalOrder()) { - const Node& node = *graph_viewer.GetNode(node_index); - + for (const auto& node : graph_viewer.Nodes()) { if (!Contains(excluded_nodes, &node) && Contains(stop_ops, node.OpType())) { excluded_nodes.insert(&node); @@ -309,10 +309,12 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const OnGroupClosedFn& on_partition_closed_fn, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, + const std::string& execution_provider_type, bool debug_output) { const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer, is_node_supported_fn, on_partition_closed_fn, + execution_provider_type, debug_output); std::vector> partitions{}; @@ -335,6 +337,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::unordered_set& stop_ops, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, + const std::string& execution_provider_type, bool debug_output) { const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops); const bool check_excluded_nodes = !excluded_nodes.empty(); @@ -348,6 +351,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, {}, generate_metadef_name_fn, execution_provider_name, + execution_provider_type, debug_output); } diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h index 865677e848..c00044cac1 100644 --- a/onnxruntime/core/providers/partitioning_utils.h +++ b/onnxruntime/core/providers/partitioning_utils.h @@ -54,6 +54,7 @@ Create the supported partitions for the execution provider. @param on_group_closed_fn Callback to indicate a completed partition node group. @param generate_metadef_name_fn Callback to create the name for the MetaDef. @param execution_provider_name Name of execution provider creating the ComputeCapability instance. +@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. @param debug_output Print diagnostic output about the partitions and reasons for partition breaks. No-op in a release build. @@ -65,6 +66,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const OnGroupClosedFn& on_group_closed_fn, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, + const std::string& execution_provider_type, bool debug_output = false); /** @@ -75,6 +77,7 @@ Create the supported partitions for the execution provider. @param stop_ops Set of operator names at which we stop considering nodes for assignment to this execution provider. @param generate_metadef_name Functor to create the name for the MetaDef. @param execution_provider_name Name of execution provider creating the ComputeCapability instance. +@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. @param debug_output Print diagnostic output about the partitions and reasons for partition breaks. No-op in a release build. @@ -86,6 +89,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::unordered_set& stop_ops, const GenerateMetadefNameFn& generate_metadef_name, const std::string& execution_provider_name, + const std::string& execution_provider_type, bool debug_output = false); /** diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 78b5158944..64d423af6c 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -135,6 +135,7 @@ struct KernelRegistry; struct Function; struct Graph; struct GraphViewer; +enum class DataLayout; struct Model; struct Path; struct Node; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index daa47db1fc..27019a7f1d 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -222,6 +222,7 @@ Status Environment::Initialize(std::unique_ptr logging_ } domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1); domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1); + domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSInternalNHWCDomain, 1, 1); #ifdef USE_DML domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1); #endif diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 71aa9bc499..cc5ed47085 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -32,6 +32,7 @@ #include "core/framework/op_kernel_context_internal.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/utils.h" +#include "core/framework/static_kernel_def_hashes.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/optimizer/graph_transformer_utils.h" @@ -40,6 +41,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" #include "core/optimizer/transformer_memcpy.h" +#include "core/optimizer/transpose_optimizer/optimizer_utils.h" #include "core/platform/Barrier.h" #include "core/platform/ort_mutex.h" #include "core/platform/threadpool.h" @@ -913,7 +915,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, // Do partitioning based on execution providers' capability. GraphPartitioner partitioner(kernel_registry_manager, providers); ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(), - session_state.GetMutableFuncMgr(), mode)); + session_state.GetMutableFuncMgr(), + layout_transformer::TransformLayout, mode)); // apply transformers except default transformers // Default transformers are required for correctness and they are owned and run by inference session @@ -1138,6 +1141,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, GraphPartitioner partitioner(kernel_registry_manager, providers); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(), + layout_transformer::TransformLayout, GraphPartitioner::Mode::kOrtFormatLoad, &compiled_kernel_hashes)); @@ -1211,6 +1215,17 @@ Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs graph.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) { ORT_RETURN_IF_ERROR(set_node_ep(node_index, kernel_def_hash)); } + + // layout transformer which is enabled in extended minimal build can add new nodes. + // The following loop fetches the hash values for these nodes. + for (const auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType().empty()) { + auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion()); + if (kernel_hash.has_value()) { + ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value())); + } + } + } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD) return Status::OK(); @@ -1236,7 +1251,7 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) { } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) -//VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong. +// VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong. #pragma warning(disable : 26117) #endif common::Status InferenceSession::Initialize() { diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 344f563fab..519599b411 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -20,6 +20,7 @@ #include "gtest/gtest.h" #include "test/test_environment.h" #include "test/util/include/default_providers.h" +#include "core/optimizer/transpose_optimizer/optimizer_utils.h" using namespace ONNX_NAMESPACE; using namespace std; @@ -142,7 +143,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { DefaultLoggingManager().DefaultLogger(), profiler); GraphPartitioner partitioner(krm, execution_providers); - status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()); + status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(), + layout_transformer::TransformLayout); ASSERT_TRUE(status.IsOK()) << status; ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -207,7 +209,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { // Partition the graph GraphPartitioner partitioner(krm, execution_providers); - status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()); + status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(), + layout_transformer::TransformLayout); ASSERT_TRUE(status.IsOK()) << status; // Finalize the session state @@ -256,7 +259,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { // Partition the graph GraphPartitioner partitioner(krm, execution_providers); - status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()); + status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(), + layout_transformer::TransformLayout); ASSERT_TRUE(status.IsOK()) << status; // Finalize the session state diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 6a9b8fb06e..9bb4e5b4bc 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -689,14 +689,15 @@ class ResizeOpTester : public OpTester { OpTester::AddNodes(graph, graph_input_defs, graph_output_defs, add_attribute_funcs); // set the Graph inputs to just X and roi (exclude 'scales') so the 'scales' are a constant initializer - if (scales_in_initializer_) { - // this isn't intended to work with a scenario where the optional 'sizes' input is provided - ASSERT_TRUE(graph_input_defs.size() == 3); + if(scales_in_initializer_) { graph.SetInputs({graph.GetNodeArg(graph_input_defs[0]->Name()), graph.GetNodeArg(graph_input_defs[1]->Name())}); - } - - if (sizes_in_initializer_) { + if(sizes_in_initializer_) { + ASSERT_TRUE(graph_input_defs.size() == 4); + } else { + ASSERT_TRUE(graph_input_defs.size() == 3); + } + } else if (sizes_in_initializer_) { ASSERT_TRUE(graph_input_defs.size() == 4); // 'sizes' is 4th input graph.SetInputs({graph.GetNodeArg(graph_input_defs[0]->Name()), graph.GetNodeArg(graph_input_defs[1]->Name()), @@ -744,9 +745,10 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Scales) { TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Sizes) { auto run_test = [](bool sizes_in_initializer) { - ResizeOpTester test(false, sizes_in_initializer); + ResizeOpTester test(sizes_in_initializer, sizes_in_initializer); std::vector roi{}; + std::vector scales{}; std::vector sizes{1, 1, 4, 4}; test.AddAttribute("mode", "nearest"); @@ -760,7 +762,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Sizes) { test.AddInput("X", {N, C, H, W}, X); test.AddInput("roi", {0}, roi); - test.AddInput("scales", {0}, {}); + test.AddInput("scales", {0}, scales, sizes_in_initializer); test.AddInput("sizes", {4}, sizes, sizes_in_initializer); std::vector Y = {1.0f, 1.0f, 2.0f, 2.0f, diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index bbf8ef053a..4a35a6e3ff 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -13,6 +13,7 @@ #include "core/graph/model.h" #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/optimizer/transpose_optimizer/optimizer_utils.h" #include @@ -22,12 +23,14 @@ constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP"; InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set& ops, const std::unordered_set& stop_ops, - bool debug_output) + bool debug_output, + DataLayout preferred_layout) : IExecutionProvider{utils::kInternalTestingExecutionProvider, true}, ep_name_{INTERNAL_TESTING_EP}, ops_{ops}, stop_ops_{stop_ops}, - debug_output_{debug_output} { + debug_output_{debug_output}, + preferred_layout_{preferred_layout} { // // TODO: Allocation planner calls GetAllocator for the individual EP. It would be better if it goes through // the session state to get the allocator so it's per-device (or for the allocation planner to try the EP first @@ -44,6 +47,10 @@ InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::un InternalTestingExecutionProvider::~InternalTestingExecutionProvider() {} +DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { + return preferred_layout_; +} + std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const std::vector& /*registries*/) const { @@ -89,7 +96,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& }; return utils::CreateSupportedPartitions(graph_viewer, supported_nodes, stop_ops_, - generate_metadef_name, ep_name_, debug_output_); + generate_metadef_name, ep_name_, onnxruntime::utils::kInternalTestingExecutionProvider, debug_output_); } common::Status InternalTestingExecutionProvider::Compile(const std::vector& fused_nodes, @@ -99,14 +106,19 @@ common::Status InternalTestingExecutionProvider::Compile(const std::vector& ops, const std::unordered_set& stop_ops = {}, - bool debug_output = false); + bool debug_output = false, + DataLayout preferred_layout = static_cast(0)); virtual ~InternalTestingExecutionProvider(); std::vector> @@ -24,6 +25,8 @@ class InternalTestingExecutionProvider : public IExecutionProvider { return FusionStyle::FilteredGraphViewer; } + DataLayout GetPreferredLayout() const override; + private: const std::string ep_name_; @@ -38,5 +41,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider { const std::unordered_set stop_ops_; const bool debug_output_; + + DataLayout preferred_layout_; }; } // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc index a0acb9518e..efe4885c3c 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc @@ -31,6 +31,7 @@ namespace test { // it would be possible to use ORT format models but the same partitioning code would run either way #if !defined(ORT_MINIMAL_BUILD) +#define ORT_MODEL_FOLDER ORT_TSTR("testdata/") // model has an unsupported node between the supported nodes after the initial topo sort. // the partition aware topo sort should result in the unsupported node moving to earlier in the order, // and allow a single partition of supported nodes to be created. @@ -44,7 +45,7 @@ TEST(InternalTestingEP, TestSortResultsInSinglePartition) { ASSERT_STATUS_OK(session->RegisterExecutionProvider( std::make_unique(supported_ops))); - const ORTCHAR_T* model_path = ORT_TSTR("testdata/ep_partitioning_test_1.onnx"); + const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "ep_partitioning_test_1.onnx"; ASSERT_STATUS_OK(session->Load(model_path)); const auto& graph = session->GetGraph(); GraphViewer viewer{graph}; @@ -89,7 +90,7 @@ TEST(InternalTestingEP, TestDependenciesCorrectlyHandled) { ASSERT_STATUS_OK(session->RegisterExecutionProvider( std::make_unique(supported_ops))); - const ORTCHAR_T* model_path = ORT_TSTR("testdata/ep_partitioning_test_2.onnx"); + const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "ep_partitioning_test_2.onnx"; ASSERT_STATUS_OK(session->Load(model_path)); const auto& graph = session->GetGraph(); GraphViewer viewer{graph}; @@ -195,7 +196,7 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin } ASSERT_STATUS_OK(session->RegisterExecutionProvider( - std::make_unique(ops, stop_ops, debug_output))); + std::make_unique(ops, stop_ops, debug_output, DataLayout::NHWC))); ASSERT_STATUS_OK(session->Load(model_uri)); const auto& graph = session->GetGraph(); @@ -310,6 +311,7 @@ TEST(InternalTestingEP, DISABLED_TestNnapiPartitioningMlPerfModels) { "deeplabv3_mnv2_ade20k_float.onnx", "mobilebert.onnx", "mobiledet.onnx", + }; for (const auto& model_uri : model_paths) { diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 71032d560c..94574366a7 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -24,8 +24,10 @@ namespace onnxruntime { namespace test { +#define ORT_MODEL_FOLDER ORT_TSTR("testdata/") + static void CreateSession(const SessionOptions& so, std::unique_ptr& session, - const ORTCHAR_T* model_path = ORT_TSTR("testdata/mnist.onnx"), // arbitrary test model + const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "mnist.onnx", // arbitrary test model bool enable_custom_ep = true, const std::unordered_set* override_supported_ops = nullptr) { session = std::make_unique(so, GetEnvironment()); @@ -86,7 +88,7 @@ static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enable #if !defined(DISABLE_SPARSE_TENSORS) #if !defined(ORT_MINIMAL_BUILD) TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { - const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.test_output.ort"); + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.test_output.ort"; // // First load the onnx format model and save as an ORT model. @@ -130,7 +132,7 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { } TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { - const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort"); + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; // make sure we can't save a model with compiled ops. input/output model format doesn't matter SessionOptions so; @@ -152,7 +154,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { // test to validate a minimal build TEST(InternalTestingEP, TestLoadOrtModel) { - const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort"); + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; std::unique_ptr session; bool enable_custom_ep = true; @@ -164,7 +166,7 @@ TEST(InternalTestingEP, TestLoadOrtModel) { // test that is the custom EP cannot take all nodes due to device limitations // that we fallback to the CPU implementations and can execute the model TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) { - const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort"); + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; const std::unordered_set supported_ops{"Conv", "Add", "Relu" /*, "MaxPool"*/}; std::unique_ptr session; @@ -225,7 +227,7 @@ static int CountAndValidateAssignedNodes(const Graph& current_graph, // Test model that contains a subgraph. This model has a Loop and an If so multiple layers of nested subgraphs. // There are Add nodes in the Loop and If subgraphs so we should see the custom EP taking nodes at both these levels. TEST(InternalTestingEP, TestModelWithSubgraph) { - const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort"); + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "ort_github_issue_4031.onnx.ort"; const std::unordered_set supported_ops{"Add"}; std::unique_ptr session; @@ -302,8 +304,11 @@ TEST(InternalTestingEP, TestOrtModelWithCompileFailure) { // In the test file, there are 2 Conv and 1 Gemm nodes, all disconnected // So we should have 3 partitions be taken by InternalTestingExecutionProvider/CompileFailureTestExecutionProvider // But CompileFailureTestExecutionProvider will fail the Compile for partition contains "Gemm" node - // This is to test the model initialization won't fail and Gemm node will not be replaced by the fused_node - const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort"); + // Post layout transformations we cannot revert back if compile fails because + // the layout transformation for this EP is already done at this stage and reverting + // can result in more failures. + // This is to test the model initialization fails if compile fails. + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; const std::unordered_set& supported_ops{"Conv", "Gemm"}; const std::unordered_set& compile_failure_ops{"Gemm"}; @@ -325,33 +330,12 @@ TEST(InternalTestingEP, TestOrtModelWithCompileFailure) { } // Use CompileFailureTestExecutionProvider which will fail Compile on "Gemm" - // We should have 2 partitions taken by the EP - // 2 Conv { InferenceSessionWrapper session(SessionOptions(), GetEnvironment()); ASSERT_STATUS_OK(session.RegisterExecutionProvider( std::make_unique(supported_ops, compile_failure_ops))); ASSERT_STATUS_OK(session.Load(ort_model_path)); - ASSERT_STATUS_OK(session.Initialize()); - - // 2 Conv nodes shoule be replaced with fused nodes - const auto& graph = session.GetGraph(); - int num_replaced_nodes = CountAndValidateAssignedNodes( - session.GetGraph(), {"Conv"}, const_cast(session.GetSessionState()).GetMutableFuncMgr()); - - ASSERT_EQ(num_replaced_nodes, 2); - - // The Gemm node should still not have been replaced - int count_compile_failure_nodes = 0; - for (const auto& node : graph.Nodes()) { - if (compile_failure_ops.find(node.OpType()) != compile_failure_ops.end()) - count_compile_failure_nodes++; - } - ASSERT_EQ(count_compile_failure_nodes, 1); - - // Execute the session, since the last node is Gemm, and its input 0 is all 0s - // So the result should be the bias initializer of the Gemm node - ExecuteMnist(session, true /* enable_custom_ep */); + ASSERT_STATUS_NOT_OK(session.Initialize()); } } } // namespace test diff --git a/onnxruntime/test/providers/kernel_def_hash_test.cc b/onnxruntime/test/providers/kernel_def_hash_test.cc index 18fe708827..9fc654567a 100644 --- a/onnxruntime/test/providers/kernel_def_hash_test.cc +++ b/onnxruntime/test/providers/kernel_def_hash_test.cc @@ -80,6 +80,7 @@ #include "core/mlas/inc/mlas.h" #include "core/platform/env_var_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" +#include "gtest/gtest.h" using json = nlohmann::json; @@ -189,5 +190,30 @@ TEST(KernelDefHashTest, ExpectedCpuKernelDefHashes) { CheckKernelDefHashes(cpu_kernel_def_hashes, expected_cpu_kernel_def_hashes, is_strict); } +// This test is to ensure the latest opset version for ops which can be added +// during layout transformation step are added. IF this test fails then it means +// there is a new version available for one of the ops in the map. +// Adding this test here because resolution for this test failure requires fetching the hash +// for one of the ops in the list below and this file has information around that. +// Please update the following 3 places: +// 1. api_impl.cc "onnx_ops_available_versions" map, include the latest version in the map +// 2. static_kernel_def_hashes.cc "static_kernel_hashes" include an entry for latest version and it's associated hash +// 3. This file "onnx_ops_available_versions" map, include the latest version in the map +TEST(KernelDefHashTest, TestNewOpsVersionSupportDuringLayoutTransform) { + static const std::unordered_map> onnx_ops_available_versions = { + {"Squeeze", {1, 11, 13}}, + {"Unsqueeze", {1, 11, 13}}, + {"Gather", {1, 11, 13}}, + {"Transpose", {1, 13}}, + {"Identity", {1, 13, 14, 16}}, + }; + + auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); + for (const auto& [op_type, version_list] : onnx_ops_available_versions) { + auto schema = schema_registry->GetSchema(op_type, INT_MAX, kOnnxDomain); + EXPECT_EQ(schema->SinceVersion(), version_list[version_list.size() - 1]) << "A new version for op: " << op_type + << "is available. Please update the files mentioned in the comments of this test."; + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index bdbde8c8b0..2fef4f9491 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -315,12 +315,17 @@ TEST(NnapiExecutionProviderTest, TestQDQConv) { TEST(NnapiExecutionProviderTest, TestQDQResize) { // NNAPI EP does not support the default setting of Resize Op // Use bi-linear and asymmetric for NNAPI EP only + // Setting verify_entire_graph_use_ep for this test as false. This is because layout transformation adds + // Transpose (NCHW -> NHWC) nodes. Post tranformation graph looks like this Transpose -> DQ -> Resize -> Q -> Transpose + // NNAPI does not pick the first Transpose as its input is graph/partition input + // See https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc#L305 + // onnxruntime::nnapi::IsInternalQuantizationSupported RunQDQModelTest(BuildQDQResizeTestCase({1, 3, 64, 64} /* input_shape */, {1, 3, 32, 32} /* sizes_data */, "linear" /* mode */, "asymmetric" /* coordinate_transformation_mode */), "nnapi_qdq_test_graph_resize", - {true /* verify_entire_graph_use_ep */}); + {false /* verify_entire_graph_use_ep */}); } TEST(NnapiExecutionProviderTest, TestQDQAveragePool) {