diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 18f762e605..bc1d62a898 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -120,6 +120,7 @@ option(onnxruntime_DISABLE_RTTI "Disable RTTI" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) +option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF) option(onnxruntime_REDUCED_OPS_BUILD "Reduced set of kernels are registered in build via modification of the kernel registration source files." OFF) option(onnxruntime_DISABLE_ORT_FORMAT_LOAD "Disable loading an ORT format model when onnxruntime_MINIMAL_BUILD=OFF (i.e. in a full build)." OFF) @@ -208,12 +209,21 @@ if(onnxruntime_USE_OPENMP) endif() endif() +# 'extended' implies minimal. +if (onnxruntime_EXTENDED_MINIMAL_BUILD AND NOT onnxruntime_MINIMAL_BUILD) + set(onnxruntime_MINIMAL_BUILD ON) +endif() + # ORT build with as much excluded as possible. Supports ORT flatbuffers models only. -# Will expose option in build.py when all pieces are available -if(onnxruntime_MINIMAL_BUILD) +if (onnxruntime_MINIMAL_BUILD) add_compile_definitions(ORT_MINIMAL_BUILD) add_compile_definitions(ENABLE_ORT_FORMAT_LOAD) + if (onnxruntime_EXTENDED_MINIMAL_BUILD) + # enable EPs that compile kernels at runtime + add_compile_definitions(ORT_EXTENDED_MINIMAL_BUILD) + endif() + set(onnxruntime_REDUCED_OPS_BUILD ON) if (NOT onnxruntime_ENABLE_PYTHON) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index e4a731a03e..70c97612c4 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -110,6 +110,7 @@ target_link_libraries(onnxruntime PRIVATE ${PROVIDERS_DML} ${PROVIDERS_ACL} ${PROVIDERS_ARMNN} + ${PROVIDERS_INTERNAL_TESTING} ${onnxruntime_winml} ${PROVIDERS_ROCM} onnxruntime_optimizer diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index b97cc0ed48..6f742c2d9b 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -10,7 +10,6 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS if (onnxruntime_MINIMAL_BUILD) file(GLOB onnxruntime_framework_src_exclude "${ONNXRUNTIME_ROOT}/core/framework/provider_bridge_ort.cc" - "${ONNXRUNTIME_ROOT}/core/framework/graph_partitioner.*" "${ONNXRUNTIME_INCLUDE_DIR}/core/framework/customregistry.h" "${ONNXRUNTIME_ROOT}/core/framework/customregistry.cc" ) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 6b0d1de5d1..4692dfd4de 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -657,6 +657,10 @@ if (onnxruntime_USE_OPENVINO) endif() if (onnxruntime_USE_NNAPI_BUILTIN) + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "NNAPI can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + add_compile_definitions(USE_NNAPI=1) # This is the minimum Android API Level required by ORT NNAPI EP to run @@ -782,7 +786,7 @@ if (onnxruntime_USE_DML) else() add_dependencies(${target} RESTORE_PACKAGES) target_link_libraries(${target} PRIVATE "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/DirectML.lib") - target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST) + target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST) endif() endfunction() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index a287ea143d..8ee81cc34b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -312,6 +312,13 @@ if (onnxruntime_USE_RKNPU) list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_rknpu_src}) endif() +if (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD) + file(GLOB_RECURSE onnxruntime_test_providers_internal_testing_src CONFIGURE_DEPENDS + "${TEST_SRC_DIR}/providers/internal_testing/*" + ) + list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_internal_testing_src}) +endif() + set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib") set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/global_thread_pools") set (ONNXRUNTIME_API_TESTS_WITHOUT_ENV_SRC_DIR "${ONNXRUNTIME_ROOT}/test/api_tests_without_env") @@ -525,7 +532,7 @@ else() target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} "${CMAKE_CURRENT_SOURCE_DIR}/external/nsync/public") endif() -onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_common onnxruntime_framework GTest::gtest onnx onnx_proto flatbuffers) +onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_common onnxruntime_framework GTest::gtest GTest::gmock onnx onnx_proto flatbuffers) if (onnxruntime_USE_DNNL) target_compile_definitions(onnxruntime_test_utils PUBLIC USE_DNNL=1) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 654487bfde..2be3f20767 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -168,10 +168,13 @@ class IExecutionProvider { void InsertAllocator(AllocatorPtr allocator); void ReplaceAllocator(AllocatorPtr allocator); + // creation of a fused node is not supported in a minimal build, so any EP enabled in that scenario must support + // compilation via GraphViewer instances. +#if !defined(ORT_MINIMAL_BUILD) /** Given a list of fused_node, return create_state/compute/release_state func for each node. */ - virtual common::Status Compile(const std::vector& fused_node, + virtual common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs); /** @@ -181,9 +184,53 @@ class IExecutionProvider { Compute_${node_name} Release_State_${node_name} */ - virtual common::Status Compile(const std::vector& fused_node, + virtual common::Status Compile(const std::vector& fused_nodes, std::string& dll_path); +#endif + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + struct FusedNodeAndGraph { + const std::reference_wrapper fused_node; + // GraphViewer that filters the full graph to the nodes that are covered by 'node' + const std::reference_wrapper filtered_graph; + }; + + /** + Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused, + return create_state/compute/release_state func for each node. + @remarks This is an optional interface that is only needed if the execution provider compiles nodes + in a scenario involving the minimal build. i.e. on a mobile or embedded device with ORT format model. + + Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions + as it is only valid for the duration of the call to Compile. + */ + virtual common::Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs); +#endif + + // Fusion approach that is suppported + enum class FusionStyle { + // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance + // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body(). + // A GraphProto can be produced from the Node body. + Function, + + // The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph + // that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body(). + // Instead a GraphViewer that filters the full Graph to the fused Nodes will be created. + // This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance, + // and can be supported in a minimal build. + FilteredGraphViewer + }; + + virtual FusionStyle GetFusionStyle() const { + // existing EPs use this mode so default to it. + // newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return + // FilteredGraphViewer + return FusionStyle::Function; + } + void SetLogger(const logging::Logger* logger) { logger_ = logger; } diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h index 1ec827c565..40b8c9a1fa 100644 --- a/include/onnxruntime/core/framework/op_kernel_info.h +++ b/include/onnxruntime/core/framework/op_kernel_info.h @@ -45,9 +45,7 @@ class OpKernelInfo : public OpNodeProtoHelper { bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const; - common::Status GetFusedFuncs(ComputeFunc* compute, - CreateFunctionStateFunc* create, - DestroyFunctionStateFunc* release) const; + common::Status GetFusedFuncs(NodeComputeInfo*& compute_info) const; private: ORT_DISALLOW_MOVE(OpKernelInfo); diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index cdf1d24b23..5c67796075 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -158,7 +158,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, shape_.Size()); + return gsl::make_span(data, static_cast::index_type>(shape_.Size())); } void* MutableDataRaw(MLDataType type) { diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index f498da464f..4d56c370f3 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -37,49 +37,50 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; constexpr const char* kAclExecutionProvider = "ACLExecutionProvider"; constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider"; constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; -constexpr const char *providers_available[] = { - kCpuExecutionProvider, + +constexpr const char* providers_available[] = { + kCpuExecutionProvider, #ifdef USE_CUDA - kCudaExecutionProvider, + kCudaExecutionProvider, #endif #ifdef USE_DNNL - kDnnlExecutionProvider, + kDnnlExecutionProvider, #endif #ifdef USE_NGRAPH - kNGraphExecutionProvider, + kNGraphExecutionProvider, #endif #ifdef USE_OPENVINO - kOpenVINOExecutionProvider, + kOpenVINOExecutionProvider, #endif #ifdef USE_NUPHAR - kNupharExecutionProvider, + kNupharExecutionProvider, #endif #ifdef USE_VITISAI - kVitisAIExecutionProvider, + kVitisAIExecutionProvider, #endif #ifdef USE_TENSORRT - kTensorrtExecutionProvider, + kTensorrtExecutionProvider, #endif #ifdef USE_NNAPI - kNnapiExecutionProvider, + kNnapiExecutionProvider, #endif #ifdef USE_RKNPU - kRknpuExecutionProvider, + kRknpuExecutionProvider, #endif #ifdef USE_DML - kDmlExecutionProvider, + kDmlExecutionProvider, #endif #ifdef USE_MIGRAPHX - kMIGraphXExecutionProvider, + kMIGraphXExecutionProvider, #endif #ifdef USE_ACL - kAclExecutionProvider, + kAclExecutionProvider, #endif #ifdef USE_ARMNN - kArmNNExecutionProvider, + kArmNNExecutionProvider, #endif #ifdef USE_ROCM - kRocmExecutionProvider, + kRocmExecutionProvider, #endif }; diff --git a/include/onnxruntime/core/graph/function.h b/include/onnxruntime/core/graph/function.h index 6328713705..efe5eb5c9f 100644 --- a/include/onnxruntime/core/graph/function.h +++ b/include/onnxruntime/core/graph/function.h @@ -31,7 +31,7 @@ class Function { /** Create a new Function instance. @param graph The graph containing the Function. -@param customized_func the IndexedSubGraph to use for the Function. +@param nodes_to_fuse the IndexedSubGraph to use for the Function. */ std::unique_ptr MakeFunction(const onnxruntime::Graph& graph, const IndexedSubGraph& nodes_to_fuse, diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 6c3df7f469..be92550efe 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -368,7 +368,9 @@ class Node { ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; } /** Sets the execution ProviderType that this Node will be executed by. */ - void SetExecutionProviderType(ProviderType execution_provider_type); + void SetExecutionProviderType(ProviderType execution_provider_type) { + execution_provider_type_ = execution_provider_type; + } /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node. If the NodeArg is an explicit or implicit input, is_input will be true when func is called. @@ -476,7 +478,7 @@ class Node { Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph) {} -#if !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) void Init(const std::string& name, const std::string& op_type, const std::string& description, @@ -485,25 +487,24 @@ class Node { const NodeAttributes* attributes, const std::string& domain); - // create a Graph instance for an attribute that contains a GraphProto - void CreateSubgraph(const std::string& attr_name); - // internal only method to allow selected classes to directly alter the input/output definitions and arg counts Definitions& MutableDefinitions() noexcept; // internal only method to allow selected classes to directly alter the links between nodes. Relationships& MutableRelationships() noexcept; + void SetNodeType(Node::Type node_type) noexcept { node_type_ = node_type; } +#endif + + // create a Graph instance for an attribute that contains a GraphProto + void CreateSubgraph(const std::string& attr_name); + const std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } - void SetNodeType(Node::Type node_type) noexcept; - - void SetFunctionBody(const Function& func); - // validate and update the input arg count common::Status UpdateInputArgCount(); -#endif // !defined(ORT_MINIMAL_BUILD) + void SetFunctionBody(const Function& func); const Definitions& GetDefinitions() const noexcept { return definitions_; } const Relationships& GetRelationships() const noexcept { return relationships_; } @@ -578,6 +579,15 @@ class Graph { /** Gets the path of the owning model, if any. */ const Path& ModelPath() const; + /** Returns true if this is a subgraph or false if it is a high-level graph. */ + bool IsSubgraph() const { return parent_graph_ != nullptr; } + + /** Returns the parent graph if this is a subgraph */ + const Graph* ParentGraph() const { return parent_graph_; } + + /** Returns the mutable parent graph if this is a subgraph */ + Graph* MutableParentGraph() { return parent_graph_; } + #if !defined(ORT_MINIMAL_BUILD) /** Sets the Graph name. */ void SetName(const std::string& name); @@ -756,6 +766,15 @@ class Graph { /** Generate a unique name in this Graph for a Node */ std::string GenerateNodeName(const std::string& base_name); + /** Copy a Node and add it to this Graph. + @param other Node to copy + @returns Reference to the Node that was created and added to this Graph. + @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe. + */ + Node& AddNode(const Node& other); +#endif + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Add a Node to this Graph. @param name The Node name. Must be unique in this Graph. @param op_type The operator type. e.g. ONNX operator name. @@ -775,13 +794,6 @@ class Graph { const NodeAttributes* attributes = nullptr, const std::string& domain = ""); - /** Copy a Node and add it to this Graph. - @param other Node to copy - @returns Reference to the Node that was created and added to this Graph. - @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe. - */ - Node& AddNode(const Node& other); - /** Remove a Node from this Graph and free it. The output edges of this specified node MUST have been removed before removing the node. The input edges of this specified node is removed while removing the node. The process of @@ -809,14 +821,15 @@ class Graph { @param dst_arg_index node arg index of destination node. */ void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index); +#endif +#if !defined(ORT_MINIMAL_BUILD) /** Add a control edge between two Nodes in this Graph. The source Node does not produce output that is directly consumed by the destination Node, however the destination Node must execute after the source node. The control edge allows this ordering to occur. */ bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index); - #endif // !defined(ORT_MINIMAL_BUILD) /** Mark the Graph as needing Resolve() to be called. @@ -880,6 +893,7 @@ class Graph { const std::function& comp, const std::function& stop) const; +#if !defined(ORT_MINIMAL_BUILD) /** Performs topological sort with Kahn's algorithm on the graph/s. @param enter Visit function that will be invoked on a node when it is visited. @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic. @@ -887,11 +901,29 @@ class Graph { void KahnsTopologicalSort(const std::function& enter, const std::function& comp) const; +#endif + /** Gets the map of operator domains to their opset versions. */ const std::unordered_map& DomainToVersionMap() const noexcept { return domain_to_version_; } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + /** + Create a single Node that will be the result of the a fusion of multiple nodes in this Graph. + @param sub_graph A IndexSubGraph instance with details of the nodes to fuse. + @param fused_node_name The name for the new Node. + @returns Node with fused subgraph. + @remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the + IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place + while this is in use. + Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created. + */ + Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name); + + void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node); +#endif + #if !defined(ORT_MINIMAL_BUILD) /** Gets the GraphProto representation of this Graph. */ const ONNX_NAMESPACE::GraphProto& ToGraphProto(); @@ -901,10 +933,11 @@ class Graph { IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const; /** - Create a single Node that is the result of the a fusion of multiple nodes in this Graph. - @param sub_graph A IndexSubGraph instance with details of the nodes to fuse. + Create a single Function based Node that is the result of the a fusion of multiple nodes in this Graph. + A new Graph instance will be created for the fused nodes. + @param sub_graph A IndexSubGraph instance with details of the nodes to fuse. Ownership is transferred to the new Node @param fused_node_name The name for the new Node. - @returns Node with fused subgraph. + @returns Function based Node with fused subgraph. The Node body will contain a Function instance. */ Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name); @@ -939,18 +972,7 @@ class Graph { @remarks Note that the output order matters for subgraphs. */ void SetOutputs(const std::vector& outputs); -#endif // !defined(ORT_MINIMAL_BUILD) - /** Returns true if this is a subgraph or false if it is a high-level graph. */ - bool IsSubgraph() const { return parent_graph_ != nullptr; } - - /** Returns the parent graph if this is a subgraph */ - const Graph* ParentGraph() const { return parent_graph_; } - - /** Returns the mutable parent graph if this is a subgraph */ - Graph* MutableParentGraph() { return parent_graph_; } - -#if !defined(ORT_MINIMAL_BUILD) /** Sets the type of a NodeArg, replacing existing type/shape if any */ void SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto); @@ -1209,12 +1231,6 @@ class Graph { // Clear all unused initializers void CleanUnusedInitializers(const std::unordered_set* initializer_names_to_preserve = nullptr); - gsl::not_null AllocateNode(); - - // Release the node. - // @returns false if node_index was invalid. - bool ReleaseNode(NodeIndex node_index); - std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, const ArgNameToTypeMap& name_to_type_map); @@ -1247,6 +1263,16 @@ class Graph { #endif // !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + gsl::not_null AllocateNode(); + + // Release the node. + // @returns false if node_index was invalid. + bool ReleaseNode(NodeIndex node_index); + + Node& CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name); +#endif + Node* NodeAtIndexImpl(NodeIndex node_index) const { // if we are trying to access a node that doesn't exist there's (most // likely) either a logic issue or a graph consistency/correctness issue. @@ -1276,12 +1302,10 @@ class Graph { sparse_tensor_names_; #if !defined(ORT_MINIMAL_BUILD) - IOnnxRuntimeOpSchemaCollectionPtr schema_registry_; std::vector> function_container_; - -#endif // !defined(ORT_MINIMAL_BUILD) +#endif // Graph nodes. // Element in may be nullptr due to graph optimization. diff --git a/include/onnxruntime/core/graph/graph_nodes.h b/include/onnxruntime/core/graph/graph_nodes.h index aaa3a2b07b..422fe9538e 100644 --- a/include/onnxruntime/core/graph/graph_nodes.h +++ b/include/onnxruntime/core/graph/graph_nodes.h @@ -28,21 +28,21 @@ class ValidNodes { Construct a ValidNodes instance to provide iteration over all valid nodes in the TNodesCollection @param[in] nodes Nodes to iterate, skipping invalid entries. */ - explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(nodes) {} + explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(&nodes) {} explicit ValidNodes(TNodesContainer& nodes, NodeFilterFunc&& filter_node_fn) noexcept - : nodes_(nodes), filter_node_fn_{std::move(filter_node_fn)} {} + : nodes_(&nodes), filter_node_fn_{std::move(filter_node_fn)} {} using ConstNodeIterator = NodeIterator; using MutableNodeIterator = NodeIterator; using ConstReverseNodeIterator = NodeIterator; ConstNodeIterator cbegin() const noexcept { - return {nodes_.cbegin(), nodes_.cend(), filter_node_fn_}; + return {nodes_->cbegin(), nodes_->cend(), filter_node_fn_}; } ConstNodeIterator cend() const noexcept { - return {nodes_.cend(), nodes_.cend(), filter_node_fn_}; + return {nodes_->cend(), nodes_->cend(), filter_node_fn_}; } ConstNodeIterator begin() const noexcept { @@ -54,11 +54,11 @@ class ValidNodes { } ConstReverseNodeIterator rbegin() const noexcept { - return {nodes_.crbegin(), nodes_.crend(), filter_node_fn_}; + return {nodes_->crbegin(), nodes_->crend(), filter_node_fn_}; } ConstReverseNodeIterator rend() const noexcept { - return {nodes_.crend(), nodes_.crend(), filter_node_fn_}; + return {nodes_->crend(), nodes_->crend(), filter_node_fn_}; } // we only allow mutable access if the container is non-const. @@ -66,16 +66,16 @@ class ValidNodes { template typename std::enable_if::value, MutableNodeIterator>::type begin() noexcept { static_assert(std::is_same::value, "Explicit specialization is not allowed"); - return MutableNodeIterator(nodes_.begin(), nodes_.end(), filter_node_fn_); + return MutableNodeIterator(nodes_->begin(), nodes_->end(), filter_node_fn_); } template typename std::enable_if::value, MutableNodeIterator>::type end() noexcept { static_assert(std::is_same::value, "Explicit specialization is not allowed"); - return MutableNodeIterator(nodes_.end(), nodes_.end(), filter_node_fn_); + return MutableNodeIterator(nodes_->end(), nodes_->end(), filter_node_fn_); } - bool empty() const noexcept { return nodes_.empty(); } + bool empty() const noexcept { return nodes_->empty(); } /** @class NodeIterator @@ -152,7 +152,7 @@ class ValidNodes { }; private: - TNodesContainer& nodes_; + gsl::not_null nodes_; // always set by ctor // no filtering if not set. this instance owns the filter func if set. NodeFilterFunc filter_node_fn_; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 6c71aa368e..c0ace1ffdb 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -158,6 +158,11 @@ class GraphViewer { } #endif + /** Get the filter info that restricts the graph viewer to a subset of nodes if set. + @returns Filter info or nullptr + */ + const IndexedSubGraph* GetFilterInfo() const { return filter_info_; } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info); diff --git a/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h index e9f68dc28a..068bee2f6c 100644 --- a/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h +++ b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h @@ -1,4 +1,5 @@ -// Copyright 2019 JD.com Inc. JD AI +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once #include "onnxruntime_c_api.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 0172dc7867..1d21104c4a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -481,7 +481,7 @@ inline SessionOptions& SessionOptions::AddInitializer(const char* name, const Or return *this; } -inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions * options, OrtCUDAProviderOptions * cuda_options) { +inline OrtStatus* SessionOptions::OrtSessionOptionsAppendExecutionProvider_CUDA(OrtSessionOptions* options, OrtCUDAProviderOptions* cuda_options) { ThrowOnError(GetApi().OrtSessionOptionsAppendExecutionProvider_CUDA(options, cuda_options)); return nullptr; } @@ -943,7 +943,8 @@ inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* return out; } -inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) { +inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, + _In_ const int64_t* dim_values, size_t dim_count) { OrtValue* out; ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out)); return out; diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 7964304bd0..3812abb473 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -43,9 +43,11 @@ IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, return result; #else + // We have saved hashes to lookup static kernels in an ORT format model so the default behavior is to return an + // empty vector to leave that in place. An EP that compiles nodes can override this in a minimal build. ORT_UNUSED_PARAMETER(graph); ORT_UNUSED_PARAMETER(kernel_registries); - ORT_NOT_IMPLEMENTED("IExecutionProvider::GetCapability is not supported in this build."); + return result; #endif } @@ -79,6 +81,7 @@ void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) { allocator_list_.push_back(allocator); } +#if !defined(ORT_MINIMAL_BUILD) common::Status IExecutionProvider::Compile(const std::vector& /*fused_node*/, std::vector& /*node_compute_funcs*/) { return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); @@ -88,6 +91,14 @@ common::Status IExecutionProvider::Compile(const std::vector std::string& /*dll_path*/) { return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } +#endif + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +common::Status IExecutionProvider::Compile(const std::vector& /*fused_nodes_and_graphs*/, + std::vector& /*node_compute_funcs*/) { + return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); +} +#endif std::shared_ptr IExecutionProvider::GetKernelRegistry() const { return nullptr; diff --git a/onnxruntime/core/framework/func_kernel.h b/onnxruntime/core/framework/func_kernel.h index 128c6cc3fd..ece2c24304 100644 --- a/onnxruntime/core/framework/func_kernel.h +++ b/onnxruntime/core/framework/func_kernel.h @@ -18,33 +18,33 @@ class FunctionKernel : public OpKernel { explicit FunctionKernel(const OpKernelInfo& info) : OpKernel(info) { num_inputs_ = info.node().InputDefs().size(); num_outputs_ = info.node().OutputDefs().size(); - CreateFunctionStateFunc create_func; - auto status = info.GetFusedFuncs(&func_, &create_func, &release_func_); + auto status = info.GetFusedFuncs(compute_info_); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - if (create_func) { + if (compute_info_->create_state_func) { //TODO: we are only provide host allocate method in compute context. //Do we need to hold the ref-counting here? host_allocator_ = info.GetAllocator(0, OrtMemType::OrtMemTypeDefault); - ComputeContext context = {allocate_helper_func, release_helper_func, host_allocator_.get(), info.node().Name().c_str()}; - ORT_ENFORCE(create_func(&context, &func_state_) == 0); + ComputeContext context = {allocate_helper_func, release_helper_func, host_allocator_.get(), + info.node().Name().c_str()}; + ORT_ENFORCE(compute_info_->create_state_func(&context, &func_state_) == 0); } } ~FunctionKernel() override { - if (release_func_ && func_state_) { - release_func_(func_state_); + if (compute_info_->release_state_func && func_state_) { + compute_info_->release_state_func(func_state_); } } virtual Status Compute(OpKernelContext* context) const override { auto* context_internal = static_cast(context); - return func_(func_state_, OrtGetApiBase()->GetApi(ORT_API_VERSION), reinterpret_cast(context_internal)); + return compute_info_->compute_func(func_state_, OrtGetApiBase()->GetApi(ORT_API_VERSION), + reinterpret_cast(context_internal)); } private: - ComputeFunc func_; - DestroyFunctionStateFunc release_func_; - FunctionState func_state_; + NodeComputeInfo* compute_info_{nullptr}; + FunctionState func_state_{nullptr}; size_t num_inputs_; size_t num_outputs_; AllocatorPtr host_allocator_; diff --git a/onnxruntime/core/framework/fuse_nodes_funcs.cc b/onnxruntime/core/framework/fuse_nodes_funcs.cc index 8665242d10..f4678afbed 100644 --- a/onnxruntime/core/framework/fuse_nodes_funcs.cc +++ b/onnxruntime/core/framework/fuse_nodes_funcs.cc @@ -6,25 +6,28 @@ Status FuncManager::AddFuncInfo(const std::string& name, const std::string& dll_ auto it = fused_funcs_->find(name); if (it != fused_funcs_->end()) return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " already exist."); - (*fused_funcs_)[name] = {dll_path, nullptr, nullptr, nullptr}; + (*fused_funcs_)[name] = {dll_path, NodeComputeInfo()}; return Status::OK(); } -Status FuncManager::AddFuncInfo(const std::string& name, ComputeFunc compute, CreateFunctionStateFunc create, DestroyFunctionStateFunc release) { +Status FuncManager::AddFuncInfo(const std::string& name, NodeComputeInfo&& compute_info) { auto it = fused_funcs_->find(name); if (it != fused_funcs_->end()) return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " already exist."); - if (!compute || !create || !release) + + if (!compute_info.compute_func || !compute_info.create_state_func || !compute_info.release_state_func) return Status(common::ONNXRUNTIME, common::FAIL, "Can't use func with null ptr"); - (*fused_funcs_)[name] = {"", compute, create, release}; + + (*fused_funcs_)[name] = {"", std::move(compute_info)}; return Status::OK(); } -Status FuncManager::GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const { +Status FuncManager::GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const { auto it = fused_funcs_->find(name); if (it == fused_funcs_->end()) return Status(common::ONNXRUNTIME, common::FAIL, "func info for node: " + name + " not found."); - if (!it->second.compute_func) { + + if (!it->second.compute_info.compute_func) { //load from path void* handle = nullptr; ORT_RETURN_IF_ERROR(lib_loader_->LoadExternalLib(it->second.dso_path, &handle)); @@ -40,21 +43,20 @@ Status FuncManager::GetFuncs(const std::string& name, ComputeFunc* compute, Crea ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, kReleaseStateFuncSymbol + name, &release_func_symbol_handle)); - it->second.compute_func = [=](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + it->second.compute_info.compute_func = [=](FunctionState state, const OrtApi* api, OrtKernelContext* context) { return reinterpret_cast(compute_func_symbol_handle)(state, api, context); }; - it->second.create_state_func = [=](ComputeContext* context, FunctionState* state) { + it->second.compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { return reinterpret_cast(create_func_symbol_handle)(context, state); }; - it->second.release_state_func = [=](FunctionState state) { + it->second.compute_info.release_state_func = [=](FunctionState state) { return reinterpret_cast(release_func_symbol_handle)(state); }; } - *compute = it->second.compute_func; - *create = it->second.create_state_func; - *release = it->second.release_state_func; + + compute_info = &it->second.compute_info; return Status::OK(); } diff --git a/onnxruntime/core/framework/fuse_nodes_funcs.h b/onnxruntime/core/framework/fuse_nodes_funcs.h index a8c2fe808d..6b36b3fb1e 100644 --- a/onnxruntime/core/framework/fuse_nodes_funcs.h +++ b/onnxruntime/core/framework/fuse_nodes_funcs.h @@ -7,13 +7,16 @@ namespace onnxruntime { class FuncManager { public: - FuncManager() : fused_funcs_(std::make_shared >()), lib_loader_(onnxruntime::make_unique()) {} + FuncManager() + : fused_funcs_(std::make_shared >()), + lib_loader_(onnxruntime::make_unique()) { + } Status AddFuncInfo(const std::string& name, const std::string& dll_path); - Status AddFuncInfo(const std::string& name, ComputeFunc compute, CreateFunctionStateFunc create, DestroyFunctionStateFunc release); + Status AddFuncInfo(const std::string& name, NodeComputeInfo&& compute_info); - Status GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const; + Status GetFuncs(const std::string& name, NodeComputeInfo*& compute_info) const; void SetFusedFuncs(const FuncManager& func_mgr) { fused_funcs_ = func_mgr.fused_funcs_; @@ -21,9 +24,7 @@ class FuncManager { struct FuncInfo { std::string dso_path; - ComputeFunc compute_func; - CreateFunctionStateFunc create_state_func; - DestroyFunctionStateFunc release_state_func; + NodeComputeInfo compute_info; }; private: diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 69c9096480..04f5979aa4 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include "core/framework/graph_partitioner.h" #include "core/framework/kernel_registry_manager.h" #include "core/graph/function.h" @@ -41,13 +43,22 @@ NonCudaOps non_cuda; using namespace ::onnxruntime::common; namespace onnxruntime { -static KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) { +// minimal KernelDef based on MetaDef instead of a Function based node +static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph::MetaDef& metadef, + const std::string& provider_type) { + builder.SetName(metadef.name) + .SetDomain(metadef.domain) + .SinceVersion(metadef.since_version) + .Provider(provider_type); +} + +#if !defined(ORT_MINIMAL_BUILD) +static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) { auto schema = node.Op(); builder.SetName(schema->Name()) .SetDomain(schema->domain()) .SinceVersion(schema->SinceVersion()) .Provider(node.GetExecutionProviderType()); - return builder; } /** @@ -57,12 +68,16 @@ static KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const on * \param capability * \param kernel_registry_mgr * \param provider_type name of the provider to test - * \param count A counter for generating fused node names. Should be unique within this subgraph + * \param count A counter for generating fused node names. Unique across the entire model. * \return Fused node. Return nullptr if there is no fuse */ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, const KernelRegistryManager& kernel_registry_mgr, const std::string& provider_type, - int& count) { + IExecutionProvider::FusionStyle fusion_style, + GraphPartitioner::Mode mode, + int& fused_node_unique_id) { + Node* result = nullptr; + if (nullptr == capability.GetMetaDef()) { // The can run a single node in the if not using meta-defs. // A fused kernel is not supported in this case. @@ -75,39 +90,82 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, } } else { // The can run a fused in the . - ORT_ENFORCE(nullptr != capability.GetMetaDef()); - // Check whether any node in the was already assigned. + + // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done + // in order of EP priority bool sub_graph_available_for_assignment = true; for (auto node_index : capability.nodes) { - auto node = graph.GetNode(node_index); + const auto* node = graph.GetNode(node_index); if (nullptr == node || !node->GetExecutionProviderType().empty()) { - // The node was fused or assigned, so that the whole sub-graph will not be assigned to this - // The assumption is that this can only run the sub-graph as a whole unit. - sub_graph_available_for_assignment = false; - break; + // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, + // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by + // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to + // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes + // should never be on those lists. + // + // when the ORT format model is loaded we will process it normally with EP priority being applied for + // whichever EPs are enabled at the time. + // + // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. + // We want the ORT format model to be able to be run as efficiently as possible on either platform, + // so we want all the nodes that either may take to be preserved. If we did not do this we would + // need to create one ORT format model for Android and one for iOS. + if (mode != GraphPartitioner::Mode::kAssignOnly) { + // The node was fused or assigned, so that the whole sub-graph will not be assigned to this + // The assumption is that this can only run the sub-graph as a whole unit. + sub_graph_available_for_assignment = false; + break; + } } } + if (sub_graph_available_for_assignment) { - std::ostringstream oss; - oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << count++; - std::string node_name = oss.str(); - auto& fused_node = graph.FuseSubGraph(capability, node_name); - fused_node.SetExecutionProviderType(provider_type); - // searching in kernel registries, if no kernel registered for the fused_node, use compile approach - if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, fused_node, provider_type)) { - return &fused_node; + if (mode == GraphPartitioner::Mode::kNormal) { + std::ostringstream oss; + oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++; + std::string node_name = oss.str(); + + Node* fused_node = nullptr; + if (fusion_style == IExecutionProvider::FusionStyle::Function) { + fused_node = &graph.FuseSubGraph(capability, node_name); + } else { + // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed + // through to Compile via a filtered GraphViewer. + fused_node = &graph.BeginFuseSubGraph(capability, node_name); + } + + fused_node->SetExecutionProviderType(provider_type); + + // searching in kernel registries, if no kernel registered for the fused_node, use compile approach + if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *fused_node, provider_type)) { + result = fused_node; + } + } else { + // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. + // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion + // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device + // capabilities. + for (auto node_index : capability.nodes) { + auto* node = graph.GetNode(node_index); + if (node != nullptr) { + node->SetExecutionProviderType(provider_type); + } + } } } } - return nullptr; + + return result; } // for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up). // assign any nodes to the EP that are currently unassigned, and that the EP can handle. -static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr, - KernelRegistryManager& kernel_registry_mgr, - KernelRegistry& fused_kernel_registry, - IExecutionProvider& current_ep) { +static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncManager& func_mgr, + KernelRegistryManager& kernel_registry_mgr, + KernelRegistry& fused_kernel_registry, + IExecutionProvider& current_ep, + GraphPartitioner::Mode mode, + int& fused_node_unique_id) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -119,8 +177,8 @@ static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { Graph* subgraph = entry.second; // we pass through the export_dll value and FuncManager from the top level graph - ORT_RETURN_IF_ERROR(PartitionImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr, fused_kernel_registry, - current_ep)); + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr, + fused_kernel_registry, current_ep, mode, fused_node_unique_id)); } } @@ -134,65 +192,135 @@ static Status PartitionImpl(Graph& graph, bool export_dll, FuncManager& func_mgr // TODO: when the graph contain a function node, and user pass in the dll which could // run the function by SessionOption, we should create a function kernel for it and // delegate the compute to the functions inside the dlls. - int count = 0; - std::vector nodes_need_compile; + const std::string& type = current_ep.Type(); + auto fusion_style = current_ep.GetFusionStyle(); + std::vector nodes_to_compile; GraphViewer graph_viewer(graph); std::vector> capabilities = - current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(current_ep.Type())); + current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type)); + + // filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with + // nodes_to_compile. + std::vector> capabilities_to_compile; + capabilities_to_compile.reserve(std::count_if(capabilities.cbegin(), capabilities.cend(), + [](const std::unique_ptr& entry) { + return entry != nullptr && + entry->sub_graph != nullptr && + entry->sub_graph->GetMetaDef() != nullptr; + })); for (auto& capability : capabilities) { if (!capability || !capability->sub_graph) { // in theory an EP could return an empty value... continue; } - Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, current_ep.Type(), count); + Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, type, fusion_style, mode, fused_node_unique_id); if (n != nullptr) { - nodes_need_compile.push_back(n); + nodes_to_compile.push_back(n); + capabilities_to_compile.push_back(std::move(capability)); } } - if (!nodes_need_compile.empty()) { + // NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode + if (!nodes_to_compile.empty()) { + std::vector node_compute_funcs; + if (export_dll) { - std::string dll_path; - ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_need_compile, dll_path)); - for (auto* node : nodes_need_compile) { - ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path)); - } - } else { - std::vector node_compute_funcs; - ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_need_compile, node_compute_funcs)); - ORT_ENFORCE(node_compute_funcs.size() == nodes_need_compile.size(), - "Provider did not return correct number of compiled functions"); - for (size_t j = 0; j < nodes_need_compile.size(); j++) { - ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_need_compile[j]->Name(), node_compute_funcs[j].compute_func, - node_compute_funcs[j].create_state_func, - node_compute_funcs[j].release_state_func)); - } + ORT_ENFORCE(fusion_style == IExecutionProvider::FusionStyle::Function, + "Must use Function based fusion when exporting compiled nodes to dll."); } - for (auto* node : nodes_need_compile) { - //prepare the func kernel - KernelDefBuilder builder; - BuildFusedKernelDef(builder, *node); - ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder, static_cast( - [](const OpKernelInfo& info) -> OpKernel* { - return new FunctionKernel(info); - }))); + if (fusion_style == IExecutionProvider::FusionStyle::Function) { + // Create a Function based node where the fused nodes have a new Graph instance. + + if (export_dll) { + std::string dll_path; + ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, dll_path)); + + for (auto* node : nodes_to_compile) { + ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path)); + } + } else { + ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs)); + + if (node_compute_funcs.size() != nodes_to_compile.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); + } + + for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) { + ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j]))); + } + } + + for (auto* node : nodes_to_compile) { + // add the KernelDef instances for the compiled nodes + KernelDefBuilder builder; + BuildFusedKernelDef(builder, *node); + ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder, + static_cast( + [](const OpKernelInfo& info) -> OpKernel* { + return new FunctionKernel(info); + }))); + } + + } else { + // temporary storage for the GraphViewer for each IndexedSubGraph + std::vector> viewers; + viewers.reserve(nodes_to_compile.size()); + std::vector nodes_and_viewers; + + for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) { + auto* node = nodes_to_compile[j]; + const auto& cur_capability = *capabilities_to_compile[j]; + viewers.push_back(onnxruntime::make_unique(graph, *cur_capability.sub_graph)); + nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{*node, *viewers.back()}); + } + + ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs)); + + if (node_compute_funcs.size() != nodes_to_compile.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); + } + + for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) { + auto* node = nodes_to_compile[j]; + + ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), std::move(node_compute_funcs[j]))); + + const auto& cur_capability = capabilities_to_compile[j]; + const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; + const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef(); + + // create the func kernel for the name in the MetaDef. this is also the node name and that name that will + // used as the key in the FuncManager entry. We need the registry to own the KernelCreateInfo that is + // used by SessionState + KernelDefBuilder builder; + BuildFusedKernelDef(builder, metadef, type); + ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(builder, + static_cast( + [](const OpKernelInfo& info) -> OpKernel* { + return new FunctionKernel(info); + }))); + + // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one + graph.FinalizeFuseSubGraph(indexed_sub_graph, *node); + } } } // if this is the main graph call Resolve to put the Graph back into a guaranteed good state - // TODO: If we fix Graph::FuseSubGraph to correctly update edges when replacing nodes with a fused node we can - // avoid this Resolve call. + // TODO: Graph::FuseSubGraph and Graph::FinalizeFuseSubGraph should now create valid edges so this call to + // Graph::Resolve should not be required. Need to test to validate that, especially if node being fused + // was a control flow node with its own subgraph as more than just the edges may need updating. if (!graph.IsSubgraph()) { ORT_RETURN_IF_ERROR(graph.Resolve()); } - //For some cases, like fp16 on cpu, right now we don't have any kernel support that. - //But we will insert cast op to run the model, so skip the error checking here. - //If after graph transform phase, the node still not assigned, we will report error - //during kernel creation phase. + // For some cases, like fp16 on cpu, right now we don't have any kernel support that. + // But we will insert cast op to run the model, so skip the error checking here. + // If after graph transform phase, the node still not assigned, we will report error + // during kernel creation phase. #ifdef COUNT_NON_CUDA_OPS for (auto& node : graph.Nodes()) { if (node.GetExecutionProviderType() != kCudaExecutionProvider && @@ -231,12 +359,157 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { modified_graph = true; } - ORT_RETURN_IF_ERROR(graph.Resolve()); + return Status::OK(); +} + +Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr, + KernelRegistry& fused_kernel_registry, Mode mode, + int& fused_node_unique_id) const { + bool modified_graph = false; + + do { + // process full graph with each EP + for (const auto& ep : providers_) { + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_, + fused_kernel_registry, *ep, mode, fused_node_unique_id)); + } + + // expand any nodes that have an ONNX function definition but no matching ORT kernel. + modified_graph = false; + ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph)); + + // Resolve and rerun graph partitioning and inlining if there was a change + if (modified_graph) { + ORT_RETURN_IF_ERROR(graph.Resolve()); + } + } while (modified_graph); return Status::OK(); } -Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr) const { +#endif // !defined(ORT_MINIMAL_BUILD) + +static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, + KernelRegistryManager& kernel_registry_mgr, + KernelRegistry& fused_kernel_registry, + IExecutionProvider& current_ep, + std::unordered_map& compiled_kernel_hashes, + int& fused_node_unique_id) { + // recurse into nested graphs first to partition bottom up. + for (auto& node : graph.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + Graph* subgraph = entry.second; + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, + current_ep, compiled_kernel_hashes, fused_node_unique_id)); + } + } + + // handle testing edge case where optimizers or constant lifting results in graph with no nodes. + // doing it here saves all providers checking for this in GetCapability + if (graph.NumberOfNodes() == 0) { + return Status::OK(); + } + + const std::string& type = current_ep.Type(); + GraphViewer graph_viewer(graph); + std::vector nodes_and_viewers; + + std::vector> capabilities = + current_ep.GetCapability(graph_viewer, kernel_registry_mgr.GetKernelRegistriesByProviderType(type)); + + // storage for the GraphViewer for each IndexedSubGraph + std::vector> viewers; + viewers.reserve(capabilities.size()); + + for (auto& capability : capabilities) { + const IndexedSubGraph& indexed_sub_graph = *capability->sub_graph; + const IndexedSubGraph::MetaDef* metadef = indexed_sub_graph.GetMetaDef(); + if (!metadef) { + // Static kernel - use the kernel hash that was saved in the ORT format model + continue; + } + + std::ostringstream oss; + oss << type << "_" << metadef->name << "_" << fused_node_unique_id++; + std::string node_name = oss.str(); + + Node& fused_node = graph.BeginFuseSubGraph(indexed_sub_graph, node_name); + fused_node.SetExecutionProviderType(type); + + // create filtered graph viewer for this set of nodes + // + // TODO: Could avoid the topological sort in the GraphViewer ctor by constructing from an existing + // GraphViewer instance instead of the Graph (copying the topological order instead of recalculating). + viewers.push_back(onnxruntime::make_unique(graph, indexed_sub_graph)); + nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{fused_node, *viewers.back()}); + } + + std::vector node_compute_funcs; + node_compute_funcs.reserve(nodes_and_viewers.size()); + + ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs)); + + if (node_compute_funcs.size() != nodes_and_viewers.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); + } + + for (size_t j = 0, end = nodes_and_viewers.size(); j < end; j++) { + Node& node = nodes_and_viewers[j].fused_node; + + ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(node_compute_funcs[j]))); + + const auto& cur_capability = capabilities[j]; + const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; + const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef(); + + KernelDefBuilder builder; + BuildFusedKernelDef(builder, metadef, type); + auto kernel_def = builder.Build(); + + // save hash so SessionState can find the kernel. each kernel name should be unique + if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) { + ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name, + ". Execution Provider must generate unique names across the entire model."); + } + + ORT_RETURN_IF_ERROR(fused_kernel_registry.Register( + KernelCreateInfo(std::move(kernel_def), static_cast( + [](const OpKernelInfo& info) -> OpKernel* { + return new FunctionKernel(info); + })))); + + // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one + graph.FinalizeFuseSubGraph(indexed_sub_graph, node); + } + + return Status::OK(); +} + +// Simplified partitioning where custom EPs may produce compiled nodes. +// EPs with static kernels do not need to be processed as their kernels are matched via hash information serialized +// as part of the ORT format model. +Status GraphPartitioner::PartitionOrtFormatModel( + Graph& graph, FuncManager& func_mgr, + KernelRegistry& fused_kernel_registry, + std::unordered_map& compiled_kernel_hashes, + int& fused_node_unique_id) const { + // process full graph with each EP + for (const auto& ep : providers_) { + if (ep->Type() == kCpuExecutionProvider) { + // hash for kernel is stored in session state for EPs that have pre-registered kernels + // (vs. runtime fused kernels) so nothing to do here. + continue; + } + + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry, + *ep, compiled_kernel_hashes, fused_node_unique_id)); + } + + return Status::OK(); +} + +Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, Mode mode, + std::unordered_map* compiled_kernel_hashes) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. // 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet. @@ -253,21 +526,23 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f // It is only visible for current session. std::shared_ptr fused_kernel_registry = std::make_shared(); - bool modified_graph = false; - do { - // process full graph with each EP - for (const auto& ep : providers_) { - ORT_RETURN_IF_ERROR(PartitionImpl(graph, export_dll, func_mgr, kernel_registry_mgr_, - *fused_kernel_registry, *ep)); - } + // we make sure each fused node name is unique across the entire model for clarity + int fused_node_unique_id = 0; - modified_graph = false; - // expand any nodes that have an ONNX function definition - // but no matching ORT kernel - ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph)); - // rerun graph partition to assign nodes added as part of - // function expansion. - } while (modified_graph); + if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode, + fused_node_unique_id)); +#else + ORT_UNUSED_PARAMETER(export_dll); + ORT_THROW("Not supported in this build."); +#endif + } else { + ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided"); + + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes, + fused_node_unique_id)); + } if (!fused_kernel_registry->IsEmpty()) { kernel_registry_mgr_.RegisterKernelRegistry(fused_kernel_registry); @@ -276,3 +551,5 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f return Status::OK(); } } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 64c0ebe71c..48225f9c74 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -3,6 +3,8 @@ #pragma once +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include "core/common/common.h" #include "core/graph/graph_viewer.h" #include "core/framework/op_kernel.h" @@ -11,21 +13,43 @@ namespace onnxruntime { class ExecutionProviders; +class KernelRegistry; class KernelRegistryManager; class GraphPartitioner { public: + enum class Mode { + kNormal = 0, + kAssignOnly = 1, // assign nodes. no call to Compile. used to create ORT format model support for compiling EPs + kOrtFormatLoad = 2 // loading ORT format model. Partition with compiling EPs, GraphViewer based Compile. + }; + //The order of providers represents the user preference. GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers) : kernel_registry_mgr_(kernel_registry_mgr), - providers_(providers) {} + providers_(providers) { + } - Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr) const; + // Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad. + Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, + Mode mode = Mode::kNormal, + std::unordered_map* compiled_kernel_hashes = nullptr) const; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner); +#if !defined(ORT_MINIMAL_BUILD) + Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr, + KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id) const; +#endif + + Status PartitionOrtFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, + std::unordered_map& compiled_kernel_hashes, + int& fused_node_unique_id) const; + KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; }; } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index b7e772701f..d4c27112e3 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -45,14 +45,16 @@ Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& executio return Status::OK(); } -#if !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptr kernel_registry) { if (nullptr == kernel_registry) { return; } custom_kernel_registries_.push_front(kernel_registry); } +#endif +#if !defined(ORT_MINIMAL_BUILD) bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) { std::vector kernel_registries = r.GetKernelRegistriesByProviderType(provider_type); return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) { @@ -84,10 +86,12 @@ Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. ")); } -#if !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) for (auto& registry : custom_kernel_registries_) { status = registry->TryFindKernel(node, std::string(), kernel_def_hash, kernel_create_info); - if (status.IsOK()) return status; + if (status.IsOK()) { + return status; + } } #endif diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 0e34e1e57d..62d91e4ab7 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -32,8 +32,7 @@ class KernelRegistryManager { // Register kernels from providers Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT; -#if !defined(ORT_MINIMAL_BUILD) - +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // The registry passed in this function has highest priority than anything already in this KernelRegistryManager, // and anything registered from RegisterKernels // For example, if you do: @@ -43,16 +42,6 @@ class KernelRegistryManager { // Then B > A > providers void RegisterKernelRegistry(std::shared_ptr kernel_registry); - // This function assumes the node is already assigned to an execution provider - // Don't call this function before graph partition is done - Status SearchKernelRegistry(const onnxruntime::Node& node, - /*out*/ const KernelCreateInfo** kernel_create_info) const; - - /** - * Whether this node can be run on this provider - */ - static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type); - /** * Search kernel registry by provider type. * @param type provider type string @@ -70,6 +59,18 @@ class KernelRegistryManager { } #endif +#if !defined(ORT_MINIMAL_BUILD) + // This function assumes the node is already assigned to an execution provider + // Don't call this function before graph partition is done + Status SearchKernelRegistry(const onnxruntime::Node& node, + /*out*/ const KernelCreateInfo** kernel_create_info) const; + + /** + * Whether this node can be run on this provider + */ + static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type); +#endif + Status SearchKernelRegistry(const onnxruntime::Node& node, uint64_t kernel_def_hash, /*out*/ const KernelCreateInfo** kernel_create_info) const; diff --git a/onnxruntime/core/framework/op_kernel_info.cc b/onnxruntime/core/framework/op_kernel_info.cc index e26f31483b..4be25b5fe7 100644 --- a/onnxruntime/core/framework/op_kernel_info.cc +++ b/onnxruntime/core/framework/op_kernel_info.cc @@ -79,7 +79,7 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_ return true; } -common::Status OpKernelInfo::GetFusedFuncs(ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const { - return funcs_mgr_.GetFuncs(node_.Name(), compute, create, release); +common::Status OpKernelInfo::GetFusedFuncs(NodeComputeInfo*& compute_info) const { + return funcs_mgr_.GetFuncs(node_.Name(), compute_info); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0de874bf7b..260b975c73 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -114,7 +114,8 @@ void SessionState::CreateGraphInfo() { for (const auto& output : graph_viewer_->GetOutputs()) { if (output->Exists()) { idx = ort_value_name_idx_map_.Add(output->Name()); - VLOGS(logger_, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx; + VLOGS(logger_, 1) << "Added graph output with name: " << output->Name() + << " to OrtValueIndex with index: " << idx; } } @@ -122,10 +123,25 @@ void SessionState::CreateGraphInfo() { } #if !defined(ORT_MINIMAL_BUILD) -Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager) { +Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager, + bool saving_ort_format) { for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, &kci)); + + auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + if (!status.IsOK() && saving_ort_format) { + // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. + // in that case we assigned the node to that EP but do not compile it into a fused node. + // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it. + // we now revert to the CPU EP to include the hash for the kernel as a fallback. at runtime when the model + // is loaded in a minimal build, the compiling EP will replace this node if possible. if that's not possible for + // some reason we can fallback to the CPU EP implementation via this hash. + node.SetExecutionProviderType(kCpuExecutionProvider); + status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + } + + ORT_RETURN_IF_ERROR(status); + ORT_IGNORE_RETURN_VALUE( kernel_create_info_map_.insert({node.Index(), gsl::not_null(kci)})); } @@ -133,7 +149,7 @@ Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_regi for (const auto& entry : subgraph_session_states_) { for (const auto& name_to_subgraph_session_state : entry.second) { SessionState& subgraph_session_state = *name_to_subgraph_session_state.second; - ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager)); + ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format)); } } @@ -813,16 +829,45 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat "Size mismatch for kernel create info node indexes and hashes. Invalid ORT format model.", node_indices->size(), " != ", kernel_def_hashes->size()); + auto add_kernel_by_hash = + [&kernel_registry_manager, this](const Node& node, uint64_t hash) { + const KernelCreateInfo* kci = nullptr; + ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, hash, &kci)); + kernel_create_info_map_.emplace(node.Index(), gsl::not_null(kci)); + return Status::OK(); + }; + + // kernel hashes for model are in top level SessionState + const auto& compiled_kernel_hashes = GetCompiledKernelHashes(); + + // process the nodes that existed when the model was created for (flatbuffers::uoffset_t i = 0; i < node_indices->size(); i++) { auto node_idx = node_indices->Get(i); - auto kernal_hash = kernel_def_hashes->Get(i); + auto kernel_hash = kernel_def_hashes->Get(i); const Node* node = graph_.GetNode(node_idx); - ORT_RETURN_IF(node == nullptr, "Can't find node with index ", node_idx, ". Invalid ORT format model."); + if (node == nullptr) { + // this is OK if we have compiled kernels and the original node was replaced. if not the model is invalid. + ORT_RETURN_IF(compiled_kernel_hashes.empty(), + "Can't find node with index ", node_idx, ". Invalid ORT format model."); + continue; + } - const KernelCreateInfo* kci = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(*node, kernal_hash, &kci)); - kernel_create_info_map_.emplace(node_idx, gsl::not_null(kci)); + ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, kernel_hash)); + } + + // lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices + // as they were created at runtime. + if (!compiled_kernel_hashes.empty()) { + for (const auto& node : graph_.Nodes()) { + if (kernel_create_info_map_.count(node.Index()) == 0) { + auto hash_info = compiled_kernel_hashes.find(node.OpType()); + ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(), + "Unable to find compiled kernel hash for node '", node.Name(), "'.") + + ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second)); + } + } } if (!subgraph_session_states_.empty()) { @@ -880,7 +925,8 @@ Status SessionState::FinalizeSessionState(const std::basic_string&& compiled_kernel_hashes) { + compiled_kernel_hashes_ = std::move(compiled_kernel_hashes); + } + Status LoadFromOrtFormat(const onnxruntime::experimental::fbs::SessionState& fbs_session_state, const KernelRegistryManager& kernel_registry_manager); #endif @@ -277,7 +281,8 @@ class SessionState { KernelRegistryManager& kernel_registry_manager, const SessionOptions& session_options = {}, const onnxruntime::experimental::fbs::SessionState* serialized_session_state = nullptr, - bool remove_initializers = true); + bool remove_initializers = true, + bool saving_ort_format = false); SessionState* Parent() { return parent_; @@ -312,7 +317,7 @@ class SessionState { std::unique_ptr session_state); #if !defined(ORT_MINIMAL_BUILD) - Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager); + Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager, bool saving_ort_format); #endif Status FinalizeSessionStateImpl(const std::basic_string& graph_loc, @@ -330,9 +335,18 @@ class SessionState { std::unordered_map& inferred_shapes) const; #endif + // the SessionState for the main Graph contains the compiled kernel hashes for the entire model + const std::unordered_map& GetCompiledKernelHashes() const { + return parent_ ? parent_->GetCompiledKernelHashes() : compiled_kernel_hashes_; + } + // KernelCreateInfo for each node so we do kernel lookup once std::unordered_map> kernel_create_info_map_; + // If we compile kernels in a minimal build we need a way to find the kernel using the hash. + // We populate this map when doing the kernel compilation in GraphPartitioner, and use it in LoadFromOrtFormat. + std::unordered_map compiled_kernel_hashes_; + // cache of the constructed kernels to avoid spending construction time per executor std::vector session_kernels_; Graph& graph_; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 2b95fe3146..0a844be24d 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -101,7 +101,10 @@ bool ProviderIsCpuBased(const std::string& provider_type) { provider_type == onnxruntime::kVitisAIExecutionProvider || provider_type == onnxruntime::kOpenVINOExecutionProvider || provider_type == onnxruntime::kNnapiExecutionProvider || - provider_type == onnxruntime::kRknpuExecutionProvider; + provider_type == onnxruntime::kAclExecutionProvider || + provider_type == onnxruntime::kArmNNExecutionProvider || + provider_type == onnxruntime::kRknpuExecutionProvider || + provider_type == onnxruntime::utils::kInternalTestingExecutionProvider; } static common::Status AllocateHelper(const AllocatorPtr& allocator, @@ -574,16 +577,16 @@ common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context) const Tensor* curr_input = context->Input(i); ORT_ENFORCE(prev_input->Shape().Size() >= 0); - + size_t input_element_count = static_cast(prev_input->Shape().Size()); size_t input_element_size = prev_input->DataType()->Size(); size_t input_aligned_bytes = 0; - + ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, &input_aligned_bytes)); - + ORT_RETURN_IF_NOT(curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + input_aligned_bytes || curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + prev_input->SizeInBytes()); - + prev_input = curr_input; } return Status::OK(); diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index cfe1087a09..fa380027e3 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -40,6 +40,13 @@ void DefaultFree(void* p); const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info); +// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want +// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal. +constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider"; + +// return true if the execution provider is CPU based (meaning no copies to device are required) +bool ProviderIsCpuBased(const std::string& provider_type); + common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name, const OrtValue& orig_mlvalue, OrtValue& new_mlvalue); diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 439073ffd5..e15d198b20 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -144,6 +144,34 @@ static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& su } } +static std::unique_ptr CreateSchema(const Graph& graph, + const IndexedSubGraph& nodes_to_fuse) { + const auto* meta_def = nodes_to_fuse.GetMetaDef(); + auto op_schema = onnxruntime::make_unique(); + op_schema->SetName(meta_def->name); + op_schema->SetDomain(meta_def->domain); + op_schema->SetDoc(meta_def->doc_string); + op_schema->SinceVersion(meta_def->since_version); + int i = 0; + + for (auto& input : meta_def->inputs) { + auto input_arg = graph.GetNodeArg(input); + // inputs must have a type. can be inferred for outputs. + ORT_ENFORCE(input_arg->Type() != nullptr); + op_schema->Input(i, input, "", *input_arg->Type()); + ++i; + } + i = 0; + for (auto& output : meta_def->outputs) { + auto output_arg = graph.GetNodeArg(output); + op_schema->Output(i, output, "", *output_arg->Type()); + ++i; + } + op_schema->Finalize(); + + return op_schema; +} + FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, const IndexedSubGraph& nodes_to_fuse, const logging::Logger& logger) @@ -154,12 +182,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, graph.DomainToVersionMap(), {}, logger) { auto& function_body_graph = body_.MainGraph(); - auto meta_def = nodes_to_fuse.GetMetaDef(); - op_schema_ = onnxruntime::make_unique(); - op_schema_->SetName(meta_def->name); - op_schema_->SetDomain(meta_def->domain); - op_schema_->SetDoc(meta_def->doc_string); - op_schema_->SinceVersion(meta_def->since_version); + auto* meta_def = nodes_to_fuse.GetMetaDef(); + op_schema_ = CreateSchema(graph, nodes_to_fuse); int i = 0; std::vector function_body_graph_inputs; @@ -168,8 +192,6 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, auto input_arg = parent_graph_->GetNodeArg(input); auto& function_body_graph_input_arg = function_body_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); function_body_graph_inputs[i] = &function_body_graph_input_arg; - ORT_ENFORCE(input_arg->Type() != nullptr); - op_schema_->Input(i, input, "", *input_arg->Type()); ++i; } @@ -180,12 +202,9 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, auto output_arg = parent_graph_->GetNodeArg(output); auto& function_body_graph_output_arg = function_body_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); function_body_graph_outputs[i] = &function_body_graph_output_arg; - op_schema_->Output(i, output, "", *output_arg->Type()); ++i; } - op_schema_->Finalize(); - function_body_graph.SetInputs(function_body_graph_inputs); function_body_graph.SetOutputs(function_body_graph_outputs); @@ -238,7 +257,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, // Hence, we make a copy prior to generating the graph representation of the function, // as we might make some modifications to the FunctionProto along the way - auto node_in_parent_graph = parent_graph_->GetNode(node_index); + const auto* node_in_parent_graph = parent_graph_->GetNode(node_index); op_schema_ = onnxruntime::make_unique(); op_schema_->SetName(onnx_func_proto_.name()); op_schema_->SetDomain(onnx_func_proto_.node().Get(0).domain()); @@ -256,14 +275,13 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, auto cached_op_schema = node_in_parent_graph->Op(); if (!cached_op_schema) { // Infer a op_schema for stand-alone functions. - IOTypeConstraintHelper(onnx_func_proto_, this->op_schema_, input_name_idx_map, output_name_idx_map); + IOTypeConstraintHelper(onnx_func_proto_, op_schema_, input_name_idx_map, output_name_idx_map); } else { auto type_constraint_params = cached_op_schema->typeConstraintParams(); for (auto& type_constraint_param : type_constraint_params) { - op_schema_->TypeConstraint( - type_constraint_param.type_param_str, - type_constraint_param.allowed_type_strs, - type_constraint_param.description); + op_schema_->TypeConstraint(type_constraint_param.type_param_str, + type_constraint_param.allowed_type_strs, + type_constraint_param.description); } int i = 0; for (auto& input : cached_op_schema->inputs()) { @@ -286,10 +304,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, op_schema_->TypeAndShapeInferenceFunction( [this](ONNX_NAMESPACE::InferenceContext& ctx) { auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); - const ONNX_NAMESPACE::FunctionProto* func_ptr = this->GetFuncProto(); - if (nullptr != func_ptr) { - ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(func_ptr, schema_registry, ctx); - } + ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(&onnx_func_proto_, schema_registry, ctx); }); } else { op_schema_->TypeAndShapeInferenceFunction(cached_op_schema->GetTypeAndShapeInferenceFunction()); @@ -321,7 +336,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, } for (int idx = 0; idx < (*node).input_size(); ++idx) { - std::string tensor_name = (*node).input().Get(idx); + const std::string& tensor_name = (*node).input().Get(idx); auto iter = input_name_idx_map.find(tensor_name); if (iter != input_name_idx_map.end()) { // Preserving NodeArg and input/output names @@ -397,10 +412,14 @@ const onnxruntime::Graph& FunctionImpl::Body() const { return body_.MainGraph(); } -const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const { - return &onnx_func_proto_; +ViewerFunctionImpl::ViewerFunctionImpl(const onnxruntime::Graph& graph, + const IndexedSubGraph& nodes_to_fuse, + const logging::Logger& /*logger*/) { + op_schema_ = CreateSchema(graph, nodes_to_fuse); } +ViewerFunctionImpl::~ViewerFunctionImpl() = default; + std::unique_ptr MakeFunction(const onnxruntime::Graph& graph, const IndexedSubGraph& nodes_to_fuse, const logging::Logger& logger) { diff --git a/onnxruntime/core/graph/function_impl.h b/onnxruntime/core/graph/function_impl.h index e477b35f22..3333e3c2a6 100644 --- a/onnxruntime/core/graph/function_impl.h +++ b/onnxruntime/core/graph/function_impl.h @@ -31,8 +31,6 @@ class FunctionImpl final : public Function { const onnxruntime::Graph& Body() const override; - const ONNX_NAMESPACE::FunctionProto* GetFuncProto() const; - private: const onnxruntime::Graph* const parent_graph_; std::unique_ptr op_schema_; @@ -40,4 +38,22 @@ class FunctionImpl final : public Function { ONNX_NAMESPACE::FunctionProto onnx_func_proto_; }; +// Function that uses a GraphViewer so does not need to build a new Model. We still need the OpSchema to be available +// though so we just create that. +class ViewerFunctionImpl final : public Function { + public: + ViewerFunctionImpl(const onnxruntime::Graph& graph, + const IndexedSubGraph& nodes_to_fuse, + const logging::Logger& logger); + + ~ViewerFunctionImpl() override; + + const ONNX_NAMESPACE::OpSchema& OpSchema() const override { return *op_schema_; } + + const onnxruntime::Graph& Body() const override { ORT_THROW("Not supported"); } + + private: + std::unique_ptr op_schema_; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index f020326a04..f2217e4b85 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -133,21 +133,26 @@ static TypeProto TypeProtoFromTensorProto(const TensorProto& tensor) { return t; } +#endif // !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) NodeArg::NodeArg(const std::string& name, const TypeProto* p_node_arg_type) { node_arg_info_.set_name(name); // If the name is empty, it means the arg does not exist. exists_ = !(name.empty()); if (nullptr != p_node_arg_type) { (*node_arg_info_.mutable_type()) = *p_node_arg_type; +#if !defined(ORT_MINIMAL_BUILD) + // should not be possible to have invalid values in the ORT format model, so we don't need this + // in a minimal build RemoveInvalidValues(*node_arg_info_.mutable_type()); +#endif type_ = DataTypeUtils::ToType(node_arg_info_.type()); } else { type_ = nullptr; } } - -#endif // !defined(ORT_MINIMAL_BUILD) +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) NodeArg::NodeArg(NodeArgInfo&& node_arg_info) { node_arg_info_ = std::move(node_arg_info); @@ -427,10 +432,6 @@ void Node::SetPriority(int priority) noexcept { #if !defined(ORT_MINIMAL_BUILD) -void Node::SetNodeType(Node::Type node_type) noexcept { - node_type_ = node_type; -} - const Function* Node::GetFunctionBody(bool try_init_func_body) { if (nullptr != func_body_) { return func_body_; @@ -450,10 +451,6 @@ void Node::SetFunctionBody(const Function& func) { since_version_ = op_->since_version(); } -void Node::SetExecutionProviderType(ProviderType execution_provider_type) { - execution_provider_type_ = execution_provider_type; -} - void Node::ToProto(NodeProto& proto, bool update_subgraphs) const { proto.set_name(name_); proto.set_op_type(op_type_); @@ -666,9 +663,9 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::experimental::fbs::NodeEd return Status::OK(); } -#endif +#endif // defined(ENABLE_ORT_FORMAT_LOAD) -#if !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) void Node::Init(const std::string& name, const std::string& op_type, const std::string& description, @@ -697,7 +694,11 @@ void Node::Init(const std::string& name, for (auto& name_to_attr : attributes_) { if (utils::HasGraph(name_to_attr.second)) { +#if !defined(ORT_MINIMAL_BUILD) CreateSubgraph(name_to_attr.first); +#else + ORT_THROW("Creating node with a subgraph via AddNode is not supported in this build."); +#endif } } } @@ -716,7 +717,9 @@ Node::Relationships& Node::MutableRelationships() noexcept { graph_->SetGraphProtoSyncNeeded(); return relationships_; } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) void Node::CreateSubgraph(const std::string& attr_name) { auto attr = attributes_.find(attr_name); @@ -1264,7 +1267,9 @@ common::Status Graph::SetOuterScopeNodeArgs(const std::unordered_setMutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], src_arg_slot, dst_arg_slot)); nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], src_arg_slot, dst_arg_slot)); } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint Status Graph::BuildConnections(std::unordered_set& outer_scope_node_args_consumed) { const std::unordered_set& outer_scope_node_args = resolve_context_.outer_scope_node_args; @@ -1596,6 +1603,7 @@ void Graph::ReverseDFSFrom(const std::vector& from, } } +#if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { std::unordered_map in_degree; @@ -1634,7 +1642,6 @@ void Graph::KahnsTopologicalSort(const std::function& enter, ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } -#if !defined(ORT_MINIMAL_BUILD) GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...) Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { @@ -2896,7 +2903,9 @@ std::string Graph::GenerateNodeName(const std::string& base_name) { return new_name; } +#endif // !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Node& Graph::AddNode(const std::string& name, const std::string& op_type, const std::string& description, @@ -2945,7 +2954,9 @@ bool Graph::RemoveNode(NodeIndex p_index) { return ReleaseNode(p_index); } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) bool Graph::AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index) { if (nodes_.size() <= src_node_index || nodes_.size() <= dst_node_index || @@ -3282,6 +3293,12 @@ Status Graph::SetGraphInputsOutputs() { return Status::OK(); } +IOnnxRuntimeOpSchemaCollectionPtr Graph::GetSchemaRegistry() const { + return schema_registry_; +} +#endif // !defined(ORT_MINIMAL_BUILD) + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // calling private ctor GSL_SUPPRESS(r .11) gsl::not_null Graph::AllocateNode() { @@ -3313,21 +3330,24 @@ bool Graph::ReleaseNode(NodeIndex index) { return true; } -IOnnxRuntimeOpSchemaCollectionPtr Graph::GetSchemaRegistry() const { - return schema_registry_; -} - -Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) { - auto* func_meta_def = sub_graph.GetMetaDef(); +Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) { + const auto* func_meta_def = sub_graph.GetMetaDef(); ORT_ENFORCE(nullptr != func_meta_def); - std::vector input_args; std::vector output_args; + std::unordered_map input_indexes; + std::unordered_map output_indexes; + + int cur_idx = 0; for (auto& arg_name : func_meta_def->inputs) { input_args.push_back(GetNodeArg(arg_name)); + input_indexes[arg_name] = cur_idx++; } + + cur_idx = 0; for (auto& arg_name : func_meta_def->outputs) { output_args.push_back(GetNodeArg(arg_name)); + output_indexes[arg_name] = cur_idx++; } auto& fused_node = AddNode(fused_node_name, @@ -3339,21 +3359,102 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& f func_meta_def->domain); fused_node.SetNodeType(Node::Type::Fused); - function_container_.emplace_back(MakeFunction(*this, sub_graph, logger_)); - fused_node.SetFunctionBody(*function_container_.back()); - // Remove nodes fused above. + return fused_node; +} + +Node& Graph::BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) { + Node& node = CreateFusedSubGraphNode(sub_graph, fused_node_name); + +#if !defined(ORT_MINIMAL_BUILD) + // if this is a full build create the lightweight Function implementation that provides the schema so that + // kernel lookup works as per usual. in an extended minimal build we do the lookup via a hash so don't + // need to create the schema. + auto func = onnxruntime::make_unique(*this, sub_graph, logger_); + function_container_.push_back(std::move(func)); + node.SetFunctionBody(*function_container_.back()); +#endif + + return node; +} + +void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node) { + const auto* func_meta_def = sub_graph.GetMetaDef(); + ORT_ENFORCE(nullptr != func_meta_def); + + std::unordered_map input_indexes; + std::unordered_map output_indexes; + + int cur_idx = 0; + for (auto& arg_name : func_meta_def->inputs) { + input_indexes[arg_name] = cur_idx++; + } + + cur_idx = 0; + for (auto& arg_name : func_meta_def->outputs) { + output_indexes[arg_name] = cur_idx++; + } + + auto new_node_idx = fused_node.Index(); + + // Remove nodes that were fused for (auto node_index : sub_graph.nodes) { auto node = GetNode(node_index); if (nullptr == node) { continue; } - auto output_edges = node->GetRelationships().output_edges; - for (auto output_edge : output_edges) { - RemoveEdge(node->Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); + + // move any applicable input edges to the new node. remove all others + auto input_edges = node->GetRelationships().input_edges; // copy so RemoveEdge doesn't invalidate iterator + for (const auto& input_edge : input_edges) { + const auto& producer = input_edge.GetNode(); + auto producer_idx = producer.Index(); + auto src_idx = input_edge.GetSrcArgIndex(); + auto dst_idx = input_edge.GetDstArgIndex(); + + // if this input is an input of the fused node add an edge for that + auto it = input_indexes.find(node->InputDefs()[dst_idx]->Name()); + if (it != input_indexes.cend()) { + AddEdge(producer_idx, new_node_idx, src_idx, it->second); + } + + RemoveEdge(producer_idx, node_index, src_idx, dst_idx); } + + // move any applicable output edges to the new node + auto output_edges = node->GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator + for (const auto& output_edge : output_edges) { + const auto& consumer = output_edge.GetNode(); + auto consumer_idx = consumer.Index(); + auto src_idx = output_edge.GetSrcArgIndex(); + auto dst_idx = output_edge.GetDstArgIndex(); + + // if this output is an output of the fused node add an edge for that + auto it = output_indexes.find(node->OutputDefs()[src_idx]->Name()); + if (it != output_indexes.cend()) { + AddEdge(new_node_idx, consumer_idx, it->second, dst_idx); + } + + RemoveEdge(node_index, consumer_idx, src_idx, dst_idx); + } + RemoveNode(node_index); } +} + +#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#if !defined(ORT_MINIMAL_BUILD) +Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, + const std::string& fused_node_name) { + Node& fused_node = CreateFusedSubGraphNode(sub_graph, fused_node_name); + + // create Function before we remove nodes + function_container_.emplace_back(MakeFunction(*this, sub_graph, logger_)); + fused_node.SetFunctionBody(*function_container_.back()); + + // remove nodes and update edges + FinalizeFuseSubGraph(sub_graph, fused_node); return fused_node; } @@ -3541,6 +3642,7 @@ Graph::Graph(const Model& owning_model, schema_registry_(std::make_shared()), #endif domain_to_version_(domain_to_version), + ir_version_(owning_model.IrVersion()), parent_graph_(parent_graph), parent_node_(parent_node), logger_(logger), diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index b0f563139b..ebfcfdd16f 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -4,6 +4,7 @@ #include "transformer_memcpy.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" +#include "core/framework/utils.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -62,15 +63,7 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // and mainly provides the subgraph recursion functionality common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { for (auto& provider : provider_types_) { - if (provider != onnxruntime::kCpuExecutionProvider && - provider != onnxruntime::kDnnlExecutionProvider && - provider != onnxruntime::kNGraphExecutionProvider && - provider != onnxruntime::kNupharExecutionProvider && - provider != onnxruntime::kVitisAIExecutionProvider && - provider != onnxruntime::kOpenVINOExecutionProvider && - provider != onnxruntime::kNnapiExecutionProvider && - provider != onnxruntime::kAclExecutionProvider && - provider != onnxruntime::kArmNNExecutionProvider) { + if (!utils::ProviderIsCpuBased(provider)) { TransformerMemcpyImpl copy_impl(graph, provider); auto current_modified = copy_impl.ModifyGraph(registry_manager_); modified = modified || current_modified; @@ -163,9 +156,11 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } -void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed) { +void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed) { auto node_provider_type = node.GetExecutionProviderType(); - if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) { + if ((node_provider_type == provider_) || + (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) { provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index dc8bbf6a74..06e77f4bba 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -745,45 +745,53 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const // If the Reshape output is also a graph output, since NNAPI output is a void* buffer, we can find the shape // information in onnxruntime::nnapi::Model and pass the correct shape information back to ORT to be used as output shape /* static */ bool ReshapeOpBuilder::CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank) { - const auto& output = node.OutputDefs()[0]->Name(); - // We will go through all the output edges - for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { - const auto& op_type = it->GetNode().OpType(); - // TODO add quantized matmul when reshape support quantized input - if (op_type != "Gemm" && op_type != "MatMul") { - LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul" - << " or no op is using the output (output is graph output)" - << ", output name, " << output - << " is used by " << op_type; - return false; - } + // + // TEMPORARILY DISABLED. Needs refinement. + // + // const auto& output = node.OutputDefs()[0]->Name(); + // // We will go through all the output edges + // for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { + // const auto& op_type = it->GetNode().OpType(); + // // TODO add quantized matmul when reshape support quantized input + // if (op_type != "Gemm" && op_type != "MatMul") { + // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul" + // << " or no op is using the output (output is graph output)" + // << ", output name, " << output + // << " is used by " << op_type; + // return false; + // } - // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0 - if (it->GetDstArgIndex() != 0) { - LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul" - << ", output name, " << output; - return false; - } + // // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0 + // if (it->GetDstArgIndex() != 0) { + // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul" + // << ", output name, " << output; + // return false; + // } - // We only support 2d matmul/gemm here - // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2 - if (input_rank < 2 || output_rank != 2) { - LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2" - << ", output name, " << output - << ", the actual input_rank, " << input_rank - << ", the actual output_rank, " << output_rank; - return false; - } - } + // // We only support 2d matmul/gemm here + // // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2 + // if (input_rank < 2 || output_rank != 2) { + // LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2" + // << ", output name, " << output + // << ", the actual input_rank, " << input_rank + // << ", the actual output_rank, " << output_rank; + // return false; + // } + // } - // If we reach here, we have either, - // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here] - // or - // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care] - // we can skip this Reshape now - LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node [" - << node.Name() << "] with output, " << output; - return true; + // // If we reach here, we have either, + // // all the Reshape outputs are used by gemm/matmul, the output can also be a model output [doesn't really matter here] + // // or + // // Reshape has no output edge ==> the output is a graph output or a dead end [which we don't care] + // // we can skip this Reshape now + // LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node [" + // << node.Name() << "] with output, " << output; + // return true; + + ORT_UNUSED_PARAMETER(node); + ORT_UNUSED_PARAMETER(input_rank); + ORT_UNUSED_PARAMETER(output_rank); + return false; } /* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index a4d9c18ee2..16651720fb 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -80,7 +80,6 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // Find inputs, initializers and outputs for each supported subgraph const std::vector& node_index = graph_view.GetNodesInTopologicalOrder(); const auto& graph_outputs = graph_view.GetOutputs(); - int counter = 0; for (const auto& group : supported_nodes_vector) { if (group.empty()) continue; @@ -173,7 +172,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // Assign inputs and outputs to subgraph's meta_def auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "NNAPI_" + std::to_string(counter++); + meta_def->name = "NNAPI_" + std::to_string(metadef_id_++); meta_def->domain = kMSDomain; for (const auto& input : inputs) { @@ -202,6 +201,7 @@ static Status GetOutputBuffer(Ort::CustomOpApi& ort, const std::vector& output_shape, const android::nn::wrapper::Type output_type, void** output_buffer) ORT_MUST_USE_RESULT; + static Status GetOutputBuffer(Ort::CustomOpApi& ort, OrtKernelContext* context, const nnapi::Model& model, @@ -235,50 +235,43 @@ static Status GetOutputBuffer(Ort::CustomOpApi& ort, return Status::OK(); } -common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes, +common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { using namespace android::nn::wrapper; - for (const auto* fused_node : fused_nodes) { - // Reconstruct graph proto from fused node's function body - const auto* func_body = fused_node->GetFunctionBody(); - if (!func_body) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty"); - } + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { + Node& fused_node = fused_node_and_graph.fused_node; + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - const Graph& graph_body = func_body->Body(); + nnapi::ModelBuilder builder(graph_viewer); + builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW); + builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16); + std::unique_ptr nnapi_model; + ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model)); + + // Build map from input name to its index in input definitions { - onnxruntime::GraphViewer graph_viewer(graph_body); - nnapi::ModelBuilder builder(graph_viewer); - builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW); - builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16); - std::unique_ptr nnapi_model; - ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model)); - - // Build map from input name to its index in input definitions - { - std::unordered_map input_map; - const auto& input_defs = fused_node->InputDefs(); - input_map.reserve(input_defs.size()); - for (size_t i = 0, end = input_defs.size(); i < end; ++i) { - input_map[input_defs[i]->Name()] = i; - } - nnapi_model->SetInputMap(std::move(input_map)); + std::unordered_map input_map; + const auto& input_defs = fused_node.InputDefs(); + input_map.reserve(input_defs.size()); + for (size_t i = 0, end = input_defs.size(); i < end; ++i) { + input_map[input_defs[i]->Name()] = i; } - - // Build map from output name to its index in output definitions - { - std::unordered_map output_map; - const auto& output_defs = fused_node->OutputDefs(); - output_map.reserve(output_defs.size()); - for (size_t i = 0, end = output_defs.size(); i < end; ++i) { - output_map[output_defs[i]->Name()] = i; - } - nnapi_model->SetOutputMap(std::move(output_map)); - } - - nnapi_models_.emplace(fused_node->Name(), std::move(nnapi_model)); + nnapi_model->SetInputMap(std::move(input_map)); } + // Build map from output name to its index in output definitions + { + std::unordered_map output_map; + const auto& output_defs = fused_node.OutputDefs(); + output_map.reserve(output_defs.size()); + for (size_t i = 0, end = output_defs.size(); i < end; ++i) { + output_map[output_defs[i]->Name()] = i; + } + nnapi_model->SetOutputMap(std::move(output_map)); + } + + nnapi_models_.emplace(fused_node.Name(), std::move(nnapi_model)); + NodeComputeInfo compute_info; compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) { *state = nnapi_models_[context->node_name].get(); @@ -432,20 +425,21 @@ common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes, +common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { - for (const auto* fused_node : fused_nodes) { - ORT_UNUSED_PARAMETER(fused_node); + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { + ORT_UNUSED_PARAMETER(fused_node_and_graph); NodeComputeInfo compute_info; compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; }; compute_info.release_state_func = [](FunctionState /*state*/) {}; - compute_info.compute_func = [](FunctionState /* state */, const OrtCustomOpApi* /* api */, OrtKernelContext* /* context */) { + compute_info.compute_func = [](FunctionState /* state */, const OrtCustomOpApi* /* api */, + OrtKernelContext* /* context */) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build."); }; node_compute_funcs.push_back(compute_info); } return Status::OK(); } -#endif +#endif // __ANDROID__ } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index ff72e93570..a443ee58c1 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -19,11 +19,21 @@ class NnapiExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const std::vector& /*kernel_registries*/) const override; - common::Status Compile(const std::vector& fused_nodes, + + // we implement the Compile that takes FusedNodeAndGraph instances + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; +#endif + unsigned long GetNNAPIFlags() const { return nnapi_flags_; } private: + // unique counter to name each fused kernel across the entire model + mutable int metadef_id_{0}; + // The bit flags which define bool options for NNAPI EP, bits are defined as // NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h const unsigned long nnapi_flags_; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 0b70fab0b6..87fe26849d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -793,7 +793,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, const InsertCastTransformer& insert_cast_transformer, - SessionState& session_state) { + SessionState& session_state, + bool saving_model_in_ort_format) { // The transformer order: // 1. built-in graph rewriter // 2. each execution provider's transformer @@ -822,10 +823,17 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, } #endif + // if saving model to ORT format we only assign nodes a custom EP can handle and don't compile them. + // we do this to preserve the original nodes in the model but prevent optimizers from changing them. + // at runtime, the ORT format model will re-do the partitioning/compilation of these nodes, which may change + // to cover fewer nodes due to device capabilities. + auto mode = saving_model_in_ort_format ? GraphPartitioner::Mode::kAssignOnly + : GraphPartitioner::Mode::kNormal; + // Do partitioning based on execution providers' capability. GraphPartitioner partitioner(kernel_registry_manager, providers); ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(), - session_state.GetMutableFuncMgr())); + session_state.GetMutableFuncMgr(), mode)); // apply transformers except default transformers // Default transformers are required for correctness and they are owned and run by inference session @@ -888,9 +896,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, return common::Status::OK(); } - #endif // !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +Status InferenceSession::PartitionOrtFormatModel(onnxruntime::Graph& graph, + const ExecutionProviders& providers, + KernelRegistryManager& kernel_registry_manager, + SessionState& session_state) const { + std::unordered_map compiled_kernel_hashes; + + GraphPartitioner partitioner(kernel_registry_manager, providers); + ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(), + session_state.GetMutableFuncMgr(), + GraphPartitioner::Mode::kOrtFormatLoad, + &compiled_kernel_hashes)); + + if (!compiled_kernel_hashes.empty()) { + session_state.SetCompiledKernelHashes(std::move(compiled_kernel_hashes)); + } + + return Status::OK(); +} +#endif + #if defined(ENABLE_ORT_FORMAT_LOAD) template static Status LoadOrtModelBytes(const std::basic_string& model_uri, @@ -1131,38 +1159,65 @@ common::Status InferenceSession::Initialize() { // Register 2nd registries into KernelRegistryManager. ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_)); + bool loading_ort_format = !ort_format_model_bytes_.empty(); + bool saving_model = !session_options_.optimized_model_filepath.empty(); + bool saving_ort_format = false; + if (saving_model) { + std::string model_type = session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSaveModelFormat, ""); + bool has_explicit_type = !model_type.empty(); + saving_ort_format = ((has_explicit_type && model_type == "ORT") || + (!has_explicit_type && + experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath))); + } + #if !defined(ORT_MINIMAL_BUILD) - // add predefined transformers - AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, - transformers_to_enable_); + if (!loading_ort_format) { + // add predefined transformers + AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, + transformers_to_enable_); - // apply any transformations to the main graph and any subgraphs - ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_, - execution_providers_, kernel_registry_manager_, - insert_cast_transformer_, - *session_state_)); + // apply any transformations to the main graph and any subgraphs + ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_, + execution_providers_, kernel_registry_manager_, + insert_cast_transformer_, + *session_state_, + saving_ort_format)); - // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. - ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. + ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph - ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_)); + // Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph + ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_)); + } else #endif // !defined(ORT_MINIMAL_BUILD) + { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // nodes are already partitioned, but a custom EP may compile some at runtime. + // run the partitioning to allow that to happen. + // + // We always have the CPU EP, so only need to run this if some other EP is enabled + if (execution_providers_.NumProviders() > 1) { + ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, + *session_state_)); + } +#endif + } - // need to keep the initializers if we're going to save the optimized model - bool keep_initializers = !session_options_.optimized_model_filepath.empty(); + const experimental::fbs::SessionState* serialized_session_state = + loading_ort_format + ? fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state() + : nullptr; - auto* serialized_session_state = !ort_format_model_bytes_.empty() - ? fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state() - : nullptr; - - ORT_RETURN_IF_ERROR_SESSIONID_(session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, - session_options_, - serialized_session_state, - !keep_initializers)); + ORT_RETURN_IF_ERROR_SESSIONID_( + session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, + session_options_, + serialized_session_state, + // need to keep the initializers if saving the optimized model + !saving_model, + saving_ort_format)); #if !defined(ORT_MINIMAL_BUILD) - if (!session_options_.optimized_model_filepath.empty()) { + if (saving_model) { if (session_options_.graph_optimization_level >= TransformerLevel::Level3) { LOGS(*session_logger_, WARNING) << "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. " @@ -1170,12 +1225,7 @@ common::Status InferenceSession::Initialize() { "and should only be used in the same environment the model was optimized for."; } - std::string model_type = session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSaveModelFormat, ""); - bool has_explicit_type = !model_type.empty(); - - if ((has_explicit_type && model_type == "ORT") || - (!has_explicit_type && - experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath))) { + if (saving_ort_format) { ORT_RETURN_IF_ERROR_SESSIONID_(SaveToOrtFormat(session_options_.optimized_model_filepath)); } else { ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index dde8911829..446c5d2753 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -57,9 +57,7 @@ class LoggingManager; */ struct ModelMetadata { ModelMetadata() = default; - ModelMetadata(const ModelMetadata& other) - : producer_name(other.producer_name), graph_name(other.graph_name), domain(other.domain), description(other.description), version(other.version), custom_metadata_map(other.custom_metadata_map) { - } + ModelMetadata(const ModelMetadata&) = default; ~ModelMetadata() = default; ModelMetadata& operator=(const ModelMetadata&) = delete; @@ -531,7 +529,8 @@ class InferenceSession { const onnxruntime::GraphTransformerManager& graph_transformer_mgr, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, const InsertCastTransformer& insert_cast_transformer, - SessionState& session_state) ORT_MUST_USE_RESULT; + SessionState& session_state, + bool saving_model_in_ort_format) ORT_MUST_USE_RESULT; onnxruntime::GraphTransformerManager graph_transformation_mgr_; @@ -543,6 +542,11 @@ class InferenceSession { std::vector transformers_to_enable_; #endif +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, + KernelRegistryManager& kernel_registry_manager, SessionState& session_state) const; +#endif + SessionOptions session_options_; /// Logging manager if provided. diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d8e71b45ea..27c5ec1fa3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -231,9 +231,10 @@ namespace py = pybind11; using namespace onnxruntime; using namespace onnxruntime::logging; +static Env& platform_env = Env::Default(); + #if !defined(ORT_MINIMAL_BUILD) // Custom op section starts -static Env& platform_env = Env::Default(); CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so) { { @@ -650,10 +651,10 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector #endif } else if (type == kRocmExecutionProvider) { #ifdef USE_ROCM - RegisterExecutionProvider( - sess, *onnxruntime::CreateExecutionProviderFactory_ROCM(cuda_device_id, - cuda_mem_limit, - arena_extend_strategy)); + RegisterExecutionProvider( + sess, *onnxruntime::CreateExecutionProviderFactory_ROCM(cuda_device_id, + cuda_mem_limit, + arena_extend_strategy)); #endif } else if (type == kDnnlExecutionProvider) { #ifdef USE_DNNL @@ -685,11 +686,9 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector } else if (option.first == "device_id") { openvino_device_id = option.second; - } - else if (option.first == "num_of_threads") { + } else if (option.first == "num_of_threads") { num_of_threads = std::stoi(option.second); - } - else { + } else { ORT_THROW("Invalid OpenVINO EP option: ", option.first); } } @@ -946,7 +945,7 @@ void addGlobalMethods(py::module& m, const Environment& env) { ORT_UNUSED_PARAMETER(algo); ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM"); #else - cudnn_conv_algo_search = algo; + cudnn_conv_algo_search = algo; #endif }); m.def("set_do_copy_in_default_stream", [](const bool use_single_stream) { @@ -954,7 +953,7 @@ void addGlobalMethods(py::module& m, const Environment& env) { ORT_UNUSED_PARAMETER(use_single_stream); ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM"); #else - do_copy_in_default_stream = use_single_stream; + do_copy_in_default_stream = use_single_stream; #endif }); m.def("set_cuda_mem_limit", [](const int64_t limit) { cuda_mem_limit = static_cast(limit); }); @@ -1734,7 +1733,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") .def("end_profiling", [](PyInferenceSession* sess) -> std::string { return sess->GetSessionHandle()->EndProfiling(); }) - .def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t{ + .def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t { return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs(); }) .def("get_providers", [](PyInferenceSession* sess) -> const std::vector& { diff --git a/onnxruntime/test/providers/internal_testing/README.md b/onnxruntime/test/providers/internal_testing/README.md new file mode 100644 index 0000000000..11a45972ae --- /dev/null +++ b/onnxruntime/test/providers/internal_testing/README.md @@ -0,0 +1,3 @@ +Internal test EP that is used to test and validate interactions between the ORT framework, optimizers and an EP. + +Primary usage currently is validating support in a minimal build for an EP that compiles nodes into kernels at runtime. \ No newline at end of file diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc new file mode 100644 index 0000000000..3c02d26cf6 --- /dev/null +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "internal_testing_execution_provider.h" + +#include "core/framework/allocatormgr.h" +#include "core/framework/compute_capability.h" +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel_context_internal.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/graph/model.h" +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { + +constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP"; + +InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set& ops) + : IExecutionProvider{utils::kInternalTestingExecutionProvider}, + ops_{ops} { + // TODO: Allocation planner calls GetAllocator for the individual EP. It would be better if it goes through + // the session state to get the allocator so it's per-device (or for the allocation planner to try the EP first + // and fall back to using session state next by passing in a functor it can use to call SessionState::GetAllocator). + + AllocatorCreationInfo device_info( + [](int) { + return onnxruntime::make_unique(OrtMemoryInfo(INTERNAL_TESTING_EP, + OrtAllocatorType::OrtDeviceAllocator)); + }); + + InsertAllocator(CreateAllocator(device_info)); +} + +InternalTestingExecutionProvider::~InternalTestingExecutionProvider() {} + +std::vector> +InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const std::vector& /*kernel_registries*/) const { + std::vector> result; + + /* + Very basic search for groups of nodes that can be handled by the EP. + This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP + but B is between them in the topological sort as you'll get two single node capabilities. However if can also + be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected. + Not sure how often each of these scenarios happens. + + A B C + | / | + D E + | | + + Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order, + accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all + connected nodes that can be handled are grouped together. + */ + + std::vector> node_groups; + std::vector cur_group; + for (NodeIndex node_index : graph_viewer.GetNodesInTopologicalOrder()) { + if (ops_.count(graph_viewer.GetNode(node_index)->OpType())) { + cur_group.push_back(node_index); + } else if (!cur_group.empty()) { + node_groups.push_back(std::move(cur_group)); + } + } + + if (!cur_group.empty()) { + node_groups.push_back(std::move(cur_group)); + } + + if (node_groups.empty()) { + return result; + } + + const auto& graph_output_list = graph_viewer.GetOutputs(); + std::unordered_set graph_outputs(graph_output_list.cbegin(), graph_output_list.cend()); + + for (const auto& group : node_groups) { + std::unordered_set node_set; + node_set.reserve(group.size()); + for (const auto& index : group) { + node_set.insert(index); + } + + std::unique_ptr sub_graph = onnxruntime::make_unique(); + + std::unordered_set node_outputs; + std::unordered_set subgraph_inputs; + std::unordered_set subgraph_outputs; + std::vector ordered_subgraph_inputs; + std::vector ordered_subgraph_outputs; + + for (const auto& index : group) { + sub_graph->nodes.push_back(index); + const auto* node = graph_viewer.GetNode(index); + + for (const auto* input : node->InputDefs()) { + // if the node input was not produced by this subgraph, add it to the subgraph inputs. + if (node_outputs.count(input) == 0) { + if (subgraph_inputs.count(input) == 0) { + subgraph_inputs.insert(input); + ordered_subgraph_inputs.push_back(input); + } + } + } + + const auto& output_defs = node->OutputDefs(); + for (const auto* output_def : output_defs) { + node_outputs.insert(output_def); + // if output is overall graph output we need to produce it. + if (graph_outputs.count(output_def) != 0) { + ordered_subgraph_outputs.push_back(output_def); + } + } + + // if output connects to a node not in this subgraph we need to produce it + for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { + if (node_set.count(it->GetNode().Index()) == 0) { + const auto* output_def = output_defs[it->GetSrcArgIndex()]; + if (subgraph_outputs.count(output_def) == 0) { + subgraph_outputs.insert(output_def); + ordered_subgraph_outputs.push_back(output_def); + } + } + } + } + + // Assign inputs and outputs to subgraph's meta_def + auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); + meta_def->name = "InternalTestingEP_" + std::to_string(metadef_id_++); + meta_def->domain = kMSDomain; + meta_def->since_version = 1; + meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + + for (const auto& input : ordered_subgraph_inputs) { + meta_def->inputs.push_back(input->Name()); + } + + for (const auto& output : ordered_subgraph_outputs) { + meta_def->outputs.push_back(output->Name()); + } + + sub_graph->SetMetaDef(std::move(meta_def)); + + result.push_back(onnxruntime::make_unique(std::move(sub_graph))); + } + + return result; +} + +common::Status InternalTestingExecutionProvider::Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) { + // Create a function to generate dummy empty output for each fused node so the model can be executed. + for (const auto& node_and_viewer : fused_nodes) { + NodeComputeInfo compute_info; + const Node& node = node_and_viewer.fused_node; + + //{ + // const GraphViewer& graph_viewer = node_and_viewer.filtered_graph; + // std::cout << "Fusing nodes: "; + // for (const auto& unfused_node : graph_viewer.Nodes()) { + // std::cout << " '" << unfused_node.Name() << "':" << unfused_node.Index(); + // } + // std::cout << std::endl; + //} + + compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { + return 0; + }; + + compute_info.release_state_func = [](FunctionState /*state*/) { + }; + + compute_info.compute_func = [&node](FunctionState /*state*/, const OrtCustomOpApi* c_api, + OrtKernelContext* context) -> Status { + Ort::CustomOpApi api{*c_api}; // use C++ API for convenience + + const auto outputs = node.OutputDefs(); + const size_t num_outputs = outputs.size(); + + for (size_t i = 0; i < num_outputs; i++) { + const auto* shape_proto = outputs[i]->Shape(); + if (shape_proto == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown output shapes are not supported"); + } + + TensorShape shape = utils::GetTensorShapeFromTensorShapeProto(*shape_proto); + if (shape.Size() < 0) { + // arbitrarily set any unknown dim to 1 + for (size_t idx = 0, end = shape.NumDimensions(); idx < end; ++idx) { + if (shape[idx] == -1) { + shape[idx] = 1; + } + } + } + + // create the output_tensor. + auto* ortvalue = api.KernelContext_GetOutput(context, i, shape.GetDims().data(), shape.GetDims().size()); + + // and fill with zeros + auto* tensor = ortvalue->GetMutable(); + void* data = tensor->MutableDataRaw(); + auto bytes = tensor->SizeInBytes(); + memset(data, 0, bytes); + }; + + return Status::OK(); + }; + + node_compute_funcs.push_back(std::move(compute_info)); + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h new file mode 100644 index 0000000000..a10ee1bd90 --- /dev/null +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +class InternalTestingExecutionProvider : public IExecutionProvider { + public: + InternalTestingExecutionProvider(const std::unordered_set& ops); + virtual ~InternalTestingExecutionProvider(); + + std::vector> + GetCapability(const onnxruntime::GraphViewer& graph_view, + const std::vector& /*kernel_registries*/) const override; + + common::Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) override; + + FusionStyle GetFusionStyle() const override { + return FusionStyle::FilteredGraphViewer; + } + + private: + const std::unordered_set ops_; + + // unique counter to name each fused kernel across the entire model + mutable int metadef_id_{0}; +}; +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc new file mode 100644 index 0000000000..9784f497dd --- /dev/null +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/logging/logging.h" +#include "core/framework/utils.h" +#include "core/session/inference_session.h" + +#include "test/framework/test_utils.h" +#include "test/test_environment.h" +#include "test/providers/internal_testing/internal_testing_execution_provider.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" +#include "test/util/include/test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + +namespace onnxruntime { + +namespace test { + +static void CreateSession(const SessionOptions& so, std::unique_ptr& session, + const ORTCHAR_T* model_path = ORT_TSTR("testdata/mnist.onnx"), // arbitrary test model + bool enable_custom_ep = true, + const std::unordered_set* override_supported_ops = nullptr) { + session = onnxruntime::make_unique(so, GetEnvironment()); + + // set supported ops to ops that are ideally found consecutively in the model. + // we can say the EP potentially handles them all, but can also test removing handling of one or more ops + // at runtime to simulate a lower spec device where not all ops can be handled. this allows us to test + // that we can revert ops back to the CPU implementation successfully + const std::unordered_set default_supported_ops{"Conv", "Add", "Relu", "MaxPool"}; + const std::unordered_set* supported_ops = override_supported_ops ? override_supported_ops + : &default_supported_ops; + + if (enable_custom_ep) { + ASSERT_STATUS_OK(session->RegisterExecutionProvider( + onnxruntime::make_unique(*supported_ops))); + } + + ASSERT_STATUS_OK(session->Load(model_path)); + ASSERT_STATUS_OK(session->Initialize()); +} + +static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enabled) { + // validate that we can execute the model. the dummy internal testing EP just creates empty output so the + // values in the output aren't relevant. all we care about is that we can execute the model and produce output. + OrtValue ml_value_x; + TensorShape input_shape{1, 1, 28, 28}; + std::vector input(input_shape.Size(), 1.f); + + CreateMLValue(input_shape.GetDims(), input.data(), OrtMemoryInfo(), &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("Input3", ml_value_x)); + + // prepare outputs + std::vector output_names; + output_names.push_back("Plus214_Output_0"); + std::vector fetches; + + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + + if (custom_ep_enabled) { + // check that the output is all zeros. the dummy EP produces output of the correct shape with all zeros, so any + // downstream operations should still result in zeros for this model + // OR it should equal the bias in the final Add operation, which is in the Parameter194 initializer + const auto& t = fetches[0].Get(); + const auto data = t.DataAsSpan(); + + int idx = 0; + const auto& session_state = session.GetSessionState(); + ASSERT_STATUS_OK(session_state.GetOrtValueNameIdxMap().GetIdx("Parameter194", idx)); + const auto& initializer = session_state.GetConstantInitializedTensors().at(idx); + const auto expected = initializer.Get().DataAsSpan(); + + ASSERT_THAT(data, ::testing::ContainerEq(expected)); + } +} + +#if !defined(ORT_MINIMAL_BUILD) +TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { + const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.test_output.ort"); + + // + // First load the onnx format model and save as an ORT model. + // This should preserve the nodes the custom EP can handle. + // + std::unique_ptr session; + SessionOptions so; + so.optimized_model_filepath = ort_model_path; + + CreateSession(so, session); + // this graph should include the original nodes that the custom EP will take at runtime + auto num_nodes = session->GetGraph().NumberOfNodes(); + + // + // Second, load the ORT format model with just the CPU EP to make sure it can be executed. This tests that the + // fallback to the CPU EP kernel hashes works. + // + std::unique_ptr session2; + + so.optimized_model_filepath.clear(); + bool enable_custom_ep = false; + + CreateSession(so, session2, ort_model_path, enable_custom_ep); + const auto& graph1 = session2->GetGraph(); + // model should have all the original nodes and we should be able to execute with the fallback to CPU EP + ASSERT_EQ(graph1.NumberOfNodes(), num_nodes); + ExecuteMnist(*session2, enable_custom_ep); + session2 = nullptr; + + // + // Finally, load the ORT format model with the custom EP enabled. This tests that we support runtime compilation + // for the ORT format model. + // + enable_custom_ep = true; + CreateSession(so, session2, ort_model_path, enable_custom_ep); + const auto& graph2 = session2->GetGraph(); + // model should be able to be loaded, and we should compile using custom ep. that will result in one node for the + // custom EP (with Conv/Add/Relu/MaxPool), one for a reshape, and one for the fused MatMul+Add. + ASSERT_EQ(graph2.NumberOfNodes(), 3); + ExecuteMnist(*session2, enable_custom_ep); +} + +#endif // !defined(ORT_MINIMAL_BUILD) + +// test to validate a minimal build +TEST(InternalTestingEP, TestLoadOrtModel) { + const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort"); + + std::unique_ptr session; + bool enable_custom_ep = true; + + CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep); + ExecuteMnist(*session, enable_custom_ep); +} + +// test that is the custom EP cannot take all nodes due to device limitations +// that we fallback to the CPU implementations and can execute the model +TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) { + const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort"); + const std::unordered_set supported_ops{"Conv", "Add", "Relu" /*, "MaxPool"*/}; + + std::unique_ptr session; + bool enable_custom_ep = true; + + CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops); + + const auto& graph = session->GetGraph(); + // Conv+Add gets fused by level 1 optimizer into single node. The 'Conv'/'Add'/'Relu' nodes should be compiled and + // handled by the custom EP. fallback to CPU for MaxPool. + ASSERT_EQ(graph.NumberOfNodes(), 6); + const auto& func_mgr = session->GetSessionState().GetFuncMgr(); + NodeComputeInfo* compute_func = nullptr; + + for (const auto& node : graph.Nodes()) { + EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) + << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; + if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); + EXPECT_NE(compute_func, nullptr); + } + } + + ExecuteMnist(*session, enable_custom_ep); +} + +TEST(InternalTestingEP, TestMinimalRegistrationOfEPwithGetCapability) { + // TODO: In a full build we want to be able to call GetCapability for the NNAPI EP and produce an ORT format model + // with nodes correctly preserved. That requires being able to do a minimal registration of that EP where + // GetCapability is fully implemented, but Compile is a stub that just throws NOT_IMPLEMENTED if someone attempts + // to execute a model in that InferenceSession. +} + +// count nodes assigned to the test EP and make sure they all have valid compute funcs +static int CountAndValidateAssignedNodes(const Graph& current_graph, + const std::unordered_set& supported_ops, + const FuncManager& func_mgr) { + int count = 0; + + for (const auto& node : current_graph.Nodes()) { + EXPECT_EQ(supported_ops.count(node.OpType()), size_t(0)) + << "Nodes with supported op types should have been replaced. Node with type " << node.OpType() << " was not."; + if (node.GetExecutionProviderType() == utils::kInternalTestingExecutionProvider) { + NodeComputeInfo* compute_func = nullptr; + EXPECT_STATUS_OK(func_mgr.GetFuncs(node.Name(), compute_func)); + EXPECT_NE(compute_func, nullptr); + ++count; + } + + if (node.ContainsSubgraph()) { + for (const auto& entry : node.GetSubgraphs()) { + count += CountAndValidateAssignedNodes(*entry, supported_ops, func_mgr); + } + } + } + + return count; +} + +// Test model that contains a subgraph. This model has a Loop and an If so multiple layers of nested subgraphs. +// There are Add nodes in the Loop and If subgraphs so we should see the custom EP taking nodes at both these levels. +TEST(InternalTestingEP, TestModelWithSubgraph) { + const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort"); + const std::unordered_set supported_ops{"Add"}; + + std::unique_ptr session; + bool enable_custom_ep = true; + + CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops); + + const auto& graph = session->GetGraph(); + const auto& func_mgr = session->GetSessionState().GetFuncMgr(); + + int num_replaced_nodes = CountAndValidateAssignedNodes(graph, supported_ops, func_mgr); + + // One Add node in the Loop. One Add node in each branch of the If inside the Loop body + ASSERT_EQ(num_replaced_nodes, 3); + + OrtValue ml_value; + + // this is a bit of a hack. the correct output is the input value + 2, so if we start with -2 the result is 0. + // the output from fused nodes using the testing EP is always 0, so we should match the expected output this way + // as we replace all the Add nodes with something that returns 0. + // RunAndVerifyOutputsWithEP checks that nodes are assigned to the EP so we know it's being used to execute the model + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {-2.f}, + &ml_value); + NameMLValMap feeds; + feeds.insert(std::make_pair("state_var_in", ml_value)); + // compare outputs from CPU EP vs custom EP + RunAndVerifyOutputsWithEP(ort_model_path, + "InternalTestingEP.TestModelWithSubgraph", + onnxruntime::make_unique(supported_ops), + feeds); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index 3fa8ef5711..e4799c7964 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -1,13 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #include "core/common/logging/logging.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h" #include "core/session/inference_session.h" -#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" -#include "test/providers/provider_test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" +#include "test/util/include/test/test_environment.h" +#include "test/util/include/test_utils.h" + +#if !defined(ORT_MINIMAL_BUILD) +// if this is a full build we need the provider test utils +#include "test/providers/provider_test_utils.h" +#endif + +#include "gtest/gtest.h" +#include "gmock/gmock.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -16,85 +28,48 @@ using namespace ::onnxruntime::logging; namespace onnxruntime { namespace test { -#ifdef __ANDROID__ -void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, - const std::vector& expected_values) { - ASSERT_EQ(1, fetches.size()); - auto& rtensor = fetches.front().Get(); - TensorShape expected_shape(expected_dims); - ASSERT_EQ(expected_shape, rtensor.Shape()); - const std::vector found(rtensor.template Data(), rtensor.template Data() + expected_values.size()); - ASSERT_EQ(expected_values, found); -} -#endif - -void RunAndVerifyOutputs(const std::string& model_file_name, - const char* log_id, - const NameMLValMap& feeds, - const std::vector& output_names, - const std::vector& expected_dims, - const std::vector& expected_values) { - SessionOptions so; - so.session_logid = log_id; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>(0))); - ASSERT_STATUS_OK(session_object.Load(model_file_name)); - ASSERT_STATUS_OK(session_object.Initialize()); - - // Since we already know the model is entirely supported by NNAPI, all nodes (2 Add nodes here) will be fused - // Get the graph after session is initialized, and verify the fused node (the only node in the graph) is using NNAPI EP - const auto& graph = session_object.GetGraph(); - ASSERT_EQ(1, graph.NumberOfNodes()); // Make sure the graph has 1 fused node - ASSERT_EQ(onnxruntime::kNnapiExecutionProvider, graph.Nodes().cbegin()->GetExecutionProviderType()); - -// The execution can only be performed on Android -#ifdef __ANDROID__ - // Now run and verify the result - std::vector fetches; - ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); - VerifyOutputs(fetches, expected_dims, expected_values); -#else - ORT_UNUSED_PARAMETER(feeds); - ORT_UNUSED_PARAMETER(output_names); - ORT_UNUSED_PARAMETER(expected_dims); - ORT_UNUSED_PARAMETER(expected_values); -#endif -} +#if !defined(ORT_MINIMAL_BUILD) // Since NNAPI EP handles Reshape and Flatten differently, -// Please see ReshapeOpBuilder::CanSkipReshape in /onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +// Please see ReshapeOpBuilder::CanSkipReshape in +// /onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc // We have a separated test for these skip reshape scenarios TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) { - std::string model_file_name = "testdata/nnapi_reshape_flatten_test.onnx"; + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/nnapi_reshape_flatten_test.onnx"); +#if defined(__ANDROID__) std::vector dims_mul_x = {2, 1, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f}; std::vector dims_mul_y = {3, 2, 2}; std::vector values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x); + CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, + &ml_value_x); OrtValue ml_value_y; - CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_y, values_mul_y, &ml_value_y); + CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_y, values_mul_y, + &ml_value_y); NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); feeds.insert(std::make_pair("Y", ml_value_y)); - // prepare outputs - std::vector output_names; - output_names.push_back("Z"); - - // prepare expected inputs and outputs - std::vector expected_dims_mul_z = {1, 6}; - std::vector expected_values_mul_z = {59.0f, 72.0f, 129.0f, 159.0f, 204.0f, 253.0f}; - - RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest", feeds, output_names, expected_dims_mul_z, expected_values_mul_z); + RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest", + onnxruntime::make_unique(0), + feeds); +#else + // test load only + SessionOptions so; + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique(0))); + ASSERT_STATUS_OK(session_object.Load(model_file_name)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0) + << "Some nodes should have been taken by the NNAPI EP"; +#endif } TEST(NnapiExecutionProviderTest, FunctionTest) { - std::string model_file_name = "nnapi_execution_provider_test_graph.onnx"; + const ORTCHAR_T* model_file_name = ORT_TSTR("nnapi_execution_provider_test_graph.onnx"); + { // Create the model with 2 add nodes onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); @@ -130,30 +105,38 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name)); } +#if defined(__ANDROID__) std::vector dims_mul_x = {1, 1, 3, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; OrtValue ml_value_x; - CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x); + CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, + &ml_value_x); OrtValue ml_value_y; - CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_y); + CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, + &ml_value_y); OrtValue ml_value_z; - CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_z); + CreateMLValue(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, + &ml_value_z); NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); feeds.insert(std::make_pair("Y", ml_value_y)); feeds.insert(std::make_pair("Z", ml_value_z)); - // prepare outputs - std::vector output_names; - output_names.push_back("M"); - std::vector fetches; - - // prepare expected inputs and outputs - std::vector expected_dims_mul_m = {1, 1, 3, 2}; - std::vector expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}; - - RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.FunctionTest", feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.FunctionTest", + onnxruntime::make_unique(0), + feeds); +#else + // test load only + SessionOptions so; + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique(0))); + ASSERT_STATUS_OK(session_object.Load(model_file_name)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0) + << "Some nodes should have been taken by the NNAPI EP"; +#endif } +#endif // !(ORT_MINIMAL_BUILD TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) { unsigned long nnapi_flags = NNAPI_FLAG_USE_NONE; @@ -164,5 +147,38 @@ TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) { ASSERT_FALSE(flags & NNAPI_FLAG_USE_NCHW); } +TEST(NnapiExecutionProviderTest, TestOrtFormatModel) { + // mnist model that has only had basic optimizations applied. nnapi should be able to take at least some of the nodes + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/mnist.level1_opt.ort"); + +// The execution can only be performed on Android +#if defined(__ANDROID__) + RandomValueGenerator random{}; + const std::vector dims = {1, 1, 28, 28}; + std::vector data = random.Gaussian(dims, 0.0f, 1.f); + + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, data, + &ml_value); + + NameMLValMap feeds; + feeds.insert(std::make_pair("Input3", ml_value)); + + RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.TestOrtFormatModel", + onnxruntime::make_unique(0), + feeds); +#else + // test load only + SessionOptions so; + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique(0))); + ASSERT_STATUS_OK(session_object.Load(model_file_name)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0) + << "Some nodes should have been taken by the NNAPI EP"; +#endif +} + } // namespace test } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/test/testdata/mnist.level1_opt.ort b/onnxruntime/test/testdata/mnist.level1_opt.ort new file mode 100644 index 0000000000..a8f7a892e6 Binary files /dev/null and b/onnxruntime/test/testdata/mnist.level1_opt.ort differ diff --git a/onnxruntime/test/testdata/mnist.onnx b/onnxruntime/test/testdata/mnist.onnx new file mode 100644 index 0000000000..fc1a3f733c Binary files /dev/null and b/onnxruntime/test/testdata/mnist.onnx differ diff --git a/onnxruntime/test/testdata/mnist.ort b/onnxruntime/test/testdata/mnist.ort new file mode 100644 index 0000000000..d16e8a5ed0 Binary files /dev/null and b/onnxruntime/test/testdata/mnist.ort differ diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 584001d788..e231cc589c 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -10,14 +10,19 @@ namespace onnxruntime { std::shared_ptr CreateExecutionProviderFactory_CPU(int use_arena); -std::shared_ptr CreateExecutionProviderFactory_CUDA(OrtDevice::DeviceId device_id, - OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE, - size_t cuda_mem_limit = std::numeric_limits::max(), - ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo, - bool do_copy_in_default_stream = true); + +std::shared_ptr CreateExecutionProviderFactory_CUDA( + OrtDevice::DeviceId device_id, + OrtCudnnConvAlgoSearch cudnn_conv_algo = OrtCudnnConvAlgoSearch::EXHAUSTIVE, + size_t cuda_mem_limit = std::numeric_limits::max(), + ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo, + bool do_copy_in_default_stream = true); + +std::shared_ptr CreateExecutionProviderFactory_OpenVINO( + const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads); + std::shared_ptr CreateExecutionProviderFactory_Dnnl(int use_arena); std::shared_ptr CreateExecutionProviderFactory_NGraph(const char* ng_backend_type); -std::shared_ptr CreateExecutionProviderFactory_OpenVINO(const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads); std::shared_ptr CreateExecutionProviderFactory_Nuphar(bool, const char*); std::shared_ptr CreateExecutionProviderFactory_Nnapi(unsigned long); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); @@ -29,6 +34,10 @@ std::shared_ptr CreateExecutionProviderFactory_ROCM(O size_t hip_mem_limit = std::numeric_limits::max(), ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo); +// EP for internal testing +std::shared_ptr CreateExecutionProviderFactory_InternalTesting( + const std::unordered_set& supported_ops); + namespace test { std::unique_ptr DefaultCpuExecutionProvider(bool enable_arena) { diff --git a/onnxruntime/test/util/include/asserts.h b/onnxruntime/test/util/include/asserts.h index 69bf85ddb6..c56728d003 100644 --- a/onnxruntime/test/util/include/asserts.h +++ b/onnxruntime/test/util/include/asserts.h @@ -4,17 +4,18 @@ #pragma once #include "core/common/status.h" +#include "gtest/gtest.h" // helpers to run a function and check the status, outputting any error if it fails. // note: wrapped in do{} while(false) so the _tmp_status variable has limited scope #define ASSERT_STATUS_OK(function) \ do { \ - Status _tmp_status = function; \ + Status _tmp_status = function; \ ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status; \ } while (false) #define EXPECT_STATUS_OK(function) \ do { \ - Status _tmp_status = function; \ + Status _tmp_status = function; \ EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status; \ } while (false) diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index c0543cd640..29d73bcc48 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -21,5 +21,9 @@ std::unique_ptr DefaultAclExecutionProvider(bool enable_aren std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultRocmExecutionProvider(); +// EP for internal testing +std::unique_ptr DefaultInternalTestingExecutionProvider( + const std::unordered_set& supported_ops); + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/include/test_utils.h b/onnxruntime/test/util/include/test_utils.h new file mode 100644 index 0000000000..388db559fc --- /dev/null +++ b/onnxruntime/test/util/include/test_utils.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/framework_common.h" +#include "core/framework/execution_provider.h" + +#include +#include +#include + +namespace onnxruntime { +class Graph; + +namespace test { + +// return number of nodes in the Graph and any subgraphs that are assigned to the specified execution provider +int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type); + +// run the model using the CPU EP to get expected output, comparing to the output when the 'execution_provider' +// is enabled. requires that at least one node is assigned to 'execution_provider' +void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path, + const char* log_id, + std::unique_ptr execution_provider, + const NameMLValMap& feeds); +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/util/test_utils.cc b/onnxruntime/test/util/test_utils.cc new file mode 100644 index 0000000000..025a339da8 --- /dev/null +++ b/onnxruntime/test/util/test_utils.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/util/include/test_utils.h" + +#include "core/framework/ml_value.h" +#include "core/session/inference_session.h" + +#include "test/util/include/asserts.h" +#include "test/util/include/test/test_environment.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gmock/gmock.h" + +namespace onnxruntime { +namespace test { +static void VerifyOutputs(const std::vector& output_names, + const std::vector& expected_fetches, + const std::vector& fetches) { + ASSERT_EQ(expected_fetches.size(), fetches.size()); + + for (size_t i = 0, end = expected_fetches.size(); i < end; ++i) { + auto& ltensor = expected_fetches[i].Get(); + auto& rtensor = fetches[i].Get(); + ASSERT_EQ(ltensor.Shape().GetDims(), rtensor.Shape().GetDims()); + auto element_type = ltensor.GetElementType(); + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + EXPECT_THAT(ltensor.DataAsSpan(), ::testing::ContainerEq(rtensor.DataAsSpan())) + << " mismatch for " << output_names[i]; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + EXPECT_THAT(ltensor.DataAsSpan(), ::testing::ContainerEq(rtensor.DataAsSpan())) + << " mismatch for " << output_names[i]; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + const float abs_err = float(1e-5); + + EXPECT_THAT(ltensor.DataAsSpan(), + ::testing::Pointwise(::testing::FloatNear(abs_err), rtensor.DataAsSpan())); + break; + } + default: + ORT_THROW("Unhandled data type. Please add 'case' statement for ", element_type); + } + } +} + +int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type) { + int count = 0; + + for (const auto& node : current_graph.Nodes()) { + if (node.GetExecutionProviderType() == ep_type) { + ++count; + } + + if (node.ContainsSubgraph()) { + for (const auto& entry : node.GetSubgraphs()) { + count += CountAssignedNodes(*entry, ep_type); + } + } + } + + return count; +} + +void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path, const char* log_id, + std::unique_ptr execution_provider, + const NameMLValMap& feeds) { + SessionOptions so; + so.session_logid = log_id; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + // + // get expected output from CPU EP + // + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_path)); + ASSERT_STATUS_OK(session_object.Initialize()); + + const auto& graph = session_object.GetGraph(); + const auto& outputs = graph.GetOutputs(); + + // fetch all outputs + std::vector output_names; + output_names.reserve(outputs.size()); + for (const auto* node_arg : outputs) { + if (node_arg->Exists()) { + output_names.push_back(node_arg->Name()); + } + } + + std::vector expected_fetches; + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &expected_fetches)); + + auto provider_type = execution_provider->Type(); // copy string so the std::move doesn't affect us + + // + // get output with EP enabled + // + InferenceSessionWrapper session_object2{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(std::move(execution_provider))); + ASSERT_STATUS_OK(session_object2.Load(model_path)); + ASSERT_STATUS_OK(session_object2.Initialize()); + + // make sure that some nodes are assigned to the EP, otherwise this test is pointless... + const auto& graph2 = session_object2.GetGraph(); + auto ep_nodes = CountAssignedNodes(graph2, provider_type); + ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type << " for " << model_path; + + // Run with EP and verify the result + std::vector fetches; + ASSERT_STATUS_OK(session_object2.Run(run_options, feeds, output_names, &fetches)); + VerifyOutputs(output_names, expected_fetches, fetches); +} + +} // namespace test +} // namespace onnxruntime diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 0dddf6e763..683663db00 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -419,10 +419,13 @@ def parse_arguments(): help="Build ONNXRuntime micro-benchmarks.") # options to reduce binary size - parser.add_argument("--minimal_build", action='store_true', + parser.add_argument("--minimal_build", action='store', + const='on', default='off', nargs='?', type=str.lower, help="Create a build that only supports ORT format models. " "See /docs/ONNX_Runtime_Format_Model_Usage.md for more information. " - "RTTI is automatically disabled in a minimal build.") + "RTTI is automatically disabled in a minimal build. " + "To enable execution providers that compile kernels at runtime (e.g. NNAPI) pass 'extended' " + "as a parameter. e.g. '--minimal_build extended'.") parser.add_argument("--include_ops_by_model", type=str, help="include ops from model(s) under designated path.") parser.add_argument("--include_ops_by_config", type=str, help="include ops from config file. " @@ -710,7 +713,8 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_DISABLE_RTTI=" + ("ON" if args.disable_rtti else "OFF"), "-Donnxruntime_DISABLE_EXCEPTIONS=" + ("ON" if args.disable_exceptions else "OFF"), "-Donnxruntime_DISABLE_ORT_FORMAT_LOAD=" + ("ON" if args.disable_ort_format_load else "OFF"), - "-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build else "OFF"), + "-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build != 'off' else "OFF"), + "-Donnxruntime_EXTENDED_MINIMAL_BUILD=" + ("ON" if args.minimal_build == 'extended' else "OFF"), "-Donnxruntime_REDUCED_OPS_BUILD=" + ( "ON" if args.include_ops_by_config or args.include_ops_by_model else "OFF"), "-Donnxruntime_MSVC_STATIC_RUNTIME=" + (