Support EPs that compile nodes in a minimal build. (#5776)

* Support EPs that compile nodes in a minimal build. This enables NNAPI being used.
This commit is contained in:
Scott McKay 2020-11-17 13:52:22 +10:00 committed by GitHub
parent 794e8479eb
commit 7b76b57fc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
52 changed files with 1866 additions and 505 deletions

View file

@ -120,6 +120,7 @@ option(onnxruntime_DISABLE_RTTI "Disable RTTI" OFF)
# For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone
option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." OFF)
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF)
option(onnxruntime_REDUCED_OPS_BUILD "Reduced set of kernels are registered in build via modification of the kernel registration source files." OFF)
option(onnxruntime_DISABLE_ORT_FORMAT_LOAD "Disable loading an ORT format model when onnxruntime_MINIMAL_BUILD=OFF (i.e. in a full build)." OFF)
@ -208,12 +209,21 @@ if(onnxruntime_USE_OPENMP)
endif()
endif()
# 'extended' implies minimal.
if (onnxruntime_EXTENDED_MINIMAL_BUILD AND NOT onnxruntime_MINIMAL_BUILD)
set(onnxruntime_MINIMAL_BUILD ON)
endif()
# ORT build with as much excluded as possible. Supports ORT flatbuffers models only.
# Will expose option in build.py when all pieces are available
if(onnxruntime_MINIMAL_BUILD)
if (onnxruntime_MINIMAL_BUILD)
add_compile_definitions(ORT_MINIMAL_BUILD)
add_compile_definitions(ENABLE_ORT_FORMAT_LOAD)
if (onnxruntime_EXTENDED_MINIMAL_BUILD)
# enable EPs that compile kernels at runtime
add_compile_definitions(ORT_EXTENDED_MINIMAL_BUILD)
endif()
set(onnxruntime_REDUCED_OPS_BUILD ON)
if (NOT onnxruntime_ENABLE_PYTHON)

View file

@ -110,6 +110,7 @@ target_link_libraries(onnxruntime PRIVATE
${PROVIDERS_DML}
${PROVIDERS_ACL}
${PROVIDERS_ARMNN}
${PROVIDERS_INTERNAL_TESTING}
${onnxruntime_winml}
${PROVIDERS_ROCM}
onnxruntime_optimizer

View file

@ -10,7 +10,6 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS
if (onnxruntime_MINIMAL_BUILD)
file(GLOB onnxruntime_framework_src_exclude
"${ONNXRUNTIME_ROOT}/core/framework/provider_bridge_ort.cc"
"${ONNXRUNTIME_ROOT}/core/framework/graph_partitioner.*"
"${ONNXRUNTIME_INCLUDE_DIR}/core/framework/customregistry.h"
"${ONNXRUNTIME_ROOT}/core/framework/customregistry.cc"
)

View file

@ -657,6 +657,10 @@ if (onnxruntime_USE_OPENVINO)
endif()
if (onnxruntime_USE_NNAPI_BUILTIN)
if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD)
message(FATAL_ERROR "NNAPI can not be used in a basic minimal build. Please build with '--minimal_build extended'")
endif()
add_compile_definitions(USE_NNAPI=1)
# This is the minimum Android API Level required by ORT NNAPI EP to run
@ -782,7 +786,7 @@ if (onnxruntime_USE_DML)
else()
add_dependencies(${target} RESTORE_PACKAGES)
target_link_libraries(${target} PRIVATE "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/DirectML.lib")
target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST)
target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST)
endif()
endfunction()

View file

@ -312,6 +312,13 @@ if (onnxruntime_USE_RKNPU)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_rknpu_src})
endif()
if (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD)
file(GLOB_RECURSE onnxruntime_test_providers_internal_testing_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/providers/internal_testing/*"
)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_internal_testing_src})
endif()
set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib")
set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/global_thread_pools")
set (ONNXRUNTIME_API_TESTS_WITHOUT_ENV_SRC_DIR "${ONNXRUNTIME_ROOT}/test/api_tests_without_env")
@ -525,7 +532,7 @@ else()
target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}
"${CMAKE_CURRENT_SOURCE_DIR}/external/nsync/public")
endif()
onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_common onnxruntime_framework GTest::gtest onnx onnx_proto flatbuffers)
onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_common onnxruntime_framework GTest::gtest GTest::gmock onnx onnx_proto flatbuffers)
if (onnxruntime_USE_DNNL)
target_compile_definitions(onnxruntime_test_utils PUBLIC USE_DNNL=1)

View file

@ -168,10 +168,13 @@ class IExecutionProvider {
void InsertAllocator(AllocatorPtr allocator);
void ReplaceAllocator(AllocatorPtr allocator);
// creation of a fused node is not supported in a minimal build, so any EP enabled in that scenario must support
// compilation via GraphViewer instances.
#if !defined(ORT_MINIMAL_BUILD)
/**
Given a list of fused_node, return create_state/compute/release_state func for each node.
*/
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_node,
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs);
/**
@ -181,9 +184,53 @@ class IExecutionProvider {
Compute_${node_name}
Release_State_${node_name}
*/
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_node,
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::string& dll_path);
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
struct FusedNodeAndGraph {
const std::reference_wrapper<onnxruntime::Node> fused_node;
// GraphViewer that filters the full graph to the nodes that are covered by 'node'
const std::reference_wrapper<GraphViewer> filtered_graph;
};
/**
Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
return create_state/compute/release_state func for each node.
@remarks This is an optional interface that is only needed if the execution provider compiles nodes
in a scenario involving the minimal build. i.e. on a mobile or embedded device with ORT format model.
Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
as it is only valid for the duration of the call to Compile.
*/
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs);
#endif
// Fusion approach that is suppported
enum class FusionStyle {
// The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
// in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
// A GraphProto can be produced from the Node body.
Function,
// The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
// that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body().
// Instead a GraphViewer that filters the full Graph to the fused Nodes will be created.
// This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
// and can be supported in a minimal build.
FilteredGraphViewer
};
virtual FusionStyle GetFusionStyle() const {
// existing EPs use this mode so default to it.
// newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return
// FilteredGraphViewer
return FusionStyle::Function;
}
void SetLogger(const logging::Logger* logger) {
logger_ = logger;
}

View file

@ -45,9 +45,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
common::Status GetFusedFuncs(ComputeFunc* compute,
CreateFunctionStateFunc* create,
DestroyFunctionStateFunc* release) const;
common::Status GetFusedFuncs(NodeComputeInfo*& compute_info) const;
private:
ORT_DISALLOW_MOVE(OpKernelInfo);

View file

