mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Add layout transformer for NNAPI (#10371)
* Add layout transformer for NNAPI * plus merge fixes * plus some more merge fixes * test fixes * comments + cleanup * plus updates * post merge changes * enable layout transformer in extended minimal build * plus more comments * more tests + fix CI * plus updates per review * more updates per review * fix file name * fix qdq tests * plus more updates * plus updates * typo fix * fix qdq selection in 2nd optimization pass * fix typo * fix a test * update dependency structure for layout transformer * plus updates * more updates * plus change * more updates to fix linker error in minimal build * remove unnecessary headers
This commit is contained in:
parent
ceb1e2b1a6
commit
f436d3437e
38 changed files with 1097 additions and 700 deletions
|
|
@ -347,10 +347,11 @@ if (onnxruntime_MINIMAL_BUILD)
|
|||
if (onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
# enable EPs that compile kernels at runtime
|
||||
add_compile_definitions(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
if (onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
|
||||
add_compile_definitions(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
|
||||
endif()
|
||||
# enable runtime optimizations. These are required for NNAPI.
|
||||
# TODO remove onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD since these optimizations are now
|
||||
# enabled for all extended minimal builds.
|
||||
SET(onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD ON)
|
||||
add_compile_definitions(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ if (onnxruntime_MINIMAL_BUILD)
|
|||
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.cc"
|
||||
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.cc"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
#include "core/framework/tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
enum class DataLayout;
|
||||
class GraphViewer;
|
||||
class Node;
|
||||
struct ComputeCapability;
|
||||
|
|
@ -274,6 +274,12 @@ class IExecutionProvider {
|
|||
return {};
|
||||
}
|
||||
|
||||
virtual DataLayout GetPreferredLayout() const {
|
||||
// NCHW is the default ONNX standard data layout, so default to it.
|
||||
// EPs which prefer a different layout should override this to return their preferred layout.
|
||||
// NOTE(review): DataLayout appears to be only forward-declared in this header, so the NCHW enumerator
// cannot be named here. 0 is the value of the first enumerator (NCHW) in the enum's definition; the cast
// silently changes meaning if the enumerators are ever reordered.
return static_cast<DataLayout>(0);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string type_;
|
||||
AllocatorMap allocators_;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ constexpr const char* kMLDomain = "ai.onnx.ml";
|
|||
constexpr const char* kMSDomain = "com.microsoft";
|
||||
constexpr const char* kMSExperimentalDomain = "com.microsoft.experimental";
|
||||
constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc";
|
||||
constexpr const char* kMSInternalNHWCDomain = "com.ms.internal.nhwc";
|
||||
constexpr const char* kMSDmlDomain = "com.microsoft.dml";
|
||||
constexpr const char* kNGraphDomain = "com.intel.ai";
|
||||
constexpr const char* kMIGraphXDomain = "";
|
||||
|
|
|
|||
|
|
@ -151,6 +151,12 @@ class Node {
|
|||
*/
|
||||
int SinceVersion() const noexcept { return since_version_; }
|
||||
|
||||
/** Sets the since version (opset version that the Node's operator was first defined in.) for this node.
|
||||
@remarks Used during layout transformation to set the since version of layout-transformed nodes with
|
||||
domain kMSInternalNHWCDomain.
|
||||
*/
|
||||
void SetSinceVersion(int since_version) noexcept { since_version_ = since_version; }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/** Gets the Node's OpSchema.
|
||||
@remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */
|
||||
|
|
|
|||
|
|
@ -19,6 +19,12 @@ struct NodeCompare {
|
|||
bool operator()(const Node* n1, const Node* n2) const;
|
||||
};
|
||||
|
||||
// Data layouts an execution provider can report as its preferred layout.
// The enumerator order matters: code that only sees a forward declaration of this enum
// obtains NCHW by casting 0, so NCHW must stay first.
enum class DataLayout {
|
||||
NCHW,  // default ONNX layout; value 0
|
||||
NHWC,  // an EP preferring this layout gets the layout transformer run during graph partitioning
|
||||
NCHWC,  // NOTE(review): no use of this value is visible here — presumably the blocked layout of the nchwc domain; confirm
|
||||
};
|
||||
|
||||
/**
|
||||
@class GraphViewer
|
||||
Class that provides a read-only view of the Graph.
|
||||
|
|
|
|||
|
|
@ -52,7 +52,97 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph
|
|||
.Provider(provider_type);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Validate all the layout sensitive nodes which were transformed for current EP are indeed taken by current EP.
|
||||
/// If not, then we have a bug. If a node with domain kMSNHWC is left in the graph at this point then
|
||||
/// graph.Resolve will fail.
|
||||
/// Since layout transformation is only enabled for compile based EPs, just checking that graph does not contain
|
||||
/// a node with kMSNHWC domain is enough. This is because after compile all the nodes which the EP claims are fused
|
||||
/// into 1 and removed from the graph.
|
||||
/// </summary>
|
||||
/// <param name="graph">Graph to validate</param>
|
||||
/// <returns></returns>
|
||||
static Status ValidateGraphPartitioning(const Graph& graph) {
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
if (node.Domain() == kMSInternalNHWCDomain) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Graph contains an invalid node: " + node.Name() + " Op Type: " + node.OpType() +
|
||||
" with domain: " + kMSInternalNHWCDomain + ". These are temporary nodes added during layout transformations " +
|
||||
" and are not expected to remain in the graph post partitioning. This is a bug in layout transformer.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
|
||||
|
||||
/// <summary>
|
||||
/// Check if a node can be placed on a specific provider. If yes, then set the nodes execution provider.
|
||||
/// Do nothing if the node is already assigned.
|
||||
/// </summary>
|
||||
/// <param name="graph">Graph in question.</param>
|
||||
/// <param name="capability">Indexed subgraph which needs to be assigned</param>
|
||||
/// <param name="provider_type">The EP to assign the Indexed subgraph to</param>
|
||||
static void AssignNodes(Graph& graph, const IndexedSubGraph& capability,
|
||||
const std::string& provider_type) {
|
||||
// Before assigning the ep to any node, first walk through all the nodes and ensure
|
||||
// none of the nodes have already been assigned. If a node is assigned, simply return.
|
||||
for (auto node_index : capability.nodes) {
|
||||
const auto* node = graph.GetNode(node_index);
|
||||
if ((nullptr == node) || (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto node_index : capability.nodes) {
|
||||
auto* node = graph.GetNode(node_index);
|
||||
node->SetExecutionProviderType(provider_type);
|
||||
}
|
||||
}
|
||||
|
||||
#endif //! defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
|
||||
|
||||
/// <summary>
/// Query an execution provider for the subgraphs it can run. When the provider prefers the NHWC layout
/// (and the mode is not kAssignOnly), additionally run the layout transformation for it and re-query.
/// </summary>
/// <param name="graph">Graph to partition; may be modified by the layout transformation.</param>
/// <param name="kernel_registry_mgr">Source of the kernel registries for the provider's type.</param>
/// <param name="current_ep">Execution provider being queried.</param>
/// <param name="mode">Partitioning mode; layout transformation is skipped for kAssignOnly.</param>
/// <param name="capabilities">Output: the capabilities returned by the provider.</param>
/// <param name="transform_layout">Callback that performs the layout transformation for current_ep.</param>
/// <returns>Status::OK(), or the error returned by transform_layout.</returns>
static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep,
                                 GraphPartitioner::Mode mode, std::vector<std::unique_ptr<ComputeCapability>>& capabilities,
                                 TransformLayoutFunction transform_layout) {
  {
    // The viewer only lives for this query; the graph may be mutated further down.
    GraphViewer initial_view(graph);
    capabilities = current_ep.GetCapability(initial_view,
                                            kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type()));
  }

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
  // The layout transformer is only run here for EPs whose preferred layout is NHWC.
  // CPU EP layout transformation happens later when level 3 transformers are run.
  const bool layout_transform_applies = mode != GraphPartitioner::Mode::kAssignOnly &&
                                        current_ep.GetPreferredLayout() == DataLayout::NHWC;
  if (!layout_transform_applies) {
    return Status::OK();
  }

  // Tentatively assign the claimed nodes to this EP before handing the graph to the transformer.
  for (auto& capability : capabilities) {
    // in theory an EP could return an empty value...
    if (capability && capability->sub_graph) {
      AssignNodes(graph, *capability->sub_graph, current_ep.Type());
    }
  }

  // Perform layout transformation on the specific EP assigned graph
  bool graph_changed = false;
  ORT_RETURN_IF_ERROR(transform_layout(graph, graph_changed, current_ep));

  // The transformation may introduce new nodes: either existing nodes which are reconstructed to update
  // their domain, or completely new nodes which are necessary for the layout transformation.
  // Re-run GetCapability so that these new nodes can be processed by this EP.
  if (graph_changed) {
    capabilities.clear();
    GraphViewer refreshed_view(graph);
    capabilities = current_ep.GetCapability(refreshed_view,
                                            kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type()));
  }
#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)

  return Status::OK();
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
|
||||
auto schema = node.Op();
|
||||
builder.SetName(schema->Name())
|
||||
|
|
@ -94,23 +184,23 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
|
|||
// Check whether any node in the <sub_graph> was already assigned. If so it cannot be stolen as assignment is done
|
||||
// in order of EP priority
|
||||
bool sub_graph_available_for_assignment = true;
|
||||
for (auto node_index : capability.nodes) {
|
||||
const auto* node = graph.GetNode(node_index);
|
||||
if (nullptr == node || !node->GetExecutionProviderType().empty()) {
|
||||
// if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned,
|
||||
// so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by
|
||||
// preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to
|
||||
// and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes
|
||||
// should never be on those lists.
|
||||
//
|
||||
// when the ORT format model is loaded we will process it normally with EP priority being applied for
|
||||
// whichever EPs are enabled at the time.
|
||||
//
|
||||
// e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP.
|
||||
// We want the ORT format model to be able to be run as efficiently as possible on either platform,
|
||||
// so we want all the nodes that either may take to be preserved. If we did not do this we would
|
||||
// need to create one ORT format model for Android and one for iOS.
|
||||
if (mode != GraphPartitioner::Mode::kAssignOnly) {
|
||||
if (mode != GraphPartitioner::Mode::kAssignOnly) {
|
||||
// if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned,
|
||||
// so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by
|
||||
// preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to
|
||||
// and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes
|
||||
// should never be on those lists.
|
||||
//
|
||||
// when the ORT format model is loaded we will process it normally with EP priority being applied for
|
||||
// whichever EPs are enabled at the time.
|
||||
//
|
||||
// e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP.
|
||||
// We want the ORT format model to be able to be run as efficiently as possible on either platform,
|
||||
// so we want all the nodes that either may take to be preserved. If we did not do this we would
|
||||
// need to create one ORT format model for Android and one for iOS.
|
||||
for (auto node_index : capability.nodes) {
|
||||
const auto* node = graph.GetNode(node_index);
|
||||
if ((nullptr == node) || (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) {
|
||||
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
|
||||
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
|
||||
sub_graph_available_for_assignment = false;
|
||||
|
|
@ -165,7 +255,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
KernelRegistry& fused_kernel_registry,
|
||||
IExecutionProvider& current_ep,
|
||||
GraphPartitioner::Mode mode,
|
||||
int& fused_node_unique_id) {
|
||||
int& fused_node_unique_id,
|
||||
TransformLayoutFunction transform_layout_function) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
// doing it here saves all providers checking for this in GetCapability
|
||||
if (graph.NumberOfNodes() == 0) {
|
||||
|
|
@ -178,28 +269,32 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
Graph* subgraph = entry.second;
|
||||
// we pass through the export_dll value and FuncManager from the top level graph
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr,
|
||||
fused_kernel_registry, current_ep, mode, fused_node_unique_id));
|
||||
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
|
||||
transform_layout_function));
|
||||
}
|
||||
}
|
||||
|
||||
// If an execution provider return the capability that he could run a sub-graph,
|
||||
// onnxruntime will fuse the sub-graph into a function node. if the execution provider
|
||||
// says he need to compile the graph at runtime (by need_compile flag),
|
||||
// onnxruntime will invoke the "Compile" method to get compiled binary.
|
||||
// There are two mode of compile, one is return the entry point to the compiled binary
|
||||
// directly, another is export the compiled binary to shared library for future reuse.
|
||||
// If an execution provider returns the capability that it can run a sub-graph,
|
||||
// onnxruntime will fuse the sub-graph into a function node. For compilation
|
||||
// based execution providers (one which needs to compile graph at runtime.
|
||||
// Indicated by need_compile flag), onnxruntime will invoke the "Compile" method
|
||||
// to get compiled binary. There are two mode of compile, one is return the entry
|
||||
// point to the compiled binary directly, another is export the compiled binary to
|
||||
// shared library for future reuse.
|
||||
|
||||
// TODO: when the graph contain a function node, and user pass in the dll which could
|
||||
// TODO: when the graph contains a function node, and user passes in the dll which could
|
||||
// run the function by SessionOption, we should create a function kernel for it and
|
||||
// delegate the compute to the functions inside the dlls.
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(graph, kernel_registry_mgr, current_ep, mode, capabilities,
|
||||
transform_layout_function));
|
||||
if (capabilities.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const std::string& type = current_ep.Type();
|
||||
auto fusion_style = current_ep.GetFusionStyle();
|
||||
std::vector<Node*> nodes_to_compile;
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
|
||||
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
|
||||
|
||||
// filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with
|
||||
// nodes_to_compile.
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_compile;
|
||||
|
|
@ -211,7 +306,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
}));
|
||||
|
||||
for (auto& capability : capabilities) {
|
||||
if (!capability || !capability->sub_graph) { // in theory an EP could return an empty value...
|
||||
// in theory an EP could return an empty value...
|
||||
if (!capability || !capability->sub_graph) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -305,6 +401,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
|
||||
}
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph));
|
||||
}
|
||||
|
||||
// if this is the main graph call Resolve to put the Graph back into a guaranteed good state
|
||||
|
|
@ -362,14 +460,16 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
|
|||
|
||||
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
KernelRegistry& fused_kernel_registry, Mode mode,
|
||||
int& fused_node_unique_id) const {
|
||||
int& fused_node_unique_id,
|
||||
TransformLayoutFunction transform_layout_function) const {
|
||||
bool modified_graph = false;
|
||||
|
||||
do {
|
||||
// process full graph with each EP
|
||||
for (const auto& ep : providers_) {
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
|
||||
fused_kernel_registry, *ep, mode, fused_node_unique_id));
|
||||
fused_kernel_registry, *ep, mode, fused_node_unique_id,
|
||||
transform_layout_function));
|
||||
}
|
||||
|
||||
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
|
||||
|
|
@ -392,13 +492,15 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
KernelRegistry& fused_kernel_registry,
|
||||
IExecutionProvider& current_ep,
|
||||
std::unordered_map<std::string, HashValue>& compiled_kernel_hashes,
|
||||
int& fused_node_unique_id) {
|
||||
int& fused_node_unique_id,
|
||||
TransformLayoutFunction transform_layout_function) {
|
||||
// recurse into nested graphs first to partition bottom up.
|
||||
for (auto& node : graph.Nodes()) {
|
||||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
Graph* subgraph = entry.second;
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry,
|
||||
current_ep, compiled_kernel_hashes, fused_node_unique_id));
|
||||
current_ep, compiled_kernel_hashes, fused_node_unique_id,
|
||||
transform_layout_function));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -409,12 +511,10 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
}
|
||||
|
||||
const std::string& type = current_ep.Type();
|
||||
GraphViewer graph_viewer(graph);
|
||||
std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
|
||||
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
|
||||
ORT_RETURN_IF_ERROR(GetCapabilityForEP(graph, kernel_registry_mgr, current_ep,
|
||||
GraphPartitioner::Mode::kOrtFormatLoad, capabilities, transform_layout_function));
|
||||
if (capabilities.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -451,40 +551,37 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
for (size_t j = 0, end = nodes_and_viewers.size(); j < end; ++j) {
|
||||
Node& node = nodes_and_viewers[j].fused_node;
|
||||
std::vector<NodeComputeInfo> single_node_compute_func;
|
||||
auto status = current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func);
|
||||
if (!status.IsOK()) {
|
||||
// There is compile error with the nodes_and_viewer[j], remove the fused_node and function from the graph
|
||||
LOGS_DEFAULT(ERROR) << "EP: " << current_ep.Type() << " has Compile error: " << status.ErrorMessage();
|
||||
graph.CancelFuseSubGraph(node);
|
||||
} else {
|
||||
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 elements");
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0])));
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func));
|
||||
|
||||
const auto& cur_capability = capabilities[j];
|
||||
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
|
||||
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
|
||||
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element.");
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0])));
|
||||
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, metadef, type);
|
||||
auto kernel_def = builder.Build();
|
||||
const auto& cur_capability = capabilities[j];
|
||||
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
|
||||
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
|
||||
|
||||
// save hash so SessionState can find the kernel. each kernel name should be unique
|
||||
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
|
||||
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
|
||||
". Execution Provider must generate unique names across the entire model.");
|
||||
}
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, metadef, type);
|
||||
auto kernel_def = builder.Build();
|
||||
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
|
||||
KernelCreateInfo(std::move(kernel_def),
|
||||
[](FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
|
||||
return FunctionKernel::Create(func_mgr, info, out);
|
||||
})));
|
||||
|
||||
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
|
||||
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
|
||||
// save hash so SessionState can find the kernel. each kernel name should be unique
|
||||
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
|
||||
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
|
||||
". Execution Provider must generate unique names across the entire model.");
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
|
||||
KernelCreateInfo(std::move(kernel_def),
|
||||
[](FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
|
||||
return FunctionKernel::Create(func_mgr, info, out);
|
||||
})));
|
||||
|
||||
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
|
||||
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -495,7 +592,8 @@ Status GraphPartitioner::PartitionOrtFormatModel(
|
|||
Graph& graph, FuncManager& func_mgr,
|
||||
KernelRegistry& fused_kernel_registry,
|
||||
std::unordered_map<std::string, HashValue>& compiled_kernel_hashes,
|
||||
int& fused_node_unique_id) const {
|
||||
int& fused_node_unique_id,
|
||||
TransformLayoutFunction transform_layout_function) const {
|
||||
// process full graph with each EP
|
||||
for (const auto& ep : providers_) {
|
||||
if (ep->Type() == kCpuExecutionProvider) {
|
||||
|
|
@ -505,13 +603,15 @@ Status GraphPartitioner::PartitionOrtFormatModel(
|
|||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry,
|
||||
*ep, compiled_kernel_hashes, fused_node_unique_id));
|
||||
*ep, compiled_kernel_hashes, fused_node_unique_id,
|
||||
transform_layout_function));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, Mode mode,
|
||||
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
TransformLayoutFunction transform_layout_function, Mode mode,
|
||||
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes) const {
|
||||
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
|
||||
// 1. Execution providers' capabilities are checked one by one.
|
||||
|
|
@ -535,16 +635,16 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
|
|||
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode,
|
||||
fused_node_unique_id));
|
||||
fused_node_unique_id, transform_layout_function));
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(export_dll);
|
||||
ORT_THROW("Not supported in this build.");
|
||||
#endif
|
||||
#endif //! defined(ORT_MINIMAL_BUILD)
|
||||
} else {
|
||||
ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided");
|
||||
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes,
|
||||
fused_node_unique_id));
|
||||
fused_node_unique_id, transform_layout_function));
|
||||
}
|
||||
|
||||
if (!fused_kernel_registry->IsEmpty()) {
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ namespace onnxruntime {
|
|||
class ExecutionProviders;
|
||||
class KernelRegistry;
|
||||
class KernelRegistryManager;
|
||||
// Callback the partitioner invokes to run layout transformation on `graph` for `current_ep`.
// `modified` is an output flag: the partitioner initializes it to false and re-queries the EP's
// capabilities when the callback leaves it true. A non-OK Status aborts partitioning.
using TransformLayoutFunction = std::function<Status(Graph& graph, bool& modified, IExecutionProvider& current_ep)>;
|
||||
|
||||
class GraphPartitioner {
|
||||
public:
|
||||
|
|
@ -31,7 +32,8 @@ class GraphPartitioner {
|
|||
}
|
||||
|
||||
// Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad.
|
||||
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
TransformLayoutFunction transform_layout_function,
|
||||
Mode mode = Mode::kNormal,
|
||||
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes = nullptr) const;
|
||||
|
||||
|
|
@ -40,12 +42,13 @@ class GraphPartitioner {
|
|||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id) const;
|
||||
KernelRegistry& fused_kernel_registry, Mode mode,
|
||||
int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const;
|
||||
#endif
|
||||
|
||||
Status PartitionOrtFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry,
|
||||
std::unordered_map<std::string, HashValue>& compiled_kernel_hashes,
|
||||
int& fused_node_unique_id) const;
|
||||
int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const;
|
||||
|
||||
KernelRegistryManager& kernel_registry_mgr_;
|
||||
const ExecutionProviders& providers_;
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#include "core/framework/session_state_flatbuffers_utils.h"
|
||||
#include "core/framework/session_state_utils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/framework/static_kernel_def_hashes.h"
|
||||
#include "core/providers/cpu/controlflow/utils.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
|
||||
|
|
@ -601,7 +602,7 @@ Status SessionState::GeneratePatternGroupCache(const gsl::span<const OrtValue>&
|
|||
auto* node = graph_viewer_->GetNode(node_plan.node_index);
|
||||
int output_start = node_index + static_cast<int>(node->InputDefs().size()) +
|
||||
static_cast<int>(node->ImplicitInputDefs().size());
|
||||
//allocate output
|
||||
// allocate output
|
||||
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
|
||||
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
|
||||
if (ml_value_idx == NodeIndexInfo::kInvalidEntry ||
|
||||
|
|
@ -631,7 +632,7 @@ Status SessionState::GeneratePatternGroupCache(const gsl::span<const OrtValue>&
|
|||
}
|
||||
}
|
||||
|
||||
//release nodes
|
||||
// release nodes
|
||||
for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) {
|
||||
auto ml_value_idx = exe_plan->to_be_freed[index];
|
||||
const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type;
|
||||
|
|
@ -1015,15 +1016,22 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
|
|||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
|
||||
|
||||
// lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices
|
||||
// lookup the hashes for any nodes we compiled or added during graph partitioning.
|
||||
// These node indexes for compiled nodes as well as newly added nodes are not in node_indices
|
||||
// as they were created at runtime.
|
||||
if (!compiled_kernel_hashes.empty()) {
|
||||
for (const auto& node : graph_.Nodes()) {
|
||||
if (kernel_create_info_map_.count(node.Index()) == 0) {
|
||||
for (const auto& node : graph_.Nodes()) {
|
||||
if (kernel_create_info_map_.count(node.Index()) == 0) {
|
||||
if (node.Domain() == kOnnxDomain || node.Domain() == kOnnxDomainAlias) {
|
||||
auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
|
||||
if (kernel_hash.has_value()) {
|
||||
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, *kernel_hash));
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to find kernel hash for node:", node.Name(), " optype:", node.OpType());
|
||||
}
|
||||
} else {
|
||||
const auto hash_info = compiled_kernel_hashes.find(node.OpType());
|
||||
ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
|
||||
"Unable to find compiled kernel hash for node '", node.Name(), "'.");
|
||||
|
||||
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second));
|
||||
}
|
||||
}
|
||||
|
|
@ -1284,7 +1292,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
subgraphs_kernel_create_info_maps,
|
||||
outer_scope_node_arg_to_location_map,
|
||||
ort_value_name_idx_map_, context, p_seq_exec_plan_));
|
||||
//Record the allocation plan
|
||||
// Record the allocation plan
|
||||
|
||||
// Uncomment the below to dump the allocation plan to std::cout
|
||||
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);
|
||||
|
|
@ -1327,7 +1335,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
},
|
||||
logger_, data_transfer_mgr_, *p_seq_exec_plan_.get(), session_options));
|
||||
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
|
||||
//Record Weight allocation info on device
|
||||
// Record Weight allocation info on device
|
||||
MemoryInfo::RecordInitializerAllocInfo(GetInitializedTensors());
|
||||
#endif
|
||||
|
||||
|
|
|
|||
37
onnxruntime/core/framework/static_kernel_def_hashes.cc
Normal file
37
onnxruntime/core/framework/static_kernel_def_hashes.cc
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
#include "core/framework/static_kernel_def_hashes.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
std::optional<HashValue> GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version) {
|
||||
// Layout tranformer can add new nodes to the graph.
|
||||
// Since layout transformation can happen in an extended build, if these nodes are not picked up and compiled by
|
||||
// NNAPI or other compiling EPs then we need a way to get the hashes for these nodes. Since the infrastructure
|
||||
// as well as op_schema required to generate these hashes is not available in an extended minimal build,
|
||||
// we maintain a static map of nodes to hash value. This hash value can then be used to retireive the
|
||||
// kernel for the given op.
|
||||
static std::unordered_map<std::string, HashValue> static_kernel_hashes{
|
||||
{"Transpose_1", 4324835766923221184ULL},
|
||||
{"Transpose_13", 17267477159887372848ULL},
|
||||
{"Squeeze_1", 12889825108950034784ULL},
|
||||
{"Squeeze_11", 14725795030460042064ULL},
|
||||
{"Squeeze_13", 16122603335179721968ULL},
|
||||
{"UnSqueeze_1", 15964030255371555232ULL},
|
||||
{"UnSqueeze_11", 16989589986691430224ULL},
|
||||
{"UnSqueeze_13", 9466011545409597224ULL},
|
||||
{"Gather_1", 625186873870077080ULL},
|
||||
{"Gather_11", 11761559382112736008ULL},
|
||||
{"Gather_13", 7462749543760614528ULL},
|
||||
{"Identity_1", 18001636502361632792ULL},
|
||||
{"Identity_13", 16879814636194901248ULL},
|
||||
{"Identity_14", 16515685968327103576ULL},
|
||||
{"Identity_16", 17661628575887109792ULL},
|
||||
};
|
||||
|
||||
auto key = op_type + "_" + std::to_string(since_version);
|
||||
auto iter = static_kernel_hashes.find(key);
|
||||
if (iter != static_kernel_hashes.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
19
onnxruntime/core/framework/static_kernel_def_hashes.h
Normal file
19
onnxruntime/core/framework/static_kernel_def_hashes.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <optional>
|
||||
#include "core/common/basic_types.h"
|
||||
namespace onnxruntime {
|
||||
/**
|
||||
* @brief Gets the hash value for provided op type + version combination if it is available, otherwise
|
||||
* returns a nullopt. The hash value is available if this node was added by layout transformer. For all other
|
||||
* nodes, the hash values should be present either in the serialized session state obtained from the ORT format model
|
||||
* or from compiled kernel hash map which is generated during partitioning.
|
||||
* @return std::optional<HashValue>
|
||||
*/
|
||||
std::optional<HashValue> GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version);
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/nhwc_transformer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/transpose_optimizer/api_impl.h"
|
||||
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
@ -55,7 +55,7 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
|
||||
|
||||
if (domain != kMSDomain) {
|
||||
SwapNodeOpTypeAndDomain(*api_graph, *node, "QLinearConv", "com.microsoft");
|
||||
SwapNodeOpTypeAndDomain(*api_graph, *node, "QLinearConv", kMSDomain);
|
||||
}
|
||||
|
||||
modified = true;
|
||||
|
|
|
|||
|
|
@ -111,7 +111,9 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
|
|||
std::vector<NodeGroup> qdq_selections;
|
||||
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const auto* node = graph_viewer.GetNode(index);
|
||||
if (node->Domain() != kOnnxDomain) {
|
||||
// post layout transformation all the layout sensitive nodes are converted to domain
|
||||
// kMSInternalNHWCDomain. Therefore need to allow this domain as well.
|
||||
if (node->Domain() != kOnnxDomain && node->Domain() != kMSInternalNHWCDomain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace onnx_layout_transformation {
|
||||
namespace api {
|
||||
|
|
@ -204,7 +205,7 @@ class NodeRef {
|
|||
}
|
||||
std::string_view node_domain = Domain();
|
||||
return node_domain == domain ||
|
||||
((domain == "" || domain == "ai.onnx") && (node_domain == "" || node_domain == "ai.onnx"));
|
||||
((domain == "" || domain == "ai.onnx") && (node_domain == "" || node_domain == "ai.onnx"));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -217,6 +218,19 @@ class NodeRef {
|
|||
return GetAttributeInt(name).value_or(default_value);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the Execution Provider assigned to this node. Any empty string means this node is
|
||||
/// not assigned to any EP.
|
||||
/// </summary>
|
||||
/// <returns>EP type or empty string</returns>
|
||||
virtual const std::string& GetExecutionProviderType() const = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Returns the schema since version for the op_type of this node. A value of -1 means it is not set.
|
||||
/// </summary>
|
||||
/// <returns>since version or default value -1</returns>
|
||||
virtual int SinceVersion() const = 0;
|
||||
|
||||
virtual ~NodeRef(){};
|
||||
};
|
||||
|
||||
|
|
@ -224,7 +238,6 @@ class NodeRef {
|
|||
/// Information regarding the consumers of a value.
|
||||
/// </summary>
|
||||
struct ValueConsumers {
|
||||
|
||||
/// <summary>
|
||||
/// List of nodes in the current graph containing value as an input
|
||||
/// </summary>
|
||||
|
|
@ -335,6 +348,14 @@ class GraphRef {
|
|||
virtual std::unique_ptr<NodeRef> AddNode(std::string_view op_type, const std::vector<std::string_view>& inputs,
|
||||
size_t num_outputs, std::string_view domain = "") = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Creates a copy of the provided node in the graph with the specified op type and domain.
|
||||
/// </summary>
|
||||
/// <param name="op_type">The new node's op type</param>
|
||||
/// <param name="domain">The new node's domain. Empty string signifies default onnx domain.</param>
|
||||
/// <returns>The new node</returns>
|
||||
virtual std::unique_ptr<NodeRef> CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain = "") = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Deletes a node from the graph. Behavior is undefined if node has any consumers.
|
||||
/// </summary>
|
||||
|
|
@ -409,6 +430,17 @@ class GraphRef {
|
|||
constexpr int64_t kMinSupportedOpset = 7;
|
||||
constexpr int64_t kMaxSupportedOpset = 15;
|
||||
|
||||
enum class OptimizerMode {
|
||||
OPTIMIZE_TRANSPOSE, // simple transpose optimization
|
||||
OPTIMIZE_LAYOUT_TRANSFORM // transpose optimization post layout transformation
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Gets a list of layout sensitive ops defined by ONNX standard.
|
||||
/// </summary>
|
||||
/// <returns>const reference to an unordered set of op_types which are layout sensitive</returns>
|
||||
const std::unordered_set<std::string_view>& GetLayoutSensitiveOps();
|
||||
|
||||
/// <summary>
|
||||
/// Performs transpose optimization on a graph. Returns true if the graph was modified.
|
||||
///
|
||||
|
|
@ -420,8 +452,16 @@ constexpr int64_t kMaxSupportedOpset = 15;
|
|||
/// </summary>
|
||||
/// <param name="graph">The graph to optimize (or a portion of a graph, see api::GraphRef docs)</param>
|
||||
/// <param name="allow_extended_ops">Whether com.microsoft ops can be used for optimization</param>
|
||||
/// <param name="provider_type">Execution provider if applicable.</param>
|
||||
/// <param name="mode">Current mode. Optimizer can be called in the context of transpose optimizations or during layout transformations.</param>
|
||||
/// <param name="layout_sensitive_ops">List of ops which are treated as layout sensitive by the ONNX standard as well as any runtime specific ops.
|
||||
/// These ops should be provided when mode is set to OPTIMIZE_LAYOUT_TRANSFORM. If these ops are not provided, transpose optimizer may convert the
|
||||
/// layout for these ops </param>
|
||||
/// <returns>true if the graph was modified</returns>
|
||||
bool Optimize(api::GraphRef& graph, bool allow_extended_ops);
|
||||
bool Optimize(api::GraphRef& graph, bool allow_extended_ops,
|
||||
const std::string& provider_type = "",
|
||||
OptimizerMode mode = OptimizerMode::OPTIMIZE_TRANSPOSE,
|
||||
const std::unordered_set<std::string_view>& layout_sensitive_ops = {});
|
||||
|
||||
/* Layout Transformation Tools
|
||||
* These methods help change the channel ordering of layout sensitive ops (like Conv). ONNX currently only supports
|
||||
|
|
@ -429,14 +469,14 @@ bool Optimize(api::GraphRef& graph, bool allow_extended_ops);
|
|||
* the new ordering. The existence of a robust transpose optimizer means that we can freely add transpose ops during
|
||||
* conversion and then call Optimize to remove as many as possible. To change the channel ordering of some/all ops
|
||||
* in a model, a user of this tool should do the following:
|
||||
*
|
||||
*
|
||||
* 1. Iterate over the graph nodes and identify nodes to convert. For each one:
|
||||
* a. Change the op type and domain (and possibly attributes) to the op/contrib op with the desired ordering.
|
||||
* b. The model is now invalid since the input tensors are in the original ordering (and all consumers
|
||||
* expect the original ordering). Use WrapTransposesAroundNode helper to insert transposes around the
|
||||
* inputs/outputs of the op to correct this.
|
||||
* 2. The model is now correct but has many unnecessary Transpose ops. Call Optimize on the graph.
|
||||
*
|
||||
*
|
||||
* After step 1, the Transpose ops will wrap converted ops in a similar manner to q/dq ops in quantization.
|
||||
* The perm attributes essentially encode the information about which ops are being reordered.
|
||||
*/
|
||||
|
|
@ -444,13 +484,13 @@ bool Optimize(api::GraphRef& graph, bool allow_extended_ops);
|
|||
/// <summary>
|
||||
/// Inserts transposes around op inputs/outputs. Alternatively transposes initializers or uses existing Transpose
|
||||
/// nodes if possible. Populates shape information on affected node inputs/outputs to reflect the change.
|
||||
///
|
||||
///
|
||||
/// Ex:
|
||||
/// * -> NhwcConv -> **
|
||||
/// becomes
|
||||
/// * -> Transpose -> NhwcConv -> Transpose -> **
|
||||
/// Conv inputs/outputs have new shape. Shapes of * and ** are unchanged (carrying NCHW data).
|
||||
///
|
||||
///
|
||||
/// input_perms/output_perms are matched with node inputs/outputs positionally. Their lengths must be at most equal to
|
||||
/// the number of inputs/outputs, respectively. nullptr entries indicate an input or output should not be transposed.
|
||||
/// </summary>
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/transpose_optimizer/api_impl.h"
|
||||
|
||||
#include "optimizer_api.h"
|
||||
#include "optimizer_utils.h"
|
||||
#include <deque>
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/cpu/tensor/transpose.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
|
@ -15,7 +15,6 @@ using namespace ::onnxruntime::common;
|
|||
using namespace onnx_layout_transformation;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class ApiValueInfo final : public api::ValueInfoRef {
|
||||
private:
|
||||
NodeArg& node_arg_;
|
||||
|
|
@ -84,6 +83,8 @@ class ApiNode final : public api::NodeRef {
|
|||
void CopyAttributes(const api::NodeRef& node) override;
|
||||
void ClearAttribute(std::string_view name) override;
|
||||
void SetInput(size_t i, std::string_view name) override;
|
||||
const std::string& GetExecutionProviderType() const override;
|
||||
virtual int SinceVersion() const override;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode);
|
||||
|
|
@ -114,6 +115,9 @@ class ApiGraph final : public api::GraphRef {
|
|||
void ReshapeInitializer(std::string_view name, const std::vector<int64_t>& shape) override;
|
||||
std::unique_ptr<api::NodeRef> AddNode(std::string_view op_type, const std::vector<std::string_view>& inputs,
|
||||
size_t num_outputs = 1, std::string_view domain = "") override;
|
||||
|
||||
std::unique_ptr<api::NodeRef> CopyNode(const api::NodeRef& source_node, std::string_view op_type,
|
||||
std::string_view domain = "") override;
|
||||
void RemoveNode(api::NodeRef& node) override;
|
||||
void RemoveInitializer(std::string_view name) override;
|
||||
std::string_view AddInitializer(api::DataType dtype, const std::vector<int64_t>& shape,
|
||||
|
|
@ -377,6 +381,15 @@ void ApiNode::SetInput(size_t i, std::string_view name) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& ApiNode::GetExecutionProviderType() const {
|
||||
return node_.GetExecutionProviderType();
|
||||
}
|
||||
|
||||
int ApiNode::SinceVersion() const {
|
||||
return node_.SinceVersion();
|
||||
}
|
||||
|
||||
// </ApiNode>
|
||||
|
||||
std::optional<int64_t> ApiGraph::Opset(std::string_view domain) const {
|
||||
|
|
@ -562,11 +575,11 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vector<int64
|
|||
node_arg->SetShape(new_shape);
|
||||
}
|
||||
|
||||
std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
|
||||
const std::vector<std::string_view>& inputs, size_t num_outputs,
|
||||
std::string_view domain) {
|
||||
static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type,
|
||||
const std::vector<std::string_view>& inputs, size_t num_outputs,
|
||||
std::string_view domain, int since_version, std::string_view node_ep) {
|
||||
const std::string op_type_str(op_type);
|
||||
std::string name = graph_.GenerateNodeName(op_type_str);
|
||||
std::string name = graph.GenerateNodeName(op_type_str);
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
|
||||
|
|
@ -574,48 +587,107 @@ std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
|
|||
for (const auto& input : inputs) {
|
||||
NodeArg* arg;
|
||||
if (input == "") {
|
||||
arg = &graph_.GetOrCreateNodeArg("", nullptr);
|
||||
arg = &graph.GetOrCreateNodeArg("", nullptr);
|
||||
} else {
|
||||
arg = graph_.GetNodeArg(std::string(input));
|
||||
arg = graph.GetNodeArg(std::string(input));
|
||||
}
|
||||
input_args.push_back(arg);
|
||||
}
|
||||
|
||||
output_args.reserve(num_outputs);
|
||||
for (size_t i = 0; i < num_outputs; ++i) {
|
||||
std::string output = graph_.GenerateNodeArgName(name + "_out" + std::to_string(i));
|
||||
NodeArg* arg = &graph_.GetOrCreateNodeArg(output, nullptr);
|
||||
std::string output = graph.GenerateNodeArgName(name + "_out" + std::to_string(i));
|
||||
NodeArg* arg = &graph.GetOrCreateNodeArg(output, nullptr);
|
||||
output_args.push_back(arg);
|
||||
}
|
||||
|
||||
std::vector<NodeArg*> outputs;
|
||||
Node& node = graph_.AddNode(name, op_type_str, "Added in transpose optimizer", input_args, output_args, nullptr,
|
||||
std::string(domain));
|
||||
Node& node = graph.AddNode(name, op_type_str, "Added in transpose optimizer", input_args, output_args, nullptr,
|
||||
std::string(domain));
|
||||
|
||||
if (new_node_ep_ != nullptr) {
|
||||
node.SetExecutionProviderType(new_node_ep_);
|
||||
if (node.SinceVersion() == -1) {
|
||||
node.SetSinceVersion(since_version);
|
||||
}
|
||||
|
||||
node.SetExecutionProviderType(std::string(node_ep));
|
||||
|
||||
for (size_t i = 0; i < input_args.size(); ++i) {
|
||||
NodeArg* arg = input_args[i];
|
||||
if (arg->Exists()) {
|
||||
const std::string& name_str = arg->Name();
|
||||
graph_.AddConsumerNode(name_str, &node);
|
||||
const auto* inp_node = graph_.GetProducerNode(name_str);
|
||||
graph.AddConsumerNode(name_str, &node);
|
||||
const auto* inp_node = graph.GetProducerNode(name_str);
|
||||
if (inp_node != nullptr) {
|
||||
int inp_node_out_index = graph_utils::GetNodeOutputIndexFromOutputName(*inp_node, name_str);
|
||||
graph_.AddEdge(inp_node->Index(), node.Index(), inp_node_out_index, gsl::narrow_cast<int>(i));
|
||||
graph.AddEdge(inp_node->Index(), node.Index(), inp_node_out_index, gsl::narrow_cast<int>(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (NodeArg* arg : output_args) {
|
||||
graph_.UpdateProducerNode(arg->Name(), node.Index());
|
||||
graph.UpdateProducerNode(arg->Name(), node.Index());
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
// This is a list of onnx ops and their versions which transpose_optimizer can potentially add to the graph.
|
||||
// This is needed in minimal build since opschema is not available.
|
||||
// The versions MUST be sorted due to how the model opset is matched with the most recent operator version.
|
||||
static const std::unordered_map<std::string, std::vector<int>> onnx_ops_available_versions = {
|
||||
{"Squeeze", {1, 11, 13}},
|
||||
{"Unsqueeze", {1, 11, 13}},
|
||||
{"Gather", {1, 11, 13}},
|
||||
{"Transpose", {1, 13}},
|
||||
{"Identity", {1, 13, 14, 16}},
|
||||
};
|
||||
|
||||
// Based on the opset version imported for this model, returns the since version for the node.
|
||||
static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view domain,
|
||||
const std::unordered_map<std::string, int>& domain_to_version_map) {
|
||||
int since_version = -1;
|
||||
ORT_ENFORCE(domain == kOnnxDomain, "Transpose optimizer is expected to add only onnx domain ops. Domain: ",
|
||||
domain, " provided for op: ", op_type);
|
||||
|
||||
auto opset_import_iter = domain_to_version_map.find(std::string(domain));
|
||||
ORT_ENFORCE(opset_import_iter != domain_to_version_map.end(), "Onnx domain not found in opset imports.");
|
||||
|
||||
int opset_version = opset_import_iter->second;
|
||||
auto iter = onnx_ops_available_versions.find(std::string(op_type));
|
||||
ORT_ENFORCE(iter != onnx_ops_available_versions.end(),
|
||||
"Transpose Optimizer is adding an unexpected node: ", op_type,
|
||||
"An entry for this node should be added in onnx_ops_available_versions and static_kernel_hashes map.");
|
||||
|
||||
for (auto version : iter->second) {
|
||||
if (version <= opset_version) {
|
||||
since_version = version;
|
||||
}
|
||||
}
|
||||
|
||||
return since_version;
|
||||
}
|
||||
|
||||
std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
|
||||
const std::vector<std::string_view>& inputs, size_t num_outputs,
|
||||
std::string_view domain) {
|
||||
int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap());
|
||||
Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs,
|
||||
domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : "");
|
||||
|
||||
return std::make_unique<ApiNode>(node, graph_);
|
||||
}
|
||||
|
||||
// Creates a new node with the given op type and domain that takes the same inputs as source_node and has the same
// number of outputs, since version, execution provider and attributes. Returns an api::NodeRef for the new node.
std::unique_ptr<api::NodeRef> ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type,
                                                 std::string_view domain) {
  const auto source_inputs = source_node.Inputs();
  const size_t output_count = source_node.Outputs().size();

  Node& duplicate = CreateNodeHelper(graph_, op_type, source_inputs, output_count, domain,
                                     source_node.SinceVersion(), source_node.GetExecutionProviderType());

  // Attributes are not part of node creation, so they are carried over separately.
  std::unique_ptr<api::NodeRef> copied_ref = std::make_unique<ApiNode>(duplicate, graph_);
  copied_ref->CopyAttributes(source_node);
  return copied_ref;
}
|
||||
|
||||
void ApiGraph::RemoveNode(api::NodeRef& node) {
|
||||
Node& ort_node = static_cast<ApiNode&>(node).Node();
|
||||
for (const auto* node_arg : ort_node.InputDefs()) {
|
||||
|
|
@ -705,6 +777,10 @@ std::unique_ptr<api::GraphRef> MakeApiGraph(onnxruntime::Graph& graph, Allocator
|
|||
return std::make_unique<ApiGraph>(graph, std::move(cpu_allocator), new_node_ep);
|
||||
}
|
||||
|
||||
std::unique_ptr<api::NodeRef> MakeApiNode(onnxruntime::Graph& graph, onnxruntime::Node& node) {
|
||||
return std::make_unique<ApiNode>(node, graph);
|
||||
}
|
||||
|
||||
onnxruntime::Graph& GraphFromApiGraph(onnx_layout_transformation::api::GraphRef& graph) {
|
||||
return static_cast<ApiGraph&>(graph).Graph();
|
||||
}
|
||||
|
|
@ -713,4 +789,99 @@ onnxruntime::Node& NodeFromApiNode(onnx_layout_transformation::api::NodeRef& nod
|
|||
return static_cast<ApiNode&>(node).Node();
|
||||
}
|
||||
|
||||
namespace layout_transformer {
|
||||
|
||||
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps() {
|
||||
static std::unordered_set<std::string_view> ort_layout_senstive_ops = []() {
|
||||
const auto& layout_sensitive_ops = onnx_layout_transformation::GetLayoutSensitiveOps();
|
||||
std::unordered_set<std::string_view> ort_specific_ops = {"Resize", "FusedConv", "QLinearAveragePool", "QLinearGlobalAveragePool"};
|
||||
ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend());
|
||||
return ort_specific_ops;
|
||||
}();
|
||||
|
||||
return ort_layout_senstive_ops;
|
||||
}
|
||||
|
||||
// Converts the layout sensitive nodes of `graph` that are assigned to `execution_provider` to channels-last form:
// each such node is wrapped in Transposes and moved to kMSInternalNHWCDomain. If `modified` is set, the transpose
// optimizer is then run in layout-transform mode. `modified` is set to true whenever a node is converted.
// Always returns Status::OK().
Status TransformLayout(Graph& graph, bool& modified, IExecutionProvider& execution_provider) {
  // sub graph recurse will be added later
  auto api_graph = MakeApiGraph(graph, execution_provider.GetAllocator(0, OrtMemTypeDefault), nullptr);
  const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();

  for (auto& node : api_graph->Nodes()) {
    // Only layout sensitive ops are candidates.
    if (layout_sensitive_ops.count(node->OpType()) == 0) {
      continue;
    }

    // Only touch nodes assigned to the EP this transformation is being run for.
    if (node->GetExecutionProviderType() != execution_provider.Type()) {
      continue;
    }

    // Skip if domain is incorrect
    auto domain = node->Domain();
    const bool is_supported_domain = domain == kOnnxDomain || domain == kOnnxDomainAlias || domain == kMSDomain;
    if (!is_supported_domain) {
      continue;
    }

    // A node that is already channels-last only needs its domain changed to kMSInternalNHWCDomain so that the EP
    // knows this op is in the expected format.
    if (node->GetAttributeIntDefault("channels_last", 0) == 1) {
      onnx_layout_transformation::SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
      // Changing the domain for the node requires creating a new node and replacing the old one,
      // therefore set the modified flag.
      modified = true;
      continue;
    }

    // Skip if unknown rank
    auto shape = api_graph->GetValueInfo(node->Inputs()[0])->Shape();
    if (!shape.has_value()) {
      continue;
    }

    // Convert to channels last
    size_t rank = shape->size();

    // Only flip the attribute on ops that already carry it.
    if (node->GetAttributeInt("channels_last").has_value()) {
      node->SetAttributeInt("channels_last", 1);
    }

    auto input_perm = onnx_layout_transformation::ChannelFirstToLastPerm(rank);
    auto output_perm = onnx_layout_transformation::ChannelLastToFirstPerm(rank);

    // Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation
    // for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout
    // transformer only converts layout for 0th input, weights should be handled by every EP.
    if (node->OpType() == "Resize") {
      // Older versions of resize have a bug where ROI and Scales cannot be made empty inputs, so the inputs need
      // some extra care here. ROI (input 1) is skipped for layout conversion because it needs special handling as
      // ROI size is 2*rank.
      // Enable passing in ROI for layout conversion when an EP which supports ROI starts using layout transformer.
      // NNAPI which currently uses layout transformer does not support it.
      std::vector<const std::vector<int64_t>*> input_perms{&input_perm, nullptr};

      const auto resize_inputs = node->Inputs();
      for (size_t i = 2; i < resize_inputs.size(); ++i) {
        // Remaining inputs are permuted only when they are non-empty constants.
        auto constant = api_graph->GetConstant(resize_inputs[i]);
        const bool is_non_empty_constant = constant != nullptr && constant->Data().size() > 0;
        input_perms.push_back(is_non_empty_constant ? &input_perm : nullptr);
      }

      onnx_layout_transformation::WrapTransposesAroundNode(*api_graph, *node, input_perms, {&output_perm});
    } else {
      onnx_layout_transformation::WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
    }

    onnx_layout_transformation::SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
    modified = true;
  }

  // Remove as many of the Transposes inserted above as possible.
  if (modified) {
    onnx_layout_transformation::Optimize(*api_graph, /*allow_extended_ops*/ true, execution_provider.Type(),
                                         onnx_layout_transformation::OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM,
                                         layout_sensitive_ops);
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace layout_transformer
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -3,15 +3,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "optimizer_api.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/optimizer/transpose_optimizer/api.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
#include "core/framework/execution_provider.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/// <summary>
|
||||
/// Creates concrete implementation of api for transpose optimizer. IMPORTANT: graph must have up-to-date edges,
|
||||
/// node_arg-to-producer, and node_arg-to-consumer relationships. Otherwise call Resolve() before this.
|
||||
|
|
@ -24,6 +20,14 @@ std::unique_ptr<onnx_layout_transformation::api::GraphRef> MakeApiGraph(onnxrunt
|
|||
AllocatorPtr cpu_allocator,
|
||||
const char* new_node_ep);
|
||||
|
||||
/// <summary>
|
||||
/// Creates NodeRef.
|
||||
/// </summary>
|
||||
/// <param name="graph">ORT Graph which owns the node</param>
|
||||
/// <param name="node">ORT Node to wrap with API.</param>
|
||||
/// <returns>api::NodeRef for use with transpose optimizer</returns>
|
||||
std::unique_ptr<onnx_layout_transformation::api::NodeRef> MakeApiNode(onnxruntime::Graph& graph, onnxruntime::Node& node);
|
||||
|
||||
/// <summary>
|
||||
/// Reveals underlying ORT graph from an api::GraphRef
|
||||
/// </summary>
|
||||
|
|
@ -38,4 +42,25 @@ onnxruntime::Graph& GraphFromApiGraph(onnx_layout_transformation::api::GraphRef&
|
|||
/// <returns>ORT node</returns>
|
||||
onnxruntime::Node& NodeFromApiNode(onnx_layout_transformation::api::NodeRef& node);
|
||||
|
||||
namespace layout_transformer {
|
||||
/// <summary>
|
||||
/// Gets a list of layout sensitive ops for ORT. This list contains onnx standard defined
|
||||
/// layout sensitive ops + contrib ops + ops which are not layout sensitive but are treated as
|
||||
/// layout sensitive by ORT EPs (for example, Resize).
|
||||
/// </summary>
|
||||
/// <returns>unordered set of op_types which are layout sensitive</returns>
|
||||
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps();
|
||||
|
||||
/// <summary>
|
||||
/// Transforms data layout from NCHW to NHWC. Applies transforms to layout sensitive nodes
|
||||
/// assigned to execution_provider provided by the caller and any other non-layout sensitive
|
||||
/// nodes in order to optimize the transposes as much as possible.
|
||||
/// </summary>
|
||||
/// <param name="graph">graph to transform</param>
|
||||
/// <param name="modified">indicates whether the graph is modified during transformation</param>
|
||||
/// <param name="execution_provider">execution provider for which the transformation needs to be performed</param>
|
||||
/// <returns></returns>
|
||||
Status TransformLayout(Graph& graph, bool& modified, IExecutionProvider& execution_provider);
|
||||
|
||||
} // namespace layout_transformer
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
|
||||
|
||||
#include "ort_transpose_optimizer.h"
|
||||
#include <deque>
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/transpose_optimizer/api_impl.h"
|
||||
#include "core/providers/cpu/tensor/transpose.h"
|
||||
#include "optimizer_utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "api.h"
|
||||
#include "optimizer_api.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <gsl/gsl>
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <cstring>
|
||||
|
||||
namespace onnx_layout_transformation {
|
||||
|
||||
|
|
@ -16,6 +17,9 @@ struct OptimizerCtx {
|
|||
api::GraphRef& graph;
|
||||
bool allow_extended_ops;
|
||||
bool skip_cost_check;
|
||||
const std::string provider_type;
|
||||
OptimizerMode mode;
|
||||
std::unordered_set<std::string_view> layout_sensitive_ops;
|
||||
};
|
||||
|
||||
// Each op handler points to a (potentially shared) function for determining which input indices are eligible for
|
||||
|
|
@ -44,7 +48,6 @@ struct HandlerInfo {
|
|||
bool transposes_outputs = true;
|
||||
};
|
||||
|
||||
|
||||
/////// <Helper Utils> ///////
|
||||
/* Small utilities for editing nodes and manipulating axes/permutations */
|
||||
|
||||
|
|
@ -63,14 +66,14 @@ static std::vector<int32_t> DataInt32(api::TensorRef& tensor) {
|
|||
}
|
||||
|
||||
static std::string_view AddInitializerInt64(api::GraphRef& graph, const std::vector<int64_t>& shape,
|
||||
const std::vector<int64_t>& values) {
|
||||
const std::vector<int64_t>& values) {
|
||||
const uint8_t* raw_data = reinterpret_cast<const uint8_t*>(values.data());
|
||||
std::vector<uint8_t> data(raw_data, raw_data + values.size() * sizeof(int64_t));
|
||||
return graph.AddInitializer(api::DataType::INT64, shape, data);
|
||||
}
|
||||
|
||||
static std::string_view AddInitializerInt32(api::GraphRef& graph, const std::vector<int64_t>& shape,
|
||||
const std::vector<int32_t>& values) {
|
||||
const std::vector<int32_t>& values) {
|
||||
const uint8_t* raw_data = reinterpret_cast<const uint8_t*>(values.data());
|
||||
std::vector<uint8_t> data(raw_data, raw_data + values.size() * sizeof(int32_t));
|
||||
return graph.AddInitializer(api::DataType::INT32, shape, data);
|
||||
|
|
@ -108,7 +111,7 @@ static std::unique_ptr<api::NodeRef> MakeTranspose(api::GraphRef& graph, std::st
|
|||
}
|
||||
|
||||
// Creates a Squeeze/Unsqueeze node. Does not update output ValueInfo.
|
||||
static std::unique_ptr<api::NodeRef> MakeSqueezeOrUnsqueeze(int64_t opset, api::GraphRef& graph,
|
||||
static std::unique_ptr<api::NodeRef> MakeSqueezeOrUnsqueeze(int64_t opset, api::GraphRef& graph,
|
||||
std::string_view op_type, std::string_view input,
|
||||
const std::vector<int64_t>& axes) {
|
||||
if (opset < 13) {
|
||||
|
|
@ -141,7 +144,7 @@ static bool IsValidPerm(const std::vector<int64_t>& perm) {
|
|||
|
||||
static std::optional<std::vector<int64_t>> GetPermAttrIfValid(const api::NodeRef& node) {
|
||||
std::optional<std::vector<int64_t>> perm = node.GetAttributeInts("perm");
|
||||
if (perm != std::nullopt && !IsValidPerm(*perm)) {
|
||||
if (perm.has_value() && !IsValidPerm(*perm)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return perm;
|
||||
|
|
@ -224,8 +227,12 @@ static bool IsIdentityPerm(const std::vector<int64_t>& perm) {
|
|||
}
|
||||
|
||||
// Computes permutation from channel last to channel first ordering of given rank. Nearly all handlers work for any
|
||||
// permutation, but some are restricted. Also used for layout transformation. Rank must be >= 1.
|
||||
// permutation, but some are restricted. Also used for layout transformation.
|
||||
std::vector<int64_t> ChannelLastToFirstPerm(size_t rank) {
|
||||
if (rank < 2) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<int64_t> p(rank);
|
||||
p[0] = 0;
|
||||
p[1] = rank - 1;
|
||||
|
|
@ -330,9 +337,9 @@ static std::vector<int64_t> SqueezePerm(const std::vector<int64_t>& axes, const
|
|||
return new_perm;
|
||||
}
|
||||
|
||||
// Computes a new axes attribute for an input that has been permuted using perm. Unsafe if axes/perm are invalid or
|
||||
// Computes a new axes attribute for an input that has been permuted using perm. Unsafe if axes/perm are invalid or
|
||||
// have negative values.
|
||||
//
|
||||
//
|
||||
// Ex: perm = [2, 0, 1], axes = [0, 1], new_axes = [2, 0]
|
||||
static std::vector<int64_t> AxesForTransposedInput(const std::vector<int64_t>& axes,
|
||||
const std::vector<int64_t>& perm) {
|
||||
|
|
@ -346,7 +353,7 @@ static std::vector<int64_t> AxesForTransposedInput(const std::vector<int64_t>& a
|
|||
|
||||
// Computes a new axes attribute for an input that has been permuted using perm and sorts the result. Axes attributes
|
||||
// are commonly sorted (unless order matters like in Slice). Unsafe if axes/perm are invalid or have negative values.
|
||||
//
|
||||
//
|
||||
// Ex: perm = [2, 0, 1], axes = [0, 1], new_axes = [0, 2]
|
||||
static std::vector<int64_t> SortedAxesForTransposedInput(const std::vector<int64_t>& axes,
|
||||
const std::vector<int64_t>& perm) {
|
||||
|
|
@ -376,10 +383,8 @@ static std::vector<int64_t> SortedAxesForTransposedInput(const std::vector<int64
|
|||
/////// <Core Helpers> ///////
|
||||
/* These helpers hide the most gnarly parts of the transpose optimizer. */
|
||||
|
||||
|
||||
static std::string_view HelpHandleUnsqueeze(HandlerArgs& args, const std::vector<int64_t>& axes);
|
||||
|
||||
|
||||
// Replaces ith input to node with unsqueezed value. Might create a new Unsqueeze node, find an existing one,
|
||||
// or reshape an initializer. Unsqueezing can be necessary before transposing inputs of a node that supports
|
||||
// broadcasting.
|
||||
|
|
@ -442,7 +447,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
|
|||
// a Transpose, optimize it here.
|
||||
if (inp_node != nullptr && inp_node->IsOp("Transpose")) {
|
||||
auto perm = GetPermAttrIfValid(*inp_node);
|
||||
if (perm != std::nullopt) {
|
||||
if (perm.has_value()) {
|
||||
auto perm_inv = InvertPerm(*perm);
|
||||
std::vector<size_t> indices = {0};
|
||||
HandlerArgs args{ctx, *inp_node, unsqueeze, *perm, perm_inv, indices};
|
||||
|
|
@ -456,6 +461,29 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
|
|||
node.SetInput(i, unsq_out);
|
||||
}
|
||||
|
||||
static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::TensorRef& constant,
|
||||
size_t i, std::string_view input_name, const std::vector<int64_t>& perm) {
|
||||
// Create new transposed initializer
|
||||
auto rank = perm.size();
|
||||
auto shape = constant.Shape();
|
||||
std::vector<uint8_t> data = constant.Data();
|
||||
std::vector<uint8_t> new_data(data.size());
|
||||
size_t bytes_per_val = data.size() / rank;
|
||||
|
||||
uint8_t* dst = new_data.data();
|
||||
for (size_t j = 0; j < rank; ++j) {
|
||||
uint8_t* src = data.data() + perm[j] * bytes_per_val;
|
||||
std::memcpy(dst, src, bytes_per_val);
|
||||
dst += bytes_per_val;
|
||||
}
|
||||
|
||||
std::string_view new_initializer = graph.AddInitializer(constant.DType(), shape, new_data);
|
||||
node.SetInput(i, new_initializer);
|
||||
if (!graph.HasValueConsumers(input_name)) {
|
||||
graph.RemoveInitializer(input_name);
|
||||
}
|
||||
}
|
||||
|
||||
// Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one,
|
||||
// or transpose an initializer.
|
||||
static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
|
||||
|
|
@ -469,6 +497,18 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
|
|||
|
||||
// Case 1: input is a constant with a known list of consumer nodes
|
||||
if (constant != nullptr && consumers->comprehensive) {
|
||||
// Input is a 1-D tensor with zero elements (shape == [0]); there is nothing to transpose, so return early.
|
||||
if (constant->Shape().size() == 1 && constant->Shape()[0] == 0) {
|
||||
return;
|
||||
}
|
||||
// This is a special case where the constant is 1D with length == perm.size().
|
||||
// TODO: TransposeInitializer should be updated to handle this case.
|
||||
// Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if
|
||||
// there are no other consumers.
|
||||
if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast<int64_t>(perm.size())) {
|
||||
Permute1DConstant(graph, node, *constant, i, input, perm);
|
||||
return;
|
||||
}
|
||||
if (consumers->nodes.size() > 0) {
|
||||
// Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv
|
||||
// to counteract the effect. These Transposes will hopefully be optimized out later.
|
||||
|
|
@ -496,6 +536,10 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
|
|||
}
|
||||
node.SetInput(i, pre_transpose_value);
|
||||
return;
|
||||
} else if (*perm2 == perm) {
|
||||
// we are trying to add a duplicate transpose.
|
||||
// do nothing and return
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove
|
||||
|
|
@ -513,7 +557,7 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
|
|||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Case 3: A Transpose op might already exist
|
||||
for (size_t j = 0; j < consumers->nodes.size(); ++j) {
|
||||
api::NodeRef& consumer = *consumers->nodes[j];
|
||||
|
|
@ -533,7 +577,7 @@ static void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i,
|
|||
}
|
||||
|
||||
// Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank.
|
||||
static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank,
|
||||
static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank,
|
||||
const std::vector<size_t>& input_indices) {
|
||||
auto inputs = node.Inputs();
|
||||
|
||||
|
|
@ -563,7 +607,7 @@ static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t tar
|
|||
}
|
||||
|
||||
// Transposes specified inputs according to perm.
|
||||
// NOTE: if a Transpose is expected to be above an input to this node, use the inverse of its permutation to cancel it.
|
||||
// NOTE: if a Transpose is expected to be above an input to this node, use the inverse of its permutation to cancel it.
|
||||
static void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::vector<int64_t>& perm,
|
||||
const std::vector<size_t>& input_indices) {
|
||||
auto perm_inv = InvertPerm(perm);
|
||||
|
|
@ -573,7 +617,7 @@ static void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::ve
|
|||
}
|
||||
|
||||
inline static void TransposeFirstInput(OptimizerCtx& ctx, api::NodeRef& node, const std::vector<int64_t>& perm) {
|
||||
std::vector<size_t> indices {0};
|
||||
std::vector<size_t> indices{0};
|
||||
TransposeInputs(ctx, node, perm, indices);
|
||||
}
|
||||
|
||||
|
|
@ -630,7 +674,7 @@ static void TransposeOutputs(OptimizerCtx& ctx, api::NodeRef& node, const std::v
|
|||
// A rank of 5 is used if rank cannot be determined since 5 is the largest rank we expect from something like a Conv
|
||||
// and an unknown rank likely corresponds to a data-carrying (non-weight) tensor, which will be large.
|
||||
|
||||
// Given a value, returns the rank of the value excluding dimensions of value 1. Returns 5 if the rank is unknown.
|
||||
// Given a value, returns the rank of the value excluding dimensions of value 1. Returns 5 if the rank is unknown.
|
||||
static int EstimateValueRank(api::GraphRef& graph, std::string_view input) {
|
||||
auto value_info = graph.GetValueInfo(input);
|
||||
std::optional<std::vector<int64_t>> shape = value_info->Shape();
|
||||
|
|
@ -774,7 +818,7 @@ std::vector<size_t> NonScalarInputs(OptimizerCtx& ctx, api::NodeRef& node) {
|
|||
constexpr HandlerInfo broadcast_node_handler = {&NonScalarInputs, &HandleSimpleNodeBroadcast};
|
||||
|
||||
// Transposes all inputs and all outputs. Updates axis attribute.
|
||||
static bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional<int64_t> default_axis=std::nullopt) {
|
||||
static bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional<int64_t> default_axis = std::nullopt) {
|
||||
size_t rank = args.perm.size();
|
||||
std::optional<int64_t> axis = args.node.GetAttributeInt("axis");
|
||||
if (axis == std::nullopt) {
|
||||
|
|
@ -848,7 +892,6 @@ static bool HandleShape(HandlerArgs& args) {
|
|||
std::vector<int64_t> new_perm;
|
||||
// For opset 15, Shape(Transpose(x, perm))[starts:stops] = Gather(Shape(x), perm[starts:stops])
|
||||
if (args.ctx.opset >= 15) {
|
||||
|
||||
// Assign new_perm = perm[starts:stops]
|
||||
int64_t start = args.node.GetAttributeIntDefault("start", 0);
|
||||
int64_t end = args.node.GetAttributeIntDefault("end", rank_int);
|
||||
|
|
@ -870,7 +913,7 @@ static bool HandleShape(HandlerArgs& args) {
|
|||
}
|
||||
|
||||
// Make new_perm initializer
|
||||
std::vector<int64_t> perm_shape {gsl::narrow_cast<int64_t>(new_perm.size())};
|
||||
std::vector<int64_t> perm_shape{gsl::narrow_cast<int64_t>(new_perm.size())};
|
||||
std::string_view perm_const = AddInitializerInt64(args.ctx.graph, perm_shape, new_perm);
|
||||
|
||||
// X -> Shape -> Y, Gather
|
||||
|
|
@ -900,7 +943,7 @@ static bool HandleShape(HandlerArgs& args) {
|
|||
constexpr HandlerInfo shape_handler = {&FirstInput, &HandleShape, /*transposes_outputs*/ false};
|
||||
|
||||
// Permutes a 1D node input by creating a new initializer or inserting a Gather op
|
||||
void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector<int64_t>& perm) {
|
||||
static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector<int64_t>& perm) {
|
||||
size_t rank = perm.size();
|
||||
int64_t rank_int = gsl::narrow_cast<int64_t>(rank);
|
||||
|
||||
|
|
@ -909,25 +952,7 @@ void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std:
|
|||
if (constant != nullptr) {
|
||||
auto shape = constant->Shape();
|
||||
if (shape.size() == 1 && (shape[0] == rank_int || shape[0] == 0)) {
|
||||
// Create new transposed initializer
|
||||
std::vector<uint8_t> data = constant->Data();
|
||||
std::vector<uint8_t> new_data(data.size());
|
||||
size_t bytes_per_val = data.size() / rank;
|
||||
|
||||
uint8_t* dst = new_data.data();
|
||||
for (size_t j = 0; j < rank; ++j) {
|
||||
uint8_t* src = data.data() + perm[j] * bytes_per_val;
|
||||
for (size_t k = 0; k < bytes_per_val; ++k) {
|
||||
*dst++ = *src++;
|
||||
}
|
||||
}
|
||||
|
||||
std::string_view new_initializer = graph.AddInitializer(constant->DType(), shape, new_data);
|
||||
node.SetInput(i, new_initializer);
|
||||
if (!graph.HasValueConsumers(input)) {
|
||||
graph.RemoveInitializer(input);
|
||||
}
|
||||
|
||||
Permute1DConstant(graph, node, *constant, i, input, perm);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -942,33 +967,39 @@ void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std:
|
|||
node.SetInput(i, gather_output);
|
||||
}
|
||||
|
||||
//static bool HandleResize(HandlerArgs& args) {
|
||||
// static bool HandleResize(HandlerArgs& args) {
|
||||
// auto inputs = args.node.Inputs();
|
||||
// int64_t rank_int = gsl::narrow_cast<int64_t>(args.perm.size());
|
||||
//
|
||||
// auto p = ChannelFirstToLastPerm(rank_int);
|
||||
// auto& perm = p == args.perm ? args.perm : args.perm_inv;
|
||||
// auto& perm_inv = p == args.perm ? args.perm_inv : args.perm;
|
||||
//
|
||||
// if (args.ctx.opset < 11) {
|
||||
// PermuteInput(args.ctx.graph, args.node, 1, args.perm_inv);
|
||||
// } else {
|
||||
// if (inputs[1] != "") {
|
||||
// std::vector<int64_t> double_perm_inv = args.perm_inv;
|
||||
// double_perm_inv.reserve(2 * args.perm_inv.size());
|
||||
// for (int64_t p : args.perm_inv) {
|
||||
// double_perm_inv.push_back(p + rank_int);
|
||||
// }
|
||||
// PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv);
|
||||
// }
|
||||
// for (size_t i = 2; i < inputs.size(); ++i) {
|
||||
// if (inputs[i] != "") {
|
||||
// PermuteInput(args.ctx.graph, args.node, i, args.perm_inv);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// PermuteInput(args.ctx.graph, args.node, 1, perm);
|
||||
// } else {
|
||||
// if (inputs[1] != "") {
|
||||
// std::vector<int64_t> double_perm_inv = perm;
|
||||
// double_perm_inv.reserve(2 * args.perm.size());
|
||||
// for (int64_t p1 : perm) {
|
||||
// double_perm_inv.push_back(p1 + rank_int);
|
||||
// }
|
||||
// PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv);
|
||||
// }
|
||||
// for (size_t i = 2; i < inputs.size(); ++i) {
|
||||
// if (inputs[i] != "") {
|
||||
// PermuteInput(args.ctx.graph, args.node, i, perm);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// TransposeFirstInput(args.ctx, args.node, args.perm_inv);
|
||||
// TransposeOutputs(args.ctx, args.node, args.perm);
|
||||
// TransposeFirstInput(args.ctx, args.node, perm);
|
||||
// TransposeOutputs(args.ctx, args.node, perm_inv);
|
||||
//
|
||||
// return true;
|
||||
//}
|
||||
// SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, args.node.OpType(), "com.microsoft.nhwc");
|
||||
//
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize};
|
||||
|
||||
|
|
@ -1028,7 +1059,6 @@ static bool HandleReduceOp(HandlerArgs& args) {
|
|||
}
|
||||
|
||||
} else {
|
||||
|
||||
if (!NormalizeAndValidateAxes(*axes, args.perm.size())) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1143,7 +1173,6 @@ static bool HandleSqueeze(HandlerArgs& args) {
|
|||
if (!args.ctx.graph.HasValueConsumers(axes_inp)) {
|
||||
args.ctx.graph.RemoveInitializer(axes_inp);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Transpose inputs/outputs
|
||||
|
|
@ -1178,25 +1207,31 @@ static bool HandleUnsqueeze(HandlerArgs& args) {
|
|||
|
||||
constexpr HandlerInfo unsqueeze_handler = {&FirstInput, &HandleUnsqueeze};
|
||||
|
||||
static bool HandleQuantizeDequantizeLinear(HandlerArgs& args) {
|
||||
size_t rank = args.perm.size();
|
||||
|
||||
if (args.ctx.opset >= 13) {
|
||||
static bool HandleQuantizeDequantizeScale(const api::GraphRef& graph, const std::vector<int64_t>& perm,
|
||||
api::NodeRef& node, int64_t opset) {
|
||||
if (opset >= 13) {
|
||||
size_t rank = perm.size();
|
||||
// Update axis in Opset >= 13 if scale/zero_point are non-scalar
|
||||
auto inputs = args.node.Inputs();
|
||||
auto inputs = node.Inputs();
|
||||
|
||||
std::optional<std::vector<int64_t>> inp_shape = args.ctx.graph.GetValueInfo(inputs[1])->Shape();
|
||||
bool scalar_params = inp_shape != std::nullopt && inp_shape->size() == 0;
|
||||
auto inp_shape = graph.GetValueInfo(inputs[1])->Shape();
|
||||
bool scalar_params = inp_shape.has_value() && inp_shape->size() == 0;
|
||||
|
||||
if (!scalar_params) {
|
||||
int64_t axis = args.node.GetAttributeIntDefault("axis", 1);
|
||||
int64_t axis = node.GetAttributeIntDefault("axis", 1);
|
||||
if (!NormalizeAndValidateAxis(axis, rank)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
args.node.SetAttributeInt("axis", args.perm[gsl::narrow_cast<size_t>(axis)]);
|
||||
node.SetAttributeInt("axis", perm[gsl::narrow_cast<size_t>(axis)]);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool HandleQuantizeDequantizeLinear(HandlerArgs& args) {
|
||||
if (!HandleQuantizeDequantizeScale(args.ctx.graph, args.perm, args.node, args.ctx.opset)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TransposeFirstInput(args.ctx, args.node, args.perm_inv);
|
||||
TransposeOutputs(args.ctx, args.node, args.perm);
|
||||
|
|
@ -1215,7 +1250,7 @@ static bool HandleArgMinMax(HandlerArgs& args) {
|
|||
return false;
|
||||
}
|
||||
int64_t new_axis = args.perm[gsl::narrow_cast<size_t>(axis)];
|
||||
std::vector<int64_t> new_axes {new_axis};
|
||||
std::vector<int64_t> new_axes{new_axis};
|
||||
args.node.SetAttributeInt("axis", new_axis);
|
||||
|
||||
TransposeInputs(args.ctx, args.node, args.perm_inv, args.transposible_inputs);
|
||||
|
|
@ -1368,7 +1403,7 @@ static bool HandleTile(HandlerArgs& args) {
|
|||
} else {
|
||||
// Case 2: Repeats is computed. Insert Gather node.
|
||||
std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv);
|
||||
std::vector<std::string_view> gather_inputs {repeats_inp, perm_inv_const};
|
||||
std::vector<std::string_view> gather_inputs{repeats_inp, perm_inv_const};
|
||||
auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1);
|
||||
api::NodeRef& gather_node = *gather_node_ptr;
|
||||
std::string_view gather_output = gather_node.Outputs()[0];
|
||||
|
|
@ -1398,7 +1433,7 @@ static bool HandleTranspose(HandlerArgs& args) {
|
|||
|
||||
if (args.perm_inv == *node_perm) {
|
||||
// Case 1: Permutations cancel.
|
||||
auto consumers = args.ctx.graph.GetValueConsumers(args.node.Outputs()[0]);
|
||||
auto consumers = args.ctx.graph.GetValueConsumers(node_output);
|
||||
if (consumers->comprehensive) {
|
||||
// If possible, replace references to output of 2nd transpose with input to 1st
|
||||
ReplaceValueReferences(consumers->nodes, node_output, transpose_input);
|
||||
|
|
@ -1425,14 +1460,13 @@ static bool HandleTranspose(HandlerArgs& args) {
|
|||
} else {
|
||||
// Worst-case scenario: Both parent output and 2nd transpose output cannot be removed (both graph outputs)
|
||||
// despite computing the same value. Use an Identity op instead.
|
||||
std::vector<std::string_view> single_empty_input {""};
|
||||
std::vector<std::string_view> single_empty_input{""};
|
||||
auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1);
|
||||
api::NodeRef& identity = *identity_ptr;
|
||||
args.ctx.graph.MoveOutput(args.node, 0, identity, 0);
|
||||
identity.SetInput(0, transpose_input);
|
||||
}
|
||||
}
|
||||
|
||||
// In any case, the 2nd transpose can be removed.
|
||||
args.ctx.graph.RemoveNode(args.node);
|
||||
} else {
|
||||
|
|
@ -1442,7 +1476,7 @@ static bool HandleTranspose(HandlerArgs& args) {
|
|||
args.node.SetInput(0, transpose_input);
|
||||
}
|
||||
|
||||
// 2nd transpose no longer references 1st. Remove 2nd if possible.
|
||||
// 2nd transpose no longer references 1st. Remove the 1st if it has no remaining consumers.
|
||||
if (!args.ctx.graph.HasValueConsumers(args.transpose.Outputs()[0])) {
|
||||
args.ctx.graph.RemoveNode(args.transpose);
|
||||
}
|
||||
|
|
@ -1484,7 +1518,7 @@ constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &Han
|
|||
|
||||
static bool HandleQLinearPoolOp(HandlerArgs& args) {
|
||||
// Swap between channel first/last variants. Only works for applicable values of perm.
|
||||
int64_t channels_last = args.node.GetAttributeIntDefault("channels_last", 1);
|
||||
int64_t channels_last = args.node.GetAttributeIntDefault("channels_last", 0);
|
||||
size_t rank = args.perm.size();
|
||||
if (rank < 2) return false;
|
||||
auto p = ChannelLastToFirstPerm(rank);
|
||||
|
|
@ -1500,7 +1534,11 @@ static bool HandleQLinearPoolOp(HandlerArgs& args) {
|
|||
constexpr HandlerInfo q_linear_pool_op_handler = {&FirstInput, &HandleQLinearPoolOp};
|
||||
|
||||
static bool HandleMaxPool(HandlerArgs& args) {
|
||||
// Replace with NhwcMaxPool if possible. Only int8 and uint8 dtypes are supported by NhwcMaxPool.
|
||||
// For CPU EP replace with NhwcMaxPool if possible. Only int8 and uint8 dtypes are supported by NhwcMaxPool.
|
||||
if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto outputs = args.node.Outputs();
|
||||
if (outputs.size() == 2 && outputs[1] != "") {
|
||||
// Can't optimize if optional "indices" output is provided
|
||||
|
|
@ -1520,7 +1558,6 @@ static bool HandleMaxPool(HandlerArgs& args) {
|
|||
|
||||
auto new_node = SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, "NhwcMaxPool", "com.microsoft");
|
||||
new_node->ClearAttribute("storage_order"); // Only relevant for indices output. Prohibited for NhwcMaxPool.
|
||||
|
||||
TransposeFirstInput(args.ctx, *new_node, args.perm_inv);
|
||||
TransposeOutputs(args.ctx, *new_node, args.perm);
|
||||
return true;
|
||||
|
|
@ -1529,71 +1566,121 @@ static bool HandleMaxPool(HandlerArgs& args) {
|
|||
constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool};
|
||||
|
||||
// TODO: check binary size of this and replace it with constexpr if large
|
||||
static const std::unordered_map<std::string_view, const HandlerInfo&> handler_map {
|
||||
static const std::unordered_map<std::string_view, const HandlerInfo&> handler_map{
|
||||
|
||||
{"Cast", simple_node_handler}, {"Exp", simple_node_handler}, {"Identity", simple_node_handler},
|
||||
{"LeakyRelu", simple_node_handler}, {"Log", simple_node_handler}, {"Reciprocal", simple_node_handler},
|
||||
{"Relu", simple_node_handler}, {"Sigmoid", simple_node_handler}, {"Sqrt", simple_node_handler},
|
||||
{"Tanh", simple_node_handler}, {"Abs", simple_node_handler}, {"Not", simple_node_handler},
|
||||
{"Ceil", simple_node_handler}, {"Floor", simple_node_handler}, {"Neg", simple_node_handler},
|
||||
{"Erf", simple_node_handler}, {"HardSigmoid", simple_node_handler}, {"Round", simple_node_handler},
|
||||
{"IsInf", simple_node_handler}, {"IsNaN", simple_node_handler},
|
||||
{"Selu", simple_node_handler}, {"Shrink", simple_node_handler}, {"Sign", simple_node_handler},
|
||||
{"Softplus", simple_node_handler}, {"Softsign", simple_node_handler}, {"ThresholdedRelu", simple_node_handler},
|
||||
{"Celu", simple_node_handler}, {"HardSwish", simple_node_handler},
|
||||
{"Cast", simple_node_handler},
|
||||
{"Exp", simple_node_handler},
|
||||
{"Identity", simple_node_handler},
|
||||
{"LeakyRelu", simple_node_handler},
|
||||
{"Log", simple_node_handler},
|
||||
{"Reciprocal", simple_node_handler},
|
||||
{"Relu", simple_node_handler},
|
||||
{"Sigmoid", simple_node_handler},
|
||||
{"Sqrt", simple_node_handler},
|
||||
{"Tanh", simple_node_handler},
|
||||
{"Abs", simple_node_handler},
|
||||
{"Not", simple_node_handler},
|
||||
{"Ceil", simple_node_handler},
|
||||
{"Floor", simple_node_handler},
|
||||
{"Neg", simple_node_handler},
|
||||
{"Erf", simple_node_handler},
|
||||
{"HardSigmoid", simple_node_handler},
|
||||
{"Round", simple_node_handler},
|
||||
{"IsInf", simple_node_handler},
|
||||
{"IsNaN", simple_node_handler},
|
||||
{"Selu", simple_node_handler},
|
||||
{"Shrink", simple_node_handler},
|
||||
{"Sign", simple_node_handler},
|
||||
{"Softplus", simple_node_handler},
|
||||
{"Softsign", simple_node_handler},
|
||||
{"ThresholdedRelu", simple_node_handler},
|
||||
{"Celu", simple_node_handler},
|
||||
{"HardSwish", simple_node_handler},
|
||||
|
||||
{"Sin", simple_node_handler}, {"Cos", simple_node_handler}, {"Tan", simple_node_handler},
|
||||
{"Sinh", simple_node_handler}, {"Cosh", simple_node_handler}, {"Tanh", simple_node_handler},
|
||||
{"Asin", simple_node_handler}, {"Acos", simple_node_handler}, {"Atan", simple_node_handler},
|
||||
{"Asinh", simple_node_handler}, {"Acosh", simple_node_handler}, {"Atanh", simple_node_handler},
|
||||
{"Sin", simple_node_handler},
|
||||
{"Cos", simple_node_handler},
|
||||
{"Tan", simple_node_handler},
|
||||
{"Sinh", simple_node_handler},
|
||||
{"Cosh", simple_node_handler},
|
||||
{"Tanh", simple_node_handler},
|
||||
{"Asin", simple_node_handler},
|
||||
{"Acos", simple_node_handler},
|
||||
{"Atan", simple_node_handler},
|
||||
{"Asinh", simple_node_handler},
|
||||
{"Acosh", simple_node_handler},
|
||||
{"Atanh", simple_node_handler},
|
||||
|
||||
{"Add", broadcast_node_handler}, {"Max", broadcast_node_handler}, {"Min", broadcast_node_handler},
|
||||
{"Mul", broadcast_node_handler}, {"Sub", broadcast_node_handler}, {"Div", broadcast_node_handler},
|
||||
{"And", broadcast_node_handler}, {"Or", broadcast_node_handler}, {"Xor", broadcast_node_handler},
|
||||
{"Mod", broadcast_node_handler}, {"PRelu", broadcast_node_handler}, {"BitShift", broadcast_node_handler},
|
||||
{"Equal", broadcast_node_handler}, {"Greater", broadcast_node_handler}, {"Less", broadcast_node_handler},
|
||||
{"GreaterOrEqual", broadcast_node_handler}, {"LessOrEqual", broadcast_node_handler},
|
||||
{"Mean", broadcast_node_handler}, {"Sum", broadcast_node_handler}, {"Pow", broadcast_node_handler},
|
||||
{"Where", broadcast_node_handler},
|
||||
{"Add", broadcast_node_handler},
|
||||
{"Max", broadcast_node_handler},
|
||||
{"Min", broadcast_node_handler},
|
||||
{"Mul", broadcast_node_handler},
|
||||
{"Sub", broadcast_node_handler},
|
||||
{"Div", broadcast_node_handler},
|
||||
{"And", broadcast_node_handler},
|
||||
{"Or", broadcast_node_handler},
|
||||
{"Xor", broadcast_node_handler},
|
||||
{"Mod", broadcast_node_handler},
|
||||
{"PRelu", broadcast_node_handler},
|
||||
{"BitShift", broadcast_node_handler},
|
||||
{"Equal", broadcast_node_handler},
|
||||
{"Greater", broadcast_node_handler},
|
||||
{"Less", broadcast_node_handler},
|
||||
{"GreaterOrEqual", broadcast_node_handler},
|
||||
{"LessOrEqual", broadcast_node_handler},
|
||||
{"Mean", broadcast_node_handler},
|
||||
{"Sum", broadcast_node_handler},
|
||||
{"Pow", broadcast_node_handler},
|
||||
{"Where", broadcast_node_handler},
|
||||
|
||||
{"Clip", node_1_inp_handler}, {"CastLike", node_1_inp_handler},
|
||||
{"Clip", node_1_inp_handler},
|
||||
{"CastLike", node_1_inp_handler},
|
||||
|
||||
{"Transpose", transpose_handler},
|
||||
{"Concat", concat_handler},
|
||||
{"Split", split_handler},
|
||||
{"Shape", shape_handler},
|
||||
{"Pad", pad_handler},
|
||||
// TODO: re-enable resize handler after adding NHWC support in upsample op on CPU
|
||||
// https://github.com/microsoft/onnxruntime/issues/9857
|
||||
//{"Resize", resize_handler},
|
||||
{"ReduceSum", reduce_sum_handler},
|
||||
{"Transpose", transpose_handler},
|
||||
{"Concat", concat_handler},
|
||||
{"Split", split_handler},
|
||||
{"Shape", shape_handler},
|
||||
{"Pad", pad_handler},
|
||||
// TODO: re-enable resize handler after adding NHWC support in upsample op on CPU
|
||||
// https://github.com/microsoft/onnxruntime/issues/9857
|
||||
// {"Resize", resize_handler},
|
||||
{"ReduceSum", reduce_sum_handler},
|
||||
|
||||
{"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, {"ReduceMax", reduce_op_handler},
|
||||
{"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler},
|
||||
{"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler},
|
||||
{"ReduceLogSum", reduce_op_handler},
|
||||
{"ReduceLogSumExp", reduce_op_handler},
|
||||
{"ReduceMax", reduce_op_handler},
|
||||
{"ReduceMean", reduce_op_handler},
|
||||
{"ReduceMin", reduce_op_handler},
|
||||
{"ReduceProd", reduce_op_handler},
|
||||
{"ReduceSumSquare", reduce_op_handler},
|
||||
{"ReduceL1", reduce_op_handler},
|
||||
{"ReduceL2", reduce_op_handler},
|
||||
|
||||
{"ArgMin", arg_min_max_handler}, {"ArgMax", arg_min_max_handler},
|
||||
{"ArgMin", arg_min_max_handler},
|
||||
{"ArgMax", arg_min_max_handler},
|
||||
|
||||
{"Squeeze", squeeze_handler},
|
||||
{"Unsqueeze", unsqueeze_handler},
|
||||
{"Slice", slice_handler},
|
||||
{"Tile", tile_handler},
|
||||
{"Squeeze", squeeze_handler},
|
||||
{"Unsqueeze", unsqueeze_handler},
|
||||
{"Slice", slice_handler},
|
||||
{"Tile", tile_handler},
|
||||
|
||||
{"Softmax", soft_hard_max_handler}, {"Hardmax", soft_hard_max_handler}, {"LogSoftmax", soft_hard_max_handler},
|
||||
{"Softmax", soft_hard_max_handler},
|
||||
{"Hardmax", soft_hard_max_handler},
|
||||
{"LogSoftmax", soft_hard_max_handler},
|
||||
|
||||
{"QuantizeLinear", quantize_dequantize_linear_handler}, {"DequantizeLinear", quantize_dequantize_linear_handler},
|
||||
{"QuantizeLinear", quantize_dequantize_linear_handler},
|
||||
{"DequantizeLinear", quantize_dequantize_linear_handler},
|
||||
};
|
||||
|
||||
static const std::unordered_map<std::string_view, const HandlerInfo&> extended_handler_map{
|
||||
{"com.microsoft.QLinearReduceMean", reduce_op_handler},
|
||||
{"com.microsoft.QLinearSigmoid", node_1_inp_handler},
|
||||
{"com.microsoft.QLinearLeakyRelu", node_1_inp_handler},
|
||||
{"com.microsoft.QLinearConcat", q_linear_concat_handler},
|
||||
{"com.microsoft.QLinearAdd", q_linear_binary_op_handler},
|
||||
{"com.microsoft.QLinearMul", q_linear_binary_op_handler},
|
||||
{"com.microsoft.QLinearAveragePool", q_linear_pool_op_handler},
|
||||
{"com.microsoft.QLinearGlobalAveragePool", q_linear_pool_op_handler},
|
||||
{"MaxPool", max_pool_op_handler},
|
||||
{"com.microsoft.QLinearReduceMean", reduce_op_handler},
|
||||
{"com.microsoft.QLinearSigmoid", node_1_inp_handler},
|
||||
{"com.microsoft.QLinearLeakyRelu", node_1_inp_handler},
|
||||
{"com.microsoft.QLinearConcat", q_linear_concat_handler},
|
||||
{"com.microsoft.QLinearAdd", q_linear_binary_op_handler},
|
||||
{"com.microsoft.QLinearMul", q_linear_binary_op_handler},
|
||||
{"com.microsoft.QLinearAveragePool", q_linear_pool_op_handler},
|
||||
{"com.microsoft.QLinearGlobalAveragePool", q_linear_pool_op_handler},
|
||||
{"MaxPool", max_pool_op_handler},
|
||||
};
|
||||
|
||||
static const HandlerInfo* GetHandler(api::NodeRef& node, bool allow_extended_ops) {
|
||||
|
|
@ -1674,7 +1761,9 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef&
|
|||
}
|
||||
|
||||
// Returns nullopt if graph opset is unsupported.
|
||||
std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph, bool allow_extended_ops) {
|
||||
std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph, bool allow_extended_ops,
|
||||
const std::string& provider_type, OptimizerMode mode,
|
||||
const std::unordered_set<std::string_view>& layout_sensitive_ops) {
|
||||
auto opset = graph.Opset("");
|
||||
if (opset == std::nullopt) {
|
||||
opset = graph.Opset("ai.onnx");
|
||||
|
|
@ -1688,14 +1777,17 @@ std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph, bool allo
|
|||
allow_extended_ops = false;
|
||||
}
|
||||
}
|
||||
OptimizerCtx ctx{*opset, graph, allow_extended_ops, /*skip_cost_check*/ false};
|
||||
|
||||
// During layout transformation we want to push the transposes as far out as possible.
|
||||
// It is important that the EP gets the entire graph in the layout it prefers.
|
||||
bool skip_cost_check = mode == OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM;
|
||||
OptimizerCtx ctx{*opset, graph, allow_extended_ops, skip_cost_check, provider_type, mode, layout_sensitive_ops};
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Performs optimization. General algorithm: iterate over nodes in topological order. If a node has a transpose
|
||||
// as input, push it through if the transpose cost does not increase and is likely to decrease.
|
||||
bool OptimizeImpl(OptimizerCtx& ctx) {
|
||||
|
||||
const std::vector<std::unique_ptr<api::NodeRef>> nodes = ctx.graph.Nodes();
|
||||
|
||||
std::unordered_set<std::string> outputs_leading_to_transpose;
|
||||
|
|
@ -1703,7 +1795,6 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
|
|||
// First iterate over sorted nodes in reverse order to find which outputs have paths through supported ops to
|
||||
// transpose nodes. We push transposes towards these outputs.
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
|
||||
api::NodeRef& node = *nodes[nodes.size() - i - 1];
|
||||
if (node.IsOp("Transpose")) {
|
||||
outputs_leading_to_transpose.insert(std::string(node.Inputs()[0]));
|
||||
|
|
@ -1731,6 +1822,13 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
|
|||
// New transpose nodes are inserted, but always as an input to an existing node.
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
api::NodeRef& node = *nodes[i];
|
||||
if (ctx.mode == OptimizerMode::OPTIMIZE_LAYOUT_TRANSFORM &&
|
||||
ctx.layout_sensitive_ops.count(node.OpType()) && node.GetExecutionProviderType() != ctx.provider_type) {
|
||||
// If the current op is layout sensitive and it is not assigned to the given provider
|
||||
// then do not process transpose.
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::string_view> inputs = node.Inputs();
|
||||
for (size_t j = 0; j < inputs.size(); ++j) {
|
||||
std::string_view inp = inputs[j];
|
||||
|
|
@ -1741,7 +1839,6 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
|
|||
if (transpose != nullptr && transpose->IsOp("Transpose")) {
|
||||
std::optional<std::vector<int64_t>> perm = GetPermAttrIfValid(*transpose);
|
||||
if (perm != std::nullopt) {
|
||||
std::vector<int64_t> perm_inv = InvertPerm(*perm);
|
||||
if (ProcessTranspose(ctx, *transpose, node, *perm, j, outputs_leading_to_transpose)) {
|
||||
changed = true;
|
||||
// Subsequent inputs may have changed and node may have been removed.
|
||||
|
|
@ -1751,11 +1848,69 @@ bool OptimizeImpl(OptimizerCtx& ctx) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Currently limiting the second optimization pass to layout transform mode
|
||||
// TODO: Enable this for both the modes.
|
||||
if (ctx.mode == OptimizerMode::OPTIMIZE_TRANSPOSE) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Run second optimization pass.
|
||||
// If any transpose succeeds a DQ node, move it above the DQ node.
|
||||
// In case of QDQ models this helps to preserve the QDQ node unit
|
||||
// In all other scenarios this is beneficial as well because moving transpose above DQ node is more efficient as
|
||||
// transpose node now handles less data.
|
||||
auto graph_nodes = ctx.graph.Nodes();
|
||||
for (size_t i = 1; i < graph_nodes.size(); i++) {
|
||||
if (graph_nodes[i]->OpType() == "Transpose") {
|
||||
auto& transpose_node = *graph_nodes[i];
|
||||
auto dq_node = ctx.graph.GetNodeProducingOutput(transpose_node.Inputs()[0]);
|
||||
if (!dq_node || dq_node->OpType() != "DequantizeLinear") {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto consumers = ctx.graph.GetValueConsumers(transpose_node.Outputs()[0]);
|
||||
bool is_part_of_qdq_unit = std::find_if(consumers->nodes.cbegin(), consumers->nodes.cend(),
|
||||
[](const std::unique_ptr<api::NodeRef>& node) {
|
||||
return node->OpType() == "QuantizeLinear";
|
||||
}) != consumers->nodes.cend();
|
||||
if (is_part_of_qdq_unit) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Update Dequantize Node and move the transpose above it
|
||||
auto perm = GetPermAttrIfValid(transpose_node);
|
||||
if (!perm.has_value()) {
|
||||
continue;
|
||||
}
|
||||
if (!HandleQuantizeDequantizeScale(ctx.graph, *perm, *dq_node, ctx.opset)) {
|
||||
continue;
|
||||
}
|
||||
TransposeFirstInput(ctx, *dq_node, *perm);
|
||||
|
||||
// remove existing transpose node
|
||||
transpose_node.SetInput(0, "");
|
||||
ctx.graph.MoveOutput(transpose_node, 0, *dq_node, 0);
|
||||
ctx.graph.RemoveNode(transpose_node);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool Optimize(api::GraphRef& graph, bool allow_extended_ops) {
|
||||
auto ctx = MakeOptimizerContext(graph, allow_extended_ops);
|
||||
// Returns the set of operator types, defined in the ONNX standard, that are layout sensitive.
//
// The set is created on first use and the same instance is returned on every call, so the
// returned reference stays valid for the lifetime of the program. The elements are views
// over string literals and therefore never dangle.
const std::unordered_set<std::string_view>& GetLayoutSensitiveOps() {
  // List of all layout sensitive ops defined in ONNX standard.
  // Declared const: callers only ever receive a const reference, so nothing should be able
  // to mutate the shared instance after initialization.
  static const std::unordered_set<std::string_view> layout_sensitive_ops = {
      "Conv", "QLinearConv", "BatchNormalization",
      "AveragePool", "GlobalAveragePool", "MaxPool",
      "GlobalMaxPool", "LRN"};

  return layout_sensitive_ops;
}
|
||||
|
||||
bool Optimize(api::GraphRef& graph, bool allow_extended_ops,
|
||||
const std::string& provider_type, OptimizerMode mode,
|
||||
const std::unordered_set<std::string_view>& layout_sensitive_ops) {
|
||||
auto ctx = MakeOptimizerContext(graph, allow_extended_ops, provider_type, mode, layout_sensitive_ops);
|
||||
if (ctx == std::nullopt) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1781,15 +1936,14 @@ void WrapTransposesAroundNode(api::GraphRef& graph, api::NodeRef& node,
|
|||
|
||||
std::unique_ptr<api::NodeRef> SwapNodeOpTypeAndDomain(api::GraphRef& graph, api::NodeRef& node,
|
||||
std::string_view op_type, std::string_view domain) {
|
||||
auto inputs = node.Inputs();
|
||||
auto outputs = node.Outputs();
|
||||
auto new_node = graph.AddNode(op_type, inputs, outputs.size(), domain);
|
||||
auto new_node = graph.CopyNode(node, op_type, domain);
|
||||
|
||||
for (size_t j = 0; j < outputs.size(); ++j) {
|
||||
if (outputs[j] != "") {
|
||||
graph.MoveOutput(node, j, *new_node, j);
|
||||
}
|
||||
}
|
||||
new_node->CopyAttributes(node);
|
||||
graph.RemoveNode(node);
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
|
|||
};
|
||||
|
||||
result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
|
||||
gen_metadef_name, COREML);
|
||||
gen_metadef_name, COREML, kCoreMLExecutionProvider);
|
||||
|
||||
const auto num_of_partitions = result.size();
|
||||
const auto num_of_supported_nodes = std::transform_reduce(
|
||||
|
|
|
|||
|
|
@ -262,7 +262,7 @@ Status ModelBuilder::RegisterInitializers() {
|
|||
shaper_.AddShape(name, operand_type.dimensions);
|
||||
|
||||
uint32_t index = 0;
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, false /* is_nhwc */, index));
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, index));
|
||||
const size_t size = operand_type.GetOperandBlobByteSize();
|
||||
const size_t padded_size = GetPaddedByteSize(size);
|
||||
sizeAll += padded_size;
|
||||
|
|
@ -352,7 +352,7 @@ Status ModelBuilder::RegisterModelInputs() {
|
|||
shaper_.AddShape(input_name, operand_type.dimensions);
|
||||
|
||||
uint32_t index = 0;
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(input_name, operand_type, false /* is_nhwc */, index));
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(input_name, operand_type, index));
|
||||
input_index_vec_.push_back(index);
|
||||
nnapi_model_->AddInput(input_name, operand_type);
|
||||
}
|
||||
|
|
@ -386,11 +386,6 @@ Status ModelBuilder::RegisterModelOutputs() {
|
|||
}
|
||||
|
||||
std::string nnapi_output_name = output_name;
|
||||
if (IsOperandNHWC(output_name)) {
|
||||
// We need to transpose the output still in nhwc back to nchw
|
||||
nnapi_output_name = GetUniqueName(output_name + "_nhwc_to_nchw");
|
||||
ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(*this, output_name, nnapi_output_name));
|
||||
}
|
||||
|
||||
output_index_vec_.push_back(operand_indices_[nnapi_output_name]);
|
||||
nnapi_model_->AddOutput(output_name, nnapi_output_name, operand_types_.at(nnapi_output_name));
|
||||
|
|
@ -405,10 +400,10 @@ void ModelBuilder::RegisterModelShaper() {
|
|||
|
||||
Status ModelBuilder::AddNewOperand(const std::string& name,
|
||||
const OperandType& operand_type,
|
||||
bool is_nhwc, uint32_t& index) {
|
||||
uint32_t& index) {
|
||||
LOGS_DEFAULT(VERBOSE) << "operand name: " << name;
|
||||
ORT_RETURN_IF_ERROR(AddNewNNAPIOperand(operand_type, index));
|
||||
RegisterOperand(name, index, operand_type, is_nhwc);
|
||||
RegisterOperand(name, index, operand_type);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -432,13 +427,10 @@ Status ModelBuilder::AddNewNNAPIOperand(const OperandType& operand_type, uint32_
|
|||
}
|
||||
|
||||
void ModelBuilder::RegisterOperand(const std::string& name, uint32_t index,
|
||||
const OperandType& operand_type, bool is_nhwc) {
|
||||
const OperandType& operand_type) {
|
||||
operand_indices_[name] = index;
|
||||
operand_types_.emplace(name, operand_type);
|
||||
operands_.insert(name);
|
||||
|
||||
if (is_nhwc)
|
||||
RegisterNHWCOperand(name);
|
||||
}
|
||||
|
||||
Status ModelBuilder::SetOperandValue(uint32_t index,
|
||||
|
|
@ -466,7 +458,7 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
|
|||
const android::nn::wrapper::OperandType& operand_type) {
|
||||
shaper_.AddShape(name, operand_type.dimensions);
|
||||
uint32_t index = 0;
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, false /* is_nhwc */, index));
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(name, operand_type, index));
|
||||
const size_t size = operand_type.GetOperandBlobByteSize();
|
||||
|
||||
// for small size operand, the value will be copied
|
||||
|
|
@ -528,12 +520,11 @@ Status ModelBuilder::AddOperations() {
|
|||
|
||||
Status ModelBuilder::AddOperation(int op, const std::vector<uint32_t>& input_indices,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<OperandType>& types,
|
||||
const std::vector<bool>& is_nhwc_vec) {
|
||||
const std::vector<OperandType>& types) {
|
||||
std::vector<uint32_t> output_indices;
|
||||
for (size_t i = 0; i < types.size(); i++) {
|
||||
uint32_t index = 0;
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(output_names[i], types[i], is_nhwc_vec[i], index));
|
||||
ORT_RETURN_IF_ERROR(AddNewOperand(output_names[i], types[i], index));
|
||||
output_indices.push_back(index);
|
||||
}
|
||||
|
||||
|
|
@ -693,7 +684,6 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {
|
|||
const auto& op_type = node_unit.GetNode().OpType();
|
||||
if (!Contains(op_builders, op_type))
|
||||
return nullptr;
|
||||
|
||||
return op_builders.at(op_type);
|
||||
}
|
||||
|
||||
|
|
@ -708,47 +698,13 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) {
|
|||
return unique_name;
|
||||
}
|
||||
|
||||
// Reports the data layout this builder prefers, derived from the use_nchw_ flag:
// NCHW when the flag is set, NHWC otherwise.
DataLayout ModelBuilder::GetPreferredLayout() const {
  if (use_nchw_) {
    return DataLayout::NCHW;
  }

  return DataLayout::NHWC;
}
|
||||
|
||||
const InitializedTensorSet& ModelBuilder::GetInitializerTensors() const {
|
||||
return graph_viewer_.GetAllInitializedTensors();
|
||||
}
|
||||
|
||||
void ModelBuilder::RegisterNHWCOperand(const std::string& name) {
|
||||
nhwc_operands_.insert(name);
|
||||
}
|
||||
|
||||
bool ModelBuilder::IsOperandNHWC(const std::string& name) const {
|
||||
return Contains(nhwc_operands_, name);
|
||||
}
|
||||
|
||||
bool ModelBuilder::GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name) {
|
||||
if (Contains(nhwc_to_nchw_map_, nhwc_name)) {
|
||||
nchw_name = nhwc_to_nchw_map_[nhwc_name];
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ModelBuilder::GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name) {
|
||||
if (Contains(nchw_to_nhwc_map_, nchw_name)) {
|
||||
nhwc_name = nchw_to_nhwc_map_[nchw_name];
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status ModelBuilder::SetNHWCToNCHWOperandMap(const std::string& nhwc_name,
|
||||
const std::string& nchw_name) {
|
||||
ORT_RETURN_IF_NOT(!Contains(nhwc_to_nchw_map_, nhwc_name), "A previous nchw to nhwc map exists");
|
||||
nhwc_to_nchw_map_[nhwc_name] = nchw_name;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ModelBuilder::SetNCHWToNHWCOperandMap(const std::string& nchw_name,
|
||||
const std::string& nhwc_name) {
|
||||
ORT_RETURN_IF_NOT(!Contains(nchw_to_nhwc_map_, nchw_name), "A previous nchw to nhwc map exists");
|
||||
nchw_to_nhwc_map_[nchw_name] = nhwc_name;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace nnapi
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -13,6 +13,7 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
class GraphViewer;
|
||||
enum class DataLayout;
|
||||
class NodeUnit;
|
||||
class Node;
|
||||
class NodeArg;
|
||||
|
|
@ -47,8 +48,7 @@ class ModelBuilder {
|
|||
// Add an NNAPI operation (operator)
|
||||
common::Status AddOperation(int op, const std::vector<uint32_t>& input_indices,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<android::nn::wrapper::OperandType>& types,
|
||||
const std::vector<bool>& is_nhwc_vec);
|
||||
const std::vector<android::nn::wrapper::OperandType>& types);
|
||||
|
||||
// Find if the given node_unit has a fuseable activation (Relu/Relu1/Relu6)
|
||||
// For now we only support node_unit with a single output
|
||||
|
|
@ -69,8 +69,7 @@ class ModelBuilder {
|
|||
|
||||
// Register information for a particular operand
|
||||
void RegisterOperand(const std::string& name, uint32_t index,
|
||||
const android::nn::wrapper::OperandType& operand_type,
|
||||
bool is_nhwc);
|
||||
const android::nn::wrapper::OperandType& operand_type);
|
||||
|
||||
// Generate a unique name for an intermediate result
|
||||
std::string GetUniqueName(const std::string& base_name);
|
||||
|
|
@ -79,6 +78,9 @@ class ModelBuilder {
|
|||
void SetUseNCHW(bool use_nchw) { use_nchw_ = use_nchw; }
|
||||
bool UseNCHW() const { return use_nchw_; }
|
||||
|
||||
// Returns the preferred layout for this EP.
|
||||
DataLayout GetPreferredLayout() const;
|
||||
|
||||
// Relax fp32 computation to fp16
|
||||
// It is off by default
|
||||
void SetUseFp16(bool use_fp16) { use_fp16_ = use_fp16; }
|
||||
|
|
@ -106,21 +108,9 @@ class ModelBuilder {
|
|||
|
||||
const GraphViewer& GetGraphViewer() const { return graph_viewer_; }
|
||||
|
||||
void RegisterNHWCOperand(const std::string& name);
|
||||
bool IsOperandNHWC(const std::string& name) const;
|
||||
|
||||
// Get the operand transposed to nchw/nhwc from given nhwc/nchw operand, if it exists
|
||||
bool GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name);
|
||||
bool GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name);
|
||||
|
||||
// Get the NodeUnit which contains the given node
|
||||
const NodeUnit& GetNodeUnit(const Node* node) const;
|
||||
|
||||
common::Status SetNHWCToNCHWOperandMap(const std::string& nhwc_name,
|
||||
const std::string& nchw_name);
|
||||
common::Status SetNCHWToNHWCOperandMap(const std::string& nchw_name,
|
||||
const std::string& nhwc_name);
|
||||
|
||||
private:
|
||||
const NnApi* nnapi_{nullptr};
|
||||
const GraphViewer& graph_viewer_;
|
||||
|
|
@ -148,12 +138,6 @@ class ModelBuilder {
|
|||
|
||||
std::unordered_map<std::string, std::shared_ptr<IOpSupportChecker>> op_support_checkers_;
|
||||
|
||||
// Operands in nhwc
|
||||
std::unordered_set<std::string> nhwc_operands_;
|
||||
|
||||
// Maps between nhwc and nchw, and vice versa
|
||||
std::unordered_map<std::string, std::string> nhwc_to_nchw_map_;
|
||||
std::unordered_map<std::string, std::string> nchw_to_nhwc_map_;
|
||||
|
||||
std::vector<uint32_t> input_index_vec_;
|
||||
std::vector<uint32_t> output_index_vec_;
|
||||
|
|
@ -206,7 +190,6 @@ class ModelBuilder {
|
|||
common::Status AddNewNNAPIOperand(const android::nn::wrapper::OperandType& type, uint32_t& index);
|
||||
common::Status AddNewOperand(const std::string& name,
|
||||
const android::nn::wrapper::OperandType& operand_type,
|
||||
bool is_nhwc,
|
||||
uint32_t& index);
|
||||
|
||||
static const IOpBuilder* GetOpBuilder(const NodeUnit& node_unit);
|
||||
|
|
|
|||
|
|
@ -40,8 +40,7 @@ Status AddTransposeOperator(ModelBuilder& model_builder,
|
|||
const std::string& input,
|
||||
const std::string& perm_name,
|
||||
std::vector<int32_t> perm,
|
||||
const std::string& output,
|
||||
bool output_is_nhwc) {
|
||||
const std::string& output) {
|
||||
auto& shaper(model_builder.GetShaper());
|
||||
const auto& operand_indices(model_builder.GetOperandIndices());
|
||||
const auto& operand_types(model_builder.GetOperandTypes());
|
||||
|
|
@ -59,102 +58,7 @@ Status AddTransposeOperator(ModelBuilder& model_builder,
|
|||
OperandType output_operand_type = operand_types.at(input);
|
||||
output_operand_type.SetDimensions(shaper[output]);
|
||||
return model_builder.AddOperation(ANEURALNETWORKS_TRANSPOSE, input_indices, {output},
|
||||
{output_operand_type}, {output_is_nhwc});
|
||||
}
|
||||
|
||||
Status TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder,
|
||||
const std::string& input,
|
||||
const std::string& output,
|
||||
bool nchw_to_nhwc) {
|
||||
ORT_RETURN_IF_NOT(!model_builder.UseNCHW(), "model_builder.UseNCHW() is on");
|
||||
const auto& shaper(model_builder.GetShaper());
|
||||
ORT_RETURN_IF_NOT(4 == shaper[input].size(),
|
||||
"TransposeBetweenNCHWAndNHWC input has to be a 4d tensor, actual dimensions: ", shaper[input].size());
|
||||
|
||||
std::string perm_name;
|
||||
std::vector<int32_t> perm;
|
||||
if (nchw_to_nhwc) {
|
||||
perm_name = model_builder.GetUniqueName(input + "nchw_to_nhwc_perm");
|
||||
perm = {0, 2, 3, 1};
|
||||
} else { // nhwc_to_nchw
|
||||
perm_name = model_builder.GetUniqueName(input + "nhwc_to_nchw_perm");
|
||||
perm = {0, 3, 1, 2};
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output, nchw_to_nhwc));
|
||||
|
||||
if (nchw_to_nhwc) {
|
||||
ORT_RETURN_IF_ERROR(model_builder.SetNCHWToNHWCOperandMap(input, output));
|
||||
} else { // nhwc_to_nchw
|
||||
ORT_RETURN_IF_ERROR(model_builder.SetNHWCToNCHWOperandMap(input, output));
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Operand [" << input << "] with shape "
|
||||
<< Shape2String(shaper[input])
|
||||
<< " is transposed "
|
||||
<< (nchw_to_nhwc ? "nchw_to_nhwc" : "nhwc_to_nchw")
|
||||
<< " to [" << output << "] with shape "
|
||||
<< Shape2String(shaper[output]);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TransposeNHWCToNCHW(ModelBuilder& model_builder,
|
||||
const std::string& input,
|
||||
const std::string& output) {
|
||||
return TransposeBetweenNCHWAndNHWC(model_builder, input, output, false /* nchw_to_nhwc */);
|
||||
}
|
||||
|
||||
Status TransposeNCHWToNHWC(ModelBuilder& model_builder,
|
||||
const std::string& input,
|
||||
const std::string& output) {
|
||||
return TransposeBetweenNCHWAndNHWC(model_builder, input, output, true /* nchw_to_nhwc */);
|
||||
}
|
||||
|
||||
// Convert the input from nchw to nhwc
|
||||
// Caller should ensure input is currently in nchw format using ModelBuilder::IsOperandNHWC
|
||||
Status GetNHWCInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nhwc_input) {
|
||||
const auto& nchw_input = node_unit.Inputs()[input_index].node_arg.Name();
|
||||
if (!model_builder.GetNHWCOperand(nchw_input, nhwc_input)) {
|
||||
nhwc_input = model_builder.GetUniqueName(nchw_input + "_nchw_to_nhwc");
|
||||
ORT_RETURN_IF_ERROR(TransposeNCHWToNHWC(model_builder, nchw_input, nhwc_input));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Convert the input from nhwc to nchw
|
||||
// Caller should ensure input is currently in nhwc format using ModelBuilder::IsOperandNHWC
|
||||
Status GetNCHWInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nchw_input) {
|
||||
const auto& nhwc_input = node_unit.Inputs()[input_index].node_arg.Name();
|
||||
if (!model_builder.GetNCHWOperand(nhwc_input, nchw_input)) {
|
||||
nchw_input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw");
|
||||
ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(model_builder, nhwc_input, nchw_input));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Transpose layouts if necessary for element wise operators with 2 inputs
|
||||
// and return the layout type of output tensor
|
||||
// If both inputs have same layout, the output will have the same layout
|
||||
// Otherwise we will need transpose the nhwc input back to nchw, and output will be nchw
|
||||
Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const NodeUnit& node_unit,
|
||||
std::string& input1, std::string& input2,
|
||||
bool& output_is_nhwc) {
|
||||
bool input1_is_nhwc = model_builder.IsOperandNHWC(input1);
|
||||
bool input2_is_nhwc = model_builder.IsOperandNHWC(input2);
|
||||
output_is_nhwc = false;
|
||||
|
||||
if (input1_is_nhwc == input2_is_nhwc) {
|
||||
output_is_nhwc = input1_is_nhwc;
|
||||
} else if (input1_is_nhwc) {
|
||||
// need transpose input1 back to nchw
|
||||
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input1));
|
||||
} else { // input2_is_nhwc
|
||||
// need transpose input2 back to nchw
|
||||
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 1, input2));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
{output_operand_type});
|
||||
}
|
||||
|
||||
static Status AddBinaryOperator(int32_t op_type,
|
||||
|
|
@ -164,7 +68,6 @@ static Status AddBinaryOperator(int32_t op_type,
|
|||
bool add_activation,
|
||||
int32_t fuse_code,
|
||||
const std::string& output,
|
||||
bool output_is_nhwc,
|
||||
float output_scale = 0.0f,
|
||||
int32_t output_zero_point = 0) {
|
||||
auto& shaper(model_builder.GetShaper());
|
||||
|
|
@ -183,7 +86,7 @@ static Status AddBinaryOperator(int32_t op_type,
|
|||
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output],
|
||||
output_scale, output_zero_point);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_type, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -231,7 +134,7 @@ static Status AddSqueezeOp(ModelBuilder& model_builder,
|
|||
ORT_RETURN_IF_ERROR(shaper.Squeeze(input, axes, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_SQUEEZE, input_indices,
|
||||
{output}, {output_operand_type}, {false}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -595,6 +498,16 @@ static void AddInputToSkip(ModelBuilder& model_builder, const NodeUnitIODef& io_
|
|||
AddQuantizationScaleAndZeroPointToSkip(model_builder, *io_def.quant_param);
|
||||
}
|
||||
|
||||
// Checks that the layout implied by the node unit's domain is compatible with the layout
// the model builder was asked to use.
//
// use_nchw:  true when the builder is configured for NCHW.
// node_unit: the node unit about to be added to the model.
//
// Returns a FAIL status when the node unit is in the internal NHWC domain while NCHW is
// requested; returns OK in every other case.
static Status IsOpInRequiredLayout(bool use_nchw, const NodeUnit& node_unit) {
  const bool op_in_nhwc_domain = (node_unit.Domain() == kMSInternalNHWCDomain);

  // Only the combination "NCHW requested + NHWC-domain op" is an error.
  if (!use_nchw || !op_in_nhwc_domain) {
    return Status::OK();
  }

  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                         "Expected layout and operator layout do not match. Possible bug in layout optimizer.");
}
|
||||
|
||||
template <class T>
|
||||
void CreateSharedOpBuilderImpl(const std::string& op_type,
|
||||
OpBuilderRegistrations& op_registrations,
|
||||
|
|
@ -715,10 +628,6 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
std::string input2 = inputs[1].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
|
||||
bool output_is_nhwc = false;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
TransposeBinaryOpInputLayout(model_builder, node_unit, input1, input2, output_is_nhwc));
|
||||
|
||||
float a_scale = 0.0f,
|
||||
b_scale = 0.0f,
|
||||
y_scale = 0.0f;
|
||||
|
|
@ -747,7 +656,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
return AddBinaryOperator(op_code, model_builder,
|
||||
input1, input2,
|
||||
add_activation, fuse_code,
|
||||
output, output_is_nhwc, y_scale, y_zero_point);
|
||||
output, y_scale, y_zero_point);
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
|
@ -766,19 +675,18 @@ Status ReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
|
||||
const auto& input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
|
||||
// skip this relu if it is some op's fuse output
|
||||
if (Contains(model_builder.GetFusedActivations(), input)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node_unit.Name() << "] fused";
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc);
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
|
||||
} else {
|
||||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RELU, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -825,15 +733,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
|
|||
ORT_RETURN_IF_NOT(perm.size() == input_dims, "Perm and input should have same dimension");
|
||||
}
|
||||
|
||||
if (model_builder.IsOperandNHWC(input)) {
|
||||
ORT_RETURN_IF_NOT(input_dims == 4, "Only 4D shape can be nhwc");
|
||||
|
||||
// we are using nhwc here, but the axis is in nchw, need to transpose axis from nchw to nhwc
|
||||
const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2};
|
||||
for (size_t i = 0; i < perm.size(); i++)
|
||||
perm[i] = axis_nchw_to_nhwc[perm[i]];
|
||||
}
|
||||
|
||||
// Check if the quantization scale and ZP are correct
|
||||
if (IsQuantizedOp(node_unit)) {
|
||||
float x_scale = 0.0f;
|
||||
|
|
@ -845,11 +744,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
|
|||
|
||||
std::string perm_name = model_builder.GetUniqueName(node_unit.Name() + input + "perm");
|
||||
|
||||
// It is possible this onnx transpose operator can be nchw->nhwc, but so far I don't see
|
||||
// any scenario will do this since onnx is nchw only, assume the output is always not nhwc
|
||||
// even it is, there will be extra transpose in the onnx model to convert it back to nchw
|
||||
// before conv/pool/... operators
|
||||
ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output, false /* is_nhwc */));
|
||||
ORT_RETURN_IF_ERROR(AddTransposeOperator(model_builder, input, perm_name, perm, output));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -884,7 +779,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
|
|||
}
|
||||
|
||||
// We can skip the Reshape if all the output edges satisfy both of the following conditions
|
||||
// 1. The output the reshape/flatten is not an output of the graph
|
||||
// 1. The output of the reshape/flatten is not an output of the graph
|
||||
// 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
|
||||
// and not any other types of operator,
|
||||
// and the input rank >= 2 and output_rank == 2
|
||||
|
|
@ -975,7 +870,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
|
|||
// NNAPI CPU impl and NNAPI hardware accelerator impl
|
||||
if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) {
|
||||
// Since reshape can be skipped, only register the dimension and type, with same index and new name
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
|
||||
} else {
|
||||
// We still need to perform a reshape here
|
||||
// Add input
|
||||
|
|
@ -987,7 +882,7 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
|
|||
OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type));
|
||||
input_indices.push_back(operand_indices.at(shape_name));
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false}));
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -998,10 +893,6 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
const auto& initializers(model_builder.GetInitializerTensors());
|
||||
|
||||
auto input = node_unit.Inputs()[0].node_arg.Name();
|
||||
if (model_builder.IsOperandNHWC(input)) {
|
||||
// We want to transpose nhwc operand back to nchw before reshape
|
||||
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
|
||||
const auto& shape_tensor = *initializers.at(node_unit.Inputs()[1].node_arg.Name());
|
||||
std::vector<uint8_t> unpacked_tensor;
|
||||
|
|
@ -1099,10 +990,10 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
|
|||
const auto tensor_imm_product_name = model_builder.GetUniqueName(node_unit.Name() + input + "_imm_mul");
|
||||
Shape tensor_a_dimen = {size};
|
||||
|
||||
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
bool output_is_nhwc = input_is_nhwc;
|
||||
bool use_nchw = model_builder.UseNCHW();
|
||||
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
|
||||
|
||||
if (!input_is_nhwc) {
|
||||
if (use_nchw) {
|
||||
// the batch normalization is applied on C channel,
|
||||
// if the input is NC[HW], will need correct shape for tensor_a/b
|
||||
// to make sure we are broadcasting on the correct channel,
|
||||
|
|
@ -1126,8 +1017,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
|
|||
model_builder,
|
||||
input, tensor_a_name,
|
||||
true /* add_activation */, ANEURALNETWORKS_FUSED_NONE,
|
||||
tensor_imm_product_name,
|
||||
output_is_nhwc));
|
||||
tensor_imm_product_name));
|
||||
|
||||
// Add
|
||||
int32_t fuse_code = model_builder.FindActivation(node_unit);
|
||||
|
|
@ -1135,8 +1025,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
|
|||
model_builder,
|
||||
tensor_imm_product_name, tensor_b_name,
|
||||
true /* add_activation */, fuse_code,
|
||||
output,
|
||||
output_is_nhwc));
|
||||
output));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -1190,16 +1079,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
|
||||
auto input = node_unit.Inputs()[0].node_arg.Name();
|
||||
bool use_nchw = model_builder.UseNCHW();
|
||||
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
bool output_is_nhwc = false;
|
||||
if (use_nchw) {
|
||||
ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC");
|
||||
} else {
|
||||
output_is_nhwc = true;
|
||||
if (!input_is_nhwc) {
|
||||
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
|
||||
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
const auto& op_type = node_unit.OpType();
|
||||
|
|
@ -1290,7 +1170,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -1361,16 +1241,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
|
||||
auto input = inputs[0].node_arg.Name();
|
||||
bool use_nchw = model_builder.UseNCHW();
|
||||
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
bool output_is_nhwc = false;
|
||||
if (use_nchw) {
|
||||
ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC");
|
||||
} else {
|
||||
output_is_nhwc = true;
|
||||
if (!input_is_nhwc) {
|
||||
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
|
||||
|
||||
const auto& weight = inputs[1].node_arg.Name();
|
||||
const auto& weight_tensor = *initializers.at(weight);
|
||||
|
|
@ -1559,7 +1430,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(operationCode, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -1579,7 +1450,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
|
||||
const auto& input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
auto to = helper.Get("to", 0);
|
||||
Type type;
|
||||
|
|
@ -1599,7 +1469,7 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(type, shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CAST, input_indices, {output},
|
||||
{output_operand_type}, {output_is_nhwc}));
|
||||
{output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -1620,21 +1490,14 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
NodeAttrHelper helper(node_unit);
|
||||
|
||||
auto input = node_unit.Inputs()[0].node_arg.Name();
|
||||
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
bool output_is_nhwc = input_is_nhwc;
|
||||
|
||||
// TODO: Needs fix.
|
||||
if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) {
|
||||
if (model_builder.IsOperandNHWC(input)) {
|
||||
output_is_nhwc = false;
|
||||
// We want to transpose nhwc operand back to nchw before softmax
|
||||
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
ORT_ENFORCE(model_builder.UseNCHW(),
|
||||
"For Android API Level < 29 input for softmax needs to be NCHW.");
|
||||
}
|
||||
|
||||
int32_t axis = helper.Get("axis", 1);
|
||||
if (output_is_nhwc) {
|
||||
const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2};
|
||||
axis = axis_nchw_to_nhwc[axis];
|
||||
}
|
||||
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
float beta = 1.f;
|
||||
|
|
@ -1650,7 +1513,7 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -1672,14 +1535,13 @@ Status IdentityOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, con
|
|||
|
||||
const auto& input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input)); // input
|
||||
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc);
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -1839,7 +1701,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
ORT_RETURN_IF_ERROR(shaper.FC(input1, input2, output));
|
||||
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output], y_scale, y_zero_point);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_FULLY_CONNECTED, input_indices,
|
||||
{output}, {output_operand_type}, {false}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -1896,8 +1758,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
|
||||
const auto& input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
bool is_qlinear_sigmoid = op_type == "QLinearSigmoid";
|
||||
|
||||
|
|
@ -1945,7 +1805,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
input_indices.push_back(operand_indices.at(input));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -1967,8 +1827,6 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
|
||||
std::vector<uint32_t> input_indices;
|
||||
const auto& input0 = inputs[0].node_arg.Name();
|
||||
bool all_input_have_same_layout = true;
|
||||
bool output_is_nhwc = false;
|
||||
const auto node_input_size = inputs.size();
|
||||
|
||||
// First if the inputs are uint8, we need verify all the inputs have same scale and zero points
|
||||
|
|
@ -1989,48 +1847,17 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
}
|
||||
}
|
||||
|
||||
// First we want to see if all the input are same layout
|
||||
for (size_t i = 0; i < node_input_size - 1; i++) {
|
||||
all_input_have_same_layout =
|
||||
all_input_have_same_layout &&
|
||||
model_builder.IsOperandNHWC(inputs[i].node_arg.Name()) ==
|
||||
model_builder.IsOperandNHWC(inputs[i + 1].node_arg.Name());
|
||||
}
|
||||
|
||||
std::vector<std::string> input_names;
|
||||
input_names.reserve(node_input_size);
|
||||
if (all_input_have_same_layout) {
|
||||
// if all the inputs are of same layout, output will be the same layout
|
||||
output_is_nhwc = model_builder.IsOperandNHWC(input0);
|
||||
|
||||
for (size_t i = 0; i < node_input_size; i++) {
|
||||
const auto& input = inputs[i].node_arg.Name();
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
input_names.push_back(input);
|
||||
}
|
||||
} else {
|
||||
// if all the inputs are not same layout,
|
||||
// will need transpos those nhwc tensors back to nchw
|
||||
for (size_t i = 0; i < node_input_size; i++) {
|
||||
auto input = inputs[i].node_arg.Name();
|
||||
if (model_builder.IsOperandNHWC(input)) {
|
||||
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, i, input));
|
||||
}
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
input_names.push_back(input);
|
||||
}
|
||||
for (size_t i = 0; i < node_input_size; i++) {
|
||||
const auto& input = inputs[i].node_arg.Name();
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
input_names.push_back(input);
|
||||
}
|
||||
|
||||
int rank = shaper[input0].size();
|
||||
int32_t axis = static_cast<int32_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
|
||||
|
||||
if (output_is_nhwc) {
|
||||
ORT_RETURN_IF_NOT(rank == 4,
|
||||
"nhwc is only on 4d shape, input ", input0, " has rank: ", rank);
|
||||
// we are using nhwc here, but the axis is in nchw, need to transpose axis from nchw to nhwc
|
||||
const uint32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2};
|
||||
axis = axis_nchw_to_nhwc[axis];
|
||||
}
|
||||
ADD_SCALAR_OPERAND(model_builder, input_indices, axis);
|
||||
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
|
|
@ -2038,7 +1865,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
OperandType output_operand_type = operand_types.at(input0);
|
||||
output_operand_type.SetDimensions(shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -2089,10 +1916,6 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
|
|||
|
||||
Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
|
||||
auto input = node_unit.Inputs()[0].node_arg.Name();
|
||||
if (model_builder.IsOperandNHWC(input)) {
|
||||
// We want to transpose nhwc operand back to nchw before squeeze
|
||||
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
|
||||
std::vector<int32_t> axes;
|
||||
ORT_RETURN_IF_ERROR(GetAxes(model_builder, node_unit, axes));
|
||||
|
|
@ -2121,7 +1944,6 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
|
|||
|
||||
const auto& input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
float scale = 0.0f;
|
||||
int32_t zero_point = 0;
|
||||
|
|
@ -2134,7 +1956,7 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
|
|||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_QUANTIZE, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -2161,7 +1983,6 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil
|
|||
|
||||
const auto& input = inputs[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
float scale = 0.0;
|
||||
int32_t zero_point = 0;
|
||||
|
|
@ -2176,7 +1997,7 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil
|
|||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_DEQUANTIZE, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -2198,13 +2019,14 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
|
|||
|
||||
auto input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
auto use_nchw = model_builder.UseNCHW();
|
||||
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
|
||||
|
||||
if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) {
|
||||
// on android api level 28, we need to transpose the nchw input to nhwc
|
||||
output_is_nhwc = true;
|
||||
if (!model_builder.IsOperandNHWC(input)) {
|
||||
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
// it is very rare that users set nchw format when using nnapi. Therefore, instead of
|
||||
// adding the ability to support conversion we fail and stop.
|
||||
ORT_ENFORCE(!use_nchw, "NCHW format is not supported on android api level 28");
|
||||
}
|
||||
|
||||
auto alpha = helper.Get("alpha", 0.0001f);
|
||||
|
|
@ -2225,16 +2047,16 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
|
|||
// specify axis is only available on api level >= 29
|
||||
if (android_feature_level > ANEURALNETWORKS_FEATURE_LEVEL_2) {
|
||||
// ONNX LRN is always performed on C dimension
|
||||
int32_t axis = output_is_nhwc
|
||||
? 3 // nhwc
|
||||
: 1; // nchw
|
||||
int32_t axis = use_nchw
|
||||
? 1 // nchw
|
||||
: 3; // nhwc
|
||||
ADD_SCALAR_OPERAND(model_builder, input_indices, axis);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -2266,14 +2088,13 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
|
||||
const auto& input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
|
||||
if (Contains(model_builder.GetFusedActivations(), input)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node_unit.Name() << "] fused";
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc);
|
||||
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -2293,7 +2114,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -2344,16 +2165,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
|
||||
auto input = inputs[0].node_arg.Name();
|
||||
bool use_nchw = model_builder.UseNCHW();
|
||||
bool input_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
bool output_is_nhwc = false;
|
||||
if (use_nchw) {
|
||||
ORT_RETURN_IF_NOT(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC");
|
||||
} else {
|
||||
output_is_nhwc = true;
|
||||
if (!input_is_nhwc) {
|
||||
ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(IsOpInRequiredLayout(use_nchw, node_unit));
|
||||
|
||||
// Check if the quantization scale and ZP is correct
|
||||
if (IsQuantizedOp(node_unit)) {
|
||||
|
|
@ -2373,16 +2185,19 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
bool using_half_pixel = coord_trans_mode == "half_pixel";
|
||||
bool using_align_corners = coord_trans_mode == "align_corners";
|
||||
|
||||
// if the node domain is NHWC it means all the node inputs are converted to NHWC format by the layout transformer.
|
||||
// pick the index for height and width based on the format.
|
||||
int h_idx = use_nchw ? 2 : 1;
|
||||
int w_idx = use_nchw ? 3 : 2;
|
||||
|
||||
if (inputs.size() == 3) { // we are using scales
|
||||
const auto& scales_name = inputs[2].node_arg.Name();
|
||||
const auto& scales_tensor = *initializers.at(scales_name);
|
||||
std::vector<uint8_t> unpacked_tensor;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor));
|
||||
const float* scales_data = reinterpret_cast<const float*>(unpacked_tensor.data());
|
||||
float scale_h = scales_data[2];
|
||||
float scale_w = scales_data[3];
|
||||
ORT_RETURN_IF_ERROR(
|
||||
shaper.ResizeUsingScales(input, scale_h, scale_w, use_nchw, output));
|
||||
shaper.ResizeUsingScales(input, scales_data[h_idx], scales_data[w_idx], use_nchw, output));
|
||||
} else { // we are using sizes
|
||||
const auto& sizes_name = inputs[3].node_arg.Name();
|
||||
const auto& sizes_tensor = *initializers.at(sizes_name);
|
||||
|
|
@ -2390,12 +2205,12 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor));
|
||||
const int64_t* sizes_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
|
||||
ORT_RETURN_IF_ERROR(
|
||||
shaper.ResizeUsingOutputSizes(input, SafeInt<uint32_t>(sizes_data[2]), SafeInt<uint32_t>(sizes_data[3]), use_nchw, output));
|
||||
shaper.ResizeUsingOutputSizes(input, SafeInt<uint32_t>(sizes_data[h_idx]), SafeInt<uint32_t>(sizes_data[w_idx]), use_nchw, output));
|
||||
}
|
||||
|
||||
const auto& output_shape = shaper[output];
|
||||
int32_t output_h = use_nchw ? output_shape[2] : output_shape[1];
|
||||
int32_t output_w = use_nchw ? output_shape[3] : output_shape[2];
|
||||
int32_t output_h = output_shape[h_idx];
|
||||
int32_t output_w = output_shape[w_idx];
|
||||
|
||||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
|
|
@ -2420,7 +2235,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
OperandType output_operand_type = operand_types.at(input);
|
||||
output_operand_type.SetDimensions(output_shape);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(operationCode, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -2436,10 +2251,6 @@ class FlattenOpBuilder : public BaseOpBuilder {
|
|||
|
||||
Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
|
||||
auto input = node_unit.Inputs()[0].node_arg.Name();
|
||||
if (model_builder.IsOperandNHWC(input)) {
|
||||
// We want to transpose nhwc operand back to nchw before reshape
|
||||
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input));
|
||||
}
|
||||
|
||||
// Flatten is basically a reshape to 2d tensor
|
||||
// Get the shape for Reshape here
|
||||
|
|
@ -2467,8 +2278,7 @@ class MinMaxOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
|
||||
static Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
|
||||
const std::string& input1, const std::string& input2,
|
||||
bool output_is_nhwc);
|
||||
const std::string& input1, const std::string& input2);
|
||||
};
|
||||
|
||||
/* static */ void MinMaxOpBuilder::CreateSharedOpBuilder(
|
||||
|
|
@ -2482,8 +2292,7 @@ class MinMaxOpBuilder : public BaseOpBuilder {
|
|||
}
|
||||
|
||||
/* static */ Status MinMaxOpBuilder::AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
|
||||
const std::string& input1, const std::string& input2,
|
||||
bool output_is_nhwc) {
|
||||
const std::string& input1, const std::string& input2) {
|
||||
auto& shaper(model_builder.GetShaper());
|
||||
const auto& operand_indices(model_builder.GetOperandIndices());
|
||||
const auto& operand_types(model_builder.GetOperandTypes());
|
||||
|
|
@ -2506,7 +2315,7 @@ class MinMaxOpBuilder : public BaseOpBuilder {
|
|||
ORT_RETURN_IF_ERROR(shaper.Eltwise(input1, input2, output));
|
||||
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
{output}, {output_operand_type}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -2515,11 +2324,8 @@ Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
const auto& inputs = node_unit.Inputs();
|
||||
std::string input1 = inputs[0].node_arg.Name();
|
||||
std::string input2 = inputs[1].node_arg.Name();
|
||||
bool output_is_nhwc = false;
|
||||
ORT_RETURN_IF_ERROR(TransposeBinaryOpInputLayout(model_builder, node_unit,
|
||||
input1, input2, output_is_nhwc));
|
||||
|
||||
return AddMinMaxOperator(model_builder, node_unit, input1, input2, output_is_nhwc);
|
||||
return AddMinMaxOperator(model_builder, node_unit, input1, input2);
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
|
@ -2537,7 +2343,6 @@ Status EluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
|
|||
const auto& operand_types(model_builder.GetOperandTypes());
|
||||
const auto& input = node_unit.Inputs()[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
NodeAttrHelper helper(node_unit);
|
||||
|
|
@ -2546,7 +2351,7 @@ Status EluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No
|
|||
input_indices.push_back(operand_indices.at(input));
|
||||
ADD_SCALAR_OPERAND(model_builder, input_indices, alpha);
|
||||
return model_builder.AddOperation(ANEURALNETWORKS_ELU, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc});
|
||||
{output}, {output_operand_type});
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
|
@ -2644,7 +2449,6 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
|
||||
const auto& input = inputs[0].node_arg.Name();
|
||||
const auto& output = node_unit.Outputs()[0].node_arg.Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
// No shape inference for Slice, everything is calculated here, we only need to add the output shape
|
||||
// to the shaper
|
||||
|
|
@ -2718,7 +2522,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
ADD_SCALAR_OPERAND(model_builder, input_indices, 0); // end_mask
|
||||
ADD_SCALAR_OPERAND(model_builder, input_indices, 0); // shrink_axis_mask
|
||||
}
|
||||
return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}, {output_is_nhwc});
|
||||
return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type});
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
|
|
|||
|
|
@ -72,6 +72,10 @@ static bool HasExternalInitializer(const InitializedTensorSet& initializers, con
|
|||
return false;
|
||||
}
|
||||
|
||||
// Returns true when the node unit's domain equals kMSInternalNHWCDomain.
inline bool IsNodeLayoutNHWC(const NodeUnit& node_unit) {
|
||||
return node_unit.Domain() == kMSInternalNHWCDomain;
|
||||
}
|
||||
|
||||
static bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers,
|
||||
const NodeUnitIODef& io_def,
|
||||
const OpSupportCheckParams& params,
|
||||
|
|
@ -1768,7 +1772,7 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi
|
|||
}
|
||||
const float* scales_data = reinterpret_cast<const float*>(unpacked_tensor.data());
|
||||
float scale_n = scales_data[0];
|
||||
float scale_c = scales_data[1];
|
||||
float scale_c = IsNodeLayoutNHWC(node_unit) ? scales_data[3] : scales_data[1];
|
||||
if (scale_n != 1.0f || scale_c != 1.0f) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Scales of N/C channel should be 1"
|
||||
<< "Resize of N/C channels are not supported"
|
||||
|
|
@ -1785,14 +1789,16 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi
|
|||
LOGS_DEFAULT(ERROR) << "Error while unpacking sizes_tensor: " << status.ErrorMessage();
|
||||
return false;
|
||||
}
|
||||
|
||||
int channel_idx = IsNodeLayoutNHWC(node_unit) ? 3 : 1;
|
||||
const int64_t* sizes_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
|
||||
uint32_t size_n = SafeInt<uint32_t>(sizes_data[0]);
|
||||
uint32_t size_c = SafeInt<uint32_t>(sizes_data[1]);
|
||||
if (size_n != input_shape[0] || size_c != input_shape[1]) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Output sizes of N/C chanel should match the input sizes, "
|
||||
uint32_t size_c = SafeInt<uint32_t>(sizes_data[channel_idx]);
|
||||
if (size_n != input_shape[0] || size_c != input_shape[channel_idx]) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Output sizes of N/C channel should match the input sizes, "
|
||||
<< "Resize of N/C channels are not supported"
|
||||
<< ", input_size_n, " << input_shape[0] << ", output_size_n, " << size_n
|
||||
<< ". input_size_c, " << input_shape[1] << ", output_size_c, " << size_c;
|
||||
<< ". input_size_c, " << input_shape[channel_idx] << ", output_size_c, " << size_c;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
};
|
||||
|
||||
result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
|
||||
gen_metadef_name, NNAPI);
|
||||
gen_metadef_name, NNAPI, kNnapiExecutionProvider);
|
||||
|
||||
const auto num_of_partitions = result.size();
|
||||
const auto num_of_supported_nodes = std::transform_reduce(
|
||||
|
|
@ -203,6 +203,10 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
return result;
|
||||
}
|
||||
|
||||
// Reports the data layout this EP prefers: DataLayout::NCHW when the NNAPI_FLAG_USE_NCHW
// bit is set in nnapi_flags_, DataLayout::NHWC otherwise.
DataLayout NnapiExecutionProvider::GetPreferredLayout() const {
|
||||
// Bitwise `&` binds tighter than `?:`, so the flag test is evaluated before the selection.
return nnapi_flags_ & NNAPI_FLAG_USE_NCHW ? DataLayout::NCHW : DataLayout::NHWC;
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
static Status GetOutputBuffer(Ort::CustomOpApi& ort,
|
||||
OrtKernelContext* context,
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ class NnapiExecutionProvider : public IExecutionProvider {
|
|||
|
||||
uint32_t GetNNAPIFlags() const { return nnapi_flags_; }
|
||||
|
||||
DataLayout GetPreferredLayout() const override;
|
||||
|
||||
private:
|
||||
// The bit flags which define bool options for NNAPI EP, bits are defined as
|
||||
// NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
const GraphViewer& graph_viewer,
|
||||
const IsNodeSupportedFn& is_node_supported_fn,
|
||||
const OnGroupClosedFn& on_group_closed_fn,
|
||||
const std::string& execution_provider_type,
|
||||
bool debug_output) {
|
||||
#ifdef NDEBUG
|
||||
ORT_UNUSED_PARAMETER(debug_output);
|
||||
|
|
@ -104,7 +105,8 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
std::deque<const Node*> nodes_to_process_with_next_group{};
|
||||
|
||||
// initialize in-degrees and find root nodes
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const auto& node = *graph_viewer.GetNode(node_index);
|
||||
const auto node_input_edge_count = node.GetInputEdgesCount();
|
||||
in_degree.insert({node.Index(), node_input_edge_count});
|
||||
if (node_input_edge_count == 0) {
|
||||
|
|
@ -156,9 +158,9 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
const Node& node = *nodes_to_process.front();
|
||||
nodes_to_process.pop_front();
|
||||
|
||||
// a node that is already assigned to an EP other than current EP is unsupported
|
||||
const bool is_node_supported =
|
||||
node.GetExecutionProviderType().empty() && // a node that is already assigned to an EP is unsupported
|
||||
is_node_supported_fn(node);
|
||||
(node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node);
|
||||
|
||||
if (!is_node_supported && Contains(supported_group_border, &node)) {
|
||||
// an unsupported node on the border will be processed after the current partition node group
|
||||
|
|
@ -203,9 +205,7 @@ std::unordered_set<const Node*> CreateExcludedNodeSet(const GraphViewer& graph_v
|
|||
const std::unordered_set<std::string>& stop_ops) {
|
||||
std::unordered_set<const Node*> excluded_nodes;
|
||||
|
||||
for (const NodeIndex node_index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const Node& node = *graph_viewer.GetNode(node_index);
|
||||
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
if (!Contains(excluded_nodes, &node) && Contains(stop_ops, node.OpType())) {
|
||||
excluded_nodes.insert(&node);
|
||||
|
||||
|
|
@ -309,10 +309,12 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const OnGroupClosedFn& on_partition_closed_fn,
|
||||
const GenerateMetadefNameFn& generate_metadef_name_fn,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
bool debug_output) {
|
||||
const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer,
|
||||
is_node_supported_fn,
|
||||
on_partition_closed_fn,
|
||||
execution_provider_type,
|
||||
debug_output);
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> partitions{};
|
||||
|
|
@ -335,6 +337,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const std::unordered_set<std::string>& stop_ops,
|
||||
const GenerateMetadefNameFn& generate_metadef_name_fn,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
bool debug_output) {
|
||||
const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops);
|
||||
const bool check_excluded_nodes = !excluded_nodes.empty();
|
||||
|
|
@ -348,6 +351,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
{},
|
||||
generate_metadef_name_fn,
|
||||
execution_provider_name,
|
||||
execution_provider_type,
|
||||
debug_output);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ Create the supported partitions for the execution provider.
|
|||
@param on_group_closed_fn Callback to indicate a completed partition node group.
|
||||
@param generate_metadef_name_fn Callback to create the name for the MetaDef.
|
||||
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
|
||||
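@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. Nodes that are unassigned or already assigned to this type remain eligible for partitioning; nodes assigned to any other EP are treated as unsupported.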
|
||||
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
|
||||
No-op in a release build.
|
||||
|
||||
|
|
@ -65,6 +66,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const OnGroupClosedFn& on_group_closed_fn,
|
||||
const GenerateMetadefNameFn& generate_metadef_name_fn,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
bool debug_output = false);
|
||||
|
||||
/**
|
||||
|
|
@ -75,6 +77,7 @@ Create the supported partitions for the execution provider.
|
|||
@param stop_ops Set of operator names at which we stop considering nodes for assignment to this execution provider.
|
||||
@param generate_metadef_name Functor to create the name for the MetaDef.
|
||||
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
|
||||
@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
|
||||
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
|
||||
No-op in a release build.
|
||||
|
||||
|
|
@ -86,6 +89,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const std::unordered_set<std::string>& stop_ops,
|
||||
const GenerateMetadefNameFn& generate_metadef_name,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
bool debug_output = false);
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ struct KernelRegistry;
|
|||
struct Function;
|
||||
struct Graph;
|
||||
struct GraphViewer;
|
||||
enum class DataLayout;
|
||||
struct Model;
|
||||
struct Path;
|
||||
struct Node;
|
||||
|
|
|
|||
|
|
@ -222,6 +222,7 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
|
|||
}
|
||||
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSExperimentalDomain, 1, 1);
|
||||
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1);
|
||||
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSInternalNHWCDomain, 1, 1);
|
||||
#ifdef USE_DML
|
||||
domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@
|
|||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/framework/ort_value_pattern_planner.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/framework/static_kernel_def_hashes.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/optimizer/graph_transformer_utils.h"
|
||||
|
|
@ -40,6 +41,7 @@
|
|||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
|
||||
#include "core/optimizer/transformer_memcpy.h"
|
||||
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
|
||||
#include "core/platform/Barrier.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
|
@ -913,7 +915,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
// Do partitioning based on execution providers' capability.
|
||||
GraphPartitioner partitioner(kernel_registry_manager, providers);
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
|
||||
session_state.GetMutableFuncMgr(), mode));
|
||||
session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayout, mode));
|
||||
|
||||
// apply transformers except default transformers
|
||||
// Default transformers are required for correctness and they are owned and run by inference session
|
||||
|
|
@ -1138,6 +1141,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
|
|||
GraphPartitioner partitioner(kernel_registry_manager, providers);
|
||||
ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(),
|
||||
session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayout,
|
||||
GraphPartitioner::Mode::kOrtFormatLoad,
|
||||
&compiled_kernel_hashes));
|
||||
|
||||
|
|
@ -1211,6 +1215,17 @@ Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs
|
|||
graph.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
|
||||
ORT_RETURN_IF_ERROR(set_node_ep(node_index, kernel_def_hash));
|
||||
}
|
||||
|
||||
// layout transformer which is enabled in extended minimal build can add new nodes.
|
||||
// The following loop fetches the hash values for these nodes.
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
if (node.GetExecutionProviderType().empty()) {
|
||||
auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
|
||||
if (kernel_hash.has_value()) {
|
||||
ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value()));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_IN_MINIMAL_BUILD)
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -1236,7 +1251,7 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) {
|
|||
}
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
//VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
|
||||
// VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
|
||||
#pragma warning(disable : 26117)
|
||||
#endif
|
||||
common::Status InferenceSession::Initialize() {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace std;
|
||||
|
|
@ -142,7 +143,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
DefaultLoggingManager().DefaultLogger(), profiler);
|
||||
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayout);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
|
||||
|
|
@ -207,7 +209,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
|
|||
|
||||
// Partition the graph
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayout);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
// Finalize the session state
|
||||
|
|
@ -256,7 +259,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
|
|||
|
||||
// Partition the graph
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayout);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
// Finalize the session state
|
||||
|
|
|
|||
|
|
@ -689,14 +689,15 @@ class ResizeOpTester : public OpTester {
|
|||
OpTester::AddNodes(graph, graph_input_defs, graph_output_defs, add_attribute_funcs);
|
||||
|
||||
// set the Graph inputs to just X and roi (exclude 'scales') so the 'scales' are a constant initializer
|
||||
if (scales_in_initializer_) {
|
||||
// this isn't intended to work with a scenario where the optional 'sizes' input is provided
|
||||
ASSERT_TRUE(graph_input_defs.size() == 3);
|
||||
if(scales_in_initializer_) {
|
||||
graph.SetInputs({graph.GetNodeArg(graph_input_defs[0]->Name()),
|
||||
graph.GetNodeArg(graph_input_defs[1]->Name())});
|
||||
}
|
||||
|
||||
if (sizes_in_initializer_) {
|
||||
if(sizes_in_initializer_) {
|
||||
ASSERT_TRUE(graph_input_defs.size() == 4);
|
||||
} else {
|
||||
ASSERT_TRUE(graph_input_defs.size() == 3);
|
||||
}
|
||||
} else if (sizes_in_initializer_) {
|
||||
ASSERT_TRUE(graph_input_defs.size() == 4); // 'sizes' is 4th input
|
||||
graph.SetInputs({graph.GetNodeArg(graph_input_defs[0]->Name()),
|
||||
graph.GetNodeArg(graph_input_defs[1]->Name()),
|
||||
|
|
@ -744,9 +745,10 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Scales) {
|
|||
|
||||
TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Sizes) {
|
||||
auto run_test = [](bool sizes_in_initializer) {
|
||||
ResizeOpTester test(false, sizes_in_initializer);
|
||||
ResizeOpTester test(sizes_in_initializer, sizes_in_initializer);
|
||||
|
||||
std::vector<float> roi{};
|
||||
std::vector<float> scales{};
|
||||
std::vector<int64_t> sizes{1, 1, 4, 4};
|
||||
|
||||
test.AddAttribute("mode", "nearest");
|
||||
|
|
@ -760,7 +762,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Sizes) {
|
|||
|
||||
test.AddInput<float>("X", {N, C, H, W}, X);
|
||||
test.AddInput<float>("roi", {0}, roi);
|
||||
test.AddInput<float>("scales", {0}, {});
|
||||
test.AddInput<float>("scales", {0}, scales, sizes_in_initializer);
|
||||
test.AddInput<int64_t>("sizes", {4}, sizes, sizes_in_initializer);
|
||||
|
||||
std::vector<float> Y = {1.0f, 1.0f, 2.0f, 2.0f,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include "core/graph/model.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
|
||||
|
||||
#include <queue>
|
||||
|
||||
|
|
@ -22,12 +23,14 @@ constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP";
|
|||
|
||||
InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops,
|
||||
const std::unordered_set<std::string>& stop_ops,
|
||||
bool debug_output)
|
||||
bool debug_output,
|
||||
DataLayout preferred_layout)
|
||||
: IExecutionProvider{utils::kInternalTestingExecutionProvider, true},
|
||||
ep_name_{INTERNAL_TESTING_EP},
|
||||
ops_{ops},
|
||||
stop_ops_{stop_ops},
|
||||
debug_output_{debug_output} {
|
||||
debug_output_{debug_output},
|
||||
preferred_layout_{preferred_layout} {
|
||||
//
|
||||
// TODO: Allocation planner calls GetAllocator for the individual EP. It would be better if it goes through
|
||||
// the session state to get the allocator so it's per-device (or for the allocation planner to try the EP first
|
||||
|
|
@ -44,6 +47,10 @@ InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::un
|
|||
|
||||
InternalTestingExecutionProvider::~InternalTestingExecutionProvider() {}
|
||||
|
||||
DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const {
|
||||
return preferred_layout_;
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& /*registries*/) const {
|
||||
|
|
@ -89,7 +96,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer&
|
|||
};
|
||||
|
||||
return utils::CreateSupportedPartitions(graph_viewer, supported_nodes, stop_ops_,
|
||||
generate_metadef_name, ep_name_, debug_output_);
|
||||
generate_metadef_name, ep_name_, onnxruntime::utils::kInternalTestingExecutionProvider, debug_output_);
|
||||
}
|
||||
|
||||
common::Status InternalTestingExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
|
|
@ -99,14 +106,19 @@ common::Status InternalTestingExecutionProvider::Compile(const std::vector<Fused
|
|||
NodeComputeInfo compute_info;
|
||||
const Node& node = node_and_viewer.fused_node;
|
||||
|
||||
//{
|
||||
// const GraphViewer& graph_viewer = node_and_viewer.filtered_graph;
|
||||
// std::cout << "Fusing nodes: ";
|
||||
// for (const auto& unfused_node : graph_viewer.Nodes()) {
|
||||
// std::cout << " '" << unfused_node.Name() << "':" << unfused_node.Index();
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
//}
|
||||
if (preferred_layout_ == DataLayout::NHWC) {
|
||||
const GraphViewer& graph_viewer = node_and_viewer.filtered_graph;
|
||||
auto layout_sensitive_ops = layout_transformer::GetORTLayoutSensitiveOps();
|
||||
for (const auto& unfused_node : graph_viewer.Nodes()) {
|
||||
std::cout << unfused_node.OpType() << std::endl;
|
||||
if (layout_sensitive_ops.count(unfused_node.OpType()) && unfused_node.Domain() != kMSInternalNHWCDomain) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"Found a layout sensitive op which is still in NCHW format. Node: ",
|
||||
unfused_node.OpType(), " ", unfused_node.Name(),
|
||||
" The preferrd layout for this EP is NHWC. This is a possible bug in layout transformer.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) {
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
|
|||
public:
|
||||
InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops,
|
||||
const std::unordered_set<std::string>& stop_ops = {},
|
||||
bool debug_output = false);
|
||||
bool debug_output = false,
|
||||
DataLayout preferred_layout = static_cast<DataLayout>(0));
|
||||
virtual ~InternalTestingExecutionProvider();
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
|
|
@ -24,6 +25,8 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
|
|||
return FusionStyle::FilteredGraphViewer;
|
||||
}
|
||||
|
||||
DataLayout GetPreferredLayout() const override;
|
||||
|
||||
private:
|
||||
const std::string ep_name_;
|
||||
|
||||
|
|
@ -38,5 +41,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
|
|||
const std::unordered_set<std::string> stop_ops_;
|
||||
|
||||
const bool debug_output_;
|
||||
|
||||
DataLayout preferred_layout_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ namespace test {
|
|||
// it would be possible to use ORT format models but the same partitioning code would run either way
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")
|
||||
// model has an unsupported node between the supported nodes after the initial topo sort.
|
||||
// the partition aware topo sort should result in the unsupported node moving to earlier in the order,
|
||||
// and allow a single partition of supported nodes to be created.
|
||||
|
|
@ -44,7 +45,7 @@ TEST(InternalTestingEP, TestSortResultsInSinglePartition) {
|
|||
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
|
||||
std::make_unique<InternalTestingExecutionProvider>(supported_ops)));
|
||||
|
||||
const ORTCHAR_T* model_path = ORT_TSTR("testdata/ep_partitioning_test_1.onnx");
|
||||
const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "ep_partitioning_test_1.onnx";
|
||||
ASSERT_STATUS_OK(session->Load(model_path));
|
||||
const auto& graph = session->GetGraph();
|
||||
GraphViewer viewer{graph};
|
||||
|
|
@ -89,7 +90,7 @@ TEST(InternalTestingEP, TestDependenciesCorrectlyHandled) {
|
|||
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
|
||||
std::make_unique<InternalTestingExecutionProvider>(supported_ops)));
|
||||
|
||||
const ORTCHAR_T* model_path = ORT_TSTR("testdata/ep_partitioning_test_2.onnx");
|
||||
const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "ep_partitioning_test_2.onnx";
|
||||
ASSERT_STATUS_OK(session->Load(model_path));
|
||||
const auto& graph = session->GetGraph();
|
||||
GraphViewer viewer{graph};
|
||||
|
|
@ -195,7 +196,7 @@ static void TestNnapiPartitioning(const std::string& test_name, const std::strin
|
|||
}
|
||||
|
||||
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
|
||||
std::make_unique<InternalTestingExecutionProvider>(ops, stop_ops, debug_output)));
|
||||
std::make_unique<InternalTestingExecutionProvider>(ops, stop_ops, debug_output, DataLayout::NHWC)));
|
||||
|
||||
ASSERT_STATUS_OK(session->Load(model_uri));
|
||||
const auto& graph = session->GetGraph();
|
||||
|
|
@ -310,6 +311,7 @@ TEST(InternalTestingEP, DISABLED_TestNnapiPartitioningMlPerfModels) {
|
|||
"deeplabv3_mnv2_ade20k_float.onnx",
|
||||
"mobilebert.onnx",
|
||||
"mobiledet.onnx",
|
||||
|
||||
};
|
||||
|
||||
for (const auto& model_uri : model_paths) {
|
||||
|
|
|
|||
|
|
@ -24,8 +24,10 @@ namespace onnxruntime {
|
|||
|
||||
namespace test {
|
||||
|
||||
#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")
|
||||
|
||||
static void CreateSession(const SessionOptions& so, std::unique_ptr<InferenceSessionWrapper>& session,
|
||||
const ORTCHAR_T* model_path = ORT_TSTR("testdata/mnist.onnx"), // arbitrary test model
|
||||
const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "mnist.onnx", // arbitrary test model
|
||||
bool enable_custom_ep = true,
|
||||
const std::unordered_set<std::string>* override_supported_ops = nullptr) {
|
||||
session = std::make_unique<InferenceSessionWrapper>(so, GetEnvironment());
|
||||
|
|
@ -86,7 +88,7 @@ static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enable
|
|||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.test_output.ort");
|
||||
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.test_output.ort";
|
||||
|
||||
//
|
||||
// First load the onnx format model and save as an ORT model.
|
||||
|
|
@ -130,7 +132,7 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
|
|||
}
|
||||
|
||||
TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
|
||||
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
|
||||
|
||||
// make sure we can't save a model with compiled ops. input/output model format doesn't matter
|
||||
SessionOptions so;
|
||||
|
|
@ -152,7 +154,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
|
|||
|
||||
// test to validate a minimal build
|
||||
TEST(InternalTestingEP, TestLoadOrtModel) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
|
||||
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
|
||||
|
||||
std::unique_ptr<InferenceSessionWrapper> session;
|
||||
bool enable_custom_ep = true;
|
||||
|
|
@ -164,7 +166,7 @@ TEST(InternalTestingEP, TestLoadOrtModel) {
|
|||
// test that if the custom EP cannot take all nodes due to device limitations
|
||||
// that we fallback to the CPU implementations and can execute the model
|
||||
TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
|
||||
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
|
||||
const std::unordered_set<std::string> supported_ops{"Conv", "Add", "Relu" /*, "MaxPool"*/};
|
||||
|
||||
std::unique_ptr<InferenceSessionWrapper> session;
|
||||
|
|
@ -225,7 +227,7 @@ static int CountAndValidateAssignedNodes(const Graph& current_graph,
|
|||
// Test model that contains a subgraph. This model has a Loop and an If so multiple layers of nested subgraphs.
|
||||
// There are Add nodes in the Loop and If subgraphs so we should see the custom EP taking nodes at both these levels.
|
||||
TEST(InternalTestingEP, TestModelWithSubgraph) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
|
||||
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "ort_github_issue_4031.onnx.ort";
|
||||
const std::unordered_set<std::string> supported_ops{"Add"};
|
||||
|
||||
std::unique_ptr<InferenceSessionWrapper> session;
|
||||
|
|
@ -302,8 +304,11 @@ TEST(InternalTestingEP, TestOrtModelWithCompileFailure) {
|
|||
// In the test file, there are 2 Conv and 1 Gemm nodes, all disconnected
|
||||
// So we should have 3 partitions be taken by InternalTestingExecutionProvider/CompileFailureTestExecutionProvider
|
||||
// But CompileFailureTestExecutionProvider will fail the Compile for partition contains "Gemm" node
|
||||
// This is to test the model initialization won't fail and Gemm node will not be replaced by the fused_node
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.internal_testing_ep.ort");
|
||||
// Post layout transformations we cannot revert back if compile fails because
|
||||
// the layout transformation for this EP is already done at this stage and reverting
|
||||
// can result in more failures.
|
||||
// This is to test the model initialization fails if compile fails.
|
||||
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort";
|
||||
|
||||
const std::unordered_set<std::string>& supported_ops{"Conv", "Gemm"};
|
||||
const std::unordered_set<std::string>& compile_failure_ops{"Gemm"};
|
||||
|
|
@ -325,33 +330,12 @@ TEST(InternalTestingEP, TestOrtModelWithCompileFailure) {
|
|||
}
|
||||
|
||||
// Use CompileFailureTestExecutionProvider which will fail Compile on "Gemm"
|
||||
// We should have 2 partitions taken by the EP
|
||||
// 2 Conv
|
||||
{
|
||||
InferenceSessionWrapper session(SessionOptions(), GetEnvironment());
|
||||
ASSERT_STATUS_OK(session.RegisterExecutionProvider(
|
||||
std::make_unique<CompileFailureTestExecutionProvider>(supported_ops, compile_failure_ops)));
|
||||
ASSERT_STATUS_OK(session.Load(ort_model_path));
|
||||
ASSERT_STATUS_OK(session.Initialize());
|
||||
|
||||
// 2 Conv nodes shoule be replaced with fused nodes
|
||||
const auto& graph = session.GetGraph();
|
||||
int num_replaced_nodes = CountAndValidateAssignedNodes(
|
||||
session.GetGraph(), {"Conv"}, const_cast<SessionState&>(session.GetSessionState()).GetMutableFuncMgr());
|
||||
|
||||
ASSERT_EQ(num_replaced_nodes, 2);
|
||||
|
||||
// The Gemm node should still not have been replaced
|
||||
int count_compile_failure_nodes = 0;
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
if (compile_failure_ops.find(node.OpType()) != compile_failure_ops.end())
|
||||
count_compile_failure_nodes++;
|
||||
}
|
||||
ASSERT_EQ(count_compile_failure_nodes, 1);
|
||||
|
||||
// Execute the session, since the last node is Gemm, and its input 0 is all 0s
|
||||
// So the result should be the bias initializer of the Gemm node
|
||||
ExecuteMnist(session, true /* enable_custom_ep */);
|
||||
ASSERT_STATUS_NOT_OK(session.Initialize());
|
||||
}
|
||||
}
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@
|
|||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
|
|
@ -189,5 +190,30 @@ TEST(KernelDefHashTest, ExpectedCpuKernelDefHashes) {
|
|||
CheckKernelDefHashes(cpu_kernel_def_hashes, expected_cpu_kernel_def_hashes, is_strict);
|
||||
}
|
||||
|
||||
// This test is to ensure the latest opset version for ops which can be added
|
||||
// during layout transformation step are added. If this test fails then it means
|
||||
// there is a new version available for one of the ops in the map.
|
||||
// Adding this test here because resolution for this test failure requires fetching the hash
|
||||
// for one of the ops in the list below and this file has information around that.
|
||||
// Please update the following 3 places:
|
||||
// 1. api_impl.cc "onnx_ops_available_versions" map, include the latest version in the map
|
||||
// 2. static_kernel_def_hashes.cc "static_kernel_hashes" include an entry for latest version and its associated hash
|
||||
// 3. This file "onnx_ops_available_versions" map, include the latest version in the map
|
||||
TEST(KernelDefHashTest, TestNewOpsVersionSupportDuringLayoutTransform) {
|
||||
static const std::unordered_map<std::string, std::vector<int>> onnx_ops_available_versions = {
|
||||
{"Squeeze", {1, 11, 13}},
|
||||
{"Unsqueeze", {1, 11, 13}},
|
||||
{"Gather", {1, 11, 13}},
|
||||
{"Transpose", {1, 13}},
|
||||
{"Identity", {1, 13, 14, 16}},
|
||||
};
|
||||
|
||||
auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
|
||||
for (const auto& [op_type, version_list] : onnx_ops_available_versions) {
|
||||
auto schema = schema_registry->GetSchema(op_type, INT_MAX, kOnnxDomain);
|
||||
EXPECT_EQ(schema->SinceVersion(), version_list[version_list.size() - 1]) << "A new version for op: " << op_type
|
||||
<< "is available. Please update the files mentioned in the comments of this test.";
|
||||
}
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -315,12 +315,17 @@ TEST(NnapiExecutionProviderTest, TestQDQConv) {
|
|||
TEST(NnapiExecutionProviderTest, TestQDQResize) {
|
||||
// NNAPI EP does not support the default setting of Resize Op
|
||||
// Use bi-linear and asymmetric for NNAPI EP only
|
||||
// Setting verify_entire_graph_use_ep for this test as false. This is because layout transformation adds
|
||||
// Transpose (NCHW -> NHWC) nodes. Post transformation graph looks like this Transpose -> DQ -> Resize -> Q -> Transpose
|
||||
// NNAPI does not pick the first Transpose as its input is graph/partition input
|
||||
// See https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc#L305
|
||||
// onnxruntime::nnapi::IsInternalQuantizationSupported
|
||||
RunQDQModelTest(BuildQDQResizeTestCase({1, 3, 64, 64} /* input_shape */,
|
||||
{1, 3, 32, 32} /* sizes_data */,
|
||||
"linear" /* mode */,
|
||||
"asymmetric" /* coordinate_transformation_mode */),
|
||||
"nnapi_qdq_test_graph_resize",
|
||||
{true /* verify_entire_graph_use_ep */});
|
||||
{false /* verify_entire_graph_use_ep */});
|
||||
}
|
||||
|
||||
TEST(NnapiExecutionProviderTest, TestQDQAveragePool) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue