mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Support EPs that compile nodes in a minimal build. (#5776)
* Support EPs that compile nodes in a minimal build. This enables NNAPI being used.
This commit is contained in:
parent
794e8479eb
commit
7b76b57fc8
52 changed files with 1866 additions and 505 deletions
|
|
@ -120,6 +120,7 @@ option(onnxruntime_DISABLE_RTTI "Disable RTTI" OFF)
|
|||
# For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone
|
||||
option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." OFF)
|
||||
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
|
||||
option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF)
|
||||
option(onnxruntime_REDUCED_OPS_BUILD "Reduced set of kernels are registered in build via modification of the kernel registration source files." OFF)
|
||||
option(onnxruntime_DISABLE_ORT_FORMAT_LOAD "Disable loading an ORT format model when onnxruntime_MINIMAL_BUILD=OFF (i.e. in a full build)." OFF)
|
||||
|
||||
|
|
@ -208,12 +209,21 @@ if(onnxruntime_USE_OPENMP)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
# 'extended' implies minimal.
|
||||
if (onnxruntime_EXTENDED_MINIMAL_BUILD AND NOT onnxruntime_MINIMAL_BUILD)
|
||||
set(onnxruntime_MINIMAL_BUILD ON)
|
||||
endif()
|
||||
|
||||
# ORT build with as much excluded as possible. Supports ORT flatbuffers models only.
|
||||
# Will expose option in build.py when all pieces are available
|
||||
if(onnxruntime_MINIMAL_BUILD)
|
||||
if (onnxruntime_MINIMAL_BUILD)
|
||||
add_compile_definitions(ORT_MINIMAL_BUILD)
|
||||
add_compile_definitions(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
if (onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
# enable EPs that compile kernels at runtime
|
||||
add_compile_definitions(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
endif()
|
||||
|
||||
set(onnxruntime_REDUCED_OPS_BUILD ON)
|
||||
|
||||
if (NOT onnxruntime_ENABLE_PYTHON)
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ target_link_libraries(onnxruntime PRIVATE
|
|||
${PROVIDERS_DML}
|
||||
${PROVIDERS_ACL}
|
||||
${PROVIDERS_ARMNN}
|
||||
${PROVIDERS_INTERNAL_TESTING}
|
||||
${onnxruntime_winml}
|
||||
${PROVIDERS_ROCM}
|
||||
onnxruntime_optimizer
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS
|
|||
if (onnxruntime_MINIMAL_BUILD)
|
||||
file(GLOB onnxruntime_framework_src_exclude
|
||||
"${ONNXRUNTIME_ROOT}/core/framework/provider_bridge_ort.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/framework/graph_partitioner.*"
|
||||
"${ONNXRUNTIME_INCLUDE_DIR}/core/framework/customregistry.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/framework/customregistry.cc"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -657,6 +657,10 @@ if (onnxruntime_USE_OPENVINO)
|
|||
endif()
|
||||
|
||||
if (onnxruntime_USE_NNAPI_BUILTIN)
|
||||
if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
message(FATAL_ERROR "NNAPI can not be used in a basic minimal build. Please build with '--minimal_build extended'")
|
||||
endif()
|
||||
|
||||
add_compile_definitions(USE_NNAPI=1)
|
||||
|
||||
# This is the minimum Android API Level required by ORT NNAPI EP to run
|
||||
|
|
@ -782,7 +786,7 @@ if (onnxruntime_USE_DML)
|
|||
else()
|
||||
add_dependencies(${target} RESTORE_PACKAGES)
|
||||
target_link_libraries(${target} PRIVATE "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/DirectML.lib")
|
||||
target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST)
|
||||
target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
|
|
|
|||
|
|
@ -312,6 +312,13 @@ if (onnxruntime_USE_RKNPU)
|
|||
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_rknpu_src})
|
||||
endif()
|
||||
|
||||
if (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
file(GLOB_RECURSE onnxruntime_test_providers_internal_testing_src CONFIGURE_DEPENDS
|
||||
"${TEST_SRC_DIR}/providers/internal_testing/*"
|
||||
)
|
||||
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_internal_testing_src})
|
||||
endif()
|
||||
|
||||
set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib")
|
||||
set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/global_thread_pools")
|
||||
set (ONNXRUNTIME_API_TESTS_WITHOUT_ENV_SRC_DIR "${ONNXRUNTIME_ROOT}/test/api_tests_without_env")
|
||||
|
|
@ -525,7 +532,7 @@ else()
|
|||
target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/external/nsync/public")
|
||||
endif()
|
||||
onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_common onnxruntime_framework GTest::gtest onnx onnx_proto flatbuffers)
|
||||
onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_common onnxruntime_framework GTest::gtest GTest::gmock onnx onnx_proto flatbuffers)
|
||||
|
||||
if (onnxruntime_USE_DNNL)
|
||||
target_compile_definitions(onnxruntime_test_utils PUBLIC USE_DNNL=1)
|
||||
|
|
|
|||
|
|
@ -168,10 +168,13 @@ class IExecutionProvider {
|
|||
void InsertAllocator(AllocatorPtr allocator);
|
||||
void ReplaceAllocator(AllocatorPtr allocator);
|
||||
|
||||
// creation of a fused node is not supported in a minimal build, so any EP enabled in that scenario must support
|
||||
// compilation via GraphViewer instances.
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/**
|
||||
Given a list of fused_node, return create_state/compute/release_state func for each node.
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_node,
|
||||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs);
|
||||
|
||||
/**
|
||||
|
|
@ -181,9 +184,53 @@ class IExecutionProvider {
|
|||
Compute_${node_name}
|
||||
Release_State_${node_name}
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_node,
|
||||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::string& dll_path);
|
||||
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
struct FusedNodeAndGraph {
|
||||
const std::reference_wrapper<onnxruntime::Node> fused_node;
|
||||
// GraphViewer that filters the full graph to the nodes that are covered by 'fused_node'
|
||||
const std::reference_wrapper<GraphViewer> filtered_graph;
|
||||
};
|
||||
|
||||
/**
|
||||
Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
|
||||
return create_state/compute/release_state func for each node.
|
||||
@remarks This is an optional interface that is only needed if the execution provider compiles nodes
|
||||
in a scenario involving the minimal build. i.e. on a mobile or embedded device with ORT format model.
|
||||
|
||||
Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
|
||||
as it is only valid for the duration of the call to Compile.
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs);
|
||||
#endif
|
||||
|
||||
// Fusion approach that is supported
|
||||
enum class FusionStyle {
|
||||
// The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
|
||||
// in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
|
||||
// A GraphProto can be produced from the Node body.
|
||||
Function,
|
||||
|
||||
// The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
|
||||
// that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body().
|
||||
// Instead a GraphViewer that filters the full Graph to the fused Nodes will be created.
|
||||
// This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
|
||||
// and can be supported in a minimal build.
|
||||
FilteredGraphViewer
|
||||
};
|
||||
|
||||
virtual FusionStyle GetFusionStyle() const {
|
||||
// existing EPs use this mode so default to it.
|
||||
// newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return
|
||||
// FilteredGraphViewer
|
||||
return FusionStyle::Function;
|
||||
}
|
||||
|
||||
void SetLogger(const logging::Logger* logger) {
|
||||
logger_ = logger;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,9 +45,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
|
|||
|
||||
bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
|
||||
|
||||
common::Status GetFusedFuncs(ComputeFunc* compute,
|
||||
CreateFunctionStateFunc* create,
|
||||
DestroyFunctionStateFunc* release) const;
|
||||
common::Status GetFusedFuncs(NodeComputeInfo*& compute_info) const;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_MOVE(OpKernelInfo);
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ class Tensor final {
|
|||
ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ",
|
||||
"T ", "!=", dtype_);
|
||||
const T* data = reinterpret_cast<const T*>(static_cast<char*>(p_data_) + byte_offset_);
|
||||
return gsl::make_span(data, shape_.Size());
|
||||
return gsl::make_span(data, static_cast<typename gsl::span<T>::index_type>(shape_.Size()));
|
||||
}
|
||||
|
||||
void* MutableDataRaw(MLDataType type) {
|
||||
|
|
|
|||
|
|
@ -37,49 +37,50 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
|
|||
constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
|
||||
constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider";
|
||||
constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider";
|
||||
constexpr const char *providers_available[] = {
|
||||
kCpuExecutionProvider,
|
||||
|
||||
constexpr const char* providers_available[] = {
|
||||
kCpuExecutionProvider,
|
||||
#ifdef USE_CUDA
|
||||
kCudaExecutionProvider,
|
||||
kCudaExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_DNNL
|
||||
kDnnlExecutionProvider,
|
||||
kDnnlExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NGRAPH
|
||||
kNGraphExecutionProvider,
|
||||
kNGraphExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_OPENVINO
|
||||
kOpenVINOExecutionProvider,
|
||||
kOpenVINOExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NUPHAR
|
||||
kNupharExecutionProvider,
|
||||
kNupharExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_VITISAI
|
||||
kVitisAIExecutionProvider,
|
||||
kVitisAIExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_TENSORRT
|
||||
kTensorrtExecutionProvider,
|
||||
kTensorrtExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NNAPI
|
||||
kNnapiExecutionProvider,
|
||||
kNnapiExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_RKNPU
|
||||
kRknpuExecutionProvider,
|
||||
kRknpuExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
kDmlExecutionProvider,
|
||||
kDmlExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_MIGRAPHX
|
||||
kMIGraphXExecutionProvider,
|
||||
kMIGraphXExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ACL
|
||||
kAclExecutionProvider,
|
||||
kAclExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ARMNN
|
||||
kArmNNExecutionProvider,
|
||||
kArmNNExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
kRocmExecutionProvider,
|
||||
kRocmExecutionProvider,
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class Function {
|
|||
/**
|
||||
Create a new Function instance.
|
||||
@param graph The graph containing the Function.
|
||||
@param customized_func the IndexedSubGraph to use for the Function.
|
||||
@param nodes_to_fuse the IndexedSubGraph to use for the Function.
|
||||
*/
|
||||
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
|
||||
const IndexedSubGraph& nodes_to_fuse,
|
||||
|
|
|
|||
|
|
@ -368,7 +368,9 @@ class Node {
|
|||
ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
|
||||
|
||||
/** Sets the execution ProviderType that this Node will be executed by. */
|
||||
void SetExecutionProviderType(ProviderType execution_provider_type);
|
||||
void SetExecutionProviderType(ProviderType execution_provider_type) {
|
||||
execution_provider_type_ = execution_provider_type;
|
||||
}
|
||||
|
||||
/** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
|
||||
If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
|
||||
|
|
@ -476,7 +478,7 @@ class Node {
|
|||
|
||||
Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph) {}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
void Init(const std::string& name,
|
||||
const std::string& op_type,
|
||||
const std::string& description,
|
||||
|
|
@ -485,25 +487,24 @@ class Node {
|
|||
const NodeAttributes* attributes,
|
||||
const std::string& domain);
|
||||
|
||||
// create a Graph instance for an attribute that contains a GraphProto
|
||||
void CreateSubgraph(const std::string& attr_name);
|
||||
|
||||
// internal only method to allow selected classes to directly alter the input/output definitions and arg counts
|
||||
Definitions& MutableDefinitions() noexcept;
|
||||
|
||||
// internal only method to allow selected classes to directly alter the links between nodes.
|
||||
Relationships& MutableRelationships() noexcept;
|
||||
|
||||
void SetNodeType(Node::Type node_type) noexcept { node_type_ = node_type; }
|
||||
#endif
|
||||
|
||||
// create a Graph instance for an attribute that contains a GraphProto
|
||||
void CreateSubgraph(const std::string& attr_name);
|
||||
|
||||
const std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept { return subgraphs_; }
|
||||
|
||||
void SetNodeType(Node::Type node_type) noexcept;
|
||||
|
||||
void SetFunctionBody(const Function& func);
|
||||
|
||||
// validate and update the input arg count
|
||||
common::Status UpdateInputArgCount();
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
void SetFunctionBody(const Function& func);
|
||||
|
||||
const Definitions& GetDefinitions() const noexcept { return definitions_; }
|
||||
const Relationships& GetRelationships() const noexcept { return relationships_; }
|
||||
|
|
@ -578,6 +579,15 @@ class Graph {
|
|||
/** Gets the path of the owning model, if any. */
|
||||
const Path& ModelPath() const;
|
||||
|
||||
/** Returns true if this is a subgraph or false if it is a high-level graph. */
|
||||
bool IsSubgraph() const { return parent_graph_ != nullptr; }
|
||||
|
||||
/** Returns the parent graph if this is a subgraph */
|
||||
const Graph* ParentGraph() const { return parent_graph_; }
|
||||
|
||||
/** Returns the mutable parent graph if this is a subgraph */
|
||||
Graph* MutableParentGraph() { return parent_graph_; }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/** Sets the Graph name. */
|
||||
void SetName(const std::string& name);
|
||||
|
|
@ -756,6 +766,15 @@ class Graph {
|
|||
/** Generate a unique name in this Graph for a Node */
|
||||
std::string GenerateNodeName(const std::string& base_name);
|
||||
|
||||
/** Copy a Node and add it to this Graph.
|
||||
@param other Node to copy
|
||||
@returns Reference to the Node that was created and added to this Graph.
|
||||
@remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
|
||||
*/
|
||||
Node& AddNode(const Node& other);
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
/** Add a Node to this Graph.
|
||||
@param name The Node name. Must be unique in this Graph.
|
||||
@param op_type The operator type. e.g. ONNX operator name.
|
||||
|
|
@ -775,13 +794,6 @@ class Graph {
|
|||
const NodeAttributes* attributes = nullptr,
|
||||
const std::string& domain = "");
|
||||
|
||||
/** Copy a Node and add it to this Graph.
|
||||
@param other Node to copy
|
||||
@returns Reference to the Node that was created and added to this Graph.
|
||||
@remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
|
||||
*/
|
||||
Node& AddNode(const Node& other);
|
||||
|
||||
/** Remove a Node from this Graph and free it.
|
||||
The output edges of this specified node MUST have been removed before removing the node.
|
||||
The input edges of this specified node is removed while removing the node. The process of
|
||||
|
|
@ -809,14 +821,15 @@ class Graph {
|
|||
@param dst_arg_index node arg index of destination node.
|
||||
*/
|
||||
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/**
|
||||
Add a control edge between two Nodes in this Graph.
|
||||
The source Node does not produce output that is directly consumed by the destination Node, however the
|
||||
destination Node must execute after the source node. The control edge allows this ordering to occur.
|
||||
*/
|
||||
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index);
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
/** Mark the Graph as needing Resolve() to be called.
|
||||
|
|
@ -880,6 +893,7 @@ class Graph {
|
|||
const std::function<bool(const Node*, const Node*)>& comp,
|
||||
const std::function<bool(const Node*, const Node*)>& stop) const;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/** Performs topological sort with Kahn's algorithm on the graph/s.
|
||||
@param enter Visit function that will be invoked on a node when it is visited.
|
||||
@param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
|
||||
|
|
@ -887,11 +901,29 @@ class Graph {
|
|||
void KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
|
||||
const std::function<bool(const Node*, const Node*)>& comp) const;
|
||||
|
||||
#endif
|
||||
|
||||
/** Gets the map of operator domains to their opset versions. */
|
||||
const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
|
||||
return domain_to_version_;
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
/**
|
||||
Create a single Node that will be the result of the fusion of multiple nodes in this Graph.
|
||||
@param sub_graph An IndexedSubGraph instance with details of the nodes to fuse.
|
||||
@param fused_node_name The name for the new Node.
|
||||
@returns Node with fused subgraph.
|
||||
@remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
|
||||
IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
|
||||
while this is in use.
|
||||
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
|
||||
*/
|
||||
Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
|
||||
|
||||
void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/** Gets the GraphProto representation of this Graph. */
|
||||
const ONNX_NAMESPACE::GraphProto& ToGraphProto();
|
||||
|
|
@ -901,10 +933,11 @@ class Graph {
|
|||
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const;
|
||||
|
||||
/**
|
||||
Create a single Node that is the result of the a fusion of multiple nodes in this Graph.
|
||||
@param sub_graph A IndexSubGraph instance with details of the nodes to fuse.
|
||||
Create a single Function based Node that is the result of the fusion of multiple nodes in this Graph.
|
||||
A new Graph instance will be created for the fused nodes.
|
||||
@param sub_graph An IndexedSubGraph instance with details of the nodes to fuse. Ownership is transferred to the new Node.
|
||||
@param fused_node_name The name for the new Node.
|
||||
@returns Node with fused subgraph.
|
||||
@returns Function based Node with fused subgraph. The Node body will contain a Function instance.
|
||||
*/
|
||||
Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
|
||||
|
||||
|
|
@ -939,18 +972,7 @@ class Graph {
|
|||
@remarks Note that the output order matters for subgraphs.
|
||||
*/
|
||||
void SetOutputs(const std::vector<const NodeArg*>& outputs);
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
/** Returns true if this is a subgraph or false if it is a high-level graph. */
|
||||
bool IsSubgraph() const { return parent_graph_ != nullptr; }
|
||||
|
||||
/** Returns the parent graph if this is a subgraph */
|
||||
const Graph* ParentGraph() const { return parent_graph_; }
|
||||
|
||||
/** Returns the mutable parent graph if this is a subgraph */
|
||||
Graph* MutableParentGraph() { return parent_graph_; }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/** Sets the type of a NodeArg, replacing existing type/shape if any */
|
||||
void SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto);
|
||||
|
||||
|
|
@ -1209,12 +1231,6 @@ class Graph {
|
|||
// Clear all unused initializers
|
||||
void CleanUnusedInitializers(const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr);
|
||||
|
||||
gsl::not_null<Node*> AllocateNode();
|
||||
|
||||
// Release the node.
|
||||
// @returns false if node_index was invalid.
|
||||
bool ReleaseNode(NodeIndex node_index);
|
||||
|
||||
std::vector<NodeArg*> CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
|
||||
const ArgNameToTypeMap& name_to_type_map);
|
||||
|
||||
|
|
@ -1247,6 +1263,16 @@ class Graph {
|
|||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
gsl::not_null<Node*> AllocateNode();
|
||||
|
||||
// Release the node.
|
||||
// @returns false if node_index was invalid.
|
||||
bool ReleaseNode(NodeIndex node_index);
|
||||
|
||||
Node& CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
|
||||
#endif
|
||||
|
||||
Node* NodeAtIndexImpl(NodeIndex node_index) const {
|
||||
// if we are trying to access a node that doesn't exist there's (most
|
||||
// likely) either a logic issue or a graph consistency/correctness issue.
|
||||
|
|
@ -1276,12 +1302,10 @@ class Graph {
|
|||
sparse_tensor_names_;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry_;
|
||||
|
||||
std::vector<std::unique_ptr<onnxruntime::Function>> function_container_;
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
#endif
|
||||
|
||||
// Graph nodes.
|
||||
// Element in <nodes_> may be nullptr due to graph optimization.
|
||||
|
|
|
|||
|
|
@ -28,21 +28,21 @@ class ValidNodes {
|
|||
Construct a ValidNodes instance to provide iteration over all valid nodes in the TNodesContainer
|
||||
@param[in] nodes Nodes to iterate, skipping invalid entries.
|
||||
*/
|
||||
explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(nodes) {}
|
||||
explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(&nodes) {}
|
||||
|
||||
explicit ValidNodes(TNodesContainer& nodes, NodeFilterFunc&& filter_node_fn) noexcept
|
||||
: nodes_(nodes), filter_node_fn_{std::move(filter_node_fn)} {}
|
||||
: nodes_(&nodes), filter_node_fn_{std::move(filter_node_fn)} {}
|
||||
|
||||
using ConstNodeIterator = NodeIterator<typename TNodesContainer::const_iterator>;
|
||||
using MutableNodeIterator = NodeIterator<typename TNodesContainer::iterator>;
|
||||
using ConstReverseNodeIterator = NodeIterator<typename TNodesContainer::const_reverse_iterator>;
|
||||
|
||||
ConstNodeIterator cbegin() const noexcept {
|
||||
return {nodes_.cbegin(), nodes_.cend(), filter_node_fn_};
|
||||
return {nodes_->cbegin(), nodes_->cend(), filter_node_fn_};
|
||||
}
|
||||
|
||||
ConstNodeIterator cend() const noexcept {
|
||||
return {nodes_.cend(), nodes_.cend(), filter_node_fn_};
|
||||
return {nodes_->cend(), nodes_->cend(), filter_node_fn_};
|
||||
}
|
||||
|
||||
ConstNodeIterator begin() const noexcept {
|
||||
|
|
@ -54,11 +54,11 @@ class ValidNodes {
|
|||
}
|
||||
|
||||
ConstReverseNodeIterator rbegin() const noexcept {
|
||||
return {nodes_.crbegin(), nodes_.crend(), filter_node_fn_};
|
||||
return {nodes_->crbegin(), nodes_->crend(), filter_node_fn_};
|
||||
}
|
||||
|
||||
ConstReverseNodeIterator rend() const noexcept {
|
||||
return {nodes_.crend(), nodes_.crend(), filter_node_fn_};
|
||||
return {nodes_->crend(), nodes_->crend(), filter_node_fn_};
|
||||
}
|
||||
|
||||
// we only allow mutable access if the container is non-const.
|
||||
|
|
@ -66,16 +66,16 @@ class ValidNodes {
|
|||
template <typename T2 = TNodesContainer>
|
||||
typename std::enable_if<!std::is_const<T2>::value, MutableNodeIterator>::type begin() noexcept {
|
||||
static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
|
||||
return MutableNodeIterator(nodes_.begin(), nodes_.end(), filter_node_fn_);
|
||||
return MutableNodeIterator(nodes_->begin(), nodes_->end(), filter_node_fn_);
|
||||
}
|
||||
|
||||
template <typename T2 = TNodesContainer>
|
||||
typename std::enable_if<!std::is_const<T2>::value, MutableNodeIterator>::type end() noexcept {
|
||||
static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
|
||||
return MutableNodeIterator(nodes_.end(), nodes_.end(), filter_node_fn_);
|
||||
return MutableNodeIterator(nodes_->end(), nodes_->end(), filter_node_fn_);
|
||||
}
|
||||
|
||||
bool empty() const noexcept { return nodes_.empty(); }
|
||||
bool empty() const noexcept { return nodes_->empty(); }
|
||||
|
||||
/**
|
||||
@class NodeIterator
|
||||
|
|
@ -152,7 +152,7 @@ class ValidNodes {
|
|||
};
|
||||
|
||||
private:
|
||||
TNodesContainer& nodes_;
|
||||
gsl::not_null<TNodesContainer*> nodes_; // always set by ctor
|
||||
|
||||
// no filtering if not set. this instance owns the filter func if set.
|
||||
NodeFilterFunc filter_node_fn_;
|
||||
|
|
|
|||
|
|
@ -158,6 +158,11 @@ class GraphViewer {
|
|||
}
|
||||
#endif
|
||||
|
||||
/** Get the filter info that restricts the graph viewer to a subset of nodes if set.
|
||||
@returns Filter info or nullptr
|
||||
*/
|
||||
const IndexedSubGraph* GetFilterInfo() const { return filter_info_; }
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer);
|
||||
GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright 2019 JD.com Inc. JD AI
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
|
|
|||
|
|
@ -481,7 +481,7 @@ inline SessionOptions& SessionOptions::AddInitializer(const char* name, const Or
|
|||
return *this;
|
||||
}
|
||||
|
||||
inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions * options, OrtCUDAProviderOptions * cuda_options) {
|
||||
inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions* options, OrtCUDAProviderOptions* cuda_options) {
|
||||
ThrowOnError(GetApi().OrtSessionOptionsAppendExecutionProvider_CUDA(options, cuda_options));
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -943,7 +943,8 @@ inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext*
|
|||
return out;
|
||||
}
|
||||
|
||||
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
|
||||
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
|
||||
_In_ const int64_t* dim_values, size_t dim_count) {
|
||||
OrtValue* out;
|
||||
ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
|
||||
return out;
|
||||
|
|
|
|||
|
|
@ -43,9 +43,11 @@ IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
|
||||
return result;
|
||||
#else
|
||||
// We have saved hashes to lookup static kernels in an ORT format model so the default behavior is to return an
|
||||
// empty vector to leave that in place. An EP that compiles nodes can override this in a minimal build.
|
||||
ORT_UNUSED_PARAMETER(graph);
|
||||
ORT_UNUSED_PARAMETER(kernel_registries);
|
||||
ORT_NOT_IMPLEMENTED("IExecutionProvider::GetCapability is not supported in this build.");
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -79,6 +81,7 @@ void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) {
|
|||
allocator_list_.push_back(allocator);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
|
||||
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
|
||||
|
|
@ -88,6 +91,14 @@ common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>
|
|||
std::string& /*dll_path*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& /*fused_nodes_and_graphs*/,
|
||||
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<KernelRegistry> IExecutionProvider::GetKernelRegistry() const {
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -18,33 +18,33 @@ class FunctionKernel : public OpKernel {
|
|||
explicit FunctionKernel(const OpKernelInfo& info) : OpKernel(info) {
|
||||
num_inputs_ = info.node().InputDefs().size();
|
||||
num_outputs_ = info.node().OutputDefs().size();
|
||||
CreateFunctionStateFunc create_func;
|
||||
auto status = info.GetFusedFuncs(&func_, &create_func, &release_func_);
|
||||
auto status = info.GetFusedFuncs(compute_info_);
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
if (create_func) {
|
||||
if (compute_info_->create_state_func) {
|
||||
//TODO: we only provide the host allocate method in the compute context.
|
||||
//Do we need to hold the ref-counting here?
|
||||
host_allocator_ = info.GetAllocator(0, OrtMemType::OrtMemTypeDefault);
|
||||
ComputeContext context = {allocate_helper_func, release_helper_func, host_allocator_.get(), info.node().Name().c_str()};
|
||||
ORT_ENFORCE(create_func(&context, &func_state_) == 0);
|
||||
ComputeContext context = {allocate_helper_func, release_helper_func, host_allocator_.get(),
|
||||
info.node().Name().c_str()};
|
||||
ORT_ENFORCE(compute_info_->create_state_func(&context, &func_state_) == 0);
|
||||
}
|
||||
}
|
||||
|
||||
~FunctionKernel() override {
|
||||
if (release_func_ && func_state_) {
|
||||
release_func_(func_state_);
|
||||
if (compute_info_->release_state_func && func_state_) {
|
||||
compute_info_->release_state_func(func_state_);
|
||||
}
|
||||
}
|
||||
|
||||
virtual Status Compute(OpKernelContext* context) const override {
|
||||
auto* context_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
return func_(func_state_, OrtGetApiBase()->GetApi(ORT_API_VERSION), reinterpret_cast<OrtKernelContext*>(context_internal));
|
||||
return compute_info_->compute_func(func_state_, OrtGetApiBase()->GetApi(ORT_API_VERSION),
|
||||
reinterpret_cast<OrtKernelContext*>(context_internal));
|
||||
}
|
||||
|
||||
private:
|
||||
ComputeFunc func_;
|
||||
DestroyFunctionStateFunc release_func_;
|
||||
FunctionState func_state_;
|
||||
NodeComputeInfo* compute_info_{nullptr};
|
||||
FunctionState func_state_{nullptr};
|
||||
size_t num_inputs_;
|
||||
size_t num_outputs_;
|
||||
AllocatorPtr host_allocator_;
|
||||
|
|
|
|||
|
|
@ -6,25 +6,28 @@ Status FuncManager::AddFuncInfo(const std::string& name, const std::string& dll_
|
|||
auto it = fused_funcs_->find(name);
|
||||
if (it != fused_funcs_->end())
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " already exist.");
|
||||
(*fused_funcs_)[name] = {dll_path, nullptr, nullptr, nullptr};
|
||||
(*fused_funcs_)[name] = {dll_path, NodeComputeInfo()};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FuncManager::AddFuncInfo(const std::string& name, ComputeFunc compute, CreateFunctionStateFunc create, DestroyFunctionStateFunc release) {
|
||||
Status FuncManager::AddFuncInfo(const std::string& name, NodeComputeInfo&& compute_info) {
|
||||
auto it = fused_funcs_->find(name);
|
||||
if (it != fused_funcs_->end())
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " already exist.");
|
||||
if (!compute || !create || !release)
|
||||
|
||||
if (!compute_info.compute_func || !compute_info.create_state_func || !compute_info.release_state_func)
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Can't use func with null ptr");
|
||||
(*fused_funcs_)[name] = {"", compute, create, release};
|
||||
|
||||
(*fused_funcs_)[name] = {"", std::move(compute_info)};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FuncManager::GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const {
|
||||
Status FuncManager::GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const {
|
||||
auto it = fused_funcs_->find(name);
|
||||
if (it == fused_funcs_->end())
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " not found.");
|
||||
if (!it->second.compute_func) {
|
||||
|
||||
if (!it->second.compute_info.compute_func) {
|
||||
//load from path
|
||||
void* handle = nullptr;
|
||||
ORT_RETURN_IF_ERROR(lib_loader_->LoadExternalLib(it->second.dso_path, &handle));
|
||||
|
|
@ -40,21 +43,20 @@ Status FuncManager::GetFuncs(const std::string& name, ComputeFunc* compute, Crea
|
|||
ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle,
|
||||
kReleaseStateFuncSymbol + name,
|
||||
&release_func_symbol_handle));
|
||||
it->second.compute_func = [=](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
|
||||
it->second.compute_info.compute_func = [=](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
|
||||
return reinterpret_cast<ComputeFuncC>(compute_func_symbol_handle)(state, api, context);
|
||||
};
|
||||
|
||||
it->second.create_state_func = [=](ComputeContext* context, FunctionState* state) {
|
||||
it->second.compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
|
||||
return reinterpret_cast<CreateFunctionStateC>(create_func_symbol_handle)(context, state);
|
||||
};
|
||||
|
||||
it->second.release_state_func = [=](FunctionState state) {
|
||||
it->second.compute_info.release_state_func = [=](FunctionState state) {
|
||||
return reinterpret_cast<DestroyFunctionStateC>(release_func_symbol_handle)(state);
|
||||
};
|
||||
}
|
||||
*compute = it->second.compute_func;
|
||||
*create = it->second.create_state_func;
|
||||
*release = it->second.release_state_func;
|
||||
|
||||
compute_info = &it->second.compute_info;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,16 @@ namespace onnxruntime {
|
|||
|
||||
class FuncManager {
|
||||
public:
|
||||
FuncManager() : fused_funcs_(std::make_shared<std::unordered_map<std::string, FuncInfo> >()), lib_loader_(onnxruntime::make_unique<ExLibLoader>()) {}
|
||||
FuncManager()
|
||||
: fused_funcs_(std::make_shared<std::unordered_map<std::string, FuncInfo> >()),
|
||||
lib_loader_(onnxruntime::make_unique<ExLibLoader>()) {
|
||||
}
|
||||
|
||||
Status AddFuncInfo(const std::string& name, const std::string& dll_path);
|
||||
|
||||
Status AddFuncInfo(const std::string& name, ComputeFunc compute, CreateFunctionStateFunc create, DestroyFunctionStateFunc release);
|
||||
Status AddFuncInfo(const std::string& name, NodeComputeInfo&& compute_info);
|
||||
|
||||
Status GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const;
|
||||
Status GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const;
|
||||
|
||||
void SetFusedFuncs(const FuncManager& func_mgr) {
|
||||
fused_funcs_ = func_mgr.fused_funcs_;
|
||||
|
|
@ -21,9 +24,7 @@ class FuncManager {
|
|||
|
||||
struct FuncInfo {
|
||||
std::string dso_path;
|
||||
ComputeFunc compute_func;
|
||||
CreateFunctionStateFunc create_state_func;
|
||||
DestroyFunctionStateFunc release_state_func;
|
||||
NodeComputeInfo compute_info;
|
||||
};
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#include "core/framework/graph_partitioner.h"
|
||||
#include "core/framework/kernel_registry_manager.h"
|
||||
#include "core/graph/function.h"
|
||||
|
|
@ -41,13 +43,22 @@ NonCudaOps non_cuda;
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
static KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
|
||||
// minimal KernelDef based on MetaDef instead of a Function based node
|
||||
static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph::MetaDef& metadef,
|
||||
const std::string& provider_type) {
|
||||
builder.SetName(metadef.name)
|
||||
.SetDomain(metadef.domain)
|
||||
.SinceVersion(metadef.since_version)
|
||||
.Provider(provider_type);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
|
||||
auto schema = node.Op();
|
||||
builder.SetName(schema->Name())
|
||||
.SetDomain(schema->domain())
|
||||
.SinceVersion(schema->SinceVersion())
|
||||
.Provider(node.GetExecutionProviderType());
|
||||
return builder;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -57,12 +68,16 @@ static KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const on
|
|||
* \param capability
|
||||
* \param kernel_registry_mgr
|
||||
* \param provider_type name of the provider to test
|
||||
* \param count A counter for generating fused node names. Should be unique within this subgraph
|
||||
* \param count A counter for generating fused node names. Unique across the entire model.
|
||||
* \return Fused node. Return nullptr if there is no fuse
|
||||
*/
|
||||
static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
|
||||
const KernelRegistryManager& kernel_registry_mgr, const std::string& provider_type,
|
||||
int& count) {
|
||||
IExecutionProvider::FusionStyle fusion_style,
|
||||
GraphPartitioner::Mode mode,
|
||||
int& fused_node_unique_id) {
|
||||
Node* result = nullptr;
|
||||
|
||||
if (nullptr == capability.GetMetaDef()) {
|
||||
// The <provider> can run a single node in the <graph> if not using meta-defs.
|
||||
// A fused kernel is not supported in this case.
|
||||
|
|
@ -75,39 +90,82 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
|
|||
}
|
||||
} else {
|
||||
// The <provider> can run a fused <sub_graph> in the <graph>.
|
||||
ORT_ENFORCE(nullptr != capability.GetMetaDef());
|
||||
// Check whether any node in the <sub_graph> was already assigned.
|
||||
|
||||
// Check whether any node in the <sub_graph> was already assigned. If so it cannot be stolen as assignment is done
|
||||
// in order of EP priority
|
||||
bool sub_graph_available_for_assignment = true;
|
||||
for (auto node_index : capability.nodes) {
|
||||
auto node = graph.GetNode(node_index);
|
||||
const auto* node = graph.GetNode(node_index);
|
||||
if (nullptr == node || !node->GetExecutionProviderType().empty()) {
|
||||
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
|
||||
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
|
||||
sub_graph_available_for_assignment = false;
|
||||
break;
|
||||
// if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned,
|
||||
// so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by
|
||||
// preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to
|
||||
// and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes
|
||||
// should never be on those lists.
|
||||
//
|
||||
// when the ORT format model is loaded we will process it normally with EP priority being applied for
|
||||
// whichever EPs are enabled at the time.
|
||||
//
|
||||
// e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP.
|
||||
// We want the ORT format model to be able to be run as efficiently as possible on either platform,
|
||||
// so we want all the nodes that either may take to be preserved. If we did not do this we would
|
||||
// need to create one ORT format model for Android and one for iOS.
|
||||
if (mode != GraphPartitioner::Mode::kAssignOnly) {
|
||||
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
|
||||
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
|
||||
sub_graph_available_for_assignment = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (sub_graph_available_for_assignment) {
|
||||
std::ostringstream oss;
|
||||
oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << count++;
|
||||
std::string node_name = oss.str();
|
||||
auto& fused_node = graph.FuseSubGraph(capability, node_name);
|
||||
fused_node.SetExecutionProviderType(provider_type);
|
||||
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, fused_node, provider_type)) {
|
||||
return &fused_node;
|
||||
if (mode == GraphPartitioner::Mode::kNormal) {
|
||||
std::ostringstream oss;
|
||||
oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++;
|
||||
std::string node_name = oss.str();
|
||||
|
||||
Node* fused_node = nullptr;
|
||||
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
|
||||
fused_node = &graph.FuseSubGraph(capability, node_name);
|
||||
} else {
|
||||
// create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed
|
||||
// through to Compile via a filtered GraphViewer.
|
||||
fused_node = &graph.BeginFuseSubGraph(capability, node_name);
|
||||
}
|
||||
|
||||
fused_node->SetExecutionProviderType(provider_type);
|
||||
|
||||
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *fused_node, provider_type)) {
|
||||
result = fused_node;
|
||||
}
|
||||
} else {
|
||||
// assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them.
|
||||
// This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion
|
||||
// at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device
|
||||
// capabilities.
|
||||
for (auto node_index : capability.nodes) {
|
||||
auto* node = graph.GetNode(node_index);
|
||||
if (node != nullptr) {
|
||||
node->SetExecutionProviderType(provider_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up).
|
||||
// assign any nodes to the EP that are currently unassigned, and that the EP can handle.
|
||||
static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
KernelRegistryManager& kernel_registry_mgr,
|
||||
KernelRegistry& fused_kernel_registry,
|
||||
IExecutionProvider& current_ep) {
|
||||
static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
KernelRegistryManager& kernel_registry_mgr,
|
||||
KernelRegistry& fused_kernel_registry,
|
||||
IExecutionProvider& current_ep,
|
||||
GraphPartitioner::Mode mode,
|
||||
int& fused_node_unique_id) {
|
||||
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
|
||||
// doing it here saves all providers checking for this in GetCapability
|
||||
if (graph.NumberOfNodes() == 0) {
|
||||
|
|
@ -119,8 +177,8 @@ static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr
|
|||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
Graph* subgraph = entry.second;
|
||||
// we pass through the export_dll value and FuncManager from the top level graph
|
||||
ORT_RETURN_IF_ERROR(PartitionImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr, fused_kernel_registry,
|
||||
current_ep));
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr,
|
||||
fused_kernel_registry, current_ep, mode, fused_node_unique_id));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -134,65 +192,135 @@ static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr
|
|||
// TODO: when the graph contain a function node, and user pass in the dll which could
|
||||
// run the function by SessionOption, we should create a function kernel for it and
|
||||
// delegate the compute to the functions inside the dlls.
|
||||
int count = 0;
|
||||
std::vector<Node*> nodes_need_compile;
|
||||
const std::string& type = current_ep.Type();
|
||||
auto fusion_style = current_ep.GetFusionStyle();
|
||||
std::vector<Node*> nodes_to_compile;
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities =
|
||||
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type()));
|
||||
current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));
|
||||
|
||||
// filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with
|
||||
// nodes_to_compile.
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_compile;
|
||||
capabilities_to_compile.reserve(std::count_if(capabilities.cbegin(), capabilities.cend(),
|
||||
[](const std::unique_ptr<ComputeCapability>& entry) {
|
||||
return entry != nullptr &&
|
||||
entry->sub_graph != nullptr &&
|
||||
entry->sub_graph->GetMetaDef() != nullptr;
|
||||
}));
|
||||
|
||||
for (auto& capability : capabilities) {
|
||||
if (!capability || !capability->sub_graph) { // in theory an EP could return an empty value...
|
||||
continue;
|
||||
}
|
||||
|
||||
Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, current_ep.Type(), count);
|
||||
Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, type, fusion_style, mode, fused_node_unique_id);
|
||||
if (n != nullptr) {
|
||||
nodes_need_compile.push_back(n);
|
||||
nodes_to_compile.push_back(n);
|
||||
capabilities_to_compile.push_back(std::move(capability));
|
||||
}
|
||||
}
|
||||
|
||||
if (!nodes_need_compile.empty()) {
|
||||
// NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode
|
||||
if (!nodes_to_compile.empty()) {
|
||||
std::vector<NodeComputeInfo> node_compute_funcs;
|
||||
|
||||
if (export_dll) {
|
||||
std::string dll_path;
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_need_compile, dll_path));
|
||||
for (auto* node : nodes_need_compile) {
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path));
|
||||
}
|
||||
} else {
|
||||
std::vector<NodeComputeInfo> node_compute_funcs;
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_need_compile, node_compute_funcs));
|
||||
ORT_ENFORCE(node_compute_funcs.size() == nodes_need_compile.size(),
|
||||
"Provider did not return correct number of compiled functions");
|
||||
for (size_t j = 0; j < nodes_need_compile.size(); j++) {
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_need_compile[j]->Name(), node_compute_funcs[j].compute_func,
|
||||
node_compute_funcs[j].create_state_func,
|
||||
node_compute_funcs[j].release_state_func));
|
||||
}
|
||||
ORT_ENFORCE(fusion_style == IExecutionProvider::FusionStyle::Function,
|
||||
"Must use Function based fusion when exporting compiled nodes to dll.");
|
||||
}
|
||||
|
||||
for (auto* node : nodes_need_compile) {
|
||||
//prepare the func kernel
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, *node);
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder, static_cast<KernelCreatePtrFn>(
|
||||
[](const OpKernelInfo& info) -> OpKernel* {
|
||||
return new FunctionKernel(info);
|
||||
})));
|
||||
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
|
||||
// Create a Function based node where the fused nodes have a new Graph instance.
|
||||
|
||||
if (export_dll) {
|
||||
std::string dll_path;
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, dll_path));
|
||||
|
||||
for (auto* node : nodes_to_compile) {
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path));
|
||||
}
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs));
|
||||
|
||||
if (node_compute_funcs.size() != nodes_to_compile.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
|
||||
}
|
||||
|
||||
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j])));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* node : nodes_to_compile) {
|
||||
// add the KernelDef instances for the compiled nodes
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, *node);
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder,
|
||||
static_cast<KernelCreatePtrFn>(
|
||||
[](const OpKernelInfo& info) -> OpKernel* {
|
||||
return new FunctionKernel(info);
|
||||
})));
|
||||
}
|
||||
|
||||
} else {
|
||||
// temporary storage for the GraphViewer for each IndexedSubGraph
|
||||
std::vector<std::unique_ptr<GraphViewer>> viewers;
|
||||
viewers.reserve(nodes_to_compile.size());
|
||||
std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;
|
||||
|
||||
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
|
||||
auto* node = nodes_to_compile[j];
|
||||
const auto& cur_capability = *capabilities_to_compile[j];
|
||||
viewers.push_back(onnxruntime::make_unique<GraphViewer>(graph, *cur_capability.sub_graph));
|
||||
nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{*node, *viewers.back()});
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));
|
||||
|
||||
if (node_compute_funcs.size() != nodes_to_compile.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
|
||||
}
|
||||
|
||||
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
|
||||
auto* node = nodes_to_compile[j];
|
||||
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), std::move(node_compute_funcs[j])));
|
||||
|
||||
const auto& cur_capability = capabilities_to_compile[j];
|
||||
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
|
||||
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
|
||||
|
||||
// create the func kernel for the name in the MetaDef. this is also the node name and that name that will
|
||||
// used as the key in the FuncManager entry. We need the registry to own the KernelCreateInfo that is
|
||||
// used by SessionState
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, metadef, type);
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder,
|
||||
static_cast<KernelCreatePtrFn>(
|
||||
[](const OpKernelInfo& info) -> OpKernel* {
|
||||
return new FunctionKernel(info);
|
||||
})));
|
||||
|
||||
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
|
||||
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if this is the main graph call Resolve to put the Graph back into a guaranteed good state
|
||||
// TODO: If we fix Graph::FuseSubGraph to correctly update edges when replacing nodes with a fused node we can
|
||||
// avoid this Resolve call.
|
||||
// TODO: Graph::FuseSubGraph and Graph::FinalizeFuseSubGraph should now create valid edges so this call to
|
||||
// Graph::Resolve should not be required. Need to test to validate that, especially if node being fused
|
||||
// was a control flow node with its own subgraph as more than just the edges may need updating.
|
||||
if (!graph.IsSubgraph()) {
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
}
|
||||
|
||||
//For some cases, like fp16 on cpu, right now we don't have any kernel support that.
|
||||
//But we will insert cast op to run the model, so skip the error checking here.
|
||||
//If after graph transform phase, the node still not assigned, we will report error
|
||||
//during kernel creation phase.
|
||||
// For some cases, like fp16 on cpu, right now we don't have any kernel support that.
|
||||
// But we will insert cast op to run the model, so skip the error checking here.
|
||||
// If after graph transform phase, the node still not assigned, we will report error
|
||||
// during kernel creation phase.
|
||||
#ifdef COUNT_NON_CUDA_OPS
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.GetExecutionProviderType() != kCudaExecutionProvider &&
|
||||
|
|
@ -231,12 +359,157 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
|
|||
modified_graph = true;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Partition an ONNX format model.
// Every registered execution provider gets a pass over the full graph (in the order of providers_),
// after which InlineNodes is run. If inlining changed the graph it is resolved and the whole
// process repeats so that the newly added nodes can be assigned as well.
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
                                                  KernelRegistry& fused_kernel_registry, Mode mode,
                                                  int& fused_node_unique_id) const {
  // always run at least one pass; keep going for as long as inlining alters the graph
  for (bool graph_changed = true; graph_changed;) {
    // process full graph with each EP
    for (const auto& provider : providers_) {
      ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
                                                       fused_kernel_registry, *provider, mode,
                                                       fused_node_unique_id));
    }

    // expand any nodes that have an ONNX function definition but no matching ORT kernel.
    graph_changed = false;
    ORT_RETURN_IF_ERROR(InlineNodes(graph, graph_changed));

    // Resolve before re-running partitioning and inlining on the changed graph
    if (graph_changed) {
      ORT_RETURN_IF_ERROR(graph.Resolve());
    }
  }

  return Status::OK();
}
|
||||
|
||||
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr) const {
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
// Partition a graph loaded from an ORT format model for a single execution provider that
// compiles nodes.
//
// Nested subgraphs are processed first (bottom-up). For every capability returned by the EP
// that carries a MetaDef, a fused node is started with Graph::BeginFuseSubGraph, the EP is asked
// to Compile it via a filtered GraphViewer, the resulting NodeComputeInfo is added to func_mgr,
// a FunctionKernel is registered in fused_kernel_registry, and the fusion is completed with
// Graph::FinalizeFuseSubGraph.
//
// \param graph                  Graph to partition. Modified when nodes are fused.
// \param func_mgr               Receives the compiled functions, keyed by fused node name.
// \param kernel_registry_mgr    Used to look up the kernel registries for the EP's type.
// \param fused_kernel_registry  Receives a KernelCreateInfo for each compiled node.
// \param current_ep             Execution provider being processed.
// \param compiled_kernel_hashes Receives MetaDef name -> kernel def hash for each compiled node.
// \param fused_node_unique_id   Counter used to generate fused node names. Shared across calls.
// \return Status::OK on success, otherwise the first failure encountered.
//         Throws (ORT_THROW) if two compiled nodes share the same MetaDef name.
static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
                                          KernelRegistryManager& kernel_registry_mgr,
                                          KernelRegistry& fused_kernel_registry,
                                          IExecutionProvider& current_ep,
                                          std::unordered_map<std::string, uint64_t>& compiled_kernel_hashes,
                                          int& fused_node_unique_id) {
  // recurse into nested graphs first to partition bottom up.
  for (auto& node : graph.Nodes()) {
    for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
      Graph* subgraph = entry.second;
      ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry,
                                                      current_ep, compiled_kernel_hashes, fused_node_unique_id));
    }
  }

  // handle testing edge case where optimizers or constant lifting results in graph with no nodes.
  // doing it here saves all providers checking for this in GetCapability
  if (graph.NumberOfNodes() == 0) {
    return Status::OK();
  }

  const std::string& type = current_ep.Type();
  GraphViewer graph_viewer(graph);

  std::vector<std::unique_ptr<ComputeCapability>> capabilities =
      current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type));

  // storage for the GraphViewer for each IndexedSubGraph
  std::vector<std::unique_ptr<GraphViewer>> viewers;
  viewers.reserve(capabilities.size());

  std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;

  // IndexedSubGraph for each entry in nodes_and_viewers (1:1, same order). Capabilities without a MetaDef are
  // skipped below, so an index into nodes_and_viewers is NOT a valid index into `capabilities`.
  // The pointers stay valid for the rest of this function as `capabilities` owns the instances.
  std::vector<const IndexedSubGraph*> sub_graphs_to_compile;
  sub_graphs_to_compile.reserve(capabilities.size());

  for (auto& capability : capabilities) {
    if (!capability || !capability->sub_graph) {  // in theory an EP could return an empty value...
      continue;
    }

    const IndexedSubGraph& indexed_sub_graph = *capability->sub_graph;
    const IndexedSubGraph::MetaDef* metadef = indexed_sub_graph.GetMetaDef();
    if (!metadef) {
      // Static kernel - use the kernel hash that was saved in the ORT format model
      continue;
    }

    std::ostringstream oss;
    oss << type << "_" << metadef->name << "_" << fused_node_unique_id++;
    std::string node_name = oss.str();

    Node& fused_node = graph.BeginFuseSubGraph(indexed_sub_graph, node_name);
    fused_node.SetExecutionProviderType(type);

    // create filtered graph viewer for this set of nodes
    //
    // TODO: Could avoid the topological sort in the GraphViewer ctor by constructing from an existing
    // GraphViewer instance instead of the Graph (copying the topological order instead of recalculating).
    viewers.push_back(onnxruntime::make_unique<GraphViewer>(graph, indexed_sub_graph));
    nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{fused_node, *viewers.back()});
    sub_graphs_to_compile.push_back(&indexed_sub_graph);
  }

  // nothing for this EP to compile, so there is no reason to call Compile.
  // this matches the ONNX format model path, which only calls Compile when there are nodes to compile.
  if (nodes_and_viewers.empty()) {
    return Status::OK();
  }

  std::vector<NodeComputeInfo> node_compute_funcs;
  node_compute_funcs.reserve(nodes_and_viewers.size());

  ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));

  if (node_compute_funcs.size() != nodes_and_viewers.size()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
  }

  for (size_t j = 0, end = nodes_and_viewers.size(); j < end; j++) {
    Node& node = nodes_and_viewers[j].fused_node;

    ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(node_compute_funcs[j])));

    // use the sub-graph that was recorded for this fused node. only entries with a MetaDef were recorded
    // so the dereference of GetMetaDef() is safe.
    const IndexedSubGraph& indexed_sub_graph = *sub_graphs_to_compile[j];
    const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();

    KernelDefBuilder builder;
    BuildFusedKernelDef(builder, metadef, type);
    auto kernel_def = builder.Build();

    // save hash so SessionState can find the kernel. each kernel name should be unique
    if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
      ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
                ". Execution Provider must generate unique names across the entire model.");
    }

    ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
        KernelCreateInfo(std::move(kernel_def), static_cast<KernelCreatePtrFn>(
                                                    [](const OpKernelInfo& info) -> OpKernel* {
                                                      return new FunctionKernel(info);
                                                    }))));

    // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
    graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
  }

  return Status::OK();
}
|
||||
|
||||
// Simplified partitioning where custom EPs may produce compiled nodes.
|
||||
// EPs with static kernels do not need to be processed as their kernels are matched via hash information serialized
|
||||
// as part of the ORT format model.
|
||||
Status GraphPartitioner::PartitionOrtFormatModel(
|
||||
Graph& graph, FuncManager& func_mgr,
|
||||
KernelRegistry& fused_kernel_registry,
|
||||
std::unordered_map<std::string, uint64_t>& compiled_kernel_hashes,
|
||||
int& fused_node_unique_id) const {
|
||||
// process full graph with each EP
|
||||
for (const auto& ep : providers_) {
|
||||
if (ep->Type() == kCpuExecutionProvider) {
|
||||
// hash for kernel is stored in session state for EPs that have pre-registered kernels
|
||||
// (vs. runtime fused kernels) so nothing to do here.
|
||||
continue;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry,
|
||||
*ep, compiled_kernel_hashes, fused_node_unique_id));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, Mode mode,
|
||||
std::unordered_map<std::string, uint64_t>* compiled_kernel_hashes) const {
|
||||
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
|
||||
// 1. Execution providers' capabilities are checked one by one.
|
||||
// 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet.
|
||||
|
|
@ -253,21 +526,23 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
|
|||
// It is only visible for current session.
|
||||
std::shared_ptr<KernelRegistry> fused_kernel_registry = std::make_shared<KernelRegistry>();
|
||||
|
||||
bool modified_graph = false;
|
||||
do {
|
||||
// process full graph with each EP
|
||||
for (const auto& ep : providers_) {
|
||||
ORT_RETURN_IF_ERROR(PartitionImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
|
||||
*fused_kernel_registry, *ep));
|
||||
}
|
||||
// we make sure each fused node name is unique across the entire model for clarity
|
||||
int fused_node_unique_id = 0;
|
||||
|
||||
modified_graph = false;
|
||||
// expand any nodes that have an ONNX function definition
|
||||
// but no matching ORT kernel
|
||||
ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph));
|
||||
// rerun graph partition to assign nodes added as part of
|
||||
// function expansion.
|
||||
} while (modified_graph);
|
||||
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode,
|
||||
fused_node_unique_id));
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(export_dll);
|
||||
ORT_THROW("Not supported in this build.");
|
||||
#endif
|
||||
} else {
|
||||
ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided");
|
||||
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes,
|
||||
fused_node_unique_id));
|
||||
}
|
||||
|
||||
if (!fused_kernel_registry->IsEmpty()) {
|
||||
kernel_registry_mgr_.RegisterKernelRegistry(fused_kernel_registry);
|
||||
|
|
@ -276,3 +551,5 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
|
|||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
|
@ -11,21 +13,43 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
class ExecutionProviders;
|
||||
class KernelRegistry;
|
||||
class KernelRegistryManager;
|
||||
|
||||
class GraphPartitioner {
|
||||
public:
|
||||
enum class Mode {
|
||||
kNormal = 0,
|
||||
kAssignOnly = 1, // assign nodes. no call to Compile. used to create ORT format model support for compiling EPs
|
||||
kOrtFormatLoad = 2 // loading ORT format model. Partition with compiling EPs, GraphViewer based Compile.
|
||||
};
|
||||
|
||||
//The order of providers represents the user preference.
|
||||
GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers)
|
||||
: kernel_registry_mgr_(kernel_registry_mgr),
|
||||
providers_(providers) {}
|
||||
providers_(providers) {
|
||||
}
|
||||
|
||||
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr) const;
|
||||
// Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad.
|
||||
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
Mode mode = Mode::kNormal,
|
||||
std::unordered_map<std::string, uint64_t>* compiled_kernel_hashes = nullptr) const;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id) const;
|
||||
#endif
|
||||
|
||||
Status PartitionOrtFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry,
|
||||
std::unordered_map<std::string, uint64_t>& compiled_kernel_hashes,
|
||||
int& fused_node_unique_id) const;
|
||||
|
||||
KernelRegistryManager& kernel_registry_mgr_;
|
||||
const ExecutionProviders& providers_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -45,14 +45,16 @@ Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& executio
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry) {
|
||||
if (nullptr == kernel_registry) {
|
||||
return;
|
||||
}
|
||||
custom_kernel_registries_.push_front(kernel_registry);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
|
||||
std::vector<const KernelRegistry*> kernel_registries = r.GetKernelRegistriesByProviderType(provider_type);
|
||||
return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) {
|
||||
|
|
@ -84,10 +86,12 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node
|
|||
return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. "));
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
for (auto& registry : custom_kernel_registries_) {
|
||||
status = registry->TryFindKernel(node, std::string(), kernel_def_hash, kernel_create_info);
|
||||
if (status.IsOK()) return status;
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -32,8 +32,7 @@ class KernelRegistryManager {
|
|||
// Register kernels from providers
|
||||
Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// The registry passed in this function has highest priority than anything already in this KernelRegistryManager,
|
||||
// and anything registered from RegisterKernels
|
||||
// For example, if you do:
|
||||
|
|
@ -43,16 +42,6 @@ class KernelRegistryManager {
|
|||
// Then B > A > providers
|
||||
void RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry);
|
||||
|
||||
// This function assumes the node is already assigned to an execution provider
|
||||
// Don't call this function before graph partition is done
|
||||
Status SearchKernelRegistry(const onnxruntime::Node& node,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const;
|
||||
|
||||
/**
|
||||
* Whether this node can be run on this provider
|
||||
*/
|
||||
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
|
||||
|
||||
/**
|
||||
* Search kernel registry by provider type.
|
||||
* @param type provider type string
|
||||
|
|
@ -70,6 +59,18 @@ class KernelRegistryManager {
|
|||
}
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// This function assumes the node is already assigned to an execution provider
|
||||
// Don't call this function before graph partition is done
|
||||
Status SearchKernelRegistry(const onnxruntime::Node& node,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const;
|
||||
|
||||
/**
|
||||
* Whether this node can be run on this provider
|
||||
*/
|
||||
static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
|
||||
#endif
|
||||
|
||||
Status SearchKernelRegistry(const onnxruntime::Node& node,
|
||||
uint64_t kernel_def_hash,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const;
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_
|
|||
return true;
|
||||
}
|
||||
|
||||
common::Status OpKernelInfo::GetFusedFuncs(ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const {
|
||||
return funcs_mgr_.GetFuncs(node_.Name(), compute, create, release);
|
||||
common::Status OpKernelInfo::GetFusedFuncs(NodeComputeInfo*& compute_info) const {
|
||||
return funcs_mgr_.GetFuncs(node_.Name(), compute_info);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -114,7 +114,8 @@ void SessionState::CreateGraphInfo() {
|
|||
for (const auto& output : graph_viewer_->GetOutputs()) {
|
||||
if (output->Exists()) {
|
||||
idx = ort_value_name_idx_map_.Add(output->Name());
|
||||
VLOGS(logger_, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx;
|
||||
VLOGS(logger_, 1) << "Added graph output with name: " << output->Name()
|
||||
<< " to OrtValueIndex with index: " << idx;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -122,10 +123,25 @@ void SessionState::CreateGraphInfo() {
|
|||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager) {
|
||||
Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager,
|
||||
bool saving_ort_format) {
|
||||
for (auto& node : graph_.Nodes()) {
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, &kci));
|
||||
|
||||
auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
|
||||
if (!status.IsOK() && saving_ort_format) {
|
||||
// if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
|
||||
// in that case we assigned the node to that EP but do not compile it into a fused node.
|
||||
// this keeps the original node and prevents level 2 and level 3 optimizers from modifying it.
|
||||
// we now revert to the CPU EP to include the hash for the kernel as a fallback. at runtime when the model
|
||||
// is loaded in a minimal build, the compiling EP will replace this node if possible. if that's not possible for
|
||||
// some reason we can fallback to the CPU EP implementation via this hash.
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
ORT_IGNORE_RETURN_VALUE(
|
||||
kernel_create_info_map_.insert({node.Index(), gsl::not_null<const KernelCreateInfo*>(kci)}));
|
||||
}
|
||||
|
|
@ -133,7 +149,7 @@ Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_regi
|
|||
for (const auto& entry : subgraph_session_states_) {
|
||||
for (const auto& name_to_subgraph_session_state : entry.second) {
|
||||
SessionState& subgraph_session_state = *name_to_subgraph_session_state.second;
|
||||
ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager));
|
||||
ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -813,16 +829,45 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
|
|||
"Size mismatch for kernel create info node indexes and hashes. Invalid ORT format model.",
|
||||
node_indices->size(), " != ", kernel_def_hashes->size());
|
||||
|
||||
auto add_kernel_by_hash =
|
||||
[&kernel_registry_manager, this](const Node& node, uint64_t hash) {
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, hash, &kci));
|
||||
kernel_create_info_map_.emplace(node.Index(), gsl::not_null<const KernelCreateInfo*>(kci));
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// kernel hashes for model are in top level SessionState
|
||||
const auto& compiled_kernel_hashes = GetCompiledKernelHashes();
|
||||
|
||||
// process the nodes that existed when the model was created
|
||||
for (flatbuffers::uoffset_t i = 0; i < node_indices->size(); i++) {
|
||||
auto node_idx = node_indices->Get(i);
|
||||
auto kernal_hash = kernel_def_hashes->Get(i);
|
||||
auto kernel_hash = kernel_def_hashes->Get(i);
|
||||
|
||||
const Node* node = graph_.GetNode(node_idx);
|
||||
ORT_RETURN_IF(node == nullptr, "Can't find node with index ", node_idx, ". Invalid ORT format model.");
|
||||
if (node == nullptr) {
|
||||
// this is OK if we have compiled kernels and the original node was replaced. if not the model is invalid.
|
||||
ORT_RETURN_IF(compiled_kernel_hashes.empty(),
|
||||
"Can't find node with index ", node_idx, ". Invalid ORT format model.");
|
||||
continue;
|
||||
}
|
||||
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(*node, kernal_hash, &kci));
|
||||
kernel_create_info_map_.emplace(node_idx, gsl::not_null<const KernelCreateInfo*>(kci));
|
||||
ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, kernel_hash));
|
||||
}
|
||||
|
||||
// lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices
|
||||
// as they were created at runtime.
|
||||
if (!compiled_kernel_hashes.empty()) {
|
||||
for (const auto& node : graph_.Nodes()) {
|
||||
if (kernel_create_info_map_.count(node.Index()) == 0) {
|
||||
auto hash_info = compiled_kernel_hashes.find(node.OpType());
|
||||
ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
|
||||
"Unable to find compiled kernel hash for node '", node.Name(), "'.")
|
||||
|
||||
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!subgraph_session_states_.empty()) {
|
||||
|
|
@ -880,7 +925,8 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
|
|||
KernelRegistryManager& kernel_registry_manager,
|
||||
const SessionOptions& session_options,
|
||||
const onnxruntime::experimental::fbs::SessionState* serialized_session_state,
|
||||
bool remove_initializers) {
|
||||
bool remove_initializers,
|
||||
bool saving_ort_format) {
|
||||
// recursively create the subgraph session state instances and populate the kernel create info in them.
|
||||
// it's simpler to handle the kernel create info recursively when deserializing,
|
||||
// so also do it recursively when calling PopulateKernelCreateInfo for consistency.
|
||||
|
|
@ -896,12 +942,13 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
|
|||
|
||||
} else {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager));
|
||||
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(graph_location);
|
||||
ORT_UNUSED_PARAMETER(kernel_registry_manager);
|
||||
ORT_UNUSED_PARAMETER(session_options);
|
||||
ORT_UNUSED_PARAMETER(remove_initializers);
|
||||
ORT_UNUSED_PARAMETER(saving_ort_format);
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Serialized session state must be provided from an ORT format model in this build.");
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -269,6 +269,10 @@ class SessionState {
|
|||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
void SetCompiledKernelHashes(std::unordered_map<std::string, uint64_t>&& compiled_kernel_hashes) {
|
||||
compiled_kernel_hashes_ = std::move(compiled_kernel_hashes);
|
||||
}
|
||||
|
||||
Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::SessionState& fbs_session_state,
|
||||
const KernelRegistryManager& kernel_registry_manager);
|
||||
#endif
|
||||
|
|
@ -277,7 +281,8 @@ class SessionState {
|
|||
KernelRegistryManager& kernel_registry_manager,
|
||||
const SessionOptions& session_options = {},
|
||||
const onnxruntime::experimental::fbs::SessionState* serialized_session_state = nullptr,
|
||||
bool remove_initializers = true);
|
||||
bool remove_initializers = true,
|
||||
bool saving_ort_format = false);
|
||||
|
||||
SessionState* Parent() {
|
||||
return parent_;
|
||||
|
|
@ -312,7 +317,7 @@ class SessionState {
|
|||
std::unique_ptr<SessionState> session_state);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager);
|
||||
Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager, bool saving_ort_format);
|
||||
#endif
|
||||
|
||||
Status FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
|
|
@ -330,9 +335,18 @@ class SessionState {
|
|||
std::unordered_map<int, TensorShape>& inferred_shapes) const;
|
||||
#endif
|
||||
|
||||
// the SessionState for the main Graph contains the compiled kernel hashes for the entire model
|
||||
const std::unordered_map<std::string, uint64_t>& GetCompiledKernelHashes() const {
|
||||
return parent_ ? parent_->GetCompiledKernelHashes() : compiled_kernel_hashes_;
|
||||
}
|
||||
|
||||
// KernelCreateInfo for each node so we do kernel lookup once
|
||||
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>> kernel_create_info_map_;
|
||||
|
||||
// If we compile kernels in a minimal build we need a way to find the kernel using the hash.
|
||||
// We populate this map when doing the kernel compilation in GraphPartitioner, and use it in LoadFromOrtFormat.
|
||||
std::unordered_map<std::string, uint64_t> compiled_kernel_hashes_;
|
||||
|
||||
// cache of the constructed kernels to avoid spending construction time per executor
|
||||
std::vector<OpKernel*> session_kernels_;
|
||||
Graph& graph_;
|
||||
|
|
|
|||
|
|
@ -101,7 +101,10 @@ bool ProviderIsCpuBased(const std::string& provider_type) {
|
|||
provider_type == onnxruntime::kVitisAIExecutionProvider ||
|
||||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
|
||||
provider_type == onnxruntime::kNnapiExecutionProvider ||
|
||||
provider_type == onnxruntime::kRknpuExecutionProvider;
|
||||
provider_type == onnxruntime::kAclExecutionProvider ||
|
||||
provider_type == onnxruntime::kArmNNExecutionProvider ||
|
||||
provider_type == onnxruntime::kRknpuExecutionProvider ||
|
||||
provider_type == onnxruntime::utils::kInternalTestingExecutionProvider;
|
||||
}
|
||||
|
||||
static common::Status AllocateHelper(const AllocatorPtr& allocator,
|
||||
|
|
@ -574,16 +577,16 @@ common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context)
|
|||
const Tensor* curr_input = context->Input<Tensor>(i);
|
||||
|
||||
ORT_ENFORCE(prev_input->Shape().Size() >= 0);
|
||||
|
||||
|
||||
size_t input_element_count = static_cast<size_t>(prev_input->Shape().Size());
|
||||
size_t input_element_size = prev_input->DataType()->Size();
|
||||
size_t input_aligned_bytes = 0;
|
||||
|
||||
|
||||
ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, &input_aligned_bytes));
|
||||
|
||||
|
||||
ORT_RETURN_IF_NOT(curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + input_aligned_bytes ||
|
||||
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + prev_input->SizeInBytes());
|
||||
|
||||
|
||||
prev_input = curr_input;
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -40,6 +40,13 @@ void DefaultFree(void* p);
|
|||
|
||||
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
|
||||
|
||||
// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want
|
||||
// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal.
|
||||
constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";
|
||||
|
||||
// return true if the execution provider is CPU based (meaning no copies to device are required)
|
||||
bool ProviderIsCpuBased(const std::string& provider_type);
|
||||
|
||||
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
|
||||
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
|
||||
|
||||
|
|
|
|||
|
|
@ -144,6 +144,34 @@ static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& su
|
|||
}
|
||||
}
|
||||
|
||||
static std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const Graph& graph,
|
||||
const IndexedSubGraph& nodes_to_fuse) {
|
||||
const auto* meta_def = nodes_to_fuse.GetMetaDef();
|
||||
auto op_schema = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
|
||||
op_schema->SetName(meta_def->name);
|
||||
op_schema->SetDomain(meta_def->domain);
|
||||
op_schema->SetDoc(meta_def->doc_string);
|
||||
op_schema->SinceVersion(meta_def->since_version);
|
||||
int i = 0;
|
||||
|
||||
for (auto& input : meta_def->inputs) {
|
||||
auto input_arg = graph.GetNodeArg(input);
|
||||
// inputs must have a type. can be inferred for outputs.
|
||||
ORT_ENFORCE(input_arg->Type() != nullptr);
|
||||
op_schema->Input(i, input, "", *input_arg->Type());
|
||||
++i;
|
||||
}
|
||||
i = 0;
|
||||
for (auto& output : meta_def->outputs) {
|
||||
auto output_arg = graph.GetNodeArg(output);
|
||||
op_schema->Output(i, output, "", *output_arg->Type());
|
||||
++i;
|
||||
}
|
||||
op_schema->Finalize();
|
||||
|
||||
return op_schema;
|
||||
}
|
||||
|
||||
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
||||
const IndexedSubGraph& nodes_to_fuse,
|
||||
const logging::Logger& logger)
|
||||
|
|
@ -154,12 +182,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
graph.DomainToVersionMap(), {}, logger) {
|
||||
auto& function_body_graph = body_.MainGraph();
|
||||
|
||||
auto meta_def = nodes_to_fuse.GetMetaDef();
|
||||
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
|
||||
op_schema_->SetName(meta_def->name);
|
||||
op_schema_->SetDomain(meta_def->domain);
|
||||
op_schema_->SetDoc(meta_def->doc_string);
|
||||
op_schema_->SinceVersion(meta_def->since_version);
|
||||
auto* meta_def = nodes_to_fuse.GetMetaDef();
|
||||
op_schema_ = CreateSchema(graph, nodes_to_fuse);
|
||||
|
||||
int i = 0;
|
||||
std::vector<const NodeArg*> function_body_graph_inputs;
|
||||
|
|
@ -168,8 +192,6 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
auto input_arg = parent_graph_->GetNodeArg(input);
|
||||
auto& function_body_graph_input_arg = function_body_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
|
||||
function_body_graph_inputs[i] = &function_body_graph_input_arg;
|
||||
ORT_ENFORCE(input_arg->Type() != nullptr);
|
||||
op_schema_->Input(i, input, "", *input_arg->Type());
|
||||
++i;
|
||||
}
|
||||
|
||||
|
|
@ -180,12 +202,9 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
auto output_arg = parent_graph_->GetNodeArg(output);
|
||||
auto& function_body_graph_output_arg = function_body_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
|
||||
function_body_graph_outputs[i] = &function_body_graph_output_arg;
|
||||
op_schema_->Output(i, output, "", *output_arg->Type());
|
||||
++i;
|
||||
}
|
||||
|
||||
op_schema_->Finalize();
|
||||
|
||||
function_body_graph.SetInputs(function_body_graph_inputs);
|
||||
function_body_graph.SetOutputs(function_body_graph_outputs);
|
||||
|
||||
|
|
@ -238,7 +257,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
// Hence, we make a copy prior to generating the graph representation of the function,
|
||||
// as we might make some modifications to the FunctionProto along the way
|
||||
|
||||
auto node_in_parent_graph = parent_graph_->GetNode(node_index);
|
||||
const auto* node_in_parent_graph = parent_graph_->GetNode(node_index);
|
||||
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
|
||||
op_schema_->SetName(onnx_func_proto_.name());
|
||||
op_schema_->SetDomain(onnx_func_proto_.node().Get(0).domain());
|
||||
|
|
@ -256,14 +275,13 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
auto cached_op_schema = node_in_parent_graph->Op();
|
||||
if (!cached_op_schema) {
|
||||
// Infer a op_schema for stand-alone functions.
|
||||
IOTypeConstraintHelper(onnx_func_proto_, this->op_schema_, input_name_idx_map, output_name_idx_map);
|
||||
IOTypeConstraintHelper(onnx_func_proto_, op_schema_, input_name_idx_map, output_name_idx_map);
|
||||
} else {
|
||||
auto type_constraint_params = cached_op_schema->typeConstraintParams();
|
||||
for (auto& type_constraint_param : type_constraint_params) {
|
||||
op_schema_->TypeConstraint(
|
||||
type_constraint_param.type_param_str,
|
||||
type_constraint_param.allowed_type_strs,
|
||||
type_constraint_param.description);
|
||||
op_schema_->TypeConstraint(type_constraint_param.type_param_str,
|
||||
type_constraint_param.allowed_type_strs,
|
||||
type_constraint_param.description);
|
||||
}
|
||||
int i = 0;
|
||||
for (auto& input : cached_op_schema->inputs()) {
|
||||
|
|
@ -286,10 +304,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
op_schema_->TypeAndShapeInferenceFunction(
|
||||
[this](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
|
||||
const ONNX_NAMESPACE::FunctionProto* func_ptr = this->GetFuncProto();
|
||||
if (nullptr != func_ptr) {
|
||||
ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(func_ptr, schema_registry, ctx);
|
||||
}
|
||||
ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(&onnx_func_proto_, schema_registry, ctx);
|
||||
});
|
||||
} else {
|
||||
op_schema_->TypeAndShapeInferenceFunction(cached_op_schema->GetTypeAndShapeInferenceFunction());
|
||||
|
|
@ -321,7 +336,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
}
|
||||
|
||||
for (int idx = 0; idx < (*node).input_size(); ++idx) {
|
||||
std::string tensor_name = (*node).input().Get(idx);
|
||||
const std::string& tensor_name = (*node).input().Get(idx);
|
||||
auto iter = input_name_idx_map.find(tensor_name);
|
||||
if (iter != input_name_idx_map.end()) {
|
||||
// Preserving NodeArg and input/output names
|
||||
|
|
@ -397,10 +412,14 @@ const onnxruntime::Graph& FunctionImpl::Body() const {
|
|||
return body_.MainGraph();
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const {
|
||||
return &onnx_func_proto_;
|
||||
ViewerFunctionImpl::ViewerFunctionImpl(const onnxruntime::Graph& graph,
|
||||
const IndexedSubGraph& nodes_to_fuse,
|
||||
const logging::Logger& /*logger*/) {
|
||||
op_schema_ = CreateSchema(graph, nodes_to_fuse);
|
||||
}
|
||||
|
||||
ViewerFunctionImpl::~ViewerFunctionImpl() = default;
|
||||
|
||||
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
|
||||
const IndexedSubGraph& nodes_to_fuse,
|
||||
const logging::Logger& logger) {
|
||||
|
|
|
|||
|
|
@ -31,8 +31,6 @@ class FunctionImpl final : public Function {
|
|||
|
||||
const onnxruntime::Graph& Body() const override;
|
||||
|
||||
const ONNX_NAMESPACE::FunctionProto* GetFuncProto() const;
|
||||
|
||||
private:
|
||||
const onnxruntime::Graph* const parent_graph_;
|
||||
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
|
||||
|
|
@ -40,4 +38,22 @@ class FunctionImpl final : public Function {
|
|||
ONNX_NAMESPACE::FunctionProto onnx_func_proto_;
|
||||
};
|
||||
|
||||
// Function that uses a GraphViewer so does not need to build a new Model. We still need the OpSchema to be available
|
||||
// though so we just create that.
|
||||
class ViewerFunctionImpl final : public Function {
|
||||
public:
|
||||
ViewerFunctionImpl(const onnxruntime::Graph& graph,
|
||||
const IndexedSubGraph& nodes_to_fuse,
|
||||
const logging::Logger& logger);
|
||||
|
||||
~ViewerFunctionImpl() override;
|
||||
|
||||
const ONNX_NAMESPACE::OpSchema& OpSchema() const override { return *op_schema_; }
|
||||
|
||||
const onnxruntime::Graph& Body() const override { ORT_THROW("Not supported"); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -133,21 +133,26 @@ static TypeProto TypeProtoFromTensorProto(const TensorProto& tensor) {
|
|||
|
||||
return t;
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
NodeArg::NodeArg(const std::string& name, const TypeProto* p_node_arg_type) {
|
||||
node_arg_info_.set_name(name);
|
||||
// If the name is empty, it means the arg does not exist.
|
||||
exists_ = !(name.empty());
|
||||
if (nullptr != p_node_arg_type) {
|
||||
(*node_arg_info_.mutable_type()) = *p_node_arg_type;
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// should not be possible to have invalid values in the ORT format model, so we don't need this
|
||||
// in a minimal build
|
||||
RemoveInvalidValues(*node_arg_info_.mutable_type());
|
||||
#endif
|
||||
type_ = DataTypeUtils::ToType(node_arg_info_.type());
|
||||
} else {
|
||||
type_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
NodeArg::NodeArg(NodeArgInfo&& node_arg_info) {
|
||||
node_arg_info_ = std::move(node_arg_info);
|
||||
|
|
@ -427,10 +432,6 @@ void Node::SetPriority(int priority) noexcept {
|
|||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
void Node::SetNodeType(Node::Type node_type) noexcept {
|
||||
node_type_ = node_type;
|
||||
}
|
||||
|
||||
const Function* Node::GetFunctionBody(bool try_init_func_body) {
|
||||
if (nullptr != func_body_) {
|
||||
return func_body_;
|
||||
|
|
@ -450,10 +451,6 @@ void Node::SetFunctionBody(const Function& func) {
|
|||
since_version_ = op_->since_version();
|
||||
}
|
||||
|
||||
void Node::SetExecutionProviderType(ProviderType execution_provider_type) {
|
||||
execution_provider_type_ = execution_provider_type;
|
||||
}
|
||||
|
||||
void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
|
||||
proto.set_name(name_);
|
||||
proto.set_op_type(op_type_);
|
||||
|
|
@ -666,9 +663,9 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::experimental::fbs::NodeEd
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
void Node::Init(const std::string& name,
|
||||
const std::string& op_type,
|
||||
const std::string& description,
|
||||
|
|
@ -697,7 +694,11 @@ void Node::Init(const std::string& name,
|
|||
|
||||
for (auto& name_to_attr : attributes_) {
|
||||
if (utils::HasGraph(name_to_attr.second)) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
CreateSubgraph(name_to_attr.first);
|
||||
#else
|
||||
ORT_THROW("Creating node with a subgraph via AddNode is not supported in this build.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -716,7 +717,9 @@ Node::Relationships& Node::MutableRelationships() noexcept {
|
|||
graph_->SetGraphProtoSyncNeeded();
|
||||
return relationships_;
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
void Node::CreateSubgraph(const std::string& attr_name) {
|
||||
auto attr = attributes_.find(attr_name);
|
||||
|
||||
|
|
@ -1264,7 +1267,9 @@ common::Status Graph::SetOuterScopeNodeArgs(const std::unordered_set<std::string
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
void Graph::AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_slot, int dst_arg_slot) {
|
||||
if (nodes_.size() <= src_node_index || src_arg_slot < 0 || nodes_.size() <= dst_node_index || dst_arg_slot < 0 ||
|
||||
nullptr == nodes_[src_node_index] || nullptr == nodes_[dst_node_index]) {
|
||||
|
|
@ -1349,7 +1354,9 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s
|
|||
nodes_[dst_node_index]->MutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], src_arg_slot, dst_arg_slot));
|
||||
nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], src_arg_slot, dst_arg_slot));
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint
|
||||
Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed) {
|
||||
const std::unordered_set<std::string>& outer_scope_node_args = resolve_context_.outer_scope_node_args;
|
||||
|
|
@ -1596,6 +1603,7 @@ void Graph::ReverseDFSFrom(const std::vector<const Node*>& from,
|
|||
}
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
|
||||
const std::function<bool(const Node*, const Node*)>& comp) const {
|
||||
std::unordered_map<NodeIndex, size_t> in_degree;
|
||||
|
|
@ -1634,7 +1642,6 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
|
|||
ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle.");
|
||||
}
|
||||
}
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
|
||||
Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
|
||||
|
|
@ -2896,7 +2903,9 @@ std::string Graph::GenerateNodeName(const std::string& base_name) {
|
|||
|
||||
return new_name;
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
Node& Graph::AddNode(const std::string& name,
|
||||
const std::string& op_type,
|
||||
const std::string& description,
|
||||
|
|
@ -2945,7 +2954,9 @@ bool Graph::RemoveNode(NodeIndex p_index) {
|
|||
|
||||
return ReleaseNode(p_index);
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
bool Graph::AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index) {
|
||||
if (nodes_.size() <= src_node_index ||
|
||||
nodes_.size() <= dst_node_index ||
|
||||
|
|
@ -3282,6 +3293,12 @@ Status Graph::SetGraphInputsOutputs() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Accessor: returns the schema registry pointer stored in schema_registry_.
IOnnxRuntimeOpSchemaCollectionPtr Graph::GetSchemaRegistry() const {
|
||||
return schema_registry_;
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// calling private ctor
|
||||
GSL_SUPPRESS(r .11)
|
||||
gsl::not_null<Node*> Graph::AllocateNode() {
|
||||
|
|
@ -3313,21 +3330,24 @@ bool Graph::ReleaseNode(NodeIndex index) {
|
|||
return true;
|
||||
}
|
||||
|
||||
IOnnxRuntimeOpSchemaCollectionPtr Graph::GetSchemaRegistry() const {
|
||||
return schema_registry_;
|
||||
}
|
||||
|
||||
Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) {
|
||||
auto* func_meta_def = sub_graph.GetMetaDef();
|
||||
Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) {
|
||||
const auto* func_meta_def = sub_graph.GetMetaDef();
|
||||
ORT_ENFORCE(nullptr != func_meta_def);
|
||||
|
||||
std::vector<NodeArg*> input_args;
|
||||
std::vector<NodeArg*> output_args;
|
||||
std::unordered_map<std::string, int> input_indexes;
|
||||
std::unordered_map<std::string, int> output_indexes;
|
||||
|
||||
int cur_idx = 0;
|
||||
for (auto& arg_name : func_meta_def->inputs) {
|
||||
input_args.push_back(GetNodeArg(arg_name));
|
||||
input_indexes[arg_name] = cur_idx++;
|
||||
}
|
||||
|
||||
cur_idx = 0;
|
||||
for (auto& arg_name : func_meta_def->outputs) {
|
||||
output_args.push_back(GetNodeArg(arg_name));
|
||||
output_indexes[arg_name] = cur_idx++;
|
||||
}
|
||||
|
||||
auto& fused_node = AddNode(fused_node_name,
|
||||
|
|
@ -3339,21 +3359,102 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& f
|
|||
func_meta_def->domain);
|
||||
|
||||
fused_node.SetNodeType(Node::Type::Fused);
|
||||
function_container_.emplace_back(MakeFunction(*this, sub_graph, logger_));
|
||||
fused_node.SetFunctionBody(*function_container_.back());
|
||||
|
||||
// Remove nodes fused above.
|
||||
return fused_node;
|
||||
}
|
||||
|
||||
// First phase of fusing `sub_graph` into a single node.
// Creates the fused node via CreateFusedSubGraphNode and returns a reference to it.
// When ORT_MINIMAL_BUILD is not defined it also creates a ViewerFunctionImpl, stores it in
// function_container_, and sets it as the fused node's function body.
// This function does not remove the nodes being fused or touch any edges.
Node& Graph::BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) {
|
||||
Node& node = CreateFusedSubGraphNode(sub_graph, fused_node_name);
|
||||

||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// if this is a full build create the lightweight Function implementation that provides the schema so that
|
||||
// kernel lookup works as per usual. in an extended minimal build we do the lookup via a hash so don't
|
||||
// need to create the schema.
|
||||
auto func = onnxruntime::make_unique<ViewerFunctionImpl>(*this, sub_graph, logger_);
|
||||
// function_container_ takes ownership; the node is handed a reference to the stored element.
function_container_.push_back(std::move(func));
|
||||
node.SetFunctionBody(*function_container_.back());
|
||||
#endif
|
||||

||||
return node;
|
||||
}
|
||||
|
||||
// Second phase of fusing `sub_graph`: rewires edges onto `fused_node` and removes the fused nodes.
// For every node index in sub_graph.nodes that still resolves to a node:
//   - each input edge whose destination input name is one of the MetaDef inputs gets an equivalent
//     edge to fused_node; the original input edge is then removed either way;
//   - each output edge whose source output name is one of the MetaDef outputs gets an equivalent
//     edge from fused_node; the original output edge is then removed either way;
//   - the node itself is removed.
// Enforces (ORT_ENFORCE) that sub_graph has a MetaDef.
void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node) {
|
||||
const auto* func_meta_def = sub_graph.GetMetaDef();
|
||||
ORT_ENFORCE(nullptr != func_meta_def);
|
||||

||||
// name -> position maps for the fused node's inputs and outputs, in MetaDef order
std::unordered_map<std::string, int> input_indexes;
|
||||
std::unordered_map<std::string, int> output_indexes;
|
||||

||||
int cur_idx = 0;
|
||||
for (auto& arg_name : func_meta_def->inputs) {
|
||||
input_indexes[arg_name] = cur_idx++;
|
||||
}
|
||||

||||
cur_idx = 0;
|
||||
for (auto& arg_name : func_meta_def->outputs) {
|
||||
output_indexes[arg_name] = cur_idx++;
|
||||
}
|
||||

||||
auto new_node_idx = fused_node.Index();
|
||||

||||
// Remove nodes that were fused
|
||||
for (auto node_index : sub_graph.nodes) {
|
||||
auto node = GetNode(node_index);
|
||||
// skip indexes that no longer resolve to a node
if (nullptr == node) {
|
||||
continue;
|
||||
}
|
||||
// NOTE(review): the next three code rows look like removed-side residue of the diff rendering:
// the `for` opened here is never closed and `output_edges` is declared again further down in the
// same scope. Confirm against the actual file before relying on them.
auto output_edges = node->GetRelationships().output_edges;
|
||||
for (auto output_edge : output_edges) {
|
||||
RemoveEdge(node->Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex());
|
||||

||||
// move any applicable input edges to the new node. remove all others
|
||||
auto input_edges = node->GetRelationships().input_edges; // copy so RemoveEdge doesn't invalidate iterator
|
||||
for (const auto& input_edge : input_edges) {
|
||||
const auto& producer = input_edge.GetNode();
|
||||
auto producer_idx = producer.Index();
|
||||
auto src_idx = input_edge.GetSrcArgIndex();
|
||||
auto dst_idx = input_edge.GetDstArgIndex();
|
||||

||||
// if this input is an input of the fused node add an edge for that
|
||||
// NOTE(review): InputDefs() is indexed with dst_idx without a bounds check. If an input edge can
// target a slot beyond the explicit input defs (e.g. an implicit input), this reads out of
// range. TODO confirm.
auto it = input_indexes.find(node->InputDefs()[dst_idx]->Name());
|
||||
if (it != input_indexes.cend()) {
|
||||
AddEdge(producer_idx, new_node_idx, src_idx, it->second);
|
||||
}
|
||||

||||
RemoveEdge(producer_idx, node_index, src_idx, dst_idx);
|
||||
}
|
||||

||||
// move any applicable output edges to the new node
|
||||
auto output_edges = node->GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator
|
||||
for (const auto& output_edge : output_edges) {
|
||||
const auto& consumer = output_edge.GetNode();
|
||||
auto consumer_idx = consumer.Index();
|
||||
auto src_idx = output_edge.GetSrcArgIndex();
|
||||
auto dst_idx = output_edge.GetDstArgIndex();
|
||||

||||
// if this output is an output of the fused node add an edge for that
|
||||
auto it = output_indexes.find(node->OutputDefs()[src_idx]->Name());
|
||||
if (it != output_indexes.cend()) {
|
||||
AddEdge(new_node_idx, consumer_idx, it->second, dst_idx);
|
||||
}
|
||||

||||
RemoveEdge(node_index, consumer_idx, src_idx, dst_idx);
|
||||
}
|
||||

||||
// all edges of this node have been removed above, so the node itself can go
RemoveNode(node_index);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// One-shot fusion of `sub_graph` into a single node named `fused_node_name`.
// Steps, in order: create the fused node (CreateFusedSubGraphNode), build a Function for the
// sub-graph with MakeFunction and set it as the node's function body, then call
// FinalizeFuseSubGraph to remove the original nodes and update edges.
// Returns a reference to the fused node.
Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph,
|
||||
const std::string& fused_node_name) {
|
||||
Node& fused_node = CreateFusedSubGraphNode(sub_graph, fused_node_name);
|
||||

||||
// create Function before we remove nodes
|
||||
function_container_.emplace_back(MakeFunction(*this, sub_graph, logger_));
|
||||
fused_node.SetFunctionBody(*function_container_.back());
|
||||

||||
// remove nodes and update edges
|
||||
FinalizeFuseSubGraph(sub_graph, fused_node);
|
||||

||||
return fused_node;
|
||||
}
|
||||
|
|
@ -3541,6 +3642,7 @@ Graph::Graph(const Model& owning_model,
|
|||
schema_registry_(std::make_shared<SchemaRegistryManager>()),
|
||||
#endif
|
||||
domain_to_version_(domain_to_version),
|
||||
ir_version_(owning_model.IrVersion()),
|
||||
parent_graph_(parent_graph),
|
||||
parent_node_(parent_node),
|
||||
logger_(logger),
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "transformer_memcpy.h"
|
||||
#include "core/framework/kernel_registry_manager.h"
|
||||
#include "core/framework/execution_providers.h"
|
||||
#include "core/framework/utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
namespace onnxruntime {
|
||||
|
|
@ -62,15 +63,7 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
|
|||
// and mainly provides the subgraph recursion functionality
|
||||
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
for (auto& provider : provider_types_) {
|
||||
if (provider != onnxruntime::kCpuExecutionProvider &&
|
||||
provider != onnxruntime::kDnnlExecutionProvider &&
|
||||
provider != onnxruntime::kNGraphExecutionProvider &&
|
||||
provider != onnxruntime::kNupharExecutionProvider &&
|
||||
provider != onnxruntime::kVitisAIExecutionProvider &&
|
||||
provider != onnxruntime::kOpenVINOExecutionProvider &&
|
||||
provider != onnxruntime::kNnapiExecutionProvider &&
|
||||
provider != onnxruntime::kAclExecutionProvider &&
|
||||
provider != onnxruntime::kArmNNExecutionProvider) {
|
||||
if (!utils::ProviderIsCpuBased(provider)) {
|
||||
TransformerMemcpyImpl copy_impl(graph, provider);
|
||||
auto current_modified = copy_impl.ModifyGraph(registry_manager_);
|
||||
modified = modified || current_modified;
|
||||
|
|
@ -163,9 +156,11 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
|
|||
return modified;
|
||||
}
|
||||
|
||||
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed) {
|
||||
void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries,
|
||||
InitializedTensorSet& initializers_consumed) {
|
||||
auto node_provider_type = node.GetExecutionProviderType();
|
||||
if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
|
||||
if ((node_provider_type == provider_) ||
|
||||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
|
||||
provider_nodes_.insert(&node);
|
||||
// note KernelCreateInfo might be nullptr for custom kernel
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
|
|
|
|||
|
|
@ -745,45 +745,53 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
|
|||
// If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape
|
||||
// information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape
|
||||
/* static */ bool ReshapeOpBuilder::CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank) {
|
||||
const auto& output = node.OutputDefs()[0]->Name();
|
||||
// We will go through all the output edges
|
||||
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
|
||||
const auto& op_type = it->GetNode().OpType();
|
||||
// TODO add quantized matmul when reshape support quantized input
|
||||
if (op_type != "Gemm" && op_type != "MatMul") {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
|
||||
<< " or no op is using the output (output is graph output)"
|
||||
<< ", output name, " << output
|
||||
<< " is used by " << op_type;
|
||||
return false;
|
||||
}
|
||||
//
|
||||
// TEMPORARILY DISABLED. Needs refinement.
|
||||
//
|
||||
// const auto& output = node.OutputDefs()[0]->Name();
|
||||
// // We will go through all the output edges
|
||||
// for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
|
||||
// const auto& op_type = it->GetNode().OpType();
|
||||
// // TODO add quantized matmul when reshape support quantized input
|
||||
// if (op_type != "Gemm" && op_type != "MatMul") {
|
||||
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul"
|
||||
// << " or no op is using the output (output is graph output)"
|
||||
// << ", output name, " << output
|
||||
// << " is used by " << op_type;
|
||||
// return false;
|
||||
// }
|
||||
|
||||
// NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
|
||||
if (it->GetDstArgIndex() != 0) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
|
||||
<< ", output name, " << output;
|
||||
return false;
|
||||
}
|
||||
// // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0
|
||||
// if (it->GetDstArgIndex() != 0) {
|
||||
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul"
|
||||
// << ", output name, " << output;
|
||||
// return false;
|
||||
// }
|
||||
|
||||
// We only support 2d matmul/gemm here
|
||||
// And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
|
||||
if (input_rank < 2 || output_rank != 2) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
|
||||
<< ", output name, " << output
|
||||
<< ", the actual input_rank, " << input_rank
|
||||
<< ", the actual output_rank, " << output_rank;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// // We only support 2d matmul/gemm here
|
||||
// // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2
|
||||
// if (input_rank < 2 || output_rank != 2) {
|
||||
// LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2"
|
||||
// << ", output name, " << output
|
||||
// << ", the actual input_rank, " << input_rank
|
||||
// << ", the actual output_rank, " << output_rank;
|
||||
// return false;
|
||||
// }
|
||||
// }
|
||||
|
||||
// If we reach here, we have either,
|
||||
// all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
|
||||
// or
|
||||
// Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
|
||||
// we can skip this Reshape now
|
||||
LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
|
||||
<< node.Name() << "] with output, " << output;
|
||||
return true;
|
||||
// // If we reach here, we have either,
|
||||
// // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here]
|
||||
// // or
|
||||
// // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care]
|
||||
// // we can skip this Reshape now
|
||||
// LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node ["
|
||||
// << node.Name() << "] with output, " << output;
|
||||
// return true;
|
||||
|
||||
ORT_UNUSED_PARAMETER(node);
|
||||
ORT_UNUSED_PARAMETER(input_rank);
|
||||
ORT_UNUSED_PARAMETER(output_rank);
|
||||
return false;
|
||||
}
|
||||
|
||||
/* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder,
|
||||
|
|
|
|||
|
|
@ -80,7 +80,6 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
// Find inputs, initializers and outputs for each supported subgraph
|
||||
const std::vector<NodeIndex>& node_index = graph_view.GetNodesInTopologicalOrder();
|
||||
const auto& graph_outputs = graph_view.GetOutputs();
|
||||
int counter = 0;
|
||||
for (const auto& group : supported_nodes_vector) {
|
||||
if (group.empty())
|
||||
continue;
|
||||
|
|
@ -173,7 +172,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
|
||||
// Assign inputs and outputs to subgraph's meta_def
|
||||
auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
|
||||
meta_def->name = "NNAPI_" + std::to_string(counter++);
|
||||
meta_def->name = "NNAPI_" + std::to_string(metadef_id_++);
|
||||
meta_def->domain = kMSDomain;
|
||||
|
||||
for (const auto& input : inputs) {
|
||||
|
|
@ -202,6 +201,7 @@ static Status GetOutputBuffer(Ort::CustomOpApi& ort,
|
|||
const std::vector<uint32_t>& output_shape,
|
||||
const android::nn::wrapper::Type output_type,
|
||||
void** output_buffer) ORT_MUST_USE_RESULT;
|
||||
|
||||
static Status GetOutputBuffer(Ort::CustomOpApi& ort,
|
||||
OrtKernelContext* context,
|
||||
const nnapi::Model& model,
|
||||
|
|
@ -235,50 +235,43 @@ static Status GetOutputBuffer(Ort::CustomOpApi& ort,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
using namespace android::nn::wrapper;
|
||||
for (const auto* fused_node : fused_nodes) {
|
||||
// Reconstruct graph proto from fused node's function body
|
||||
const auto* func_body = fused_node->GetFunctionBody();
|
||||
if (!func_body) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
|
||||
}
|
||||
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
|
||||
Node& fused_node = fused_node_and_graph.fused_node;
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
|
||||
|
||||
const Graph& graph_body = func_body->Body();
|
||||
nnapi::ModelBuilder builder(graph_viewer);
|
||||
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
|
||||
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
|
||||
std::unique_ptr<nnapi::Model> nnapi_model;
|
||||
ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model));
|
||||
|
||||
// Build map from input name to its index in input definitions
|
||||
{
|
||||
onnxruntime::GraphViewer graph_viewer(graph_body);
|
||||
nnapi::ModelBuilder builder(graph_viewer);
|
||||
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
|
||||
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
|
||||
std::unique_ptr<nnapi::Model> nnapi_model;
|
||||
ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model));
|
||||
|
||||
// Build map from input name to its index in input definitions
|
||||
{
|
||||
std::unordered_map<std::string, size_t> input_map;
|
||||
const auto& input_defs = fused_node->InputDefs();
|
||||
input_map.reserve(input_defs.size());
|
||||
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
|
||||
input_map[input_defs[i]->Name()] = i;
|
||||
}
|
||||
nnapi_model->SetInputMap(std::move(input_map));
|
||||
std::unordered_map<std::string, size_t> input_map;
|
||||
const auto& input_defs = fused_node.InputDefs();
|
||||
input_map.reserve(input_defs.size());
|
||||
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
|
||||
input_map[input_defs[i]->Name()] = i;
|
||||
}
|
||||
|
||||
// Build map from output name to its index in output definitions
|
||||
{
|
||||
std::unordered_map<std::string, size_t> output_map;
|
||||
const auto& output_defs = fused_node->OutputDefs();
|
||||
output_map.reserve(output_defs.size());
|
||||
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
|
||||
output_map[output_defs[i]->Name()] = i;
|
||||
}
|
||||
nnapi_model->SetOutputMap(std::move(output_map));
|
||||
}
|
||||
|
||||
nnapi_models_.emplace(fused_node->Name(), std::move(nnapi_model));
|
||||
nnapi_model->SetInputMap(std::move(input_map));
|
||||
}
|
||||
|
||||
// Build map from output name to its index in output definitions
|
||||
{
|
||||
std::unordered_map<std::string, size_t> output_map;
|
||||
const auto& output_defs = fused_node.OutputDefs();
|
||||
output_map.reserve(output_defs.size());
|
||||
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
|
||||
output_map[output_defs[i]->Name()] = i;
|
||||
}
|
||||
nnapi_model->SetOutputMap(std::move(output_map));
|
||||
}
|
||||
|
||||
nnapi_models_.emplace(fused_node.Name(), std::move(nnapi_model));
|
||||
|
||||
NodeComputeInfo compute_info;
|
||||
compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) {
|
||||
*state = nnapi_models_[context->node_name].get();
|
||||
|
|
@ -432,20 +425,21 @@ common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::No
|
|||
return Status::OK();
|
||||
}
|
||||
#else
|
||||
common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
for (const auto* fused_node : fused_nodes) {
|
||||
ORT_UNUSED_PARAMETER(fused_node);
|
||||
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
|
||||
ORT_UNUSED_PARAMETER(fused_node_and_graph);
|
||||
NodeComputeInfo compute_info;
|
||||
compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; };
|
||||
compute_info.release_state_func = [](FunctionState /*state*/) {};
|
||||
compute_info.compute_func = [](FunctionState /* state */, const OrtCustomOpApi* /* api */, OrtKernelContext* /* context */) {
|
||||
compute_info.compute_func = [](FunctionState /* state */, const OrtCustomOpApi* /* api */,
|
||||
OrtKernelContext* /* context */) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build.");
|
||||
};
|
||||
node_compute_funcs.push_back(compute_info);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
#endif // __ANDROID__
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -19,11 +19,21 @@ class NnapiExecutionProvider : public IExecutionProvider {
|
|||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
GetCapability(const onnxruntime::GraphViewer& graph_view,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
|
||||
// we implement the Compile that takes FusedNodeAndGraph instances
|
||||
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
#endif
|
||||
|
||||
unsigned long GetNNAPIFlags() const { return nnapi_flags_; }
|
||||
|
||||
private:
|
||||
// unique counter to name each fused kernel across the entire model
|
||||
mutable int metadef_id_{0};
|
||||
|
||||
// The bit flags which define bool options for NNAPI EP, bits are defined as
|
||||
// NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
|
||||
const unsigned long nnapi_flags_;
|
||||
|
|
|
|||
|
|
@ -793,7 +793,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
const ExecutionProviders& providers,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const InsertCastTransformer& insert_cast_transformer,
|
||||
SessionState& session_state) {
|
||||
SessionState& session_state,
|
||||
bool saving_model_in_ort_format) {
|
||||
// The transformer order:
|
||||
// 1. built-in graph rewriter
|
||||
// 2. each execution provider's transformer
|
||||
|
|
@ -822,10 +823,17 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
}
|
||||
#endif
|
||||
|
||||
// if saving model to ORT format we only assign nodes a custom EP can handle and don't compile them.
|
||||
// we do this to preserve the original nodes in the model but prevent optimizers from changing them.
|
||||
// at runtime, the ORT format model will re-do the partitioning/compilation of these nodes, which may change
|
||||
// to cover fewer nodes due to device capabilities.
|
||||
auto mode = saving_model_in_ort_format ? GraphPartitioner::Mode::kAssignOnly
|
||||
: GraphPartitioner::Mode::kNormal;
|
||||
|
||||
// Do partitioning based on execution providers' capability.
|
||||
GraphPartitioner partitioner(kernel_registry_manager, providers);
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
|
||||
session_state.GetMutableFuncMgr()));
|
||||
session_state.GetMutableFuncMgr(), mode));
|
||||
|
||||
// apply transformers except default transformers
|
||||
// Default transformers are required for correctness and they are owned and run by inference session
|
||||
|
|
@ -888,9 +896,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
|
||||
return common::Status::OK();
|
||||
}
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// Runs GraphPartitioner over `graph` in Mode::kOrtFormatLoad.
// The partitioner is given an output map (string -> uint64_t) of compiled kernel hashes; if that
// map is non-empty after partitioning it is moved into `session_state` via SetCompiledKernelHashes.
// Returns the partitioner's error status if Partition fails (through
// ORT_RETURN_IF_ERROR_SESSIONID_), otherwise Status::OK().
Status InferenceSession::PartitionOrtFormatModel(onnxruntime::Graph& graph,
|
||||
const ExecutionProviders& providers,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
SessionState& session_state) const {
|
||||
// filled in by Partition() through the pointer passed below
std::unordered_map<std::string, uint64_t> compiled_kernel_hashes;
|
||||

||||
GraphPartitioner partitioner(kernel_registry_manager, providers);
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
|
||||
session_state.GetMutableFuncMgr(),
|
||||
GraphPartitioner::Mode::kOrtFormatLoad,
|
||||
&compiled_kernel_hashes));
|
||||

||||
// only hand the hashes to the session state when partitioning actually produced some
if (!compiled_kernel_hashes.empty()) {
|
||||
session_state.SetCompiledKernelHashes(std::move(compiled_kernel_hashes));
|
||||
}
|
||||

||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
||||
template <typename T>
|
||||
static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
|
||||
|
|
@ -1131,38 +1159,65 @@ common::Status InferenceSession::Initialize() {
|
|||
// Register 2nd registries into KernelRegistryManager.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_));
|
||||
|
||||
bool loading_ort_format = !ort_format_model_bytes_.empty();
|
||||
bool saving_model = !session_options_.optimized_model_filepath.empty();
|
||||
bool saving_ort_format = false;
|
||||
if (saving_model) {
|
||||
std::string model_type = session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSaveModelFormat, "");
|
||||
bool has_explicit_type = !model_type.empty();
|
||||
saving_ort_format = ((has_explicit_type && model_type == "ORT") ||
|
||||
(!has_explicit_type &&
|
||||
experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)));
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// add predefined transformers
|
||||
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
|
||||
transformers_to_enable_);
|
||||
if (!loading_ort_format) {
|
||||
// add predefined transformers
|
||||
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
|
||||
transformers_to_enable_);
|
||||
|
||||
// apply any transformations to the main graph and any subgraphs
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
|
||||
execution_providers_, kernel_registry_manager_,
|
||||
insert_cast_transformer_,
|
||||
*session_state_));
|
||||
// apply any transformations to the main graph and any subgraphs
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
|
||||
execution_providers_, kernel_registry_manager_,
|
||||
insert_cast_transformer_,
|
||||
*session_state_,
|
||||
saving_ort_format));
|
||||
|
||||
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
|
||||
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
|
||||
|
||||
// Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_));
|
||||
// Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_));
|
||||
} else
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
{
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// nodes are already partitioned, but a custom EP may compile some at runtime.
|
||||
// run the partitioning to allow that to happen.
|
||||
//
|
||||
// We always have the CPU EP, so only need to run this if some other EP is enabled
|
||||
if (execution_providers_.NumProviders() > 1) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_,
|
||||
*session_state_));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// need to keep the initializers if we're going to save the optimized model
|
||||
bool keep_initializers = !session_options_.optimized_model_filepath.empty();
|
||||
const experimental::fbs::SessionState* serialized_session_state =
|
||||
loading_ort_format
|
||||
? fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state()
|
||||
: nullptr;
|
||||
|
||||
auto* serialized_session_state = !ort_format_model_bytes_.empty()
|
||||
? fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state()
|
||||
: nullptr;
|
||||
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_,
|
||||
session_options_,
|
||||
serialized_session_state,
|
||||
!keep_initializers));
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_,
|
||||
session_options_,
|
||||
serialized_session_state,
|
||||
// need to keep the initializers if saving the optimized model
|
||||
!saving_model,
|
||||
saving_ort_format));
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
if (!session_options_.optimized_model_filepath.empty()) {
|
||||
if (saving_model) {
|
||||
if (session_options_.graph_optimization_level >= TransformerLevel::Level3) {
|
||||
LOGS(*session_logger_, WARNING)
|
||||
<< "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. "
|
||||
|
|
@ -1170,12 +1225,7 @@ common::Status InferenceSession::Initialize() {
|
|||
"and should only be used in the same environment the model was optimized for.";
|
||||
}
|
||||
|
||||
std::string model_type = session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSaveModelFormat, "");
|
||||
bool has_explicit_type = !model_type.empty();
|
||||
|
||||
if ((has_explicit_type && model_type == "ORT") ||
|
||||
(!has_explicit_type &&
|
||||
experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath))) {
|
||||
if (saving_ort_format) {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(SaveToOrtFormat(session_options_.optimized_model_filepath));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));
|
||||
|
|
|
|||
|
|
@ -57,9 +57,7 @@ class LoggingManager;
|
|||
*/
|
||||
struct ModelMetadata {
|
||||
ModelMetadata() = default;
|
||||
ModelMetadata(const ModelMetadata& other)
|
||||
: producer_name(other.producer_name), graph_name(other.graph_name), domain(other.domain), description(other.description), version(other.version), custom_metadata_map(other.custom_metadata_map) {
|
||||
}
|
||||
ModelMetadata(const ModelMetadata&) = default;
|
||||
~ModelMetadata() = default;
|
||||
ModelMetadata& operator=(const ModelMetadata&) = delete;
|
||||
|
||||
|
|
@ -531,7 +529,8 @@ class InferenceSession {
|
|||
const onnxruntime::GraphTransformerManager& graph_transformer_mgr,
|
||||
const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager,
|
||||
const InsertCastTransformer& insert_cast_transformer,
|
||||
SessionState& session_state) ORT_MUST_USE_RESULT;
|
||||
SessionState& session_state,
|
||||
bool saving_model_in_ort_format) ORT_MUST_USE_RESULT;
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr_;
|
||||
|
||||
|
|
@ -543,6 +542,11 @@ class InferenceSession {
|
|||
std::vector<std::string> transformers_to_enable_;
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers,
|
||||
KernelRegistryManager& kernel_registry_manager, SessionState& session_state) const;
|
||||
#endif
|
||||
|
||||
SessionOptions session_options_;
|
||||
|
||||
/// Logging manager if provided.
|
||||
|
|
|
|||
|
|
@ -231,9 +231,10 @@ namespace py = pybind11;
|
|||
using namespace onnxruntime;
|
||||
using namespace onnxruntime::logging;
|
||||
|
||||
static Env& platform_env = Env::Default();
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Custom op section starts
|
||||
static Env& platform_env = Env::Default();
|
||||
|
||||
CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so) {
|
||||
{
|
||||
|
|
@ -650,10 +651,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
#endif
|
||||
} else if (type == kRocmExecutionProvider) {
|
||||
#ifdef USE_ROCM
|
||||
RegisterExecutionProvider(
|
||||
sess, *onnxruntime::CreateExecutionProviderFactory_ROCM(cuda_device_id,
|
||||
cuda_mem_limit,
|
||||
arena_extend_strategy));
|
||||
RegisterExecutionProvider(
|
||||
sess, *onnxruntime::CreateExecutionProviderFactory_ROCM(cuda_device_id,
|
||||
cuda_mem_limit,
|
||||
arena_extend_strategy));
|
||||
#endif
|
||||
} else if (type == kDnnlExecutionProvider) {
|
||||
#ifdef USE_DNNL
|
||||
|
|
@ -685,11 +686,9 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
|
||||
} else if (option.first == "device_id") {
|
||||
openvino_device_id = option.second;
|
||||
}
|
||||
else if (option.first == "num_of_threads") {
|
||||
} else if (option.first == "num_of_threads") {
|
||||
num_of_threads = std::stoi(option.second);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
ORT_THROW("Invalid OpenVINO EP option: ", option.first);
|
||||
}
|
||||
}
|
||||
|
|
@ -946,7 +945,7 @@ void addGlobalMethods(py::module& m, const Environment& env) {
|
|||
ORT_UNUSED_PARAMETER(algo);
|
||||
ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
|
||||
#else
|
||||
cudnn_conv_algo_search = algo;
|
||||
cudnn_conv_algo_search = algo;
|
||||
#endif
|
||||
});
|
||||
m.def("set_do_copy_in_default_stream", [](const bool use_single_stream) {
|
||||
|
|
@ -954,7 +953,7 @@ void addGlobalMethods(py::module& m, const Environment& env) {
|
|||
ORT_UNUSED_PARAMETER(use_single_stream);
|
||||
ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM");
|
||||
#else
|
||||
do_copy_in_default_stream = use_single_stream;
|
||||
do_copy_in_default_stream = use_single_stream;
|
||||
#endif
|
||||
});
|
||||
m.def("set_cuda_mem_limit", [](const int64_t limit) { cuda_mem_limit = static_cast<size_t>(limit); });
|
||||
|
|
@ -1734,7 +1733,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
.def("end_profiling", [](PyInferenceSession* sess) -> std::string {
|
||||
return sess->GetSessionHandle()->EndProfiling();
|
||||
})
|
||||
.def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t{
|
||||
.def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t {
|
||||
return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs();
|
||||
})
|
||||
.def("get_providers", [](PyInferenceSession* sess) -> const std::vector<std::string>& {
|
||||
|
|
|
|||
3
onnxruntime/test/providers/internal_testing/README.md
Normal file
3
onnxruntime/test/providers/internal_testing/README.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
Internal test EP that is used to test and validate interactions between the ORT framework, optimizers and an EP.
|
||||
|
||||
Primary usage currently is validating support in a minimal build for an EP that compiles nodes into kernels at runtime.
|
||||
|
|
@ -0,0 +1,218 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "internal_testing_execution_provider.h"
|
||||
|
||||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/feeds_fetches_manager.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP";
|
||||
|
||||
InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops)
|
||||
: IExecutionProvider{utils::kInternalTestingExecutionProvider},
|
||||
ops_{ops} {
|
||||
// TODO: Allocation planner calls GetAllocator for the individual EP. It would be better if it goes through
|
||||
// the session state to get the allocator so it's per-device (or for the allocation planner to try the EP first
|
||||
// and fall back to using session state next by passing in a functor it can use to call SessionState::GetAllocator).
|
||||
|
||||
AllocatorCreationInfo device_info(
|
||||
[](int) {
|
||||
return onnxruntime::make_unique<CPUAllocator>(OrtMemoryInfo(INTERNAL_TESTING_EP,
|
||||
OrtAllocatorType::OrtDeviceAllocator));
|
||||
});
|
||||
|
||||
InsertAllocator(CreateAllocator(device_info));
|
||||
}
|
||||
|
||||
// Nothing to release explicitly; members clean up after themselves.
InternalTestingExecutionProvider::~InternalTestingExecutionProvider() = default;
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const {
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
|
||||
/*
|
||||
Very basic search for groups of nodes that can be handled by the EP.
|
||||
This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP
|
||||
but B is between them in the topological sort as you'll get two single node capabilities. However if can also
|
||||
be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected.
|
||||
Not sure how often each of these scenarios happens.
|
||||
|
||||
A B C
|
||||
| / |
|
||||
D E
|
||||
| |
|
||||
|
||||
Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order,
|
||||
accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all
|
||||
connected nodes that can be handled are grouped together.
|
||||
*/
|
||||
|
||||
std::vector<std::vector<NodeIndex>> node_groups;
|
||||
std::vector<NodeIndex> cur_group;
|
||||
for (NodeIndex node_index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
if (ops_.count(graph_viewer.GetNode(node_index)->OpType())) {
|
||||
cur_group.push_back(node_index);
|
||||
} else if (!cur_group.empty()) {
|
||||
node_groups.push_back(std::move(cur_group));
|
||||
}
|
||||
}
|
||||
|
||||
if (!cur_group.empty()) {
|
||||
node_groups.push_back(std::move(cur_group));
|
||||
}
|
||||
|
||||
if (node_groups.empty()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
const auto& graph_output_list = graph_viewer.GetOutputs();
|
||||
std::unordered_set<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());
|
||||
|
||||
for (const auto& group : node_groups) {
|
||||
std::unordered_set<NodeIndex> node_set;
|
||||
node_set.reserve(group.size());
|
||||
for (const auto& index : group) {
|
||||
node_set.insert(index);
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::make_unique<IndexedSubGraph>();
|
||||
|
||||
std::unordered_set<const NodeArg*> node_outputs;
|
||||
std::unordered_set<const NodeArg*> subgraph_inputs;
|
||||
std::unordered_set<const NodeArg*> subgraph_outputs;
|
||||
std::vector<const NodeArg*> ordered_subgraph_inputs;
|
||||
std::vector<const NodeArg*> ordered_subgraph_outputs;
|
||||
|
||||
for (const auto& index : group) {
|
||||
sub_graph->nodes.push_back(index);
|
||||
const auto* node = graph_viewer.GetNode(index);
|
||||
|
||||
for (const auto* input : node->InputDefs()) {
|
||||
// if the node input was not produced by this subgraph, add it to the subgraph inputs.
|
||||
if (node_outputs.count(input) == 0) {
|
||||
if (subgraph_inputs.count(input) == 0) {
|
||||
subgraph_inputs.insert(input);
|
||||
ordered_subgraph_inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto& output_defs = node->OutputDefs();
|
||||
for (const auto* output_def : output_defs) {
|
||||
node_outputs.insert(output_def);
|
||||
// if output is overall graph output we need to produce it.
|
||||
if (graph_outputs.count(output_def) != 0) {
|
||||
ordered_subgraph_outputs.push_back(output_def);
|
||||
}
|
||||
}
|
||||
|
||||
// if output connects to a node not in this subgraph we need to produce it
|
||||
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
|
||||
if (node_set.count(it->GetNode().Index()) == 0) {
|
||||
const auto* output_def = output_defs[it->GetSrcArgIndex()];
|
||||
if (subgraph_outputs.count(output_def) == 0) {
|
||||
subgraph_outputs.insert(output_def);
|
||||
ordered_subgraph_outputs.push_back(output_def);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assign inputs and outputs to subgraph's meta_def
|
||||
auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
|
||||
meta_def->name = "InternalTestingEP_" + std::to_string(metadef_id_++);
|
||||
meta_def->domain = kMSDomain;
|
||||
meta_def->since_version = 1;
|
||||
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
|
||||
|
||||
for (const auto& input : ordered_subgraph_inputs) {
|
||||
meta_def->inputs.push_back(input->Name());
|
||||
}
|
||||
|
||||
for (const auto& output : ordered_subgraph_outputs) {
|
||||
meta_def->outputs.push_back(output->Name());
|
||||
}
|
||||
|
||||
sub_graph->SetMetaDef(std::move(meta_def));
|
||||
|
||||
result.push_back(onnxruntime::make_unique<ComputeCapability>(std::move(sub_graph)));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Produces a NodeComputeInfo for every fused node. The compute function does no real work: for each output of
// the fused node it allocates a tensor of the statically known shape (unknown dims replaced by 1) and zero-fills
// it, which is enough to let a model be executed end to end in tests.
common::Status InternalTestingExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
                                                         std::vector<NodeComputeInfo>& node_compute_funcs) {
  for (const auto& fused : fused_nodes) {
    const Node& node = fused.fused_node;
    NodeComputeInfo compute_info;

    // no per-kernel state is needed
    compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) {
      return 0;
    };

    compute_info.release_state_func = [](FunctionState /*state*/) {
    };

    compute_info.compute_func = [&node](FunctionState /*state*/, const OrtCustomOpApi* c_api,
                                        OrtKernelContext* context) -> Status {
      Ort::CustomOpApi api{*c_api};  // use C++ API for convenience

      const auto output_defs = node.OutputDefs();

      for (size_t output_idx = 0, output_count = output_defs.size(); output_idx < output_count; ++output_idx) {
        const auto* shape_proto = output_defs[output_idx]->Shape();
        if (shape_proto == nullptr) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown output shapes are not supported");
        }

        TensorShape shape = utils::GetTensorShapeFromTensorShapeProto(*shape_proto);
        const bool has_unknown_dim = shape.Size() < 0;
        if (has_unknown_dim) {
          // arbitrarily set any unknown dim to 1
          const size_t rank = shape.NumDimensions();
          for (size_t dim = 0; dim < rank; ++dim) {
            if (shape[dim] == -1) {
              shape[dim] = 1;
            }
          }
        }

        // create the output_tensor.
        auto* ortvalue = api.KernelContext_GetOutput(context, output_idx,
                                                     shape.GetDims().data(), shape.GetDims().size());

        // and fill with zeros
        auto* tensor = ortvalue->GetMutable<Tensor>();
        memset(tensor->MutableDataRaw(), 0, tensor->SizeInBytes());
      }

      return Status::OK();
    };

    node_compute_funcs.push_back(std::move(compute_info));
  }

  return Status::OK();
}
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "core/framework/execution_provider.h"

namespace onnxruntime {
// Execution provider used only by tests. It claims nodes whose op type is in a caller-supplied set and
// "compiles" them into kernels at runtime, which lets the tests exercise the framework's handling of EPs
// that compile nodes.
class InternalTestingExecutionProvider : public IExecutionProvider {
 public:
  // ops: operator type names that this EP reports it can handle.
  InternalTestingExecutionProvider(const std::unordered_set<std::string>& ops);
  virtual ~InternalTestingExecutionProvider();

  // Groups nodes with a supported op type and returns one capability per group.
  std::vector<std::unique_ptr<ComputeCapability>>
  GetCapability(const onnxruntime::GraphViewer& graph_view,
                const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;

  // Creates the compute functions for the fused nodes produced from GetCapability's groups.
  common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
                         std::vector<NodeComputeInfo>& node_compute_funcs) override;

  // Compile receives each fused node together with a filtered GraphViewer of the original nodes.
  FusionStyle GetFusionStyle() const override {
    return FusionStyle::FilteredGraphViewer;
  }

 private:
  // op types this EP claims in GetCapability
  const std::unordered_set<std::string> ops_;

  // unique counter to name each fused kernel across the entire model.
  // mutable because it is incremented from the const GetCapability.
  mutable int metadef_id_{0};
};
}  // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/providers/internal_testing/internal_testing_execution_provider.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
#include "test/util/include/inference_session_wrapper.h"
|
||||
#include "test/util/include/test_utils.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::logging;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace test {
|
||||
|
||||
static void CreateSession(const SessionOptions& so, std::unique_ptr<InferenceSessionWrapper>& session,
|
||||
const ORTCHAR_T* model_path = ORT_TSTR("testdata/mnist.onnx"), // arbitrary test model
|
||||
bool enable_custom_ep = true,
|
||||
const std::unordered_set<std::string>* override_supported_ops = nullptr) {
|
||||
session = onnxruntime::make_unique<InferenceSessionWrapper>(so, GetEnvironment());
|
||||
|
||||
// set supported ops to ops that are ideally found consecutively in the model.
|
||||
// we can say the EP potentially handles them all, but can also test removing handling of one or more ops
|
||||
// at runtime to simulate a lower spec device where not all ops can be handled. this allows us to test
|
||||
// that we can revert ops back to the CPU implementation successfully
|
||||
const std::unordered_set<std::string> default_supported_ops{"Conv", "Add", "Relu", "MaxPool"};
|
||||
const std::unordered_set<std::string>* supported_ops = override_supported_ops ? override_supported_ops
|
||||
: &default_supported_ops;
|
||||
|
||||
if (enable_custom_ep) {
|
||||
ASSERT_STATUS_OK(session->RegisterExecutionProvider(
|
||||
onnxruntime::make_unique<InternalTestingExecutionProvider>(*supported_ops)));
|
||||
}
|
||||
|
||||
ASSERT_STATUS_OK(session->Load(model_path));
|
||||
ASSERT_STATUS_OK(session->Initialize());
|
||||
}
|
||||
|
||||
// Runs the mnist model with an all-ones input to validate that the session can execute it.
// The dummy internal testing EP just creates zero-filled output, so when it is enabled the only thing checked is
// that the final output equals the Parameter194 initializer (the bias of the last Add).
static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enabled) {
  // build the single input
  TensorShape input_shape{1, 1, 28, 28};
  std::vector<float> input(input_shape.Size(), 1.f);

  OrtValue ml_value_x;
  CreateMLValue<float>(input_shape.GetDims(), input.data(), OrtMemoryInfo(), &ml_value_x);

  NameMLValMap feeds;
  feeds.insert(std::make_pair("Input3", ml_value_x));

  // prepare outputs
  std::vector<std::string> output_names{"Plus214_Output_0"};
  std::vector<OrtValue> fetches;

  ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));

  if (!custom_ep_enabled) {
    // values produced by the real kernels are not validated here
    return;
  }

  // the dummy EP produces output of the correct shape with all zeros, so any downstream operations should still
  // result in zeros for this model OR the output should equal the bias in the final Add operation,
  // which is in the Parameter194 initializer
  const auto actual = fetches[0].Get<Tensor>().DataAsSpan<float>();

  const auto& session_state = session.GetSessionState();
  int initializer_idx = 0;
  ASSERT_STATUS_OK(session_state.GetOrtValueNameIdxMap().GetIdx("Parameter194", initializer_idx));

  const auto& bias = session_state.GetConstantInitializedTensors().at(initializer_idx);
  const auto expected = bias.Get<Tensor>().DataAsSpan<float>();

  ASSERT_THAT(actual, ::testing::ContainerEq(expected));
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Round trip: save an ORT format model from the ONNX model with the custom EP registered, then load the ORT
// format model both without and with the custom EP.
TEST(InternalTestingEP, TestSaveAndLoadOrtModel) {
  const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.test_output.ort");

  //
  // Step 1: load the onnx format model and save as an ORT model.
  // This should preserve the nodes the custom EP can handle.
  //
  SessionOptions so;
  so.optimized_model_filepath = ort_model_path;

  std::unique_ptr<InferenceSessionWrapper> onnx_session;
  CreateSession(so, onnx_session);

  // this graph should include the original nodes that the custom EP will take at runtime
  const auto expected_num_nodes = onnx_session->GetGraph().NumberOfNodes();

  //
  // Step 2: load the ORT format model with just the CPU EP to make sure it can be executed. This tests that the
  // fallback to the CPU EP kernel hashes works.
  //
  so.optimized_model_filepath.clear();

  std::unique_ptr<InferenceSessionWrapper> ort_session;
  CreateSession(so, ort_session, ort_model_path, /*enable_custom_ep*/ false);

  // model should have all the original nodes and we should be able to execute with the fallback to CPU EP
  ASSERT_EQ(ort_session->GetGraph().NumberOfNodes(), expected_num_nodes);
  ExecuteMnist(*ort_session, /*custom_ep_enabled*/ false);
  ort_session = nullptr;

  //
  // Step 3: load the ORT format model with the custom EP enabled. This tests that we support runtime compilation
  // for the ORT format model.
  //
  CreateSession(so, ort_session, ort_model_path, /*enable_custom_ep*/ true);

  // model should be able to be loaded, and we should compile using custom ep. that will result in one node for the
  // custom EP (with Conv/Add/Relu/MaxPool), one for a reshape, and one for the fused MatMul+Add.
  ASSERT_EQ(ort_session->GetGraph().NumberOfNodes(), 3);
  ExecuteMnist(*ort_session, /*custom_ep_enabled*/ true);
}
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
// test to validate a minimal build
|
||||
TEST(InternalTestingEP, TestLoadOrtModel) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort");
|
||||
|
||||
std::unique_ptr<InferenceSessionWrapper> session;
|
||||
bool enable_custom_ep = true;
|
||||
|
||||
CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep);
|
||||
ExecuteMnist(*session, enable_custom_ep);
|
||||
}
|
||||
|
||||
// test that is the custom EP cannot take all nodes due to device limitations
|
||||
// that we fallback to the CPU implementations and can execute the model
|
||||
TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort");
|
||||
const std::unordered_set<std::string> supported_ops{"Conv", "Add", "Relu" /*, "MaxPool"*/};
|
||||
|
||||
std::unique_ptr<InferenceSessionWrapper> session;
|
||||
bool enable_custom_ep = true;
|
||||
|
||||
CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops);
|
||||
|
||||
const auto& graph = session->GetGraph();
|
||||
// Conv+Add gets fused by level 1 optimizer into single node. The 'Conv'/'Add'/'Relu' nodes should be compiled and
|
||||
// handled by the custom EP. fallback to CPU for MaxPool.
|
||||
ASSERT_EQ(graph.NumberOfNodes(), 6);
|
||||
const auto& func_mgr = session->GetSessionState().GetFuncMgr();
|
||||
NodeComputeInfo* compute_func = nullptr;
|
||||
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
|
||||
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
|
||||
if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
|
||||
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
|
||||
EXPECT_NE(compute_func, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
ExecuteMnist(*session, enable_custom_ep);
|
||||
}
|
||||
|
||||
// Placeholder: intentionally has no assertions yet.
TEST(InternalTestingEP, TestMinimalRegistrationOfEPwithGetCapability) {
  // TODO: In a full build we want to be able to call GetCapability for the NNAPI EP and produce an ORT format model
  // with nodes correctly preserved.
  // That requires being able to do a minimal registration of that EP where GetCapability is fully implemented,
  // but Compile is a stub that just throws NOT_IMPLEMENTED if someone attempts to execute a model in that
  // InferenceSession.
}
|
||||
|
||||
// count nodes assigned to the test EP and make sure they all have valid compute funcs
|
||||
static int CountAndValidateAssignedNodes(const Graph& current_graph,
|
||||
const std::unordered_set<std::string>& supported_ops,
|
||||
const FuncManager& func_mgr) {
|
||||
int count = 0;
|
||||
|
||||
for (const auto& node : current_graph.Nodes()) {
|
||||
EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0))
|
||||
<< "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not.";
|
||||
if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) {
|
||||
NodeComputeInfo* compute_func = nullptr;
|
||||
EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func));
|
||||
EXPECT_NE(compute_func, nullptr);
|
||||
++count;
|
||||
}
|
||||
|
||||
if (node.ContainsSubgraph()) {
|
||||
for (const auto& entry : node.GetSubgraphs()) {
|
||||
count += CountAndValidateAssignedNodes(*entry, supported_ops, func_mgr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
// Test model that contains a subgraph. This model has a Loop and an If so multiple layers of nested subgraphs.
|
||||
// There are Add nodes in the Loop and If subgraphs so we should see the custom EP taking nodes at both these levels.
|
||||
TEST(InternalTestingEP, TestModelWithSubgraph) {
|
||||
const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
|
||||
const std::unordered_set<std::string> supported_ops{"Add"};
|
||||
|
||||
std::unique_ptr<InferenceSessionWrapper> session;
|
||||
bool enable_custom_ep = true;
|
||||
|
||||
CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops);
|
||||
|
||||
const auto& graph = session->GetGraph();
|
||||
const auto& func_mgr = session->GetSessionState().GetFuncMgr();
|
||||
|
||||
int num_replaced_nodes = CountAndValidateAssignedNodes(graph, supported_ops, func_mgr);
|
||||
|
||||
// One Add node in the Loop. One Add node in each branch of the If inside the Loop body
|
||||
ASSERT_EQ(num_replaced_nodes, 3);
|
||||
|
||||
OrtValue ml_value;
|
||||
|
||||
// this is a bit of a hack. the correct output is the input value + 2, so if we start with -2 the result is 0.
|
||||
// the output from fused nodes using the testing EP is always 0, so we should match the expected output this way
|
||||
// as we replace all the Add nodes with something that returns 0.
|
||||
// RunAndVerifyOutputsWithEP checks that nodes are assigned to the EP so we know it's being used to execute the model
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {-2.f},
|
||||
&ml_value);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("state_var_in", ml_value));
|
||||
// compare outputs from CPU EP vs custom EP
|
||||
RunAndVerifyOutputsWithEP(ort_model_path,
|
||||
"InternalTestingEP.TestModelWithSubgraph",
|
||||
onnxruntime::make_unique<InternalTestingExecutionProvider>(supported_ops),
|
||||
feeds);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,13 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "test/util/include/inference_session_wrapper.h"
|
||||
#include "test/util/include/test/test_environment.h"
|
||||
#include "test/util/include/test_utils.h"
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// if this is a full build we need the provider test utils
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#endif
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
|
@ -16,85 +28,48 @@ using namespace ::onnxruntime::logging;
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
#ifdef __ANDROID__
|
||||
void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
|
||||
const std::vector<float>& expected_values) {
|
||||
ASSERT_EQ(1, fetches.size());
|
||||
auto& rtensor = fetches.front().Get<Tensor>();
|
||||
TensorShape expected_shape(expected_dims);
|
||||
ASSERT_EQ(expected_shape, rtensor.Shape());
|
||||
const std::vector<float> found(rtensor.template Data<float>(), rtensor.template Data<float>() + expected_values.size());
|
||||
ASSERT_EQ(expected_values, found);
|
||||
}
|
||||
#endif
|
||||
|
||||
void RunAndVerifyOutputs(const std::string& model_file_name,
|
||||
const char* log_id,
|
||||
const NameMLValMap& feeds,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<int64_t>& expected_dims,
|
||||
const std::vector<float>& expected_values) {
|
||||
SessionOptions so;
|
||||
so.session_logid = log_id;
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>(0)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
// Since we already know the model is entirely supported by NNAPI, all nodes (2 Add nodes here) will be fused
|
||||
// Get the graph after session is initialized, and verify the fused node (the only node in the graph) is using NNAPI EP
|
||||
const auto& graph = session_object.GetGraph();
|
||||
ASSERT_EQ(1, graph.NumberOfNodes()); // Make sure the graph has 1 fused node
|
||||
ASSERT_EQ(onnxruntime::kNnapiExecutionProvider, graph.Nodes().cbegin()->GetExecutionProviderType());
|
||||
|
||||
// The execution can only be performed on Android
|
||||
#ifdef __ANDROID__
|
||||
// Now run and verify the result
|
||||
std::vector<OrtValue> fetches;
|
||||
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
|
||||
VerifyOutputs(fetches, expected_dims, expected_values);
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(feeds);
|
||||
ORT_UNUSED_PARAMETER(output_names);
|
||||
ORT_UNUSED_PARAMETER(expected_dims);
|
||||
ORT_UNUSED_PARAMETER(expected_values);
|
||||
#endif
|
||||
}
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
// Since NNAPI EP handles Reshape and Flatten differently,
|
||||
// Please see ReshapeOpBuilder::CanSkipReshape in <repo_root>/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc
|
||||
// Please see ReshapeOpBuilder::CanSkipReshape in
|
||||
// <repo_root>/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc
|
||||
// We have a separated test for these skip reshape scenarios
|
||||
TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) {
|
||||
std::string model_file_name = "testdata/nnapi_reshape_flatten_test.onnx";
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/nnapi_reshape_flatten_test.onnx");
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
std::vector<int64_t> dims_mul_x = {2, 1, 2};
|
||||
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
std::vector<int64_t> dims_mul_y = {3, 2, 2};
|
||||
std::vector<float> values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
|
||||
OrtValue ml_value_x;
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x);
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
|
||||
&ml_value_x);
|
||||
OrtValue ml_value_y;
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_y, values_mul_y, &ml_value_y);
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_y, values_mul_y,
|
||||
&ml_value_y);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
feeds.insert(std::make_pair("Y", ml_value_y));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("Z");
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_mul_z = {1, 6};
|
||||
std::vector<float> expected_values_mul_z = {59.0f, 72.0f, 129.0f, 159.0f, 204.0f, 253.0f};
|
||||
|
||||
RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest", feeds, output_names, expected_dims_mul_z, expected_values_mul_z);
|
||||
RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest",
|
||||
onnxruntime::make_unique<NnapiExecutionProvider>(0),
|
||||
feeds);
|
||||
#else
|
||||
// test load only
|
||||
SessionOptions so;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
|
||||
<< "Some nodes should have been taken by the NNAPI EP";
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(NnapiExecutionProviderTest, FunctionTest) {
|
||||
std::string model_file_name = "nnapi_execution_provider_test_graph.onnx";
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("nnapi_execution_provider_test_graph.onnx");
|
||||
|
||||
{ // Create the model with 2 add nodes
|
||||
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
|
@ -130,30 +105,38 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
|
|||
ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
std::vector<int64_t> dims_mul_x = {1, 1, 3, 2};
|
||||
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
OrtValue ml_value_x;
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x);
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
|
||||
&ml_value_x);
|
||||
OrtValue ml_value_y;
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_y);
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
|
||||
&ml_value_y);
|
||||
OrtValue ml_value_z;
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_z);
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
|
||||
&ml_value_z);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
feeds.insert(std::make_pair("Y", ml_value_y));
|
||||
feeds.insert(std::make_pair("Z", ml_value_z));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("M");
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_mul_m = {1, 1, 3, 2};
|
||||
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
|
||||
|
||||
RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.FunctionTest", feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
|
||||
RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.FunctionTest",
|
||||
onnxruntime::make_unique<NnapiExecutionProvider>(0),
|
||||
feeds);
|
||||
#else
|
||||
// test load only
|
||||
SessionOptions so;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
|
||||
<< "Some nodes should have been taken by the NNAPI EP";
|
||||
#endif
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
|
||||
unsigned long nnapi_flags = NNAPI_FLAG_USE_NONE;
|
||||
|
|
@ -164,5 +147,38 @@ TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
|
|||
ASSERT_FALSE(flags & NNAPI_FLAG_USE_NCHW);
|
||||
}
|
||||
|
||||
TEST(NnapiExecutionProviderTest, TestOrtFormatModel) {
|
||||
// mnist model that has only had basic optimizations applied. nnapi should be able to take at least some of the nodes
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/mnist.level1_opt.ort");
|
||||
|
||||
// The execution can only be performed on Android
|
||||
#if defined(__ANDROID__)
|
||||
RandomValueGenerator random{};
|
||||
const std::vector<int64_t> dims = {1, 1, 28, 28};
|
||||
std::vector<float> data = random.Gaussian<float>(dims, 0.0f, 1.f);
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, data,
|
||||
&ml_value);
|
||||
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("Input3", ml_value));
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.TestOrtFormatModel",
|
||||
onnxruntime::make_unique<NnapiExecutionProvider>(0),
|
||||
feeds);
|
||||
#else
|
||||
// test load only
|
||||
SessionOptions so;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
|
||||
<< "Some nodes should have been taken by the NNAPI EP";
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/mnist.level1_opt.ort
vendored
Normal file
BIN
onnxruntime/test/testdata/mnist.level1_opt.ort
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/mnist.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/mnist.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/mnist.ort
vendored
Normal file
BIN
onnxruntime/test/testdata/mnist.ort
vendored
Normal file
Binary file not shown.
|
|
@ -10,14 +10,19 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id,
|
||||
OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
|
||||
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
|
||||
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo,
|
||||
bool do_copy_in_default_stream = true);
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(
|
||||
OrtDevice::DeviceId device_id,
|
||||
OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE,
|
||||
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
|
||||
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo,
|
||||
bool do_copy_in_default_stream = true);
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVINO(
|
||||
const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads);
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_NGraph(const char* ng_backend_type);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVINO(const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(bool, const char*);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(unsigned long);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
|
||||
|
|
@ -29,6 +34,10 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ROCM(O
|
|||
size_t hip_mem_limit = std::numeric_limits<size_t>::max(),
|
||||
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo);
|
||||
|
||||
// EP for internal testing
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_InternalTesting(
|
||||
const std::unordered_set<std::string>& supported_ops);
|
||||
|
||||
namespace test {
|
||||
|
||||
std::unique_ptr<IExecutionProvider> DefaultCpuExecutionProvider(bool enable_arena) {
|
||||
|
|
|
|||
|
|
@ -4,17 +4,18 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// helpers to run a function and check the status, outputting any error if it fails.
|
||||
// note: wrapped in do{} while(false) so the _tmp_status variable has limited scope
|
||||
#define ASSERT_STATUS_OK(function) \
|
||||
do { \
|
||||
Status _tmp_status = function; \
|
||||
Status _tmp_status = function; \
|
||||
ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status; \
|
||||
} while (false)
|
||||
|
||||
#define EXPECT_STATUS_OK(function) \
|
||||
do { \
|
||||
Status _tmp_status = function; \
|
||||
Status _tmp_status = function; \
|
||||
EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status; \
|
||||
} while (false)
|
||||
|
|
|
|||
|
|
@ -21,5 +21,9 @@ std::unique_ptr<IExecutionProvider> DefaultAclExecutionProvider(bool enable_aren
|
|||
std::unique_ptr<IExecutionProvider> DefaultArmNNExecutionProvider(bool enable_arena = true);
|
||||
std::unique_ptr<IExecutionProvider> DefaultRocmExecutionProvider();
|
||||
|
||||
// EP for internal testing
|
||||
std::unique_ptr<IExecutionProvider> DefaultInternalTestingExecutionProvider(
|
||||
const std::unordered_set<std::string>& supported_ops);
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
28
onnxruntime/test/util/include/test_utils.h
Normal file
28
onnxruntime/test/util/include/test_utils.h
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace onnxruntime {
|
||||
class Graph;
|
||||
|
||||
namespace test {
|
||||
|
||||
// return number of nodes in the Graph and any subgraphs that are assigned to the specified execution provider
|
||||
int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type);
|
||||
|
||||
// run the model using the CPU EP to get expected output, comparing to the output when the 'execution_provider'
|
||||
// is enabled. requires that at least one node is assigned to 'execution_provider'
|
||||
void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path,
|
||||
const char* log_id,
|
||||
std::unique_ptr<IExecutionProvider> execution_provider,
|
||||
const NameMLValMap& feeds);
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
119
onnxruntime/test/util/test_utils.cc
Normal file
119
onnxruntime/test/util/test_utils.cc
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "test/util/include/test_utils.h"
|
||||
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
#include "test/util/include/asserts.h"
|
||||
#include "test/util/include/test/test_environment.h"
|
||||
#include "test/util/include/inference_session_wrapper.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
static void VerifyOutputs(const std::vector<std::string>& output_names,
|
||||
const std::vector<OrtValue>& expected_fetches,
|
||||
const std::vector<OrtValue>& fetches) {
|
||||
ASSERT_EQ(expected_fetches.size(), fetches.size());
|
||||
|
||||
for (size_t i = 0, end = expected_fetches.size(); i < end; ++i) {
|
||||
auto& ltensor = expected_fetches[i].Get<Tensor>();
|
||||
auto& rtensor = fetches[i].Get<Tensor>();
|
||||
ASSERT_EQ(ltensor.Shape().GetDims(), rtensor.Shape().GetDims());
|
||||
auto element_type = ltensor.GetElementType();
|
||||
switch (element_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
EXPECT_THAT(ltensor.DataAsSpan<int32_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<int32_t>()))
|
||||
<< " mismatch for " << output_names[i];
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
EXPECT_THAT(ltensor.DataAsSpan<int64_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<int64_t>()))
|
||||
<< " mismatch for " << output_names[i];
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
|
||||
const float abs_err = float(1e-5);
|
||||
|
||||
EXPECT_THAT(ltensor.DataAsSpan<float>(),
|
||||
::testing::Pointwise(::testing::FloatNear(abs_err), rtensor.DataAsSpan<float>()));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Unhandled data type. Please add 'case' statement for ", element_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type) {
|
||||
int count = 0;
|
||||
|
||||
for (const auto& node : current_graph.Nodes()) {
|
||||
if (node.GetExecutionProviderType() == ep_type) {
|
||||
++count;
|
||||
}
|
||||
|
||||
if (node.ContainsSubgraph()) {
|
||||
for (const auto& entry : node.GetSubgraphs()) {
|
||||
count += CountAssignedNodes(*entry, ep_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path, const char* log_id,
|
||||
std::unique_ptr<IExecutionProvider> execution_provider,
|
||||
const NameMLValMap& feeds) {
|
||||
SessionOptions so;
|
||||
so.session_logid = log_id;
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
//
|
||||
// get expected output from CPU EP
|
||||
//
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(model_path));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
const auto& graph = session_object.GetGraph();
|
||||
const auto& outputs = graph.GetOutputs();
|
||||
|
||||
// fetch all outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.reserve(outputs.size());
|
||||
for (const auto* node_arg : outputs) {
|
||||
if (node_arg->Exists()) {
|
||||
output_names.push_back(node_arg->Name());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<OrtValue> expected_fetches;
|
||||
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &expected_fetches));
|
||||
|
||||
auto provider_type = execution_provider->Type(); // copy string so the std::move doesn't affect us
|
||||
|
||||
//
|
||||
// get output with EP enabled
|
||||
//
|
||||
InferenceSessionWrapper session_object2{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(std::move(execution_provider)));
|
||||
ASSERT_STATUS_OK(session_object2.Load(model_path));
|
||||
ASSERT_STATUS_OK(session_object2.Initialize());
|
||||
|
||||
// make sure that some nodes are assigned to the EP, otherwise this test is pointless...
|
||||
const auto& graph2 = session_object2.GetGraph();
|
||||
auto ep_nodes = CountAssignedNodes(graph2, provider_type);
|
||||
ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type << " for " << model_path;
|
||||
|
||||
// Run with EP and verify the result
|
||||
std::vector<OrtValue> fetches;
|
||||
ASSERT_STATUS_OK(session_object2.Run(run_options, feeds, output_names, &fetches));
|
||||
VerifyOutputs(output_names, expected_fetches, fetches);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -419,10 +419,13 @@ def parse_arguments():
|
|||
help="Build ONNXRuntime micro-benchmarks.")
|
||||
|
||||
# options to reduce binary size
|
||||
parser.add_argument("--minimal_build", action='store_true',
|
||||
parser.add_argument("--minimal_build", action='store',
|
||||
const='on', default='off', nargs='?', type=str.lower,
|
||||
help="Create a build that only supports ORT format models. "
|
||||
"See /docs/ONNX_Runtime_Format_Model_Usage.md for more information. "
|
||||
"RTTI is automatically disabled in a minimal build.")
|
||||
"RTTI is automatically disabled in a minimal build. "
|
||||
"To enable execution providers that compile kernels at runtime (e.g. NNAPI) pass 'extended' "
|
||||
"as a parameter. e.g. '--minimal_build extended'.")
|
||||
parser.add_argument("--include_ops_by_model", type=str, help="include ops from model(s) under designated path.")
|
||||
parser.add_argument("--include_ops_by_config", type=str,
|
||||
help="include ops from config file. "
|
||||
|
|
@ -710,7 +713,8 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
|
|||
"-Donnxruntime_DISABLE_RTTI=" + ("ON" if args.disable_rtti else "OFF"),
|
||||
"-Donnxruntime_DISABLE_EXCEPTIONS=" + ("ON" if args.disable_exceptions else "OFF"),
|
||||
"-Donnxruntime_DISABLE_ORT_FORMAT_LOAD=" + ("ON" if args.disable_ort_format_load else "OFF"),
|
||||
"-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build else "OFF"),
|
||||
"-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build != 'off' else "OFF"),
|
||||
"-Donnxruntime_EXTENDED_MINIMAL_BUILD=" + ("ON" if args.minimal_build == 'extended' else "OFF"),
|
||||
"-Donnxruntime_REDUCED_OPS_BUILD=" + (
|
||||
"ON" if args.include_ops_by_config or args.include_ops_by_model else "OFF"),
|
||||
"-Donnxruntime_MSVC_STATIC_RUNTIME=" + (
|
||||
|
|
|
|||
Loading…
Reference in a new issue