@ -158,7 +158,7 @@ class Tensor final {
ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ",
"T ", "!=", dtype_);
const T* data = reinterpret_cast<const T*>(static_cast<char*>(p_data_) + byte_offset_);
return gsl::make_span(data, shape_.Size());
return gsl::make_span(data, static_cast<typename gsl::span<T>::index_type>(shape_.Size()));
}
void* MutableDataRaw(MLDataType type) {

View file

@ -37,49 +37,50 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider";
constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider";
constexpr const char *providers_available[] = {
kCpuExecutionProvider,
constexpr const char* providers_available[] = {
kCpuExecutionProvider,
#ifdef USE_CUDA
kCudaExecutionProvider,
kCudaExecutionProvider,
#endif
#ifdef USE_DNNL
kDnnlExecutionProvider,
kDnnlExecutionProvider,
#endif
#ifdef USE_NGRAPH
kNGraphExecutionProvider,
kNGraphExecutionProvider,
#endif
#ifdef USE_OPENVINO
kOpenVINOExecutionProvider,
kOpenVINOExecutionProvider,
#endif
#ifdef USE_NUPHAR
kNupharExecutionProvider,
kNupharExecutionProvider,
#endif
#ifdef USE_VITISAI
kVitisAIExecutionProvider,
kVitisAIExecutionProvider,
#endif
#ifdef USE_TENSORRT
kTensorrtExecutionProvider,
kTensorrtExecutionProvider,
#endif
#ifdef USE_NNAPI
kNnapiExecutionProvider,
kNnapiExecutionProvider,
#endif
#ifdef USE_RKNPU
kRknpuExecutionProvider,
kRknpuExecutionProvider,
#endif
#ifdef USE_DML
kDmlExecutionProvider,
kDmlExecutionProvider,
#endif
#ifdef USE_MIGRAPHX
kMIGraphXExecutionProvider,
kMIGraphXExecutionProvider,
#endif
#ifdef USE_ACL
kAclExecutionProvider,
kAclExecutionProvider,
#endif
#ifdef USE_ARMNN
kArmNNExecutionProvider,
kArmNNExecutionProvider,
#endif
#ifdef USE_ROCM
kRocmExecutionProvider,
kRocmExecutionProvider,
#endif
};

View file

@ -31,7 +31,7 @@ class Function {
/**
Create a new Function instance.
@param graph The graph containing the Function.
@param customized_func the IndexedSubGraph to use for the Function.
@param nodes_to_fuse the IndexedSubGraph to use for the Function.
*/
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
const IndexedSubGraph& nodes_to_fuse,

View file

@ -368,7 +368,9 @@ class Node {
ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
/** Sets the execution ProviderType that this Node will be executed by. */
void SetExecutionProviderType(ProviderType execution_provider_type);
void SetExecutionProviderType(ProviderType execution_provider_type) {
execution_provider_type_ = execution_provider_type;
}
/** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
@ -476,7 +478,7 @@ class Node {
Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph) {}
#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
void Init(const std::string& name,
const std::string& op_type,
const std::string& description,
@ -485,25 +487,24 @@ class Node {
const NodeAttributes* attributes,
const std::string& domain);
// create a Graph instance for an attribute that contains a GraphProto
void CreateSubgraph(const std::string& attr_name);
// internal only method to allow selected classes to directly alter the input/output definitions and arg counts
Definitions& MutableDefinitions() noexcept;
// internal only method to allow selected classes to directly alter the links between nodes.
Relationships& MutableRelationships() noexcept;
void SetNodeType(Node::Type node_type) noexcept { node_type_ = node_type; }
#endif
// create a Graph instance for an attribute that contains a GraphProto
void CreateSubgraph(const std::string& attr_name);
const std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept { return subgraphs_; }
void SetNodeType(Node::Type node_type) noexcept;
void SetFunctionBody(const Function& func);
// validate and update the input arg count
common::Status UpdateInputArgCount();
#endif // !defined(ORT_MINIMAL_BUILD)
void SetFunctionBody(const Function& func);
const Definitions& GetDefinitions() const noexcept { return definitions_; }
const Relationships& GetRelationships() const noexcept { return relationships_; }
@ -578,6 +579,15 @@ class Graph {
/** Gets the path of the owning model, if any. */
const Path& ModelPath() const;
/** Returns true if this is a subgraph or false if it is a high-level graph. */
bool IsSubgraph() const { return parent_graph_ != nullptr; }
/** Returns the parent graph if this is a subgraph */
const Graph* ParentGraph() const { return parent_graph_; }
/** Returns the mutable parent graph if this is a subgraph */
Graph* MutableParentGraph() { return parent_graph_; }
#if !defined(ORT_MINIMAL_BUILD)
/** Sets the Graph name. */
void SetName(const std::string& name);
@ -756,6 +766,15 @@ class Graph {
/** Generate a unique name in this Graph for a Node */
std::string GenerateNodeName(const std::string& base_name);
/** Copy a Node and add it to this Graph.
@param other Node to copy
@returns Reference to the Node that was created and added to this Graph.
@remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
*/
Node& AddNode(const Node& other);
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Add a Node to this Graph.
@param name The Node name. Must be unique in this Graph.
@param op_type The operator type. e.g. ONNX operator name.
@ -775,13 +794,6 @@ class Graph {
const NodeAttributes* attributes = nullptr,
const std::string& domain = "");
/** Copy a Node and add it to this Graph.
@param other Node to copy
@returns Reference to the Node that was created and added to this Graph.
@remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
*/
Node& AddNode(const Node& other);
/** Remove a Node from this Graph and free it.
The output edges of this specified node MUST have been removed before removing the node.
The input edges of this specified node is removed while removing the node. The process of
@ -809,14 +821,15 @@ class Graph {
@param dst_arg_index node arg index of destination node.
*/
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
#endif
#if !defined(ORT_MINIMAL_BUILD)
/**
Add a control edge between two Nodes in this Graph.
The source Node does not produce output that is directly consumed by the destination Node, however the
destination Node must execute after the source node. The control edge allows this ordering to occur.
*/
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index);
#endif // !defined(ORT_MINIMAL_BUILD)
/** Mark the Graph as needing Resolve() to be called.
@ -880,6 +893,7 @@ class Graph {
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node*, const Node*)>& stop) const;
#if !defined(ORT_MINIMAL_BUILD)
/** Performs topological sort with Kahn's algorithm on the graph/s.
@param enter Visit function that will be invoked on a node when it is visited.
@param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
@ -887,11 +901,29 @@ class Graph {
void KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const;
#endif
/** Gets the map of operator domains to their opset versions. */
const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
return domain_to_version_;
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/**
Create a single Node that will be the result of the a fusion of multiple nodes in this Graph.
@param sub_graph A IndexSubGraph instance with details of the nodes to fuse.
@param fused_node_name The name for the new Node.
@returns Node with fused subgraph.
@remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
while this is in use.
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
*/
Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
#endif
#if !defined(ORT_MINIMAL_BUILD)
/** Gets the GraphProto representation of this Graph. */
const ONNX_NAMESPACE::GraphProto& ToGraphProto();
@ -901,10 +933,11 @@ class Graph {
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const;
/**
Create a single Node that is the result of the a fusion of multiple nodes in this Graph.
@param sub_graph A IndexSubGraph instance with details of the nodes to fuse.
Create a single Function based Node that is the result of the a fusion of multiple nodes in this Graph.
A new Graph instance will be created for the fused nodes.
@param sub_graph A IndexSubGraph instance with details of the nodes to fuse. Ownership is transferred to the new Node
@param fused_node_name The name for the new Node.
@returns Node with fused subgraph.
@returns Function based Node with fused subgraph. The Node body will contain a Function instance.
*/
Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
@ -939,18 +972,7 @@ class Graph {
@remarks Note that the output order matters for subgraphs.
*/
void SetOutputs(const std::vector<const NodeArg*>& outputs);
#endif // !defined(ORT_MINIMAL_BUILD)
/** Returns true if this is a subgraph or false if it is a high-level graph. */
bool IsSubgraph() const { return parent_graph_ != nullptr; }
/** Returns the parent graph if this is a subgraph */
const Graph* ParentGraph() const { return parent_graph_; }
/** Returns the mutable parent graph if this is a subgraph */
Graph* MutableParentGraph() { return parent_graph_; }
#if !defined(ORT_MINIMAL_BUILD)
/** Sets the type of a NodeArg, replacing existing type/shape if any */
void SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto);
@ -1209,12 +1231,6 @@ class Graph {
// Clear all unused initializers
void CleanUnusedInitializers(const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr);
gsl::not_null<Node*> AllocateNode();
// Release the node.
// @returns false if node_index was invalid.
bool ReleaseNode(NodeIndex node_index);
std::vector<NodeArg*> CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
const ArgNameToTypeMap& name_to_type_map);
@ -1247,6 +1263,16 @@ class Graph {
#endif // !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
gsl::not_null<Node*> AllocateNode();
// Release the node.
// @returns false if node_index was invalid.
bool ReleaseNode(NodeIndex node_index);
Node& CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
#endif
Node* NodeAtIndexImpl(NodeIndex node_index) const {
// if we are trying to access a node that doesn't exist there's (most
// likely) either a logic issue or a graph consistency/correctness issue.
@ -1276,12 +1302,10 @@ class Graph {
sparse_tensor_names_;
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry_;
std::vector<std::unique_ptr<onnxruntime::Function>> function_container_;
#endif // !defined(ORT_MINIMAL_BUILD)
#endif
// Graph nodes.
// Element in <nodes_> may be nullptr due to graph optimization.

View file

@ -28,21 +28,21 @@ class ValidNodes {
Construct a ValidNodes instance to provide iteration over all valid nodes in the TNodesCollection
@param[in] nodes Nodes to iterate, skipping invalid entries.
*/
explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(nodes) {}
explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(&nodes) {}
explicit ValidNodes(TNodesContainer& nodes, NodeFilterFunc&& filter_node_fn) noexcept
: nodes_(nodes), filter_node_fn_{std::move(filter_node_fn)} {}
: nodes_(&nodes), filter_node_fn_{std::move(filter_node_fn)} {}
using ConstNodeIterator = NodeIterator<typename TNodesContainer::const_iterator>;
using MutableNodeIterator = NodeIterator<typename TNodesContainer::iterator>;
using ConstReverseNodeIterator = NodeIterator<typename TNodesContainer::const_reverse_iterator>;
ConstNodeIterator cbegin() const noexcept {
return {nodes_.cbegin(), nodes_.cend(), filter_node_fn_};
return {nodes_->cbegin(), nodes_->cend(), filter_node_fn_};
}
ConstNodeIterator cend() const noexcept {
return {nodes_.cend(), nodes_.cend(), filter_node_fn_};
return {nodes_->cend(), nodes_->cend(), filter_node_fn_};
}
ConstNodeIterator begin() const noexcept {
@ -54,11 +54,11 @@ class ValidNodes {
}
ConstReverseNodeIterator rbegin() const noexcept {
return {nodes_.crbegin(), nodes_.crend(), filter_node_fn_};
return {nodes_->crbegin(), nodes_->crend(), filter_node_fn_};
}
ConstReverseNodeIterator rend() const noexcept {
return {nodes_.crend(), nodes_.crend(), filter_node_fn_};
return {nodes_->crend(), nodes_->crend(), filter_node_fn_};
}
// we only allow mutable access if the container is non-const.
@ -66,16 +66,16 @@ class ValidNodes {
template <typename T2 = TNodesContainer>
typename std::enable_if<!std::is_const<T2>::value, MutableNodeIterator>::type begin() noexcept {
static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
return MutableNodeIterator(nodes_.begin(), nodes_.end(), filter_node_fn_);
return MutableNodeIterator(nodes_->begin(), nodes_->end(), filter_node_fn_);
}
template <typename T2 = TNodesContainer>
typename std::enable_if<!std::is_const<T2>::value, MutableNodeIterator>::type end() noexcept {
static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
return MutableNodeIterator(nodes_.end(), nodes_.end(), filter_node_fn_);
return MutableNodeIterator(nodes_->end(), nodes_->end(), filter_node_fn_);
}
bool empty() const noexcept { return nodes_.empty(); }
bool empty() const noexcept { return nodes_->empty(); }
/**
@class NodeIterator
@ -152,7 +152,7 @@ class ValidNodes {
};
private:
TNodesContainer& nodes_;
gsl::not_null<TNodesContainer*> nodes_; // always set by ctor
// no filtering if not set. this instance owns the filter func if set.
NodeFilterFunc filter_node_fn_;

View file

@ -158,6 +158,11 @@ class GraphViewer {
}
#endif
/** Get the filter info that restricts the graph viewer to a subset of nodes if set.
@returns Filter info or nullptr
*/
const IndexedSubGraph* GetFilterInfo() const { return filter_info_; }
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer);
GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info);

View file

@ -1,4 +1,5 @@
// Copyright 2019 JD.com Inc. JD AI
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "onnxruntime_c_api.h"

View file

@ -481,7 +481,7 @@ inline SessionOptions& SessionOptions::AddInitializer(const char* name, const Or
return *this;
}
inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions * options, OrtCUDAProviderOptions * cuda_options) {
inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions* options, OrtCUDAProviderOptions* cuda_options) {
ThrowOnError(GetApi().OrtSessionOptionsAppendExecutionProvider_CUDA(options, cuda_options));
return nullptr;
}
@ -943,7 +943,8 @@ inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext*
return out;
}
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
_In_ const int64_t* dim_values, size_t dim_count) {
OrtValue* out;
ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
return out;

View file

@ -43,9 +43,11 @@ IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
return result;
#else
// We have saved hashes to lookup static kernels in an ORT format model so the default behavior is to return an
// empty vector to leave that in place. An EP that compiles nodes can override this in a minimal build.
ORT_UNUSED_PARAMETER(graph);
ORT_UNUSED_PARAMETER(kernel_registries);
ORT_NOT_IMPLEMENTED("IExecutionProvider::GetCapability is not supported in this build.");
return result;
#endif
}
@ -79,6 +81,7 @@ void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) {
allocator_list_.push_back(allocator);
}
#if !defined(ORT_MINIMAL_BUILD)
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
@ -88,6 +91,14 @@ common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>
std::string& /*dll_path*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
}
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& /*fused_nodes_and_graphs*/,
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
}
#endif
std::shared_ptr<KernelRegistry> IExecutionProvider::GetKernelRegistry() const {
return nullptr;

View file

@ -18,33 +18,33 @@ class FunctionKernel : public OpKernel {
explicit FunctionKernel(const OpKernelInfo& info) : OpKernel(info) {
num_inputs_ = info.node().InputDefs().size();
num_outputs_ = info.node().OutputDefs().size();
CreateFunctionStateFunc create_func;
auto status = info.GetFusedFuncs(&func_, &create_func, &release_func_);
auto status = info.GetFusedFuncs(compute_info_);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
if (create_func) {
if (compute_info_->create_state_func) {
//TODO: we are only provide host allocate method in compute context.
//Do we need to hold the ref-counting here?
host_allocator_ = info.GetAllocator(0, OrtMemType::OrtMemTypeDefault);
ComputeContext context = {allocate_helper_func, release_helper_func, host_allocator_.get(), info.node().Name().c_str()};
ORT_ENFORCE(create_func(&context, &func_state_) == 0);
ComputeContext context = {allocate_helper_func, release_helper_func, host_allocator_.get(),
info.node().Name().c_str()};
ORT_ENFORCE(compute_info_->create_state_func(&context, &func_state_) == 0);
}
}
~FunctionKernel() override {
if (release_func_ && func_state_) {
release_func_(func_state_);
if (compute_info_->release_state_func && func_state_) {
compute_info_->release_state_func(func_state_);
}
}
virtual Status Compute(OpKernelContext* context) const override {
auto* context_internal = static_cast<OpKernelContextInternal*>(context);
return func_(func_state_, OrtGetApiBase()->GetApi(ORT_API_VERSION), reinterpret_cast<OrtKernelContext*>(context_internal));
return compute_info_->compute_func(func_state_, OrtGetApiBase()->GetApi(ORT_API_VERSION),
reinterpret_cast<OrtKernelContext*>(context_internal));
}
private:
ComputeFunc func_;
DestroyFunctionStateFunc release_func_;
FunctionState func_state_;
NodeComputeInfo* compute_info_{nullptr};
FunctionState func_state_{nullptr};
size_t num_inputs_;
size_t num_outputs_;
AllocatorPtr host_allocator_;

View file

@ -6,25 +6,28 @@ Status FuncManager::AddFuncInfo(const std::string& name, const std::string& dll_
auto it = fused_funcs_->find(name);
if (it != fused_funcs_->end())
return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " already exist.");
(*fused_funcs_)[name] = {dll_path, nullptr, nullptr, nullptr};
(*fused_funcs_)[name] = {dll_path, NodeComputeInfo()};
return Status::OK();
}
Status FuncManager::AddFuncInfo(const std::string& name, ComputeFunc compute, CreateFunctionStateFunc create, DestroyFunctionStateFunc release) {
Status FuncManager::AddFuncInfo(const std::string& name, NodeComputeInfo&& compute_info) {
auto it = fused_funcs_->find(name);
if (it != fused_funcs_->end())
return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " already exist.");
if (!compute || !create || !release)
if (!compute_info.compute_func || !compute_info.create_state_func || !compute_info.release_state_func)
return Status(common::ONNXRUNTIME, common::FAIL, "Can't use func with null ptr");
(*fused_funcs_)[name] = {"", compute, create, release};
(*fused_funcs_)[name] = {"", std::move(compute_info)};
return Status::OK();
}
Status FuncManager::GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const {
Status FuncManager::GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const {
auto it = fused_funcs_->find(name);
if (it == fused_funcs_->end())
return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " not found.");
if (!it->second.compute_func) {
if (!it->second.compute_info.compute_func) {
//load from path
void* handle = nullptr;
ORT_RETURN_IF_ERROR(lib_loader_->LoadExternalLib(it->second.dso_path, &handle));
@ -40,21 +43,20 @@ Status FuncManager::GetFuncs(const std::string& name, ComputeFunc* compute, Crea
ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle,
kReleaseStateFuncSymbol + name,
&release_func_symbol_handle));
it->second.compute_func = [=](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
it->second.compute_info.compute_func = [=](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
return reinterpret_cast<ComputeFuncC>(compute_func_symbol_handle)(state, api, context);
};
it->second.create_state_func = [=](ComputeContext* context, FunctionState* state) {
it->second.compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
return reinterpret_cast<CreateFunctionStateC>(create_func_symbol_handle)(context, state);
};
it->second.release_state_func = [=](FunctionState state) {
it->second.compute_info.release_state_func = [=](FunctionState state) {
return reinterpret_cast<DestroyFunctionStateC>(release_func_symbol_handle)(state);
};
}
*compute = it->second.compute_func;
*create = it->second.create_state_func;
*release = it->second.release_state_func;
compute_info = &it->second.compute_info;
return Status::OK();
}

View file

@ -7,13 +7,16 @@ namespace onnxruntime {
class FuncManager {
public:
FuncManager() : fused_funcs_(std::make_shared<std::unordered_map<std::string, FuncInfo> >()), lib_loader_(onnxruntime::make_unique<ExLibLoader>()) {}
FuncManager()
: fused_funcs_(std::make_shared<std::unordered_map<std::string, FuncInfo> >()),
lib_loader_(onnxruntime::make_unique<ExLibLoader>()) {
}
Status AddFuncInfo(const std::string& name, const std::string& dll_path);
Status AddFuncInfo(const std::string& name, ComputeFunc compute, CreateFunctionStateFunc create, DestroyFunctionStateFunc release);
Status AddFuncInfo(const std::string& name, NodeComputeInfo&& compute_info);
Status GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const;
Status GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const;
void SetFusedFuncs(const FuncManager& func_mgr) {
fused_funcs_ = func_mgr.fused_funcs_;
@ -21,9 +24,7 @@ class FuncManager {
struct FuncInfo {
std::string dso_path;
ComputeFunc compute_func;
CreateFunctionStateFunc create_state_func;
DestroyFunctionStateFunc release_state_func;
NodeComputeInfo compute_info;
};
private:

View file

@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include "core/framework/graph_partitioner.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/graph/function.h"
@ -41,13 +43,22 @@ NonCudaOps non_cuda;
using namespace ::onnxruntime::common;
namespace onnxruntime {
static KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
// minimal KernelDef based on MetaDef instead of a Function based node
static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph::MetaDef& metadef,
const std::string& provider_type) {
builder.SetName(metadef.name)
.SetDomain(metadef.domain)
.SinceVersion(metadef.since_version)
.Provider(provider_type);
}
#if !defined(ORT_MINIMAL_BUILD)
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
auto schema = node.Op();
builder.SetName(schema->Name())
.SetDomain(schema->domain())
.SinceVersion(schema->SinceVersion())
.Provider(node.GetExecutionProviderType());
return builder;
}
/**
@ -57,12 +68,16 @@ static KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const on
* \param capability
* \param kernel_registry_mgr
* \param provider_type name of the provider to test
* \param count A counter for generating fused node names. Should be unique within this subgraph
* \param count A counter for generating fused node names. Unique across the entire model.
* \return Fused node. Return nullptr if there is no fuse
*/
static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
const KernelRegistryManager& kernel_registry_mgr, const std::string& provider_type,
int& count) {
IExecutionProvider::FusionStyle fusion_style,
GraphPartitioner::Mode mode,
int& fused_node_unique_id) {
Node* result = nullptr;
if (nullptr == capability.GetMetaDef()) {
// The <provider> can run a single node in the <graph> if not using meta-defs.
// A fused kernel is not supported in this case.
@ -75,39 +90,82 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
}
} else {
// The <provider> can run a fused <sub_graph> in the <graph>.
ORT_ENFORCE(nullptr != capability.GetMetaDef());
// Check whether any node in the <sub_graph> was already assigned.
// Check whether any node in the <sub_graph> was already assigned. If so it cannot be stolen as assignment is done
// in order of EP priority
bool sub_graph_available_for_assignment = true;
for (auto node_index : capability.nodes) {
auto node = graph.GetNode(node_index);
const auto* node = graph.GetNode(node_index);
if (nullptr == node || !node->GetExecutionProviderType().empty()) {
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
sub_graph_available_for_assignment = false;
break;
// if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned,
// so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by
// preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to
// and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes
// should never be on those lists.
//
// when the ORT format model is loaded we will process it normally with EP priority being applied for
// whichever EPs are enabled at the time.
//
// e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP.
// We want the ORT format model to be able to be run as efficiently as possible on either platform,
// so we want all the nodes that either may take to be preserved. If we did not do this we would
// need to create one ORT format model for Android and one for iOS.
if (mode != GraphPartitioner::Mode::kAssignOnly) {
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
sub_graph_available_for_assignment = false;
break;
}
}
}
if (sub_graph_available_for_assignment) {
std::ostringstream oss;
oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << count++;
std::string node_name = oss.str();
auto& fused_node = graph.FuseSubGraph(capability, node_name);
fused_node.SetExecutionProviderType(provider_type);
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, fused_node, provider_type)) {
return &fused_node;
if (mode == GraphPartitioner::Mode::kNormal) {
std::ostringstream oss;
oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++;
std::string node_name = oss.str();
Node* fused_node = nullptr;
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
fused_node = &graph.FuseSubGraph(capability, node_name);
} else {
// create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed
// through to Compile via a filtered GraphViewer.
fused_node = &graph.BeginFuseSubGraph(capability, node_name);
}
fused_node->SetExecutionProviderType(provider_type);
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *fused_node, provider_type)) {
result = fused_node;
}
} else {
// assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them.
// This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion
// at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device
// capabilities.
for (auto node_index : capability.nodes) {
auto* node = graph.GetNode(node_index);
if (node != nullptr) {
node->SetExecutionProviderType(provider_type);
}
}
}
}
}
return nullptr;
return result;
}
// for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up).
// assign any nodes to the EP that are currently unassigned, and that the EP can handle.
static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr,
KernelRegistryManager& kernel_registry_mgr,
KernelRegistry& fused_kernel_registry,
IExecutionProvider& current_ep) {
static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncManager& func_mgr,
KernelRegistryManager& kernel_registry_mgr,
KernelRegistry& fused_kernel_registry,
IExecutionProvider& current_ep,
GraphPartitioner::Mode mode,
int& fused_node_unique_id) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
if (graph.NumberOfNodes() == 0) {
@ -119,8 +177,8 @@ static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
Graph* subgraph = entry.second;
// we pass through the export_dll value and FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr, fused_kernel_registry,
current_ep));
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id));
}
}
@ -134,65 +192,135 @@ static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr
// TODO: when the graph contain a function node, and user pass in the dll which could
// run the function by SessionOption, we should create a function kernel for it and
// delegate the compute to the functions inside the dlls.
int count = 0;
std::vector<Node*> nodes_need_compile;
const std::string& type = current_ep.Type();
auto fusion_style = current_ep.GetFusionStyle();
std::vector<Node*> nodes_to_compile;
GraphViewer graph_viewer(graph);
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type()));
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
// filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with
// nodes_to_compile.
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_compile;
capabilities_to_compile.reserve(std::count_if(capabilities.cbegin(), capabilities.cend(),
[](const std::unique_ptr<ComputeCapability>& entry) {
return entry != nullptr &&
entry->sub_graph != nullptr &&
entry->sub_graph->GetMetaDef() != nullptr;
}));
for (auto& capability : capabilities) {
if (!capability || !capability->sub_graph) { // in theory an EP could return an empty value...
continue;
}
Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, current_ep.Type(), count);
Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, type, fusion_style, mode, fused_node_unique_id);
if (n != nullptr) {
nodes_need_compile.push_back(n);
nodes_to_compile.push_back(n);
capabilities_to_compile.push_back(std::move(capability));
}
}
if (!nodes_need_compile.empty()) {
// NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode
if (!nodes_to_compile.empty()) {
std::vector<NodeComputeInfo> node_compute_funcs;
if (export_dll) {
std::string dll_path;
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_need_compile, dll_path));
for (auto* node : nodes_need_compile) {
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path));
}
} else {
std::vector<NodeComputeInfo> node_compute_funcs;
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_need_compile, node_compute_funcs));
ORT_ENFORCE(node_compute_funcs.size() == nodes_need_compile.size(),
"Provider did not return correct number of compiled functions");
for (size_t j = 0; j < nodes_need_compile.size(); j++) {
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_need_compile[j]->Name(), node_compute_funcs[j].compute_func,
node_compute_funcs[j].create_state_func,
node_compute_funcs[j].release_state_func));
}
ORT_ENFORCE(fusion_style == IExecutionProvider::FusionStyle::Function,
"Must use Function based fusion when exporting compiled nodes to dll.");
}
for (auto* node : nodes_need_compile) {
//prepare the func kernel
KernelDefBuilder builder;
BuildFusedKernelDef(builder, *node);
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder, static_cast<KernelCreatePtrFn>(
[](const OpKernelInfo& info) -> OpKernel* {
return new FunctionKernel(info);
})));
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
// Create a Function based node where the fused nodes have a new Graph instance.
if (export_dll) {
std::string dll_path;
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, dll_path));
for (auto* node : nodes_to_compile) {
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path));
}
} else {
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs));
if (node_compute_funcs.size() != nodes_to_compile.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
}
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j])));
}
}
for (auto* node : nodes_to_compile) {
// add the KernelDef instances for the compiled nodes
KernelDefBuilder builder;
BuildFusedKernelDef(builder, *node);
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder,
static_cast<KernelCreatePtrFn>(
[](const OpKernelInfo& info) -> OpKernel* {
return new FunctionKernel(info);
})));
}
} else {
// temporary storage for the GraphViewer for each IndexedSubGraph
std::vector<std::unique_ptr<GraphViewer>> viewers;
viewers.reserve(nodes_to_compile.size());
std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
auto* node = nodes_to_compile[j];
const auto& cur_capability = *capabilities_to_compile[j];
viewers.push_back(onnxruntime::make_unique<GraphViewer>(graph, *cur_capability.sub_graph));
nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{*node, *viewers.back()});
}
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));
if (node_compute_funcs.size() != nodes_to_compile.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
}
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
auto* node = nodes_to_compile[j];
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), std::move(node_compute_funcs[j])));
const auto& cur_capability = capabilities_to_compile[j];
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
// create the func kernel for the name in the MetaDef. this is also the node name and that name that will
// used as the key in the FuncManager entry. We need the registry to own the KernelCreateInfo that is
// used by SessionState
KernelDefBuilder builder;
BuildFusedKernelDef(builder, metadef, type);
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder,
static_cast<KernelCreatePtrFn>(
[](const OpKernelInfo& info) -> OpKernel* {
return new FunctionKernel(info);
})));
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
}
}
}
// if this is the main graph call Resolve to put the Graph back into a guaranteed good state
// TODO: If we fix Graph::FuseSubGraph to correctly update edges when replacing nodes with a fused node we can
// avoid this Resolve call.
// TODO: Graph::FuseSubGraph and Graph::FinalizeFuseSubGraph should now create valid edges so this call to
// Graph::Resolve should not be required. Need to test to validate that, especially if node being fused
// was a control flow node with its own subgraph as more than just the edges may need updating.
if (!graph.IsSubgraph()) {
ORT_RETURN_IF_ERROR(graph.Resolve());
}
//For some cases, like fp16 on cpu, right now we don't have any kernel support that.
//But we will insert cast op to run the model, so skip the error checking here.
//If after graph transform phase, the node still not assigned, we will report error
//during kernel creation phase.
// For some cases, like fp16 on cpu, right now we don't have any kernel support that.
// But we will insert cast op to run the model, so skip the error checking here.
// If after graph transform phase, the node still not assigned, we will report error
// during kernel creation phase.
#ifdef COUNT_NON_CUDA_OPS
for (auto& node : graph.Nodes()) {
if (node.GetExecutionProviderType() != kCudaExecutionProvider &&
@ -231,12 +359,157 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
modified_graph = true;
}
ORT_RETURN_IF_ERROR(graph.Resolve());
return Status::OK();
}
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry, Mode mode,
int& fused_node_unique_id) const {
bool modified_graph = false;
do {
// process full graph with each EP
for (const auto& ep : providers_) {
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
fused_kernel_registry, *ep, mode, fused_node_unique_id));
}
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
modified_graph = false;
ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph));
// Resolve and rerun graph partitioning and inlining if there was a change
if (modified_graph) {
ORT_RETURN_IF_ERROR(graph.Resolve());
}
} while (modified_graph);
return Status::OK();
}
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr) const {
#endif // !defined(ORT_MINIMAL_BUILD)
static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
KernelRegistryManager& kernel_registry_mgr,
KernelRegistry& fused_kernel_registry,
IExecutionProvider& current_ep,
std::unordered_map<std::string, uint64_t>& compiled_kernel_hashes,
int& fused_node_unique_id) {
// recurse into nested graphs first to partition bottom up.
for (auto& node : graph.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
Graph* subgraph = entry.second;
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry,
current_ep, compiled_kernel_hashes, fused_node_unique_id));
}
}
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
if (graph.NumberOfNodes() == 0) {
return Status::OK();
}
const std::string& type = current_ep.Type();
GraphViewer graph_viewer(graph);
std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
// storage for the GraphViewer for each IndexedSubGraph
std::vector<std::unique_ptr<GraphViewer>> viewers;
viewers.reserve(capabilities.size());
for (auto& capability : capabilities) {
const IndexedSubGraph& indexed_sub_graph = *capability->sub_graph;
const IndexedSubGraph::MetaDef* metadef = indexed_sub_graph.GetMetaDef();
if (!metadef) {
// Static kernel - use the kernel hash that was saved in the ORT format model
continue;
}
std::ostringstream oss;
oss << type << "_" << metadef->name << "_" << fused_node_unique_id++;
std::string node_name = oss.str();
Node& fused_node = graph.BeginFuseSubGraph(indexed_sub_graph, node_name);
fused_node.SetExecutionProviderType(type);
// create filtered graph viewer for this set of nodes
//
// TODO: Could avoid the topological sort in the GraphViewer ctor by constructing from an existing
// GraphViewer instance instead of the Graph (copying the topological order instead of recalculating).
viewers.push_back(onnxruntime::make_unique<GraphViewer>(graph, indexed_sub_graph));
nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{fused_node, *viewers.back()});
}
std::vector<NodeComputeInfo> node_compute_funcs;
node_compute_funcs.reserve(nodes_and_viewers.size());
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));
if (node_compute_funcs.size() != nodes_and_viewers.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
}
for (size_t j = 0, end = nodes_and_viewers.size(); j < end; j++) {
Node& node = nodes_and_viewers[j].fused_node;
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(node_compute_funcs[j])));
const auto& cur_capability = capabilities[j];
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
KernelDefBuilder builder;
BuildFusedKernelDef(builder, metadef, type);
auto kernel_def = builder.Build();
// save hash so SessionState can find the kernel. each kernel name should be unique
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
". Execution Provider must generate unique names across the entire model.");
}
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
KernelCreateInfo(std::move(kernel_def), static_cast<KernelCreatePtrFn>(
[](const OpKernelInfo& info) -> OpKernel* {
return new FunctionKernel(info);
}))));
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
}
return Status::OK();
}
// Simplified partitioning where custom EPs may produce compiled nodes.
// EPs with static kernels do not need to be processed as their kernels are matched via hash information serialized
// as part of the ORT format model.
Status GraphPartitioner::PartitionOrtFormatModel(
Graph& graph, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry,
std::unordered_map<std::string, uint64_t>& compiled_kernel_hashes,
int& fused_node_unique_id) const {
// process full graph with each EP
for (const auto& ep : providers_) {
if (ep->Type() == kCpuExecutionProvider) {
// hash for kernel is stored in session state for EPs that have pre-registered kernels
// (vs. runtime fused kernels) so nothing to do here.
continue;
}
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry,
*ep, compiled_kernel_hashes, fused_node_unique_id));
}
return Status::OK();
}
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, Mode mode,
std::unordered_map<std::string, uint64_t>* compiled_kernel_hashes) const {
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
// 1. Execution providers' capabilities are checked one by one.
// 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet.
@ -253,21 +526,23 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
// It is only visible for current session.
std::shared_ptr<KernelRegistry> fused_kernel_registry = std::make_shared<KernelRegistry>();
bool modified_graph = false;
do {
// process full graph with each EP
for (const auto& ep : providers_) {
ORT_RETURN_IF_ERROR(PartitionImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
*fused_kernel_registry, *ep));
}
// we make sure each fused node name is unique across the entire model for clarity
int fused_node_unique_id = 0;
modified_graph = false;
// expand any nodes that have an ONNX function definition
// but no matching ORT kernel
ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph));
// rerun graph partition to assign nodes added as part of
// function expansion.
} while (modified_graph);
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode,
fused_node_unique_id));
#else
ORT_UNUSED_PARAMETER(export_dll);
ORT_THROW("Not supported in this build.");
#endif
} else {
ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided");
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes,
fused_node_unique_id));
}
if (!fused_kernel_registry->IsEmpty()) {
kernel_registry_mgr_.RegisterKernelRegistry(fused_kernel_registry);
@ -276,3 +551,5 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
return Status::OK();
}
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@ -3,6 +3,8 @@
#pragma once
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include "core/common/common.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/op_kernel.h"
@ -11,21 +13,43 @@
namespace onnxruntime {
class ExecutionProviders;
class KernelRegistry;
class KernelRegistryManager;
class GraphPartitioner {
public:
enum class Mode {
kNormal = 0,
kAssignOnly = 1, // assign nodes. no call to Compile. used to create ORT format model support for compiling EPs
kOrtFormatLoad = 2 // loading ORT format model. Partition with compiling EPs, GraphViewer based Compile.
};
//The order of providers represents the user preference.
GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers)
: kernel_registry_mgr_(kernel_registry_mgr),
providers_(providers) {}
providers_(providers) {
}
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr) const;
// Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad.
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
Mode mode = Mode::kNormal,
std::unordered_map<std::string, uint64_t>* compiled_kernel_hashes = nullptr) const;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner);
#if !defined(ORT_MINIMAL_BUILD)
Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id) const;
#endif
Status PartitionOrtFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry,
std::unordered_map<std::string, uint64_t>& compiled_kernel_hashes,
int& fused_node_unique_id) const;
KernelRegistryManager& kernel_registry_mgr_;
const ExecutionProviders& providers_;
};
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@ -45,14 +45,16 @@ Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& executio
return Status::OK();
}
#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry) {
if (nullptr == kernel_registry) {
return;
}
custom_kernel_registries_.push_front(kernel_registry);
}
#endif
#if !defined(ORT_MINIMAL_BUILD)
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
std::vector<const KernelRegistry*> kernel_registries = r.GetKernelRegistriesByProviderType(provider_type);
return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) {
@ -84,10 +86,12 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node
return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. "));
}
#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
for (auto& registry : custom_kernel_registries_) {
status = registry->TryFindKernel(node, std::string(), kernel_def_hash, kernel_create_info);
if (status.IsOK()) return status;
if (status.IsOK()) {
return status;
}
}
#endif

View file

@ -32,8 +32,7 @@ class KernelRegistryManager {
// Register kernels from providers
Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT;
#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// The registry passed in this function has highest priority than anything already in this KernelRegistryManager,
// and anything registered from RegisterKernels
// For example, if you do:
@ -43,16 +42,6 @@ class KernelRegistryManager {
// Then B > A > providers
void RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry);
// This function assumes the node is already assigned to an execution provider
// Don't call this function before graph partition is done
Status SearchKernelRegistry(const onnxruntime::Node& node,
/*out*/ const KernelCreateInfo** kernel_create_info) const;
/**
* Whether this node can be run on this provider
*/
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
/**
* Search kernel registry by provider type.
* @param type provider type string
@ -70,6 +59,18 @@ class KernelRegistryManager {
}
#endif
#if !defined(ORT_MINIMAL_BUILD)
// This function assumes the node is already assigned to an execution provider
// Don't call this function before graph partition is done
Status SearchKernelRegistry(const onnxruntime::Node& node,
/*out*/ const KernelCreateInfo** kernel_create_info) const;
/**
* Whether this node can be run on this provider
*/
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
#endif
Status SearchKernelRegistry(const onnxruntime::Node& node,
uint64_t kernel_def_hash,
/*out*/ const KernelCreateInfo** kernel_create_info) const;

View file

@ -79,7 +79,7 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_
return true;
}
common::Status OpKernelInfo::GetFusedFuncs(ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const {
return funcs_mgr_.GetFuncs(node_.Name(), compute, create, release);
common::Status OpKernelInfo::GetFusedFuncs(NodeComputeInfo*& compute_info) const {
return funcs_mgr_.GetFuncs(node_.Name(), compute_info);
}
} // namespace onnxruntime

View file

@ -114,7 +114,8 @@ void SessionState::CreateGraphInfo() {
for (const auto& output : graph_viewer_->GetOutputs()) {
if (output->Exists()) {
idx = ort_value_name_idx_map_.Add(output->Name());
VLOGS(logger_, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx;
VLOGS(logger_, 1) << "Added graph output with name: " << output->Name()
<< " to OrtValueIndex with index: " << idx;
}
}
@ -122,10 +123,25 @@ void SessionState::CreateGraphInfo() {
}
#if !defined(ORT_MINIMAL_BUILD)
Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager) {
Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager,
bool saving_ort_format) {
for (auto& node : graph_.Nodes()) {
const KernelCreateInfo* kci = nullptr;
ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, &kci));
auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
if (!status.IsOK() && saving_ort_format) {
// if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
// in that case we assigned the node to that EP but do not compile it into a fused node.
// this keeps the original node and prevents level 2 and level 3 optimizers from modifying it.
// we now revert to the CPU EP to include the hash for the kernel as a fallback. at runtime when the model
// is loaded in a minimal build, the compiling EP will replace this node if possible. if that's not possible for
// some reason we can fallback to the CPU EP implementation via this hash.
node.SetExecutionProviderType(kCpuExecutionProvider);
status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
}
ORT_RETURN_IF_ERROR(status);
ORT_IGNORE_RETURN_VALUE(
kernel_create_info_map_.insert({node.Index(), gsl::not_null<const KernelCreateInfo*>(kci)}));
}
@ -133,7 +149,7 @@ Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_regi
for (const auto& entry : subgraph_session_states_) {
for (const auto& name_to_subgraph_session_state : entry.second) {
SessionState& subgraph_session_state = *name_to_subgraph_session_state.second;
ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager));
ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));
}
}
@ -813,16 +829,45 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
"Size mismatch for kernel create info node indexes and hashes. Invalid ORT format model.",
node_indices->size(), " != ", kernel_def_hashes->size());
auto add_kernel_by_hash =
[&kernel_registry_manager, this](const Node& node, uint64_t hash) {
const KernelCreateInfo* kci = nullptr;
ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, hash, &kci));
kernel_create_info_map_.emplace(node.Index(), gsl::not_null<const KernelCreateInfo*>(kci));
return Status::OK();
};
// kernel hashes for model are in top level SessionState
const auto& compiled_kernel_hashes = GetCompiledKernelHashes();
// process the nodes that existed when the model was created
for (flatbuffers::uoffset_t i = 0; i < node_indices->size(); i++) {
auto node_idx = node_indices->Get(i);
auto kernal_hash = kernel_def_hashes->Get(i);
auto kernel_hash = kernel_def_hashes->Get(i);
const Node* node = graph_.GetNode(node_idx);
ORT_RETURN_IF(node == nullptr, "Can't find node with index ", node_idx, ". Invalid ORT format model.");
if (node == nullptr) {
// this is OK if we have compiled kernels and the original node was replaced. if not the model is invalid.
ORT_RETURN_IF(compiled_kernel_hashes.empty(),
"Can't find node with index ", node_idx, ". Invalid ORT format model.");
continue;
}
const KernelCreateInfo* kci = nullptr;
ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(*node, kernal_hash, &kci));
kernel_create_info_map_.emplace(node_idx, gsl::not_null<const KernelCreateInfo*>(kci));
ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, kernel_hash));
}
// lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices
// as they were created at runtime.
if (!compiled_kernel_hashes.empty()) {
for (const auto& node : graph_.Nodes()) {
if (kernel_create_info_map_.count(node.Index()) == 0) {
auto hash_info = compiled_kernel_hashes.find(node.OpType());
ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
"Unable to find compiled kernel hash for node '", node.Name(), "'.")
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second));
}
}
}
if (!subgraph_session_states_.empty()) {
@ -880,7 +925,8 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options,
const onnxruntime::experimental::fbs::SessionState* serialized_session_state,
bool remove_initializers) {
bool remove_initializers,
bool saving_ort_format) {
// recursively create the subgraph session state instances and populate the kernel create info in them.
// it's simpler to handle the kernel create info recursively when deserializing,
// so also do it recursively when calling PopulateKernelCreateInfo for consistency.
@ -896,12 +942,13 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
} else {
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager));
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));
#else
ORT_UNUSED_PARAMETER(graph_location);
ORT_UNUSED_PARAMETER(kernel_registry_manager);
ORT_UNUSED_PARAMETER(session_options);
ORT_UNUSED_PARAMETER(remove_initializers);
ORT_UNUSED_PARAMETER(saving_ort_format);
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
"Serialized session state must be provided from an ORT format model in this build.");
#endif

View file

@ -269,6 +269,10 @@ class SessionState {
#endif
#if defined(ENABLE_ORT_FORMAT_LOAD)
void SetCompiledKernelHashes(std::unordered_map<std::string, uint64_t>&& compiled_kernel_hashes) {
compiled_kernel_hashes_ = std::move(compiled_kernel_hashes);
}
Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::SessionState& fbs_session_state,
const KernelRegistryManager& kernel_registry_manager);
#endif
@ -277,7 +281,8 @@ class SessionState {
KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options = {},
const onnxruntime::experimental::fbs::SessionState* serialized_session_state = nullptr,
bool remove_initializers = true);
bool remove_initializers = true,
bool saving_ort_format = false);
SessionState* Parent() {
return parent_;
@ -312,7 +317,7 @@ class SessionState {
std::unique_ptr<SessionState> session_state);
#if !defined(ORT_MINIMAL_BUILD)
Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager);
Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager, bool saving_ort_format);
#endif
Status FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
@ -330,9 +335,18 @@ class SessionState {
std::unordered_map<int, TensorShape>& inferred_shapes) const;
#endif
// the SessionState for the main Graph contains the compiled kernel hashes for the entire model
const std::unordered_map<std::string, uint64_t>& GetCompiledKernelHashes() const {
return parent_ ? parent_->GetCompiledKernelHashes() : compiled_kernel_hashes_;
}
// KernelCreateInfo for each node so we do kernel lookup once
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>> kernel_create_info_map_;
// If we compile kernels in a minimal build we need a way to find the kernel using the hash.
// We populate this map when doing the kernel compilation in GraphPartitioner, and use it in LoadFromOrtFormat.
std::unordered_map<std::string, uint64_t> compiled_kernel_hashes_;
// cache of the constructed kernels to avoid spending construction time per executor
std::vector<OpKernel*> session_kernels_;
Graph& graph_;

View file

@ -101,7 +101,10 @@ bool ProviderIsCpuBased(const std::string& provider_type) {
provider_type == onnxruntime::kVitisAIExecutionProvider ||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
provider_type == onnxruntime::kNnapiExecutionProvider ||
provider_type == onnxruntime::kRknpuExecutionProvider;
provider_type == onnxruntime::kAclExecutionProvider ||
provider_type == onnxruntime::kArmNNExecutionProvider ||
provider_type == onnxruntime::kRknpuExecutionProvider ||
provider_type == onnxruntime::utils::kInternalTestingExecutionProvider;
}
static common::Status AllocateHelper(const AllocatorPtr& allocator,
@ -574,16 +577,16 @@ common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context)
const Tensor* curr_input = context->Input<Tensor>(i);
ORT_ENFORCE(prev_input->Shape().Size() >= 0);
size_t input_element_count = static_cast<size_t>(prev_input->Shape().Size());
size_t input_element_size = prev_input->DataType()->Size();
size_t input_aligned_bytes = 0;
ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, &input_aligned_bytes));
ORT_RETURN_IF_NOT(curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + input_aligned_bytes ||
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + prev_input->SizeInBytes());
prev_input = curr_input;
}
return Status::OK();

View file

@ -40,6 +40,13 @@ void DefaultFree(void* p);
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want
// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal.
constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";
// return true if the execution provider is CPU based (meaning no copies to device are required)
bool ProviderIsCpuBased(const std::string& provider_type);
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);

View file

@ -144,6 +144,34 @@ static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& su
}
}
static std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const Graph& graph,
const IndexedSubGraph& nodes_to_fuse) {
const auto* meta_def = nodes_to_fuse.GetMetaDef();
auto op_schema = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
op_schema->SetName(meta_def->name);
op_schema->SetDomain(meta_def->domain);
op_schema->SetDoc(meta_def->doc_string);
op_schema->SinceVersion(meta_def->since_version);
int i = 0;
for (auto& input : meta_def->inputs) {
auto input_arg = graph.GetNodeArg(input);
// inputs must have a type. can be inferred for outputs.
ORT_ENFORCE(input_arg->Type() != nullptr);
op_schema->Input(i, input, "", *input_arg->Type());
++i;
}
i = 0;
for (auto& output : meta_def->outputs) {
auto output_arg = graph.GetNodeArg(output);
op_schema->Output(i, output, "", *output_arg->Type());
++i;
}
op_schema->Finalize();
return op_schema;
}
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
const IndexedSubGraph& nodes_to_fuse,
const logging::Logger& logger)
@ -154,12 +182,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
graph.DomainToVersionMap(), {}, logger) {
auto& function_body_graph = body_.MainGraph();
auto meta_def = nodes_to_fuse.GetMetaDef();
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
op_schema_->SetName(meta_def->name);
op_schema_->SetDomain(meta_def->domain);
op_schema_->SetDoc(meta_def->doc_string);
op_schema_->SinceVersion(meta_def->since_version);
auto* meta_def = nodes_to_fuse.GetMetaDef();
op_schema_ = CreateSchema(graph, nodes_to_fuse);
int i = 0;
std::vector<const NodeArg*> function_body_graph_inputs;
@ -168,8 +192,6 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
auto input_arg = parent_graph_->GetNodeArg(input);
auto& function_body_graph_input_arg = function_body_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
function_body_graph_inputs[i] = &function_body_graph_input_arg;
ORT_ENFORCE(input_arg->Type() != nullptr);
op_schema_->Input(i, input, "", *input_arg->Type());
++i;
}
@ -180,12 +202,9 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
auto output_arg = parent_graph_->GetNodeArg(output);
auto& function_body_graph_output_arg = function_body_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
function_body_graph_outputs[i] = &function_body_graph_output_arg;
op_schema_->Output(i, output, "", *output_arg->Type());
++i;
}
op_schema_->Finalize();
function_body_graph.SetInputs(function_body_graph_inputs);
function_body_graph.SetOutputs(function_body_graph_outputs);
@ -238,7 +257,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
// Hence, we make a copy prior to generating the graph representation of the function,
// as we might make some modifications to the FunctionProto along the way
auto node_in_parent_graph = parent_graph_->GetNode(node_index);
const auto* node_in_parent_graph = parent_graph_->GetNode(node_index);
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
op_schema_->SetName(onnx_func_proto_.name());
op_schema_->SetDomain(onnx_func_proto_.node().Get(0).domain());
@ -256,14 +275,13 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
auto cached_op_schema = node_in_parent_graph->Op();
if (!cached_op_schema) {
// Infer a op_schema for stand-alone functions.
IOTypeConstraintHelper(onnx_func_proto_, this->op_schema_, input_name_idx_map, output_name_idx_map);
IOTypeConstraintHelper(onnx_func_proto_, op_schema_, input_name_idx_map, output_name_idx_map);
} else {
auto type_constraint_params = cached_op_schema->typeConstraintParams();
for (auto& type_constraint_param : type_constraint_params) {
op_schema_->TypeConstraint(
type_constraint_param.type_param_str,
type_constraint_param.allowed_type_strs,
type_constraint_param.description);
op_schema_->TypeConstraint(type_constraint_param.type_param_str,
type_constraint_param.allowed_type_strs,
type_constraint_param.description);
}
int i = 0;
for (auto& input : cached_op_schema->inputs()) {
@ -286,10 +304,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
op_schema_->TypeAndShapeInferenceFunction(
[this](ONNX_NAMESPACE::InferenceContext& ctx) {
auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
const ONNX_NAMESPACE::FunctionProto* func_ptr = this->GetFuncProto();
if (nullptr != func_ptr) {
ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(func_ptr, schema_registry, ctx);
}
ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(&onnx_func_proto_, schema_registry, ctx);
});
} else {
op_schema_->TypeAndShapeInferenceFunction(cached_op_schema->GetTypeAndShapeInferenceFunction());
@ -321,7 +336,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
}
for (int idx = 0; idx < (*node).input_size(); ++idx) {
std::string tensor_name = (*node).input().Get(idx);
const std::string& tensor_name = (*node).input().Get(idx);
auto iter = input_name_idx_map.find(tensor_name);
if (iter != input_name_idx_map.end()) {
// Preserving NodeArg and input/output names
@ -397,10 +412,14 @@ const onnxruntime::Graph& FunctionImpl::Body() const {
return body_.MainGraph();
}
const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const {
return &onnx_func_proto_;
ViewerFunctionImpl::ViewerFunctionImpl(const onnxruntime::Graph& graph,
const IndexedSubGraph& nodes_to_fuse,
const logging::Logger& /*logger*/) {
op_schema_ = CreateSchema(graph, nodes_to_fuse);
}
ViewerFunctionImpl::~ViewerFunctionImpl() = default;
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
const IndexedSubGraph& nodes_to_fuse,
const logging::Logger& logger) {

View file

@ -31,8 +31,6 @@ class FunctionImpl final : public Function {
const onnxruntime::Graph& Body() const override;
const ONNX_NAMESPACE::FunctionProto* GetFuncProto() const;
private:
const onnxruntime::Graph* const parent_graph_;
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
@ -40,4 +38,22 @@ class FunctionImpl final : public Function {
ONNX_NAMESPACE::FunctionProto onnx_func_proto_;
};
// Function that uses a GraphViewer so does not need to build a new Model. We still need the OpSchema to be available
// though so we just create that.
class ViewerFunctionImpl final : public Function {
public:
ViewerFunctionImpl(const onnxruntime::Graph& graph,
const IndexedSubGraph& nodes_to_fuse,
const logging::Logger& logger);
~ViewerFunctionImpl() override;
const ONNX_NAMESPACE::OpSchema& OpSchema() const override { return *op_schema_; }
const onnxruntime::Graph& Body() const override { ORT_THROW("Not supported"); }
private:
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
};
} // namespace onnxruntime

View file

@ -133,21 +133,26 @@ static TypeProto TypeProtoFromTensorProto(const TensorProto& tensor) {
return t;
}
#endif // !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
NodeArg::NodeArg(const std::string& name, const TypeProto* p_node_arg_type) {
node_arg_info_.set_name(name);
// If the name is empty, it means the arg does not exist.
exists_ = !(name.empty());
if (nullptr != p_node_arg_type) {
(*node_arg_info_.mutable_type()) = *p_node_arg_type;
#if !defined(ORT_MINIMAL_BUILD)
// should not be possible to have invalid values in the ORT format model, so we don't need this
// in a minimal build
RemoveInvalidValues(*node_arg_info_.mutable_type());
#endif
type_ = DataTypeUtils::ToType(node_arg_info_.type());
} else {
type_ = nullptr;
}
}
#endif // !defined(ORT_MINIMAL_BUILD)
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
NodeArg::NodeArg(NodeArgInfo&& node_arg_info) {
node_arg_info_ = std::move(node_arg_info);
@ -427,10 +432,6 @@ void Node::SetPriority(int priority) noexcept {
#if !defined(ORT_MINIMAL_BUILD)
void Node::SetNodeType(Node::Type node_type) noexcept {
node_type_ = node_type;
}
const Function* Node::GetFunctionBody(bool try_init_func_body) {
if (nullptr != func_body_) {
return func_body_;
@ -450,10 +451,6 @@ void Node::SetFunctionBody(const Function& func) {
since_version_ = op_->since_version();
}
void Node::SetExecutionProviderType(ProviderType execution_provider_type) {
execution_provider_type_ = execution_provider_type;
}
void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
proto.set_name(name_);
proto.set_op_type(op_type_);
@ -666,9 +663,9 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::experimental::fbs::NodeEd
return Status::OK();
}
#endif
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
void Node::Init(const std::string& name,
const std::string& op_type,
const std::string& description,
@ -697,7 +694,11 @@ void Node::Init(const std::string& name,
for (auto& name_to_attr : attributes_) {
if (utils::HasGraph(name_to_attr.second)) {
#if !defined(ORT_MINIMAL_BUILD)
CreateSubgraph(name_to_attr.first);
#else
ORT_THROW("Creating node with a subgraph via AddNode is not supported in this build.");
#endif
}
}
}
@ -716,7 +717,9 @@ Node::Relationships& Node::MutableRelationships() noexcept {
graph_->SetGraphProtoSyncNeeded();
return relationships_;
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD)
void Node::CreateSubgraph(const std::string& attr_name) {
auto attr = attributes_.find(attr_name);
@ -1264,7 +1267,9 @@ common::Status Graph::SetOuterScopeNodeArgs(const std::unordered_set<std::string
return Status::OK();
}
#endif // !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
void Graph::AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_slot, int dst_arg_slot) {
if (nodes_.size() <= src_node_index || src_arg_slot < 0 || nodes_.size() <= dst_node_index || dst_arg_slot < 0 ||
nullptr == nodes_[src_node_index] || nullptr == nodes_[dst_node_index]) {
@ -1349,7 +1354,9 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s
nodes_[dst_node_index]->MutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], src_arg_slot, dst_arg_slot));
nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], src_arg_slot, dst_arg_slot));
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD)
GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint
Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed) {
const std::unordered_set<std::string>& outer_scope_node_args = resolve_context_.outer_scope_node_args;
@ -1596,6 +1603,7 @@ void Graph::ReverseDFSFrom(const std::vector<const Node*>& from,
}
}
#if !defined(ORT_MINIMAL_BUILD)
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const {
std::unordered_map<NodeIndex, size_t> in_degree;
@ -1634,7 +1642,6 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle.");
}
}
#if !defined(ORT_MINIMAL_BUILD)
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
@ -2896,7 +2903,9 @@ std::string Graph::GenerateNodeName(const std::string& base_name) {
return new_name;
}
#endif // !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Node& Graph::AddNode(const std::string& name,
const std::string& op_type,
const std::string& description,
@ -2945,7 +2954,9 @@ bool Graph::RemoveNode(NodeIndex p_index) {
return ReleaseNode(p_index);
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD)
bool Graph::AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index) {
if (nodes_.size() <= src_node_index ||
nodes_.size() <= dst_node_index ||
@ -3282,6 +3293,12 @@ Status Graph::SetGraphInputsOutputs() {
return Status::OK();
}
IOnnxRuntimeOpSchemaCollectionPtr Graph::GetSchemaRegistry() const {
return schema_registry_;
}
#endif // !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// calling private ctor
GSL_SUPPRESS(r .11)
gsl::not_null<Node*> Graph::AllocateNode() {
@ -3313,21 +3330,24 @@ bool Graph::ReleaseNode(NodeIndex index) {
return true;
}
IOnnxRuntimeOpSchemaCollectionPtr Graph::GetSchemaRegistry() const {
return schema_registry_;
}
Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) {
auto* func_meta_def = sub_graph.GetMetaDef();
Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) {
const auto* func_meta_def = sub_graph.GetMetaDef();
ORT_ENFORCE(nullptr != func_meta_def);
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
std::unordered_map<std::string, int> input_indexes;
std::unordered_map<std::string, int> output_indexes;
int cur_idx = 0;
for (auto& arg_name : func_meta_def->inputs) {
input_args.push_back(GetNodeArg(arg_name));
input_indexes[arg_name] = cur_idx++;
}
cur_idx = 0;
for (auto& arg_name : func_meta_def->outputs) {
output_args.push_back(GetNodeArg(arg_name));
output_indexes[arg_name] = cur_idx++;
}
auto& fused_node = AddNode(fused_node_name,
@ -3339,21 +3359,102 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& f
func_meta_def->domain);
fused_node.SetNodeType(Node::Type::Fused);
function_container_.emplace_back(MakeFunction(*this, sub_graph, logger_));
fused_node.SetFunctionBody(*function_container_.back());
// Remove nodes fused above.
return fused_node;
}
Node& Graph::BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) {
Node& node = CreateFusedSubGraphNode(sub_graph, fused_node_name);
#if !defined(ORT_MINIMAL_BUILD)
// if this is a full build create the lightweight Function implementation that provides the schema so that
// kernel lookup works as per usual. in an extended minimal build we do the lookup via a hash so don't
// need to create the schema.
auto func = onnxruntime::make_unique<ViewerFunctionImpl>(*this, sub_graph, logger_);
function_container_.push_back(std::move(func));
node.SetFunctionBody(*function_container_.back());
#endif
return node;
}
void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node) {
const auto* func_meta_def = sub_graph.GetMetaDef();
ORT_ENFORCE(nullptr != func_meta_def);
std::unordered_map<std::string, int> input_indexes;
std::unordered_map<std::string, int> output_indexes;
int cur_idx = 0;
for (auto& arg_name : func_meta_def->inputs) {
input_indexes[arg_name] = cur_idx++;
}
cur_idx = 0;
for (auto& arg_name : func_meta_def->outputs) {
output_indexes[arg_name] = cur_idx++;
}
auto new_node_idx = fused_node.Index();
// Remove nodes that were fused
for (auto node_index : sub_graph.nodes) {
auto node = GetNode(node_index);
if (nullptr == node) {
continue;
}
auto output_edges = node->GetRelationships().output_edges;
for (auto output_edge : output_edges) {
RemoveEdge(node->Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex());
// move any applicable input edges to the new node. remove all others
auto input_edges = node->GetRelationships().input_edges; // copy so RemoveEdge doesn't invalidate iterator
for (const auto& input_edge : input_edges) {
const auto& producer = input_edge.GetNode();
auto producer_idx = producer.Index();
auto src_idx = input_edge.GetSrcArgIndex();
auto dst_idx = input_edge.GetDstArgIndex();
// if this input is an input of the fused node add an edge for that
auto it = input_indexes.find(node->InputDefs()[dst_idx]->Name());
if (it != input_indexes.cend()) {
AddEdge(producer_idx, new_node_idx, src_idx, it->second);
}
RemoveEdge(producer_idx, node_index, src_idx, dst_idx);
}
// move any applicable output edges to the new node
auto output_edges = node->GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator
for (const auto& output_edge : output_edges) {
const auto& consumer = output_edge.GetNode();
auto consumer_idx = consumer.Index();
auto src_idx = output_edge.GetSrcArgIndex();
auto dst_idx = output_edge.GetDstArgIndex();
// if this output is an output of the fused node add an edge for that
auto it = output_indexes.find(node->OutputDefs()[src_idx]->Name());
if (it != output_indexes.cend()) {
AddEdge(new_node_idx, consumer_idx, it->second, dst_idx);
}
RemoveEdge(node_index, consumer_idx, src_idx, dst_idx);
}
RemoveNode(node_index);
}
}
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD)
Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph,
const std::string& fused_node_name) {
Node& fused_node = CreateFusedSubGraphNode(sub_graph, fused_node_name);
// create Function before we remove nodes
function_container_.emplace_back(MakeFunction(*this, sub_graph, logger_));
fused_node.SetFunctionBody(*function_container_.back());
// remove nodes and update edges
FinalizeFuseSubGraph(sub_graph, fused_node);
return fused_node;
}
@ -3541,6 +3642,7 @@ Graph::Graph(const Model& owning_model,
schema_registry_(std::make_shared<SchemaRegistryManager>()),
#endif
domain_to_version_(domain_to_version),
ir_version_(owning_model.IrVersion()),
parent_graph_(parent_graph),
parent_node_(parent_node),
logger_(logger),

View file

@ -4,6 +4,7 @@
#include "transformer_memcpy.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/execution_providers.h"
#include "core/framework/utils.h"
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
@ -62,15 +63,7 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
// and mainly provides the subgraph recursion functionality
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
for (auto& provider : provider_types_) {
if (provider != onnxruntime::kCpuExecutionProvider &&
provider != onnxruntime::kDnnlExecutionProvider &&
provider != onnxruntime::kNGraphExecutionProvider &&
provider != onnxruntime::kNupharExecutionProvider &&
provider != onnxruntime::kVitisAIExecutionProvider &&
provider != onnxruntime::kOpenVINOExecutionProvider &&
provider != onnxruntime::kNnapiExecutionProvider &&
provider != onnxruntime::kAclExecutionProvider &&
provider != onnxruntime::kArmNNExecutionProvider) {
if (!utils::ProviderIsCpuBased(provider)) {
TransformerMemcpyImpl copy_impl(graph, provider);
auto current_modified = copy_impl.ModifyGraph(registry_manager_);
modified = modified || current_modified;
@ -163,9 +156,11 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
return modified;
}
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed) {
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed) {
auto node_provider_type = node.GetExecutionProviderType();
if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
provider_nodes_.insert(&node);
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;

View file

@ -745,45 +745,53 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
// If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape
// information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank) {
const auto& output = node.OutputDefs()[0]->Name();
// We will go through all the output edges
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
const auto& op_type = it->GetNode().OpType();
// TODO add quantized matmul when reshape support quantized input
if (op_type != "Gemm" && op_type != "MatMul") {
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
<< " or no op is using the output (output is graph output)"
<< ", output name, " << output
<< " is used by " << op_type;
return false;
}
//
// TEMPORARILY DISABLED. Needs refinement.
//
// const auto& output = node.OutputDefs()[0]->Name();
// // We will go through all the output edges
// for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
// const auto& op_type = it->GetNode().OpType();
// // TODO add quantized matmul when reshape support quantized input
// if (op_type != "Gemm" && op_type != "MatMul") {
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
// << " or no op is using the output (output is graph output)"
// << ", output name, " << output
// << " is used by " << op_type;
// return false;
// }
// NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
if (it->GetDstArgIndex() != 0) {
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
<< ", output name, " << output;
return false;
}
// // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
// if (it->GetDstArgIndex() != 0) {
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
// << ", output name, " << output;
// return false;
// }
// We only support 2d matmul/gemm here
// And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
if (input_rank < 2 || output_rank != 2) {
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
<< ", output name, " << output
<< ", the actual input_rank, " << input_rank
<< ", the actual output_rank, " << output_rank;
return false;
}
}
// // We only support 2d matmul/gemm here
// // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
// if (input_rank < 2 || output_rank != 2) {
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
// << ", output name, " << output
// << ", the actual input_rank, " << input_rank
// << ", the actual output_rank, " << output_rank;
// return false;
// }
// }
// If we reach here, we have either,
// all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
// or
// Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
// we can skip this Reshape now
LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
<< node.Name() << "] with output, " << output;
return true;
// // If we reach here, we have either,
// // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
// // or
// // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
// // we can skip this Reshape now
// LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
// << node.Name() << "] with output, " << output;
// return true;
ORT_UNUSED_PARAMETER(node);
ORT_UNUSED_PARAMETER(input_rank);
ORT_UNUSED_PARAMETER(output_rank);
return false;
}
/* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder,

View file

@ -80,7 +80,6 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
// Find inputs, initializers and outputs for each supported subgraph
const std::vector<NodeIndex>& node_index = graph_view.GetNodesInTopologicalOrder();
const auto& graph_outputs = graph_view.GetOutputs();
int counter = 0;
for (const auto& group : supported_nodes_vector) {
if (group.empty())
continue;
@ -173,7 +172,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
// Assign inputs and outputs to subgraph's meta_def
auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
meta_def->name = "NNAPI_" + std::to_string(counter++);
meta_def->name = "NNAPI_" + std::to_string(metadef_id_++);
meta_def->domain = kMSDomain;
for (const auto& input : inputs) {
@ -202,6 +201,7 @@ static Status GetOutputBuffer(Ort::CustomOpApi& ort,
const std::vector<uint32_t>& output_shape,
const android::nn::wrapper::Type output_type,
void** output_buffer) ORT_MUST_USE_RESULT;
static Status GetOutputBuffer(Ort::CustomOpApi& ort,
OrtKernelContext* context,
const nnapi::Model& model,
@ -235,50 +235,43 @@ static Status GetOutputBuffer(Ort::CustomOpApi& ort,
return Status::OK();
}
common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
using namespace android::nn::wrapper;
for (const auto* fused_node : fused_nodes) {
// Reconstruct graph proto from fused node's function body
const auto* func_body = fused_node->GetFunctionBody();
if (!func_body) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
}
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
Node& fused_node = fused_node_and_graph.fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
const Graph& graph_body = func_body->Body();
nnapi::ModelBuilder builder(graph_viewer);
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
std::unique_ptr<nnapi::Model> nnapi_model;
ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model));
// Build map from input name to its index in input definitions
{
onnxruntime::GraphViewer graph_viewer(graph_body);
nnapi::ModelBuilder builder(graph_viewer);
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
std::unique_ptr<nnapi::Model> nnapi_model;
ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model));
// Build map from input name to its index in input definitions
{
std::unordered_map<std::string, size_t> input_map;
const auto& input_defs = fused_node->InputDefs();
input_map.reserve(input_defs.size());
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
input_map[input_defs[i]->Name()] = i;
}
nnapi_model->SetInputMap(std::move(input_map));
std::unordered_map<std::string, size_t> input_map;
const auto& input_defs = fused_node.InputDefs();
input_map.reserve(input_defs.size());
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
input_map[input_defs[i]->Name()] = i;
}
// Build map from output name to its index in output definitions
{
std::unordered_map<std::string, size_t> output_map;
const auto& output_defs = fused_node->OutputDefs();
output_map.reserve(output_defs.size());
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
output_map[output_defs[i]->Name()] = i;
}
nnapi_model->SetOutputMap(std::move(output_map));
}
nnapi_models_.emplace(fused_node->Name(), std::move(nnapi_model));
nnapi_model->SetInputMap(std::move(input_map));
}
// Build map from output name to its index in output definitions
{
std::unordered_map<std::string, size_t> output_map;
const auto& output_defs = fused_node.OutputDefs();
output_map.reserve(output_defs.size());
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
output_map[output_defs[i]->Name()] = i;
}
nnapi_model->SetOutputMap(std::move(output_map));
}
nnapi_models_.emplace(fused_node.Name(), std::move(nnapi_model));
NodeComputeInfo compute_info;
compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) {
*state = nnapi_models_[context->node_name].get();
@ -432,20 +425,21 @@ common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::No
return Status::OK();
}
#else
common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto* fused_node : fused_nodes) {
ORT_UNUSED_PARAMETER(fused_node);
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
ORT_UNUSED_PARAMETER(fused_node_and_graph);
NodeComputeInfo compute_info;
compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; };
compute_info.release_state_func = [](FunctionState /*state*/) {};
compute_info.compute_func = [](FunctionState /* state */, const OrtCustomOpApi* /* api */, OrtKernelContext* /* context */) {
compute_info.compute_func = [](FunctionState /* state */, const OrtCustomOpApi* /* api */,
OrtKernelContext* /* context */) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build.");
};
node_compute_funcs.push_back(compute_info);
}
return Status::OK();
}
#endif
#endif // __ANDROID__
} // namespace onnxruntime

View file

@ -19,11 +19,21 @@ class NnapiExecutionProvider : public IExecutionProvider {
std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_view,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
// we implement the Compile that takes FusedNodeAndGraph instances
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
#endif
unsigned long GetNNAPIFlags() const { return nnapi_flags_; }
private:
// unique counter to name each fused kernel across the entire model
mutable int metadef_id_{0};
// The bit flags which define bool options for NNAPI EP, bits are defined as
// NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
const unsigned long nnapi_flags_;

View file

@ -793,7 +793,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
const ExecutionProviders& providers,
KernelRegistryManager& kernel_registry_manager,
const InsertCastTransformer& insert_cast_transformer,
SessionState& session_state) {
SessionState& session_state,
bool saving_model_in_ort_format) {
// The transformer order:
// 1. built-in graph rewriter
// 2. each execution provider's transformer
@ -822,10 +823,17 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
}
#endif
// if saving model to ORT format we only assign nodes a custom EP can handle and don't compile them.
// we do this to preserve the original nodes in the model but prevent optimizers from changing them.
// at runtime, the ORT format model will re-do the partitioning/compilation of these nodes, which may change
// to cover fewer nodes due to device capabilities.
auto mode = saving_model_in_ort_format ? GraphPartitioner::Mode::kAssignOnly
: GraphPartitioner::Mode::kNormal;
// Do partitioning based on execution providers' capability.
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
session_state.GetMutableFuncMgr()));
session_state.GetMutableFuncMgr(), mode));
// apply transformers except default transformers
// Default transformers are required for correctness and they are owned and run by inference session
@ -888,9 +896,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
return common::Status::OK();
}
#endif // !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Status InferenceSession::PartitionOrtFormatModel(onnxruntime::Graph& graph,
const ExecutionProviders& providers,
KernelRegistryManager& kernel_registry_manager,
SessionState& session_state) const {
std::unordered_map<std::string, uint64_t> compiled_kernel_hashes;
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
session_state.GetMutableFuncMgr(),
GraphPartitioner::Mode::kOrtFormatLoad,
&compiled_kernel_hashes));
if (!compiled_kernel_hashes.empty()) {
session_state.SetCompiledKernelHashes(std::move(compiled_kernel_hashes));
}
return Status::OK();
}
#endif
#if defined(ENABLE_ORT_FORMAT_LOAD)
template <typename T>
static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
@ -1131,38 +1159,65 @@ common::Status InferenceSession::Initialize() {
// Register 2nd registries into KernelRegistryManager.
ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_));
bool loading_ort_format = !ort_format_model_bytes_.empty();
bool saving_model = !session_options_.optimized_model_filepath.empty();
bool saving_ort_format = false;
if (saving_model) {
std::string model_type = session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSaveModelFormat, "");
bool has_explicit_type = !model_type.empty();
saving_ort_format = ((has_explicit_type && model_type == "ORT") ||
(!has_explicit_type &&
experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)));
}
#if !defined(ORT_MINIMAL_BUILD)
// add predefined transformers
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
transformers_to_enable_);
if (!loading_ort_format) {
// add predefined transformers
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
transformers_to_enable_);
// apply any transformations to the main graph and any subgraphs
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
execution_providers_, kernel_registry_manager_,
insert_cast_transformer_,
*session_state_));
// apply any transformations to the main graph and any subgraphs
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
execution_providers_, kernel_registry_manager_,
insert_cast_transformer_,
*session_state_,
saving_ort_format));
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
// Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph
ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_));
// Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph
ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_));
} else
#endif // !defined(ORT_MINIMAL_BUILD)
{
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// nodes are already partitioned, but a custom EP may compile some at runtime.
// run the partitioning to allow that to happen.
//
// We always have the CPU EP, so only need to run this if some other EP is enabled
if (execution_providers_.NumProviders() > 1) {
ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_,
*session_state_));
}
#endif
}
// need to keep the initializers if we're going to save the optimized model
bool keep_initializers = !session_options_.optimized_model_filepath.empty();
const experimental::fbs::SessionState* serialized_session_state =
loading_ort_format
? fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state()
: nullptr;
auto* serialized_session_state = !ort_format_model_bytes_.empty()
? fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state()
: nullptr;
ORT_RETURN_IF_ERROR_SESSIONID_(session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_,
session_options_,
serialized_session_state,
!keep_initializers));
ORT_RETURN_IF_ERROR_SESSIONID_(
session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_,
session_options_,
serialized_session_state,
// need to keep the initializers if saving the optimized model
!saving_model,
saving_ort_format));
#if !defined(ORT_MINIMAL_BUILD)
if (!session_options_.optimized_model_filepath.empty()) {
if (saving_model) {
if (session_options_.graph_optimization_level >= TransformerLevel::Level3) {
LOGS(*session_logger_, WARNING)
<< "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. "
@ -1170,12 +1225,7 @@ common::Status InferenceSession::Initialize() {
"and should only be used in the same environment the model was optimized for.";
}
std::string model_type = session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSaveModelFormat, "");
bool has_explicit_type = !model_type.empty();
if ((has_explicit_type && model_type == "ORT") ||
(!has_explicit_type &&
experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath))) {
if (saving_ort_format) {
ORT_RETURN_IF_ERROR_SESSIONID_(SaveToOrtFormat(session_options_.optimized_model_filepath));
} else {
ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));

View file

@ -57,9 +57,7 @@ class LoggingManager;
*/
struct ModelMetadata {
ModelMetadata() = default;
ModelMetadata(const ModelMetadata& other)
: producer_name(other.producer_name), graph_name(other.graph_name), domain(other.domain), description(other.description), version(other.version), custom_metadata_map(other.custom_metadata_map) {
}
ModelMetadata(const ModelMetadata&) = default;
~ModelMetadata() = default;
ModelMetadata& operator=(const ModelMetadata&) = delete;
@ -531,7 +529,8 @@ class InferenceSession {
const onnxruntime::GraphTransformerManager& graph_transformer_mgr,
const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager,
const InsertCastTransformer& insert_cast_transformer,
SessionState& session_state) ORT_MUST_USE_RESULT;
SessionState& session_state,
bool saving_model_in_ort_format) ORT_MUST_USE_RESULT;
onnxruntime::GraphTransformerManager graph_transformation_mgr_;
@ -543,6 +542,11 @@ class InferenceSession {
std::vector<std::string> transformers_to_enable_;
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers,
KernelRegistryManager& kernel_registry_manager, SessionState& session_state) const;
#endif
SessionOptions session_options_;
/// Logging manager if provided.

View file

@ -231,9 +231,10 @@ namespace py = pybind11;
using namespace onnxruntime;
using namespace onnxruntime::logging;
static Env& platform_env = Env::Default();
#if !defined(ORT_MINIMAL_BUILD)
// Custom op section starts
static Env& platform_env = Env::Default();
CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so) {
{
@ -650,10 +651,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
#endif
} else if (type == kRocmExecutionProvider) {
#ifdef USE_ROCM
RegisterExecutionProvider(
sess, *onnxruntime::CreateExecutionProviderFactory_ROCM(cuda_device_id,
cuda_mem_limit,
arena_extend_strategy));
RegisterExecutionProvider(
sess, *onnxruntime::CreateExecutionProviderFactory_ROCM(cuda_device_id,
cuda_mem_limit,
arena_extend_strategy));
#endif
} else if (type == kDnnlExecutionProvider) {
#ifdef USE_DNNL
@ -685,11 +686,9 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
} else if (option.first == "device_id") {
openvino_device_id = option.second;
}
else if (option.first == "num_of_threads") {
} else if (option.first == "num_of_threads") {
num_of_threads = std::stoi(option.second);
}
else {
} else {
ORT_THROW("Invalid OpenVINO EP option: ", option.first);
}
}
@ -946,7 +945,7 @@ void addGlobalMethods(py::module& m, const Environment& env) {
ORT_UNUSED_PARAMETER(algo);
ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
#else
cudnn_conv_algo_search = algo;
cudnn_conv_algo_search = algo;
#endif
});
m.def("set_do_copy_in_default_stream", [](const bool use_single_stream) {
@ -954,7 +953,7 @@ void addGlobalMethods(py::module& m, const Environment& env) {
ORT_UNUSED_PARAMETER(use_single_stream);
ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM");
#else
do_copy_in_default_stream = use_single_stream;
do_copy_in_default_stream = use_single_stream;
#endif
});
m.def("set_cuda_mem_limit", [](const int64_t limit) { cuda_mem_limit = static_cast<size_t>(limit); });
@ -1734,7 +1733,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
.def("end_profiling", [](PyInferenceSession* sess) -> std::string {
return sess->GetSessionHandle()->EndProfiling();
})
.def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t{
.def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t {
return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs();
})
.def("get_providers", [](PyInferenceSession* sess) -> const std::vector<std::string>& {

View file

@ -0,0 +1,3 @@
Internal test EP that is used to test and validate interactions between the ORT framework, optimizers and an EP.
Primary usage currently is validating support in a minimal build for an EP that compiles nodes into kernels at runtime.

View file

@ -0,0 +1,218 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "internal_testing_execution_provider.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/compute_capability.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/graph/model.h"
#include "core/session/onnxruntime_cxx_api.h"
namespace onnxruntime {
constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP";
InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops)
: IExecutionProvider{utils::kInternalTestingExecutionProvider},
ops_{ops} {
// TODO: Allocation planner calls GetAllocator for the individual EP. It would be better if it goes through
// the session state to get the allocator so it's per-device (or for the allocation planner to try the EP first
// and fall back to using session state next by passing in a functor it can use to call SessionState::GetAllocator).
AllocatorCreationInfo device_info(
[](int) {
return onnxruntime::make_unique<CPUAllocator>(OrtMemoryInfo(INTERNAL_TESTING_EP,
OrtAllocatorType::OrtDeviceAllocator));
});
InsertAllocator(CreateAllocator(device_info));
}
InternalTestingExecutionProvider::~InternalTestingExecutionProvider() {}
std::vector<std::unique_ptr<ComputeCapability>>
InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const {
std::vector<std::unique_ptr<ComputeCapability>> result;
/*
Very basic search for groups of nodes that can be handled by the EP.
This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP
but B is between them in the topological sort as you'll get two single node capabilities. However if can also
be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected.
Not sure how often each of these scenarios happens.
A B C
| / |
D E
| |
Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order,
accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all
connected nodes that can be handled are grouped together.
*/
std::vector<std::vector<NodeIndex>> node_groups;
std::vector<NodeIndex> cur_group;
for (NodeIndex node_index : graph_viewer.GetNodesInTopologicalOrder()) {
if (ops_.count(graph_viewer.GetNode(node_index)->OpType())) {
cur_group.push_back(node_index);
} else if (!cur_group.empty()) {
node_groups.push_back(std::move(cur_group));
}
}
if (!cur_group.empty()) {
node_groups.push_back(std::move(cur_group));
}
if (node_groups.empty()) {
return result;
}
const auto& graph_output_list = graph_viewer.GetOutputs();
std::unordered_set<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());
for (const auto& group : node_groups) {
std::unordered_set<NodeIndex> node_set;
node_set.reserve(group.size());
for (const auto& index : group) {
node_set.insert(index);
}
std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::make_unique<IndexedSubGraph>();
std::unordered_set<const NodeArg*> node_outputs;
std::unordered_set<const NodeArg*> subgraph_inputs;
std::unordered_set<const NodeArg*> subgraph_outputs;
std::vector<const NodeArg*> ordered_subgraph_inputs;
std::vector<const NodeArg*> ordered_subgraph_outputs;
for (const auto& index : group) {
sub_graph->nodes.push_back(index);
const auto* node = graph_viewer.GetNode(index);
for (const auto* input : node->InputDefs()) {
// if the node input was not produced by this subgraph, add it to the subgraph inputs.
if (node_outputs.count(input) == 0) {
if (subgraph_inputs.count(input) == 0) {
subgraph_inputs.insert(input);
ordered_subgraph_inputs.push_back(input);
}
}
}
const auto& output_defs = node->OutputDefs();
for (const auto* output_def : output_defs) {
node_outputs.insert(output_def);
// if output is overall graph output we need to produce it.
if (graph_outputs.count(output_def) != 0) {
ordered_subgraph_outputs.push_back(output_def);
}
}
// if output connects to a node not in this subgraph we need to produce it
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
if (node_set.count(it->GetNode().Index()) == 0) {
const auto* output_def = output_defs[it->GetSrcArgIndex()];
if (subgraph_outputs.count(output_def) == 0) {
subgraph_outputs.insert(output_def);
ordered_subgraph_outputs.push_back(output_def);
}
}
}
}
// Assign inputs and outputs to subgraph's meta_def
auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
meta_def->name = "InternalTestingEP_" + std::to_string(metadef_id_++);
meta_def->domain = kMSDomain;
meta_def->since_version = 1;
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
for (const auto& input : ordered_subgraph_inputs) {
meta_def->inputs.push_back(input->Name());
}
for (const auto& output : ordered_subgraph_outputs) {
meta_def->outputs.push_back(output->Name());
}
sub_graph->SetMetaDef(std::move(meta_def));
result.push_back(onnxruntime::make_unique<ComputeCapability>(std::move(sub_graph)));
}
return result;
}
common::Status InternalTestingExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
// Create a function to generate dummy empty output for each fused node so the model can be executed.
for (const auto& node_and_viewer : fused_nodes) {
NodeComputeInfo compute_info;
const Node& node = node_and_viewer.fused_node;
//{
// const GraphViewer& graph_viewer = node_and_viewer.filtered_graph;
// std::cout << "Fusing nodes: ";
// for (const auto& unfused_node : graph_viewer.Nodes()) {
// std::cout << " '" << unfused_node.Name() << "':" << unfused_node.Index();
// }
// std::cout << std::endl;
//}
compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) {
return 0;
};
compute_info.release_state_func = [](FunctionState /*state*/) {
};
compute_info.compute_func = [&node](FunctionState /*state*/, const OrtCustomOpApi* c_api,
OrtKernelContext* context) -> Status {
Ort::CustomOpApi api{*c_api}; // use C++ API for convenience
const auto outputs = node.OutputDefs();
const size_t num_outputs = outputs.size();
for (size_t i = 0; i < num_outputs; i++) {
const auto* shape_proto = outputs[i]->Shape();
if (shape_proto == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown output shapes are not supported");
}
TensorShape shape = utils::GetTensorShapeFromTensorShapeProto(*shape_proto);
if (shape.Size() < 0) {
// arbitrarily set any unknown dim to 1
for (size_t idx = 0, end = shape.NumDimensions(); idx < end; ++idx) {
if (shape[idx] == -1) {
shape[idx] = 1;
}
}
}
// create the output_tensor.
auto* ortvalue = api.KernelContext_GetOutput(context, i, shape.GetDims().data(), shape.GetDims().size());
// and fill with zeros
auto* tensor = ortvalue->GetMutable<Tensor>();
void* data = tensor->MutableDataRaw();
auto bytes = tensor->SizeInBytes();
memset(data, 0, bytes);
};
return Status::OK();
};
node_compute_funcs.push_back(std::move(compute_info));
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <set>
#include "core/framework/execution_provider.h"
namespace onnxruntime {
class InternalTestingExecutionProvider : public IExecutionProvider {
public:
InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops);
virtual ~InternalTestingExecutionProvider();
std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_view,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
FusionStyle GetFusionStyle() const override {
return FusionStyle::FilteredGraphViewer;
}
private:
const std::unordered_set<std::string> ops_;
// unique counter to name each fused kernel across the entire model
mutable int metadef_id_{0};
};
} // namespace onnxruntime

View file

@ -0,0 +1,242 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/logging/logging.h"
#include "core/framework/utils.h"
#include "core/session/inference_session.h"
#include "test/framework/test_utils.h"
#include "test/test_environment.h"
#include "test/providers/internal_testing/internal_testing_execution_provider.h"
#include "test/util/include/asserts.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/util/include/test_utils.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::logging;
namespace onnxruntime {
namespace test {
static void CreateSession(const SessionOptions& so, std::unique_ptr<InferenceSessionWrapper>& session,
const ORTCHAR_T* model_path = ORT_TSTR("testdata/mnist.onnx"), // arbitrary test model
bool enable_custom_ep = true,
const std::unordered_set<std::string>* override_supported_ops = nullptr) {
session = onnxruntime::make_unique<InferenceSessionWrapper>(so, GetEnvironment());
// set supported ops to ops that are ideally found consecutively in the model.
// we can say the EP potentially handles them all, but can also test removing handling of one or more ops
// at runtime to simulate a lower spec device where not all ops can be handled. this allows us to test
// that we can revert ops back to the CPU implementation successfully
const std::unordered_set<std::string> default_supported_ops{"Conv", "Add", "Relu", "MaxPool"};
const std::unordered_set<std::string>* supported_ops = override_supported_ops ? override_supported_ops
: &default_supported_ops;
if (enable_custom_ep) {
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
onnxruntime::make_unique<InternalTestingExecutionProvider>(*supported_ops)));
}
ASSERT_STATUS_OK(session->Load(model_path));
ASSERT_STATUS_OK(session->Initialize());
}
static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enabled) {
// validate that we can execute the model. the dummy internal testing EP just creates empty output so the
// values in the output aren't relevant. all we care about is that we can execute the model and produce output.
OrtValue ml_value_x;
TensorShape input_shape{1, 1, 28, 28};
std::vector<float> input(input_shape.Size(), 1.f);
CreateMLValue<float>(input_shape.GetDims(), input.data(), OrtMemoryInfo(), &ml_value_x);
NameMLValMap feeds;
feeds.insert(std::make_pair("Input3", ml_value_x));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("Plus214_Output_0");
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
if (custom_ep_enabled) {
// check that the output is all zeros. the dummy EP produces output of the correct shape with all zeros, so any
// downstream operations should still result in zeros for this model
// OR it should equal the bias in the final Add operation, which is in the Parameter194 initializer
const auto& t = fetches[0].Get<Tensor>();
const auto data = t.DataAsSpan<float>();
int idx = 0;
const auto& session_state = session.GetSessionState();
ASSERT_STATUS_OK(session_state.GetOrtValueNameIdxMap().GetIdx("Parameter194", idx));
const auto& initializer = session_state.GetConstantInitializedTensors().at(idx);
const auto expected = initializer.Get<Tensor>().DataAsSpan<float>();
ASSERT_THAT(data, ::testing::ContainerEq(expected));
}
}
#if !defined(ORT_MINIMAL_BUILD)
TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.test_output.ort");
//
// First load the onnx format model and save as an ORT model.
// This should preserve the nodes the custom EP can handle.
//
std::unique_ptr<InferenceSessionWrapper> session;
SessionOptions so;
so.optimized_model_filepath = ort_model_path;
CreateSession(so, session);
// this graph should include the original nodes that the custom EP will take at runtime
auto num_nodes = session->GetGraph().NumberOfNodes();
//
// Second, load the ORT format model with just the CPU EP to make sure it can be executed. This tests that the
// fallback to the CPU EP kernel hashes works.
//
std::unique_ptr<InferenceSessionWrapper> session2;
so.optimized_model_filepath.clear();
bool enable_custom_ep = false;
CreateSession(so, session2, ort_model_path, enable_custom_ep);
const auto& graph1 = session2->GetGraph();
// model should have all the original nodes and we should be able to execute with the fallback to CPU EP
ASSERT_EQ(graph1.NumberOfNodes(), num_nodes);
ExecuteMnist(*session2, enable_custom_ep);
session2 = nullptr;
//
// Finally, load the ORT format model with the custom EP enabled. This tests that we support runtime compilation
// for the ORT format model.
//
enable_custom_ep = true;
CreateSession(so, session2, ort_model_path, enable_custom_ep);
const auto& graph2 = session2->GetGraph();
// model should be able to be loaded, and we should compile using custom ep. that will result in one node for the
// custom EP (with Conv/Add/Relu/MaxPool), one for a reshape, and one for the fused MatMul+Add.
ASSERT_EQ(graph2.NumberOfNodes(), 3);
ExecuteMnist(*session2, enable_custom_ep);
}
#endif // !defined(ORT_MINIMAL_BUILD)
// test to validate a minimal build
TEST(InternalTestingEP, TestLoadOrtModel) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort");
std::unique_ptr<InferenceSessionWrapper> session;
bool enable_custom_ep = true;
CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep);
ExecuteMnist(*session, enable_custom_ep);
}
// test that is the custom EP cannot take all nodes due to device limitations
// that we fallback to the CPU implementations and can execute the model
TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort");
const std::unordered_set<std::string> supported_ops{"Conv", "Add", "Relu" /*, "MaxPool"*/};
std::unique_ptr<InferenceSessionWrapper> session;
bool enable_custom_ep = true;
CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops);
const auto& graph = session->GetGraph();
// Conv+Add gets fused by level 1 optimizer into single node. The 'Conv'/'Add'/'Relu' nodes should be compiled and
// handled by the custom EP. fallback to CPU for MaxPool.
ASSERT_EQ(graph.NumberOfNodes(), 6);
const auto& func_mgr = session->GetSessionState().GetFuncMgr();
NodeComputeInfo* compute_func = nullptr;
for (const auto& node : graph.Nodes()) {
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
EXPECT_NE(compute_func, nullptr);
}
}
ExecuteMnist(*session, enable_custom_ep);
}
TEST(InternalTestingEP, TestMinimalRegistrationOfEPwithGetCapability) {
// TODO: In a full build we want to be able to call GetCapability for the NNAPI EP and produce an ORT format model
// with nodes correctly preserved. That requires being able to do a minimal registration of that EP where
// GetCapability is fully implemented, but Compile is a stub that just throws NOT_IMPLEMENTED if someone attempts
// to execute a model in that InferenceSession.
}
// count nodes assigned to the test EP and make sure they all have valid compute funcs
static int CountAndValidateAssignedNodes(const Graph& current_graph,
const std::unordered_set<std::string>& supported_ops,
const FuncManager& func_mgr) {
int count = 0;
for (const auto& node : current_graph.Nodes()) {
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
NodeComputeInfo* compute_func = nullptr;
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
EXPECT_NE(compute_func, nullptr);
++count;
}
if (node.ContainsSubgraph()) {
for (const auto& entry : node.GetSubgraphs()) {
count += CountAndValidateAssignedNodes(*entry, supported_ops, func_mgr);
}
}
}
return count;
}
// Test model that contains a subgraph. This model has a Loop and an If so multiple layers of nested subgraphs.
// There are Add nodes in the Loop and If subgraphs so we should see the custom EP taking nodes at both these levels.
TEST(InternalTestingEP, TestModelWithSubgraph) {
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
const std::unordered_set<std::string> supported_ops{"Add"};
std::unique_ptr<InferenceSessionWrapper> session;
bool enable_custom_ep = true;
CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops);
const auto& graph = session->GetGraph();
const auto& func_mgr = session->GetSessionState().GetFuncMgr();
int num_replaced_nodes = CountAndValidateAssignedNodes(graph, supported_ops, func_mgr);
// One Add node in the Loop. One Add node in each branch of the If inside the Loop body
ASSERT_EQ(num_replaced_nodes, 3);
OrtValue ml_value;
// this is a bit of a hack. the correct output is the input value + 2, so if we start with -2 the result is 0.
// the output from fused nodes using the testing EP is always 0, so we should match the expected output this way
// as we replace all the Add nodes with something that returns 0.
// RunAndVerifyOutputsWithEP checks that nodes are assigned to the EP so we know it's being used to execute the model
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {-2.f},
&ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("state_var_in", ml_value));
// compare outputs from CPU EP vs custom EP
RunAndVerifyOutputsWithEP(ort_model_path,
"InternalTestingEP.TestModelWithSubgraph",
onnxruntime::make_unique<InternalTestingExecutionProvider>(supported_ops),
feeds);
}
} // namespace test
} // namespace onnxruntime

View file

@ -1,13 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include "core/common/logging/logging.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h"
#include "core/session/inference_session.h"
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/framework/test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/asserts.h"
#include "test/util/include/default_providers.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/util/include/test/test_environment.h"
#include "test/util/include/test_utils.h"
#if !defined(ORT_MINIMAL_BUILD)
// if this is a full build we need the provider test utils
#include "test/providers/provider_test_utils.h"
#endif
#include "gtest/gtest.h"
#include "gmock/gmock.h"
using namespace std;
using namespace ONNX_NAMESPACE;
@ -16,85 +28,48 @@ using namespace ::onnxruntime::logging;
namespace onnxruntime {
namespace test {
#ifdef __ANDROID__
void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
const std::vector<float>& expected_values) {
ASSERT_EQ(1, fetches.size());
auto& rtensor = fetches.front().Get<Tensor>();
TensorShape expected_shape(expected_dims);
ASSERT_EQ(expected_shape, rtensor.Shape());
const std::vector<float> found(rtensor.template Data<float>(), rtensor.template Data<float>() + expected_values.size());
ASSERT_EQ(expected_values, found);
}
#endif
void RunAndVerifyOutputs(const std::string& model_file_name,
const char* log_id,
const NameMLValMap& feeds,
const std::vector<std::string>& output_names,
const std::vector<int64_t>& expected_dims,
const std::vector<float>& expected_values) {
SessionOptions so;
so.session_logid = log_id;
RunOptions run_options;
run_options.run_tag = so.session_logid;
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>(0)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
ASSERT_STATUS_OK(session_object.Initialize());
// Since we already know the model is entirely supported by NNAPI, all nodes (2 Add nodes here) will be fused
// Get the graph after session is initialized, and verify the fused node (the only node in the graph) is using NNAPI EP
const auto& graph = session_object.GetGraph();
ASSERT_EQ(1, graph.NumberOfNodes()); // Make sure the graph has 1 fused node
ASSERT_EQ(onnxruntime::kNnapiExecutionProvider, graph.Nodes().cbegin()->GetExecutionProviderType());
// The execution can only be performed on Android
#ifdef __ANDROID__
// Now run and verify the result
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(fetches, expected_dims, expected_values);
#else
ORT_UNUSED_PARAMETER(feeds);
ORT_UNUSED_PARAMETER(output_names);
ORT_UNUSED_PARAMETER(expected_dims);
ORT_UNUSED_PARAMETER(expected_values);
#endif
}
#if !defined(ORT_MINIMAL_BUILD)
// Since NNAPI EP handles Reshape and Flatten differently,
// Please see ReshapeOpBuilder::CanSkipReshape in <repo_root>/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc
// Please see ReshapeOpBuilder::CanSkipReshape in
// <repo_root>/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc
// We have a separated test for these skip reshape scenarios
TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) {
std::string model_file_name = "testdata/nnapi_reshape_flatten_test.onnx";
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/nnapi_reshape_flatten_test.onnx");
#if defined(__ANDROID__)
std::vector<int64_t> dims_mul_x = {2, 1, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f};
std::vector<int64_t> dims_mul_y = {3, 2, 2};
std::vector<float> values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
OrtValue ml_value_x;
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x);
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_x);
OrtValue ml_value_y;
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_y, values_mul_y, &ml_value_y);
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_y, values_mul_y,
&ml_value_y);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("Z");
// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_mul_z = {1, 6};
std::vector<float> expected_values_mul_z = {59.0f, 72.0f, 129.0f, 159.0f, 204.0f, 253.0f};
RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest", feeds, output_names, expected_dims_mul_z, expected_values_mul_z);
RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest",
onnxruntime::make_unique<NnapiExecutionProvider>(0),
feeds);
#else
// test load only
SessionOptions so;
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
ASSERT_STATUS_OK(session_object.Initialize());
ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
<< "Some nodes should have been taken by the NNAPI EP";
#endif
}
TEST(NnapiExecutionProviderTest, FunctionTest) {
std::string model_file_name = "nnapi_execution_provider_test_graph.onnx";
const ORTCHAR_T* model_file_name = ORT_TSTR("nnapi_execution_provider_test_graph.onnx");
{ // Create the model with 2 add nodes
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
@ -130,30 +105,38 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
}
#if defined(__ANDROID__)
std::vector<int64_t> dims_mul_x = {1, 1, 3, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
OrtValue ml_value_x;
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x);
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_x);
OrtValue ml_value_y;
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_y);
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_y);
OrtValue ml_value_z;
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_z);
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_z);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));
feeds.insert(std::make_pair("Z", ml_value_z));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("M");
std::vector<OrtValue> fetches;
// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_mul_m = {1, 1, 3, 2};
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.FunctionTest", feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.FunctionTest",
onnxruntime::make_unique<NnapiExecutionProvider>(0),
feeds);
#else
// test load only
SessionOptions so;
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
ASSERT_STATUS_OK(session_object.Initialize());
ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
<< "Some nodes should have been taken by the NNAPI EP";
#endif
}
#endif // !(ORT_MINIMAL_BUILD
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
unsigned long nnapi_flags = NNAPI_FLAG_USE_NONE;
@ -164,5 +147,38 @@ TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
ASSERT_FALSE(flags & NNAPI_FLAG_USE_NCHW);
}
TEST(NnapiExecutionProviderTest, TestOrtFormatModel) {
// mnist model that has only had basic optimizations applied. nnapi should be able to take at least some of the nodes
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/mnist.level1_opt.ort");
// The execution can only be performed on Android
#if defined(__ANDROID__)
RandomValueGenerator random{};
const std::vector<int64_t> dims = {1, 1, 28, 28};
std::vector<float> data = random.Gaussian<float>(dims, 0.0f, 1.f);
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, data,
&ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("Input3", ml_value));
RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.TestOrtFormatModel",
onnxruntime::make_unique<NnapiExecutionProvider>(0),
feeds);
#else
// test load only
SessionOptions so;
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
ASSERT_STATUS_OK(session_object.Initialize());
ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
<< "Some nodes should have been taken by the NNAPI EP";
#endif
}
} // namespace test
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

Binary file not shown.

BIN
onnxruntime/test/testdata/mnist.onnx vendored Normal file

Binary file not shown.

BIN
onnxruntime/test/testdata/mnist.ort vendored Normal file

Binary file not shown.

View file

@ -10,14 +10,19 @@
namespace onnxruntime {
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo,
bool do_copy_in_default_stream = true);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(
OrtDevice::DeviceId device_id,
OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo,
bool do_copy_in_default_stream = true);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVINO(
const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_NGraph(const char* ng_backend_type);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVINO(const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(bool, const char*);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(unsigned long);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
@ -29,6 +34,10 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ROCM(O
size_t hip_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo);
// EP for internal testing
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_InternalTesting(
const std::unordered_set<std::string>& supported_ops);
namespace test {
std::unique_ptr<IExecutionProvider> DefaultCpuExecutionProvider(bool enable_arena) {

View file

@ -4,17 +4,18 @@
#pragma once
#include "core/common/status.h"
#include "gtest/gtest.h"
// helpers to run a function and check the status, outputting any error if it fails.
// note: wrapped in do{} while(false) so the _tmp_status variable has limited scope
#define ASSERT_STATUS_OK(function) \
do { \
Status _tmp_status = function; \
Status _tmp_status = function; \
ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status; \
} while (false)
#define EXPECT_STATUS_OK(function) \
do { \
Status _tmp_status = function; \
Status _tmp_status = function; \
EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status; \
} while (false)

View file

@ -21,5 +21,9 @@ std::unique_ptr<IExecutionProvider> DefaultAclExecutionProvider(bool enable_aren
std::unique_ptr<IExecutionProvider> DefaultArmNNExecutionProvider(bool enable_arena = true);
std::unique_ptr<IExecutionProvider> DefaultRocmExecutionProvider();
// EP for internal testing
std::unique_ptr<IExecutionProvider> DefaultInternalTestingExecutionProvider(
const std::unordered_set<std::string>& supported_ops);
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/framework_common.h"
#include "core/framework/execution_provider.h"
#include <memory>
#include <string>
#include <vector>
namespace onnxruntime {
class Graph;
namespace test {
// return number of nodes in the Graph and any subgraphs that are assigned to the specified execution provider
int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type);
// run the model using the CPU EP to get expected output, comparing to the output when the 'execution_provider'
// is enabled. requires that at least one node is assigned to 'execution_provider'
void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path,
const char* log_id,
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds);
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,119 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test/util/include/test_utils.h"
#include "core/framework/ml_value.h"
#include "core/session/inference_session.h"
#include "test/util/include/asserts.h"
#include "test/util/include/test/test_environment.h"
#include "test/util/include/inference_session_wrapper.h"
#include "gmock/gmock.h"
namespace onnxruntime {
namespace test {
static void VerifyOutputs(const std::vector<std::string>& output_names,
const std::vector<OrtValue>& expected_fetches,
const std::vector<OrtValue>& fetches) {
ASSERT_EQ(expected_fetches.size(), fetches.size());
for (size_t i = 0, end = expected_fetches.size(); i < end; ++i) {
auto& ltensor = expected_fetches[i].Get<Tensor>();
auto& rtensor = fetches[i].Get<Tensor>();
ASSERT_EQ(ltensor.Shape().GetDims(), rtensor.Shape().GetDims());
auto element_type = ltensor.GetElementType();
switch (element_type) {
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
EXPECT_THAT(ltensor.DataAsSpan<int32_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<int32_t>()))
<< " mismatch for " << output_names[i];
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
EXPECT_THAT(ltensor.DataAsSpan<int64_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<int64_t>()))
<< " mismatch for " << output_names[i];
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
const float abs_err = float(1e-5);
EXPECT_THAT(ltensor.DataAsSpan<float>(),
::testing::Pointwise(::testing::FloatNear(abs_err), rtensor.DataAsSpan<float>()));
break;
}
default:
ORT_THROW("Unhandled data type. Please add 'case' statement for ", element_type);
}
}
}
int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type) {
int count = 0;
for (const auto& node : current_graph.Nodes()) {
if (node.GetExecutionProviderType() == ep_type) {
++count;
}
if (node.ContainsSubgraph()) {
for (const auto& entry : node.GetSubgraphs()) {
count += CountAssignedNodes(*entry, ep_type);
}
}
}
return count;
}
void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path, const char* log_id,
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds) {
SessionOptions so;
so.session_logid = log_id;
RunOptions run_options;
run_options.run_tag = so.session_logid;
//
// get expected output from CPU EP
//
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.Load(model_path));
ASSERT_STATUS_OK(session_object.Initialize());
const auto& graph = session_object.GetGraph();
const auto& outputs = graph.GetOutputs();
// fetch all outputs
std::vector<std::string> output_names;
output_names.reserve(outputs.size());
for (const auto* node_arg : outputs) {
if (node_arg->Exists()) {
output_names.push_back(node_arg->Name());
}
}
std::vector<OrtValue> expected_fetches;
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &expected_fetches));
auto provider_type = execution_provider->Type(); // copy string so the std::move doesn't affect us
//
// get output with EP enabled
//
InferenceSessionWrapper session_object2{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(std::move(execution_provider)));
ASSERT_STATUS_OK(session_object2.Load(model_path));
ASSERT_STATUS_OK(session_object2.Initialize());
// make sure that some nodes are assigned to the EP, otherwise this test is pointless...
const auto& graph2 = session_object2.GetGraph();
auto ep_nodes = CountAssignedNodes(graph2, provider_type);
ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type << " for " << model_path;
// Run with EP and verify the result
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object2.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(output_names, expected_fetches, fetches);
}
} // namespace test
} // namespace onnxruntime

View file

@ -419,10 +419,13 @@ def parse_arguments():
help="Build ONNXRuntime micro-benchmarks.")
# options to reduce binary size
parser.add_argument("--minimal_build", action='store_true',
parser.add_argument("--minimal_build", action='store',
const='on', default='off', nargs='?', type=str.lower,
help="Create a build that only supports ORT format models. "
"See /docs/ONNX_Runtime_Format_Model_Usage.md for more information. "
"RTTI is automatically disabled in a minimal build.")
"RTTI is automatically disabled in a minimal build. "
"To enable execution providers that compile kernels at runtime (e.g. NNAPI) pass 'extended' "
"as a parameter. e.g. '--minimal_build extended'.")
parser.add_argument("--include_ops_by_model", type=str, help="include ops from model(s) under designated path.")
parser.add_argument("--include_ops_by_config", type=str,
help="include ops from config file. "
@ -710,7 +713,8 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
"-Donnxruntime_DISABLE_RTTI=" + ("ON" if args.disable_rtti else "OFF"),
"-Donnxruntime_DISABLE_EXCEPTIONS=" + ("ON" if args.disable_exceptions else "OFF"),
"-Donnxruntime_DISABLE_ORT_FORMAT_LOAD=" + ("ON" if args.disable_ort_format_load else "OFF"),
"-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build else "OFF"),
"-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build != 'off' else "OFF"),
"-Donnxruntime_EXTENDED_MINIMAL_BUILD=" + ("ON" if args.minimal_build == 'extended' else "OFF"),
"-Donnxruntime_REDUCED_OPS_BUILD=" + (
"ON" if args.include_ops_by_config or args.include_ops_by_model else "OFF"),
"-Donnxruntime_MSVC_STATIC_RUNTIME=" + (