diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 7a3d0b9ba7..a3e8ffea2c 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -30,6 +30,12 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/graph/function*" ) + # remove graph proto serializer + list(APPEND onnxruntime_graph_src_exclude_patterns + "${ONNXRUNTIME_ROOT}/core/graph/graph_proto_serializer.cc" + "${ONNXRUNTIME_ROOT}/core/graph/graph_proto_serializer.h" + ) + # no optimizer support in base minimal build # some optimizer support in extended minimal build if (NOT onnxruntime_EXTENDED_MINIMAL_BUILD) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index eec8ec9948..3d1d7c0ab6 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -348,7 +348,7 @@ if (onnxruntime_USE_RKNPU) list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_rknpu_src}) endif() -if (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD) +if ((NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_NUPHAR) OR onnxruntime_EXTENDED_MINIMAL_BUILD) file(GLOB_RECURSE onnxruntime_test_providers_internal_testing_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/internal_testing/*" ) @@ -464,11 +464,10 @@ if(onnxruntime_USE_COREML) endif() if(onnxruntime_USE_NUPHAR) - file(GLOB_RECURSE onnxruntime_test_nuphar_src CONFIGURE_DEPENDS - "${TEST_SRC_DIR}/nuphar_tvm/*.h" - "${TEST_SRC_DIR}/nuphar_tvm/*.cc" - ) - + # the test case under nuphar_tvm is only to verify some basic tvm show case, which is already out of date + # it doesn't have relationship to nuphar directly. consider we have an official tvm execution provider now, + # keep those test cases doesn't bring any value now. 
+ list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/framework/nuphar/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nuphar) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nuphar) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 6283a2a09b..581c497cfd 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -197,48 +197,15 @@ class IExecutionProvider { // TODO: temparary sulotion, need to unify the interface in EP and AllocatorManager void TryInsertAllocator(AllocatorPtr allocator); - // creation of a fused node is not supported in a minimal build, so any EP enabled in that scenario must support - // compilation via GraphViewer instances. -#if !defined(ORT_MINIMAL_BUILD) - /** - Given a list of fused_node, return create_state/compute/release_state func for each node. - */ - virtual common::Status Compile(const std::vector& fused_nodes, - std::vector& node_compute_funcs); - - /** - Given a list of fused_node, return a dll that expose functions for each node. - For each node, there should be three symbols: - Create_State_${node_name} - Compute_${node_name} - Release_State_${node_name} - */ - virtual common::Status Compile(const std::vector& fused_nodes, - std::string& dll_path); - -#endif - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) struct FusedNodeAndGraph { const std::reference_wrapper fused_node; // GraphViewer that filters the full graph to the nodes that are covered by 'node' const std::reference_wrapper filtered_graph; }; - /** - Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused, - return create_state/compute/release_state func for each node. 
- @remarks This is an optional interface that is only needed if the execution provider compiles nodes - in a scenario involving the minimal build. i.e. on a mobile or embedded device with ORT format model. - - Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions - as it is only valid for the duration of the call to Compile. - */ - virtual common::Status Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs); -#endif - // Fusion approach that is suppported + // !!! The "Function" FusionStyle will be deprecated soon. + // !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style. enum class FusionStyle { // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body(). @@ -254,12 +221,35 @@ class IExecutionProvider { }; virtual FusionStyle GetFusionStyle() const { - // existing EPs use this mode so default to it. - // newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return - // FilteredGraphViewer - return FusionStyle::Function; + // All of the ORT built-in EPs have migrated to the FilteredGraphViewer style except Nuphar. + // Newer EPs should avoid the Function style, as it will be deprecated soon. + return FusionStyle::FilteredGraphViewer; } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + + /** + * !!!! This API will be deprecated soon. If your execution provider overrides this API, + * !!!! please migrate it to the "Compile" API with FusedNodeAndGraph type. + Given a list of fused_node, return create_state/compute/release_state func for each node.
+ */ + virtual common::Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs); + + /** + Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused, + return create_state/compute/release_state func for each node. + @remarks This is now the default interface when execution provider wants to compile nodes + for both minimal build and complete ort build. + + Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions + as it is only valid for the duration of the call to Compile. + */ + virtual common::Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs); + +#endif + void SetLogger(const logging::Logger* logger) { logger_ = logger; } diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 3c26379826..37bd9e75cf 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1258,6 +1258,10 @@ class Graph { return Resolve(default_options); } + const std::unordered_set& GetOuterScopeNodeArgNames() const noexcept{ + return outer_scope_node_arg_names_; + } + common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& fbs_graph) const; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index fd2af7b7d2..7ec9178d58 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -149,6 +149,10 @@ class GraphViewer { /** Get the internal graph*/ const Graph& GetGraph() const { return *graph_; } +#if !defined(ORT_MINIMAL_BUILD) + const std::unordered_set& GetOuterScopeNodeArgNames() const noexcept; +#endif + /** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. 
@param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name' @@ -156,6 +160,9 @@ class GraphViewer { */ bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const; + /** Check if a given name is an initializer tensor's name in this graph. */ + bool IsInitializedTensor(const std::string& name) const; + /** returns the initializer's TensorProto if 'name' is an initializer, is constant and cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned. @param check_outer_scope If true and the graph is a subgraph, diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 00d154f6d3..249be3b54d 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -151,26 +151,21 @@ void IExecutionProvider::RegisterAllocator(std::shared_ptr) { return; } -#if !defined(ORT_MINIMAL_BUILD) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +// !!!!This API will be deprecated soon. If your execution provider overrides this API +// !!!!Please migrate it to the "Compile" API with FusedNodeAndGraph type. 
common::Status IExecutionProvider::Compile(const std::vector& /*fused_node*/, std::vector& /*node_compute_funcs*/) { return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "IExecutionProvider::Compile with fused Node is not implemented by " + type_); } -common::Status IExecutionProvider::Compile(const std::vector& /*fused_node*/, - std::string& /*dll_path*/) { - return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, - "IExecutionProvider::Compile with fused Node and dll path is not implemented by " + type_); -} -#endif - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status IExecutionProvider::Compile(const std::vector& /*fused_nodes_and_graphs*/, std::vector& /*node_compute_funcs*/) { return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "IExecutionProvider::Compile with FusedNodeAndGraph is not implemented by " + type_); } + #endif int IExecutionProvider::ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer, diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 83f5b34e2c..b038915bbf 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -42,6 +42,15 @@ NonCudaOps non_cuda; using namespace ::onnxruntime::common; namespace onnxruntime { +#if !defined(ORT_MINIMAL_BUILD) +static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) { + auto schema = node.Op(); + builder.SetName(schema->Name()) + .SetDomain(schema->domain()) + .SinceVersion(schema->SinceVersion()) + .Provider(node.GetExecutionProviderType()); +} +#endif // minimal KernelDef based on MetaDef instead of a Function based node static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph::MetaDef& metadef, @@ -103,7 +112,7 @@ static void AssignNodes(Graph& graph, const IndexedSubGraph& capability, #endif // !defined(ORT_MINIMAL_BUILD) 
|| defined(ORT_EXTENDED_MINIMAL_BUILD) static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep, - GraphPartitioner::Mode mode, std::vector>& capabilities, + GraphPartitioner::Mode mode, std::vector>& capabilities, TransformLayoutFunction transform_layout) { { GraphViewer graph_viewer(graph); @@ -142,15 +151,6 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg } #if !defined(ORT_MINIMAL_BUILD) - -static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) { - auto schema = node.Op(); - builder.SetName(schema->Name()) - .SetDomain(schema->domain()) - .SinceVersion(schema->SinceVersion()) - .Provider(node.GetExecutionProviderType()); -} - /** * Check if a node can be placed on a specific provider. * Do nothing if the node is already assigned @@ -162,8 +162,8 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::No * \return Fused node. Return nullptr if there is no fuse */ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, - const KernelRegistryManager& kernel_registry_mgr, const std::string& provider_type, IExecutionProvider::FusionStyle fusion_style, + const std::string& provider_type, GraphPartitioner::Mode mode, int& fused_node_unique_id) { Node* result = nullptr; @@ -216,7 +216,18 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, std::string node_name = oss.str(); Node* fused_node = nullptr; - if (fusion_style == IExecutionProvider::FusionStyle::Function) { + // TODO1: The DML currently use some legacy approach. + // It registers a generic predefined kernel for all purpose fusion, + // so it rely on the function body in the fused node during kernel creation, + // which is after the graph partition phase. + // Ideally, it should be moved to "Compile" call. 
+ // Here we temporarily keep the function body for DML fusion. + // Remove it after DML is migrated to the Compile-based approach. + // TODO2: Nuphar is no longer maintained; keep it on the old API temporarily. + // We want to deprecate Nuphar soon. + if (fusion_style == IExecutionProvider::FusionStyle::Function || + provider_type == kDmlExecutionProvider || + provider_type == kNupharExecutionProvider) { fused_node = &graph.FuseSubGraph(capability, node_name); } else { // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed @@ -226,10 +237,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, fused_node->SetExecutionProviderType(provider_type); - // searching in kernel registries, if no kernel registered for the fused_node, use compile approach - if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *fused_node, provider_type)) { - result = fused_node; - } + result = fused_node; } else { // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion @@ -250,7 +258,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, // for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up). // assign any nodes to the EP that are currently unassigned, and that the EP can handle.
-static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncManager& func_mgr, +static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, KernelRegistryManager& kernel_registry_mgr, KernelRegistry& fused_kernel_registry, IExecutionProvider& current_ep, @@ -268,7 +276,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { Graph* subgraph = entry.second; // we pass through the export_dll value and FuncManager from the top level graph - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr, + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, transform_layout_function)); } @@ -295,6 +303,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa const std::string& type = current_ep.Type(); auto fusion_style = current_ep.GetFusionStyle(); std::vector nodes_to_compile; + + // The fused node may map to an existing kernel, so it is fused but doesn't need to be compiled + // But we still need to finalize the graph fusion for those nodes. + std::vector nodes_to_complete_fuse; + std::vector> capabilities_to_complete_fuse; + // filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with // nodes_to_compile. 
std::vector> capabilities_to_compile; @@ -311,42 +325,46 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa continue; } - Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, type, fusion_style, mode, fused_node_unique_id); + Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); if (n != nullptr) { - nodes_to_compile.push_back(n); - capabilities_to_compile.push_back(std::move(capability)); + // search the kernel registries; if no kernel is registered for the fused_node, use the compile approach + if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) { + nodes_to_compile.push_back(n); + capabilities_to_compile.push_back(std::move(capability)); + } else { + // there is a predefined kernel for the fused node: no compile is needed, but the fusion must still be completed. + nodes_to_complete_fuse.push_back(n); + capabilities_to_complete_fuse.push_back(std::move(capability)); + } } } // NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode + // even with a single node, the EP might still want to compile it, + // for example to JIT an optimized kernel for an LSTM with a given shape. if (!nodes_to_compile.empty()) { std::vector node_compute_funcs; - - if (export_dll) { - ORT_ENFORCE(fusion_style == IExecutionProvider::FusionStyle::Function, - "Must use Function based fusion when exporting compiled nodes to dll."); - } - + // !!! The Function style fusion will be deprecated soon. if (fusion_style == IExecutionProvider::FusionStyle::Function) { + // TODO: Nuphar is no longer maintained. Use the old API temporarily. + // We want to deprecate it soon. // Create a Function based node where the fused nodes have a new Graph instance.
+ static std::once_flag legacy_compile_method_warning_flag; + std::call_once( + legacy_compile_method_warning_flag, [](std::string_view ep_type) { + LOGS_DEFAULT(WARNING) << "Execution Provider: " << ep_type << " is still using Function style Compile API, " + << " which will be deprecated soon, please migrate to the new Compile API based on " + << " FilteredGraphViewer. "; + }, + type); + ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs)); - if (export_dll) { - std::string dll_path; - ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, dll_path)); + if (node_compute_funcs.size() != nodes_to_compile.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); + } - for (auto* node : nodes_to_compile) { - ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path)); - } - } else { - ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs)); - - if (node_compute_funcs.size() != nodes_to_compile.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); - } - - for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) { - ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j]))); - } + for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) { + ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j]))); } for (auto* node : nodes_to_compile) { @@ -358,12 +376,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa return FunctionKernel::Create(func_mgr, info, out); })); } - } else { // temporary storage for the GraphViewer for each IndexedSubGraph std::vector> viewers; viewers.reserve(nodes_to_compile.size()); std::vector nodes_and_viewers; + nodes_and_viewers.reserve(nodes_to_compile.size()); for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) { auto*
node = nodes_to_compile[j]; @@ -401,6 +419,21 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa graph.FinalizeFuseSubGraph(indexed_sub_graph, *node); } } + } + + // TODO: DML currently uses a legacy approach: + // its fusion is completed inside the FuseSubGraph function. + // Remove this special case once DML migrates to the Compile approach. + if (!nodes_to_complete_fuse.empty() && type != kDmlExecutionProvider) { + for (size_t j = 0, end = nodes_to_complete_fuse.size(); j < end; j++) { + auto* node = nodes_to_complete_fuse[j]; + + const auto& cur_capability = capabilities_to_complete_fuse[j]; + const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; + + // no compilation is needed for this node, so just remove the original nodes from the Graph and wire in the new one + graph.FinalizeFuseSubGraph(indexed_sub_graph, *node); + } ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph)); } @@ -458,7 +491,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { return Status::OK(); } -Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr, +Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const { @@ -467,7 +500,7 @@ Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, do { // process full graph with each EP for (const auto& ep : providers_) { - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_, + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_mgr_, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function)); } @@ -610,7 +643,7 @@ Status GraphPartitioner::PartitionOrtFormatModel( return Status::OK(); } -Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
+Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, TransformLayoutFunction transform_layout_function, Mode mode, std::unordered_map* compiled_kernel_hashes) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. @@ -634,19 +667,16 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode, + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, func_mgr, *fused_kernel_registry, mode, fused_node_unique_id, transform_layout_function)); #else - ORT_UNUSED_PARAMETER(export_dll); ORT_THROW("Not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided"); - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes, fused_node_unique_id, transform_layout_function)); } - if (!fused_kernel_registry->IsEmpty()) { kernel_registry_mgr_.RegisterKernelRegistry(fused_kernel_registry); } diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index db9c3fb31d..4aac959f3e 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -32,7 +32,7 @@ class GraphPartitioner { } // Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad. 
- Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr, + Status Partition(Graph& graph, FuncManager& func_mgr, TransformLayoutFunction transform_layout_function, Mode mode = Mode::kNormal, std::unordered_map* compiled_kernel_hashes = nullptr) const; @@ -41,7 +41,7 @@ class GraphPartitioner { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner); #if !defined(ORT_MINIMAL_BUILD) - Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr, + Status PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr, KernelRegistry& fused_kernel_registry, Mode mode, int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const; #endif diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 9315b06de4..a7c732d7f8 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -276,9 +276,6 @@ class SessionState { concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; } concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; } - bool ExportDll() const noexcept { return export_fused_dll_; } - void SetExportDllFlag(bool flag) noexcept { export_fused_dll_ = flag; } - const FuncManager& GetFuncMgr() const noexcept { return fused_funcs_mgr_; } FuncManager& GetMutableFuncMgr() noexcept { return fused_funcs_mgr_; } @@ -488,7 +485,6 @@ class SessionState { concurrency::ThreadPool* const thread_pool_{}; concurrency::ThreadPool* const inter_op_thread_pool_{}; - bool export_fused_dll_ = false; const DataTransferManager& data_transfer_mgr_; bool use_deterministic_compute_; diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc new file mode 100644 index 0000000000..89eb20f734 --- /dev/null +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/graph_proto_serializer.h" + +namespace onnxruntime { + +void GraphViewerToProto(const GraphViewer& graph_view, + ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializer, + bool include_outer_scope_args) { + graph_proto.set_name(graph_view.Name()); + graph_proto.set_doc_string(graph_view.Description()); + + for (const auto* input_arg : graph_view.GetInputsIncludingInitializers()) { + *(graph_proto.mutable_input()->Add()) = input_arg->ToProto(); + } + + for (const auto* output_arg : graph_view.GetOutputs()) { + *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); + } + + for (const auto* value_info : graph_view.GetValueInfo()) { + *(graph_proto.mutable_value_info()->Add()) = value_info->ToProto(); + } + + if (include_outer_scope_args){ + // add the NodeArg info for outer scope NodeArgs so we capture the type information + for (const auto& name : graph_view.GetOuterScopeNodeArgNames()) { + auto* node_arg = graph_view.GetNodeArg(name); + ORT_ENFORCE(node_arg, "Outer scope node arg name '" + name + "'was added but does not exist. "); + *(graph_proto.mutable_value_info()->Add()) = node_arg->ToProto(); + } + } + + // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. + for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) { + const gsl::not_null node_proto{graph_proto.add_node()}; + const gsl::not_null p_node{graph_view.GetNode(node_idx)}; + // we need to update any GraphProto attributes for subgraphs so that any changes made by things + // such as the optimizers are captured. otherwise we can end up saving an invalid graph. 
+ p_node->ToProto(*node_proto, /* update_subgraphs */ true); + } + + if (include_initializer) { + auto& initializers = graph_view.GetAllInitializedTensors(); + for (auto& it : initializers) { + auto* p_initializer = graph_proto.add_initializer(); + *p_initializer = *(it.second); + } + } +} + +} \ No newline at end of file diff --git a/onnxruntime/core/graph/graph_proto_serializer.h b/onnxruntime/core/graph/graph_proto_serializer.h new file mode 100644 index 0000000000..fe88dd547f --- /dev/null +++ b/onnxruntime/core/graph/graph_proto_serializer.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { + +void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args); +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index bf50574b4f..4fb6e7de82 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -271,9 +271,19 @@ bool GraphViewer::IsConstantInitializer(const std::string& name, bool check_oute return GetConstantInitializer(name, check_outer_scope) != nullptr; } +bool GraphViewer::IsInitializedTensor(const std::string& name) const { + return graph_->IsInitializedTensor(name); +} + const ONNX_NAMESPACE::TensorProto* GraphViewer::GetConstantInitializer(const std::string& initializer_name, bool check_outer_scope) const { return graph_->GetConstantInitializer(initializer_name, check_outer_scope); } +#if !defined(ORT_MINIMAL_BUILD) +const std::unordered_set& GraphViewer::GetOuterScopeNodeArgNames() const noexcept { + return graph_->GetOuterScopeNodeArgNames(); +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h 
b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 9dd5ba2819..6977dfdc1f 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -20,9 +20,6 @@ class CoreMLExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph_viewer, const std::vector& /*kernel_registries*/) const override; - // we implement the Compile that takes FusedNodeAndGraph instances - FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } - #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 713b8f414b..d2e735a9b9 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -124,33 +124,6 @@ std::vector> DNNLExecutionProvider::GetSupportedNodes(con return supported_node_vecs; } -void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) { - for (const auto* input_arg : graph.GetInputs()) { - *(graph_proto.mutable_input()->Add()) = input_arg->ToProto(); - } - - // Add all graph's initializers to the subgraph - const auto& init_tensors = graph.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - *(graph_proto.mutable_initializer()->Add()) = *(tensor.second); - } - - for (const auto* output_arg : graph.GetOutputs()) { - *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); - } - - for (const auto* value_info : graph.GetValueInfo()) { - *(graph_proto.mutable_value_info()->Add()) = value_info->ToProto(); - } - - // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. 
- for (auto& node_idx : graph.GetNodesInTopologicalOrder()) { - const gsl::not_null node_proto{graph_proto.add_node()}; - const gsl::not_null p_node{graph.GetNode(node_idx)}; - p_node->ToProto(*node_proto); - } -} - std::vector> DNNLExecutionProvider::GetCapability( const GraphViewer& graph_viewer, const std::vector& kernel_registries) const { @@ -269,7 +242,7 @@ std::vector> DNNLExecutionProvider::GetCapabi if (dump_subgraphs_) { auto model = graph_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph()); + graph_viewer.ToProto(*model_proto->mutable_graph(), false, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); HashValue model_hash; int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); @@ -280,39 +253,34 @@ std::vector> DNNLExecutionProvider::GetCapabi return result; } -Status DNNLExecutionProvider::Compile(const std::vector& fused_nodes, +Status DNNLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { //follow from coreml ep's Compile - for (const auto* fused_node : fused_nodes) { - const auto* func_body = fused_node->GetFunctionBody(); - if (!func_body) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty"); - } - const Graph& graph_body = func_body->Body(); - auto graph_body_viewer = graph_body.CreateGraphViewer(); - + for (auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; if (dump_subgraphs_) { - auto model = graph_body_viewer->CreateModel(*GetLogger()); + auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - *model_proto->mutable_graph() = *graph_body.ToGraphProto(); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), false, true); 
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); + std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); } //subgraph - auto dnnl_subgraph = std::make_unique(ort_dnnl::DnnlSubgraph(*graph_body_viewer.get())); - subgraphs_.emplace(fused_node->Name(), std::move(dnnl_subgraph)); + auto dnnl_subgraph = std::make_unique(ort_dnnl::DnnlSubgraph(graph_body_viewer)); + subgraphs_.emplace(fused_node.Name(), std::move(dnnl_subgraph)); //apply transformation to subgraph if (enable_fusion_) { - ort_dnnl::DnnlGraphTransformer().Apply(*subgraphs_[fused_node->Name()].get()); + ort_dnnl::DnnlGraphTransformer().Apply(*subgraphs_[fused_node.Name()].get(), graph_body_viewer); } //subgraph primitive - auto dnnl_subgraph_primitive = std::make_unique(*subgraphs_[fused_node->Name()].get()); + auto dnnl_subgraph_primitive = std::make_unique(*subgraphs_[fused_node.Name()].get()); { - const auto& input_defs = fused_node->InputDefs(); + const auto& input_defs = fused_node.InputDefs(); std::vector onnx_input_names(input_defs.size()); for (size_t i = 0, end = input_defs.size(); i < end; ++i) { onnx_input_names[i] = input_defs[i]->Name(); @@ -320,7 +288,7 @@ Status DNNLExecutionProvider::Compile(const std::vector& fused_nodes, dnnl_subgraph_primitive->SetOrderedInputs(std::move(onnx_input_names)); } { - const auto& output_defs = fused_node->OutputDefs(); + const auto& output_defs = fused_node.OutputDefs(); std::vector onnx_output_names(output_defs.size()); for (size_t i = 0, end = output_defs.size(); i < end; ++i) { onnx_output_names[i] = output_defs[i]->Name(); @@ -328,7 +296,7 @@ Status DNNLExecutionProvider::Compile(const std::vector& fused_nodes, dnnl_subgraph_primitive->SetOrderedOutputs(std::move(onnx_output_names)); } - subgraph_primitives_.emplace(fused_node->Name(), 
std::move(dnnl_subgraph_primitive)); + subgraph_primitives_.emplace(fused_node.Name(), std::move(dnnl_subgraph_primitive)); NodeComputeInfo compute_info; diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index dc39e91401..d4e49c2e17 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -34,7 +34,7 @@ class DNNLExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& /*kernel_registries*/) const override; - common::Status Compile(const std::vector& fused_nodes, + common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; private: diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc index a6723f6368..cf7e0bdfd4 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul_integer.cc @@ -67,14 +67,14 @@ void DnnlMatMulInteger::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& nod if (has_a_zero_point) { auto zp_A_mem_desc_s32 = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1}); - auto tensor = node.Input(IN_A_ZERO_POINT); + auto& tensor = node.Input(IN_A_ZERO_POINT); auto zp_A_mem_s32 = sp.GetMemoryAndReshape(tensor, zp_A_mem_desc_s32, eng); mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = zp_A_mem_s32; } if (has_b_zero_point) { auto zp_B_mem_desc_s32 = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1}); - auto tensor = node.Input(IN_B_ZERO_POINT); + auto& tensor = node.Input(IN_B_ZERO_POINT); auto zp_B_mem_s32 = sp.GetMemoryAndReshape(tensor, zp_B_mem_desc_s32, eng); mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zp_B_mem_s32; } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.cc 
b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.cc index b08c1d1bb9..62a815307f 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.cc @@ -164,14 +164,14 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) if (has_input_zero_point) { auto zp_mem_desc = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1}); - auto tensor = node.Input(INPUT_ZP); + auto& tensor = node.Input(INPUT_ZP); auto zp_mem = sp.GetMemoryAndReshape(tensor, zp_mem_desc, eng); mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = zp_mem; } if (has_weights_zero_point) { auto zp_mem_desc = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1}); - auto tensor = node.Input(WEIGHTS_ZP); + auto& tensor = node.Input(WEIGHTS_ZP); auto zp_mem = sp.GetMemoryAndReshape(tensor, zp_mem_desc, eng); mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zp_mem; } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.cc index b32402f1d0..1919870b55 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.cc @@ -10,31 +10,51 @@ namespace ort_dnnl { DnnlTensor DnnlNode::empty_tensor_ = DnnlTensor(""); DnnlTensor::DnnlTensor(const NodeArg* arg) { - arg_ = arg; if (!arg || !arg->Exists()) { tensor_name_ = ""; } else { tensor_name_ = arg->Name(); } + // because the passed in ort graph will be released after compile + // need to save the type/shape in dnnl IR + arg_type_ = arg->Type(); + arg_type_proto_ = ONNX_NAMESPACE::TypeProto::Create(); + arg_type_proto_->copy_from(arg->TypeAsProto()); } DnnlTensor::DnnlTensor(std::string name) { tensor_name_ = name; - arg_ = nullptr; + arg_type_ = nullptr; + arg_type_proto_ = nullptr; } std::string DnnlTensor::Name() const { return tensor_name_; } +const ONNX_NAMESPACE::TensorShapeProto* 
DnnlTensor::GetShape() const{ + if (arg_type_proto_ == nullptr || arg_type_ == nullptr) { + return nullptr; + } + + if (arg_type_proto_->value_case() != ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType) { + return nullptr; + } + auto& tensor_type = arg_type_proto_->tensor_type(); + if (tensor_type.has_shape()) { + return &tensor_type.shape(); + } + return nullptr; +} + dnnl::memory::dims DnnlTensor::Dim() const { - if (arg_ == nullptr) { + if (arg_type_proto_ == nullptr || arg_type_ == nullptr) { return dnnl::memory::dims(); } - auto shape_proto = arg_->Shape(); + auto* shape_proto = GetShape(); // a shape without any information if (shape_proto == nullptr) { - LOGS_DEFAULT(INFO) << "nullptr shape for " << arg_->Type() << ": " << arg_->Name(); + LOGS_DEFAULT(INFO) << "nullptr shape for " << arg_type_ << ": " << tensor_name_; return dnnl::memory::dims(); } std::vector shape; @@ -42,7 +62,7 @@ dnnl::memory::dims DnnlTensor::Dim() const { for (const auto& dim : dims) { bool has_dim_value = dim.value_case() == dim.kDimValue; if (!has_dim_value) { - LOGS_DEFAULT(INFO) << "Dynamic shape for " << arg_->Type() << ": " << arg_->Name(); + LOGS_DEFAULT(INFO) << "Dynamic shape for " << arg_type_ << ": " << tensor_name_; shape.push_back(DNNL_RUNTIME_DIM_VAL); } else { shape.push_back(dim.dim_value()); @@ -57,7 +77,10 @@ dnnl::memory::dims DnnlTensor::Dim() const { } dnnl::memory::data_type DnnlTensor::Type() const { - auto data_type = arg_->TypeAsProto()->tensor_type().elem_type(); + if (arg_type_proto_ == nullptr) { + ORT_THROW("Invoke DnnlTensor's arg_type_proto_ not initialized yet."); + } + auto data_type = arg_type_proto_->tensor_type().elem_type(); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: return dnnl::memory::data_type::undef; @@ -123,7 +146,7 @@ void DnnlTensor::RemoveConsumer(const DnnlNodeArg& arg) { } DnnlNode::DnnlNode(const Node* node) { - onnx_node_ = node; + since_version_ = node->SinceVersion(); name_ = node->Name(); op_type_ = 
node->OpType(); attr_->insert(node->GetAttributes()); @@ -176,11 +199,11 @@ NodeAttributes& DnnlNode::Attributes() { } int DnnlNode::SinceVersion() { - return onnx_node_->SinceVersion(); + return since_version_; } -DnnlSubgraph::DnnlSubgraph(const GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) { - Build(); +DnnlSubgraph::DnnlSubgraph(const GraphViewer& graph_viewer) { + Build(graph_viewer); is_dynamic_ = false; for (auto input : GetDnnlInputs()) { if (input->IsDynamic()) { @@ -309,19 +332,11 @@ void DnnlSubgraph::AddNode(std::unique_ptr new_node) { dnnl_nodes_.back()->Index() = index; } -bool DnnlSubgraph::GetInitializedTensor(const std::string& arg_name, const ONNX_NAMESPACE::TensorProto*& value) { - return graph_viewer_.GetInitializedTensor(arg_name, value); -} - -bool DnnlSubgraph::IsConstantInitializer(const std::string& arg_name, bool check_outer_scope) { - return graph_viewer_.IsConstantInitializer(arg_name, check_outer_scope); -} - -void DnnlSubgraph::Build() { +void DnnlSubgraph::Build(const GraphViewer& graph_viewer) { //establish nodes, tensors and nodeargs - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); + const auto* node(graph_viewer.GetNode(node_indices[i])); AddNode(std::make_unique(node)); auto dnnl_node = dnnl_nodes_.back().get(); std::vector inputs; @@ -361,15 +376,15 @@ void DnnlSubgraph::Build() { //graph inputs including initializers and outputs can be deleted by graph transformation (eg, gelu fusion) //delete unneeded inputs don't affect onnxruntime passing them as input data handle //delete unneeded outputs will cause ep to output to fewer data handles then expected - for (const auto* node_arg : graph_viewer_.GetInputsIncludingInitializers()) { + for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) { 
inputs_.push_back(dnnl_tensors_[node_arg->Name()].get()); } - for (const auto* node_arg : graph_viewer_.GetOutputs()) { + for (const auto* node_arg : graph_viewer.GetOutputs()) { outputs_.push_back(dnnl_tensors_[node_arg->Name()].get()); } - for (auto& initializer : graph_viewer_.GetAllInitializedTensors()) { + for (auto& initializer : graph_viewer.GetAllInitializedTensors()) { auto& name = initializer.first; initializers_.push_back(dnnl_tensors_[name].get()); } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.h index 963ed815c3..1204eec1c1 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph.h @@ -55,8 +55,12 @@ class DnnlTensor { void RemoveConsumer(const DnnlNodeArg& arg); private: + + const ONNX_NAMESPACE::TensorShapeProto* GetShape() const; + std::string tensor_name_; - const NodeArg* arg_; + ONNX_NAMESPACE::DataType arg_type_; + std::unique_ptr arg_type_proto_; //a tensor can have no producer (input.initializer) or no consumer (output for subgraph) DnnlNodeArg producer_; std::vector consumers_; @@ -79,7 +83,7 @@ class DnnlNode { int SinceVersion(); private: - const Node* onnx_node_ = nullptr; + int since_version_; std::vector inputs_; std::vector outputs_; static DnnlTensor empty_tensor_; @@ -101,7 +105,7 @@ class DnnlSubgraph { std::vector GetDnnlOutputs(); std::vector GetDnnlInitializers(); // build the subgraph IR - void Build(); + void Build(const GraphViewer& graph_viewer); //check whether the subgraph is dynamic void TopoSort(); bool IsDynamic(); @@ -110,9 +114,6 @@ class DnnlSubgraph { void AddTensor(std::unique_ptr new_tensor); void RemoveTensor(const std::string& tensor_name); - bool GetInitializedTensor(const std::string& arg_name, const ONNX_NAMESPACE::TensorProto*& value); - bool IsConstantInitializer(const std::string& arg_name, bool check_outer_scope); - private: //graph owns all nodes 
std::vector> dnnl_nodes_; @@ -122,7 +123,6 @@ class DnnlSubgraph { std::vector inputs_; std::vector outputs_; //output should never get deleted from graph transformation std::vector initializers_; - const GraphViewer& graph_viewer_; bool is_dynamic_; }; } // namespace ort_dnnl diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc index 92f788f4aa..f3a18efd44 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc @@ -446,7 +446,7 @@ bool DnnlSubgraphPrimitive::HasMemory(std::string memory_name, dnnl::memory::des return false; } -void DnnlSubgraphPrimitive::SetMemory(DnnlTensor tensor, dnnl::memory mem, bool always_copy_output, bool is_scalar) { +void DnnlSubgraphPrimitive::SetMemory(const DnnlTensor& tensor, dnnl::memory mem, bool always_copy_output, bool is_scalar) { if (always_copy_output) { outputs_are_always_copied_.insert(tensor.Name()); } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h index 39430f6bc9..11a6491bc4 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h @@ -59,7 +59,7 @@ class DnnlSubgraphPrimitive { //set memory to a tensor (output) //if always_copy_output is true a copy of the memory will be made when the output is leaving the subgraph. 
//is_scalar is true to indicate a scalar output in order to allocate the correct onnxruntime output buffer - void SetMemory(DnnlTensor tensor, dnnl::memory mem, bool always_copy_output = false, bool is_scalar = false); + void SetMemory(const DnnlTensor& tensor, dnnl::memory mem, bool always_copy_output = false, bool is_scalar = false); void SetMemory(std::string memory_name, dnnl::memory mem); void SetInitializer(std::string memory_name, dnnl::memory mem); dnnl::memory::desc GetOutputInfo(std::string name); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc index 799f76a9a6..e878988f91 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc @@ -12,12 +12,12 @@ namespace onnxruntime { namespace ort_dnnl { //apply all transformation rules in order -void DnnlGraphTransformer::Apply(DnnlSubgraph& subgraph) { +void DnnlGraphTransformer::Apply(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) { ConvRelu(subgraph); MatMulAdd(subgraph); - Gelu(subgraph); - FastGelu(subgraph); - RemoveMatMulIntegerZP(subgraph); + Gelu(subgraph, onnx_subgraph_viewer); + FastGelu(subgraph, onnx_subgraph_viewer); + RemoveMatMulIntegerZP(subgraph, onnx_subgraph_viewer); } //resolve a fusion by replacing old_indices nodes with a new_node @@ -162,13 +162,13 @@ bool IsScalar(const DnnlTensor& input_arg) { return dim_size == 0 || (dim_size == 1 && dim[0] == 1); } -bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(DnnlSubgraph& subgraph, DnnlTensor& input_arg, float expected_value) { +bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlTensor& input_arg, float expected_value) { if (!IsScalar(input_arg)) { return false; } const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; - if 
(!subgraph.GetInitializedTensor(input_arg.Name(), tensor_proto)) { + if (!onnx_subgraph_viewer.GetInitializedTensor(input_arg.Name(), tensor_proto)) { return false; } @@ -234,7 +234,7 @@ DnnlNode* FirstParentByType(DnnlNode* node, const std::string& parent_type) { After Fusion: [root]--> Gelu ==> */ -void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) { +void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) { static int gelu_index = 0; //traverse with max index as there will be empty nodes due to fusion size_t max_index = subgraph.GetMaxNodeIndex(); @@ -249,8 +249,8 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) { // Check second input is sqrt(2) // Some Bert models uses this approximation of SQRT2 in the Gelu function float approximated_sqrt_two = 1.4142099618911743f; - if (!IsInitilizedWithExpectedValue(subgraph, div_node->Input(1), approximated_sqrt_two) && - !IsInitilizedWithExpectedValue(subgraph, div_node->Input(1), static_cast(M_SQRT2))) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, div_node->Input(1), approximated_sqrt_two) && + !IsInitilizedWithExpectedValue(onnx_subgraph_viewer, div_node->Input(1), static_cast(M_SQRT2))) { continue; } @@ -275,7 +275,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) { } bool is_add_input0 = add_node->Input(0).Name() == erf_node->Output(0).Name(); - if (!IsInitilizedWithExpectedValue(subgraph, add_node->Input(is_add_input0 ? 1 : 0), 1.0f)) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add_node->Input(is_add_input0 ? 1 : 0), 1.0f)) { continue; } @@ -304,7 +304,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) { if (!(is_mul2_input0 ^ is_mul2_input1)) { is_pattern_1 = false; } - if (is_pattern_1 && !IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_input0 ? 1 : 0), 0.5f)) { + if (is_pattern_1 && !IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_input0 ? 
1 : 0), 0.5f)) { is_pattern_1 = false; } if (is_pattern_1 && !IsNodeFusable(subgraph, mul2_node)) { @@ -329,7 +329,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) { ORT_THROW("Invalid Mul node"); } bool is_mul2_first_input = mul2_node->Input(0).Name() == mul1_node->Output(0).Name(); - if (!IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_first_input ? 1 : 0), 0.5f)) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_first_input ? 1 : 0), 0.5f)) { continue; } } @@ -369,14 +369,14 @@ The formula corresponding to Gelu activation subgraph : where x is the input. */ -void DnnlGraphTransformer::FastGelu(DnnlSubgraph& subgraph) { +void DnnlGraphTransformer::FastGelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) { static int fastgelu_index = 0; //traverse with max index as there will be empty nodes due to fusion size_t max_index = subgraph.GetMaxNodeIndex(); for (size_t index = 0; index < max_index; index++) { auto dnnl_node = subgraph.GetDnnlNode(index); - if (!FastGeluFirstFormula(subgraph, dnnl_node, fastgelu_index)) { - FastGeluSecondFormula(subgraph, dnnl_node, fastgelu_index); + if (!FastGeluFirstFormula(subgraph, onnx_subgraph_viewer, dnnl_node, fastgelu_index)) { + FastGeluSecondFormula(subgraph, onnx_subgraph_viewer, dnnl_node, fastgelu_index); } } } @@ -389,7 +389,7 @@ The formula corresponding to Gelu activation subgraph : where x is the input. 
*/ -bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode* mul1_node, int& fastgelu_index) { +bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* mul1_node, int& fastgelu_index) { std::vector gelu_indices; //----------mul(0.44715)------------------ if (mul1_node == nullptr || mul1_node->OpType() != "Mul") { @@ -398,7 +398,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode int32_t mul1_input_index = -1; const float mul_val = 0.044715f; for (auto i = 0; i < 2; i++) { - if (IsInitilizedWithExpectedValue(subgraph, mul1_node->Input(i), mul_val)) { + if (IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul1_node->Input(i), mul_val)) { mul1_input_index = i; break; } @@ -424,7 +424,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode return false; } bool is_add_input0 = mul2_node->Output(0).Name() == add1_node->Input(0).Name(); - if (!IsInitilizedWithExpectedValue(subgraph, add1_node->Input(is_add_input0 ? 1 : 0), 1.0f)) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add1_node->Input(is_add_input0 ? 1 : 0), 1.0f)) { return false; } @@ -450,7 +450,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode int32_t mul4_input_index = -1; const float mul4_val = 0.7978845834732056f; for (auto i = 0; i < 2; i++) { - if (IsInitilizedWithExpectedValue(subgraph, prev_mul4_node->Input(i), mul4_val)) { + if (IsInitilizedWithExpectedValue(onnx_subgraph_viewer, prev_mul4_node->Input(i), mul4_val)) { mul4_input_index = i; break; } @@ -464,7 +464,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode auto tanh_node = mul3_node->Output(0).GetConsumers()[0].GetNode(); int32_t x_input_index = (mul1_input_index == 0) ? 
1 : 0; - if (FastGeluFormulaCommon(subgraph, mul1_node, x_input_index, tanh_node, gelu_indices, fastgelu_index)) { + if (FastGeluFormulaCommon(subgraph, onnx_subgraph_viewer, mul1_node, x_input_index, tanh_node, gelu_indices, fastgelu_index)) { if (debug_log_) { LOGS_DEFAULT(ERROR) << "FastGelu fusion found [" << fastgelu_index << "] (first formula)"; } @@ -481,15 +481,15 @@ The formula corresponding to Gelu activation subgraph : where x is the input. */ -void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNode* pow_node, int& fastgelu_index) { +void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* pow_node, int& fastgelu_index) { std::vector gelu_indices; //---------Pow------------------- if (pow_node == nullptr || pow_node->OpType() != "Pow") { return; } - auto pow_exponent = pow_node->Input(1); - if (!IsInitilizedWithExpectedValue(subgraph, pow_exponent, 3.0f)) { + auto& pow_exponent = pow_node->Input(1); + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, pow_exponent, 3.0f)) { return; } @@ -505,7 +505,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod float fastgelu_muliplyer = 0.044714998453855515f; bool is_mul1_input0 = pow_node->Output(0).Name() == mul1_node->Input(0).Name(); - if (!IsInitilizedWithExpectedValue(subgraph, mul1_node->Input(is_mul1_input0 ? 1 : 0), fastgelu_muliplyer)) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul1_node->Input(is_mul1_input0 ? 1 : 0), fastgelu_muliplyer)) { return; } if (!IsNodeFusable(subgraph, mul1_node)) { @@ -531,7 +531,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod // constant is sqrt(2/pi) float fastgelu_sqrt_2_div_pi = 0.7978845834732056f; bool is_mul2_input0 = add1_node->Output(0).Name() == mul2_node->Input(0).Name(); - if (!IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_input0 ? 
1 : 0), fastgelu_sqrt_2_div_pi)) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_input0 ? 1 : 0), fastgelu_sqrt_2_div_pi)) { return; } @@ -543,7 +543,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod //----------Tanh------------------ auto tanh_node = mul2_node->Output(0).GetConsumers()[0].GetNode(); // since the first node is pow the x_input_index is always 0 - if (FastGeluFormulaCommon(subgraph, pow_node, 0, tanh_node, gelu_indices, fastgelu_index)) { + if (FastGeluFormulaCommon(subgraph, onnx_subgraph_viewer, pow_node, 0, tanh_node, gelu_indices, fastgelu_index)) { if (debug_log_) { LOGS_DEFAULT(ERROR) << "FastGelu fusion found [" << fastgelu_index << "] (second formula)"; } @@ -559,7 +559,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))), where x is the input. */ -bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector& gelu_indices, int& fastgelu_index) { +bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector& gelu_indices, int& fastgelu_index) { //----------Tanh------------------ if (tanh_node == nullptr || tanh_node->OpType() != "Tanh") { return false; @@ -575,7 +575,7 @@ bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNod return false; } bool is_add2_input0 = tanh_node->Output(0).Name() == add2_node->Input(0).Name(); - if (!IsInitilizedWithExpectedValue(subgraph, add2_node->Input(is_add2_input0 ? 1 : 0), 1.0f)) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add2_node->Input(is_add2_input0 ? 
1 : 0), 1.0f)) { return false; } @@ -607,7 +607,7 @@ bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNod if (!(is_mul_input0 ^ is_mul_input1)) { return false; } - if (!IsInitilizedWithExpectedValue(subgraph, prev_mul4_node->Input(is_mul_input0 ? 1 : 0), 0.5f)) { + if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, prev_mul4_node->Input(is_mul_input0 ? 1 : 0), 0.5f)) { return false; } if (!IsNodeFusable(subgraph, prev_mul4_node)) { @@ -748,7 +748,7 @@ void DnnlGraphTransformer::MatMulAdd(DnnlSubgraph& subgraph) { } } -void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph) { +void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) { size_t max_index = subgraph.GetMaxNodeIndex(); for (size_t index = 0; index < max_index; index++) { auto dnnl_node = subgraph.GetDnnlNode(index); @@ -763,9 +763,9 @@ void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph) { continue; } - auto b_zero_point = dnnl_node->Input(3); + auto& b_zero_point = dnnl_node->Input(3); const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; - if (!subgraph.GetInitializedTensor(b_zero_point.Name(), tensor_proto)) { + if (!onnx_subgraph_viewer.GetInitializedTensor(b_zero_point.Name(), tensor_proto)) { continue; } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.h index 58fa73d771..673b7d6d0e 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.h @@ -9,7 +9,10 @@ namespace ort_dnnl { class DnnlGraphTransformer { public: - void Apply(DnnlSubgraph& subgraph); + // The passed in onnx subgraph viewer is only valid during "Compile" phase, + // so keep a reference to that onnx subgraph in DnnlSubgraph is risky. 
+ // passed in the onnx subgraph viewer explicitly to make sure we manage the lifetime correctly. + void Apply(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer_); DnnlGraphTransformer() { const std::string debug_log_env = onnxruntime::GetEnvironmentVar("ORT_DNNL_DEBUG_LOG"); if (!debug_log_env.empty()) { @@ -18,15 +21,15 @@ class DnnlGraphTransformer { } private: - void Gelu(DnnlSubgraph& subgraph); - void FastGelu(DnnlSubgraph& subgraph); - bool FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode* node, int& fastgelu_index); - void FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNode* node, int& fastgelu_index); - bool FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector& gelu_indices, int& fastgelu_index); - bool IsInitilizedWithExpectedValue(DnnlSubgraph& subgraph, DnnlTensor& input_arg, float expected_value); + void Gelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer); + void FastGelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer); + bool FastGeluFirstFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* node, int& fastgelu_index); + void FastGeluSecondFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* node, int& fastgelu_index); + bool FastGeluFormulaCommon(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector& gelu_indices, int& fastgelu_index); + bool IsInitilizedWithExpectedValue(const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlTensor& input_arg, float expected_value); void ConvRelu(DnnlSubgraph& subgraph); void MatMulAdd(DnnlSubgraph& subgraph); - void RemoveMatMulIntegerZP(DnnlSubgraph& subgraph); + void RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& 
onnx_subgraph_viewer); // This function checks a few things // - the node in question has a single output // - The output of the node is only consumed by a one other node diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index ac3d167d1f..10e5eec5d5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -660,34 +660,6 @@ static bool IsNodeSupported(const std::set& op_set, return true; } -// Convert GraphViewer graph to GraphProto -void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) { - for (const auto* input_arg : graph.GetInputs()) { - *(graph_proto.mutable_input()->Add()) = input_arg->ToProto(); - } - - // Add all graph's initializers to the subgraph - const auto& init_tensors = graph.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - *(graph_proto.mutable_initializer()->Add()) = *(tensor.second); - } - - for (const auto* output_arg : graph.GetOutputs()) { - *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); - } - - for (const auto* value_info : graph.GetValueInfo()) { - *(graph_proto.mutable_value_info()->Add()) = value_info->ToProto(); - } - - // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. 
- for (auto& node_idx : graph.GetNodesInTopologicalOrder()) { - const gsl::not_null node_proto{graph_proto.add_node()}; - const gsl::not_null p_node{graph.GetNode(node_idx)}; - p_node->ToProto(*node_proto); - } -} - std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const { std::unordered_set node_set; node_set.reserve(graph_nodes_index.size()); @@ -913,7 +885,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v std::vector> result; auto model = graph_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph()); + graph_viewer.ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string onnx_string_buffer; model_proto->SerializeToString(onnx_string_buffer); @@ -1012,30 +984,24 @@ bool get_input_output_names(const GraphViewer& graph, return no_input_shape; } -Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, +Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; bool no_input_shape = false; - for (const auto& fused_node : fused_nodes) { + for (const auto& fused_node_graph : fused_nodes) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; // map parameter input name to index std::unordered_map input_name_index; - const auto& input_defs = fused_node->InputDefs(); + const auto& input_defs = fused_node.InputDefs(); input_name_index.reserve(input_defs.size()); for (std::size_t i = 0; i < input_defs.size(); ++i) { input_name_index[input_defs[i]->Name()] = i; } - // Reconstruct graph proto from fused node's function body - const auto* func_body = fused_node->GetFunctionBody(); - if (!func_body) { - return common::Status(common::ONNXRUNTIME, 
common::INVALID_ARGUMENT, "Function body is empty"); - } - - const Graph& graph_body = func_body->Body(); - auto graph_body_viewer = graph_body.CreateGraphViewer(); - auto model = graph_body_viewer->CreateModel(*GetLogger()); + auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - *model_proto->mutable_graph() = *graph_body.ToGraphProto(); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string onnx_string_buffer; model_proto->SerializeToString(onnx_string_buffer); @@ -1069,11 +1035,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } // compile the program - map_progs_[fused_node->Name()] = prog; + map_progs_[fused_node.Name()] = prog; - map_onnx_string_[fused_node->Name()] = onnx_string_buffer; - map_input_index_[fused_node->Name()] = input_name_index; - map_no_input_shape_[fused_node->Name()] = no_input_shape; + map_onnx_string_[fused_node.Name()] = onnx_string_buffer; + map_input_index_[fused_node.Name()] = input_name_index; + map_no_input_shape_[fused_node.Name()] = no_input_shape; NodeComputeInfo compute_info; compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 9c5ac11fe3..65b252ffdd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -45,8 +45,8 @@ class MIGraphXExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph_viewer, const std::vector& kernel_registries) const override; - Status Compile(const std::vector& fused_nodes, - std::vector& node_compute_funcs) override; + common::Status Compile(const std::vector& fused_nodes, + std::vector& 
node_compute_funcs) override; virtual std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index cfb67c91c4..d59660821d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -23,9 +23,6 @@ class NnapiExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph_view, const std::vector& /*kernel_registries*/) const override; - // we implement the Compile that takes FusedNodeAndGraph instances - FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } - #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h index ede5af0b03..86bbd8552f 100644 --- a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h +++ b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.h @@ -118,6 +118,13 @@ class NupharExecutionProvider : public IExecutionProvider { return iter->second.get(); } + onnxruntime::IExecutionProvider::FusionStyle GetFusionStyle() const override { + // existing EPs use this mode so default to it. 
+ // newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return + // FilteredGraphViewer + return onnxruntime::IExecutionProvider::FusionStyle::Function; + } + private: void CreateTVMTarget(); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 5386cdefd3..acb323aec8 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -26,7 +26,9 @@ void BackendManager::ReleaseGlobalContext() { g_global_context.reset(); } -BackendManager::BackendManager(const Node* fused_node, const logging::Logger& logger) { +BackendManager::BackendManager(const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger) { auto prec_str = GetGlobalContext().precision_str; if (prec_str == "FP32") { subgraph_context_.precision = InferenceEngine::Precision::FP32; @@ -38,14 +40,14 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo // Save the indexes of graph inputs among fused_node's inputDefs // (which also contains initializers). 
- auto node_input_defs = fused_node->InputDefs(); + auto node_input_defs = fused_node.InputDefs(); int i = 0; for (auto idef : node_input_defs) { subgraph_context_.input_names.insert({idef->Name(), i}); i++; } - auto graph_inputs = fused_node->GetFunctionBody()->Body().GetInputs(); + auto graph_inputs = subgraph.GetInputs(); for (auto input : graph_inputs) { if (GetGlobalContext().device_type.find("MYRIAD") != std::string::npos) { auto shape = input->Shape(); @@ -63,14 +65,14 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo subgraph_context_.input_indexes.push_back(index); } - auto graph_outputs_defs = fused_node->OutputDefs(); + auto graph_outputs_defs = fused_node.OutputDefs(); i = 0; for (auto output_def : graph_outputs_defs) { subgraph_context_.output_names.insert({output_def->Name(), i}); i++; } - subgraph_context_.subgraph_name = fused_node->Name(); - model_proto_ = GetModelProtoFromFusedNode(fused_node, logger); + subgraph_context_.subgraph_name = fused_node.Name(); + model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, logger); if (ModelHasBatchedInputs(*model_proto_) && GetGlobalContext().is_wholly_supported_graph && @@ -81,7 +83,7 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo concrete_backend_ = BackendFactory::MakeBackend(*model_copy, GetGlobalContext(), subgraph_context_); subgraph_context_.has_dynamic_input_shape = false; - } else if (ModelHasSymbolicInputDims(fused_node)) { + } else if (ModelHasSymbolicInputDims(subgraph)) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims. 
Defering backend initialization"; subgraph_context_.has_dynamic_input_shape = true; } else { @@ -123,9 +125,9 @@ bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& mod return has_batched_inputs; } -bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::Node* fused_node) const { +bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const { bool has_sym_dims = false; - auto graph_inputs = fused_node->GetFunctionBody()->Body().GetInputs(); + auto graph_inputs = subgraph.GetInputs(); for (auto input : graph_inputs) { if (input->Shape() == nullptr) { has_sym_dims = true; @@ -145,25 +147,23 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::Node* fused_no } std::unique_ptr -BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node, +BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const { - const auto* node_function = fused_node->GetFunctionBody(); - const std::string& name = fused_node->Name(); - ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", name); - - const onnxruntime::Graph& node_subgraph = node_function->Body(); - auto model = node_subgraph.CreateGraphViewer()->CreateModel(logger); + auto model = subgraph.CreateModel(logger); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - *model_proto->mutable_graph() = *node_subgraph.ToGraphProto(); + subgraph.ToProto(*model_proto->mutable_graph(), true, true); #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { + const std::string& name = fused_node.Name(); std::fstream dump(name + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); } +#else + ORT_UNUSED_PARAMETER(fused_node); #endif return model_proto; diff --git 
a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index 3ed2835d65..4bec0a8f26 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -13,7 +13,7 @@ namespace openvino_ep { // Singleton class that manages all the backends class BackendManager { public: - BackendManager(const onnxruntime::Node* fused_node, const logging::Logger& logger); + BackendManager(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger); void Compute(Ort::CustomOpApi api, OrtKernelContext* context); void ShutdownBackendManager(); static GlobalContext& GetGlobalContext(); @@ -21,8 +21,8 @@ class BackendManager { private: std::unique_ptr GetModelProtoFromFusedNode( - const onnxruntime::Node* fused_node, const logging::Logger& logger) const; - bool ModelHasSymbolicInputDims(const onnxruntime::Node* fused_node) const; + const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const; + bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; std::shared_ptr diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 72ca41bcc0..8fb2c49921 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -131,18 +131,21 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const } common::Status OpenVINOExecutionProvider::Compile( - const std::vector& fused_nodes, + const std::vector& fused_nodes, std::vector& node_compute_funcs) { - for (const auto& fused_node : fused_nodes) { - NodeComputeInfo compute_info; - - #if defined (OV_API_20) - 
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true; - # else - openvino_ep::BackendManager::GetGlobalContext().use_api_2 = false; - #endif + for (const auto& fused_node_graph : fused_nodes) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; - std::shared_ptr backend_manager = std::make_shared(fused_node, *GetLogger()); + NodeComputeInfo compute_info; + +#if defined(OV_API_20) + openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true; +#else + openvino_ep::BackendManager::GetGlobalContext().use_api_2 = false; +#endif + + std::shared_ptr backend_manager = std::make_shared(fused_node, graph_body_viewer, *GetLogger()); compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 9fab5e519f..8a3e65cbeb 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -151,8 +151,8 @@ class OpenVINOExecutionProvider : public IExecutionProvider { GetCapability(const GraphViewer& graph_viewer, const std::vector& kernel_registries) const override; - Status Compile(const std::vector& fused_nodes, - std::vector& node_compute_funcs) override; + Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) override; const void* GetExecutionHandle() const noexcept override { return nullptr; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index 2942d1267d..6afedaaf96 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -256,30 +256,24 @@ RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& 
graph_view return result; } -common::Status RknpuExecutionProvider::Compile( - const std::vector& fused_nodes, - std::vector& node_compute_funcs) { - for (const auto* fused_node : fused_nodes) { - // Reconstruct graph proto from fused node's function body - const auto* func_body = fused_node->GetFunctionBody(); - if (!func_body) { - return common::Status(common::ONNXRUNTIME, - common::INVALID_ARGUMENT, "Function body is empty"); - } - const Graph& graph_body = func_body->Body(); - onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(), +common::Status RknpuExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; + onnxruntime::Model model(graph_body_viewer.Name(), true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - graph_body.DomainToVersionMap(), + graph_body_viewer.DomainToVersionMap(), std::vector(), *GetLogger()); ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); - *(model_proto.mutable_graph()) = graph_body.ToGraphProto(); + graph_body_viewer.ToProto(*model_proto.mutable_graph(), true, true); model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); // Build map from input name to its index in input definitions std::unordered_map input_map; - const auto& input_defs = fused_node->InputDefs(); + const auto& input_defs = fused_node.InputDefs(); input_map.reserve(input_defs.size()); for (int i = 0, end = input_defs.size(); i < end; ++i) { input_map[input_defs[i]->Name()] = i; @@ -287,15 +281,15 @@ common::Status RknpuExecutionProvider::Compile( // Build map from output name to its index in output definitions std::unordered_map output_map; - const auto& output_defs = fused_node->OutputDefs(); + const auto& output_defs = fused_node.OutputDefs();
output_map.reserve(output_defs.size()); for (int i = 0, end = output_defs.size(); i < end; ++i) { output_map[output_defs[i]->Name()] = i; } - model_proto_[fused_node->Name()] = model_proto; - input_info_[fused_node->Name()] = input_map; - output_info_[fused_node->Name()] = output_map; + model_proto_[fused_node.Name()] = model_proto; + input_info_[fused_node.Name()] = input_map; + output_info_[fused_node.Name()] = output_map; NodeComputeInfo compute_info; compute_info.create_state_func = [&](ComputeContext* context, diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h index 1316af5951..22ad809781 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h @@ -20,7 +20,7 @@ class RknpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& /*kernel_registries*/) const override; - common::Status Compile(const std::vector& fused_nodes, + common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index ccbcd47cd1..9f0dea5b05 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -291,17 +291,11 @@ std::vector> IExecutionProvider::GetCapabilit const std::vector& kernel_registries) const { return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_registries); } - +// !!! This API will be deprecated soon. 
common::Status IExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { return g_host->IExecutionProvider__Compile(this, fused_nodes, node_compute_funcs); } - -common::Status IExecutionProvider::Compile(const std::vector& fused_nodes, - std::string& dll_path) { - return g_host->IExecutionProvider__Compile(this, fused_nodes, dll_path); -} - common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { return g_host->IExecutionProvider__Compile(this, fused_nodes_and_graphs, node_compute_funcs); diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index a9f85fe2b5..60662637e2 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -218,8 +218,9 @@ struct ProviderHost { virtual void IExecutionProvider__TryInsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) = 0; virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const std::vector& kernel_registries) = 0; + //!!! 
This API will be deprecated soon virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes, std::vector& node_compute_funcs) = 0; - virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes, std::string& dll_path) = 0; + virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; virtual int IExecutionProvider__GenerateMetaDefId(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) = 0; @@ -286,6 +287,8 @@ struct ProviderHost { #endif // TypeProto + virtual std::unique_ptr TypeProto__construct() = 0; + virtual void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) = 0; virtual const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) = 0; virtual ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) = 0; @@ -368,6 +371,7 @@ struct ProviderHost { virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0; virtual bool TensorProto_DataType_IsValid(int value) = 0; @@ -670,6 +674,8 @@ struct ProviderHost { virtual const std::vector& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) = 0; virtual const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0; + virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0; + // Path virtual PathString 
Path__ToPathString(const Path* p) noexcept = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 59398e18bf..cf1266b565 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -159,6 +159,8 @@ struct TensorProto final { static bool DataType_IsValid(int value) { return g_host->TensorProto_DataType_IsValid(value); } + void copy_from(const TensorProto* other) { return g_host->TensorProto__CopyFrom(this, other); } + TensorProto() = delete; TensorProto(const TensorProto&) = delete; }; @@ -240,6 +242,8 @@ struct TypeProto_Sequence final { }; struct TypeProto final { + static std::unique_ptr Create() { return g_host->TypeProto__construct(); } + const TypeProto_Tensor& tensor_type() const { return g_host->TypeProto__tensor_type(this); } TypeProto_Tensor* mutable_tensor_type() { return g_host->TypeProto__mutable_tensor_type(this); } @@ -268,7 +272,10 @@ struct TypeProto final { ValueCase value_case() const { return ValueCase(g_host->TypeProto__value_case(this)); } - PROVIDER_DISALLOW_ALL(TypeProto) + void copy_from(const TypeProto* other) { return g_host->TypeProto__CopyFrom(this, other); } + + TypeProto() = delete; + TypeProto(const TypeProto&) = delete; }; struct ValueInfoProto final { @@ -705,6 +712,8 @@ struct GraphViewer final { const std::vector& GetNodesInTopologicalOrder() const { return g_host->GraphViewer__GetNodesInTopologicalOrder(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); } + void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); } + GraphViewer() = delete; GraphViewer(const GraphViewer&) = delete; void 
operator=(const GraphViewer&) = delete; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 287fdd522f..f9145b925c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -516,34 +516,6 @@ Status TensorrtExecutionProvider::SetComputeStream(void* stream) { return Status::OK(); } -// Convert GraphViewer graph to GraphProto -void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) { - for (const auto* input_arg : graph.GetInputs()) { - *(graph_proto.mutable_input()->Add()) = input_arg->ToProto(); - } - - // Add all graph's initializers to the subgraph - const auto& init_tensors = graph.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - *(graph_proto.mutable_initializer()->Add()) = *(tensor.second); - } - - for (const auto* output_arg : graph.GetOutputs()) { - *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); - } - - for (const auto* value_info : graph.GetValueInfo()) { - *(graph_proto.mutable_value_info()->Add()) = value_info->ToProto(); - } - - // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. 
- for (auto& node_idx : graph.GetNodesInTopologicalOrder()) { - const gsl::not_null node_proto{graph_proto.add_node()}; - const gsl::not_null p_node{graph.GetNode(node_idx)}; - p_node->ToProto(*node_proto); - } -} - std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const { const std::vector& node_index = graph.GetNodesInTopologicalOrder(); std::unordered_set node_set; @@ -799,7 +771,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto graph_viewer = graph_build.CreateGraphViewer(); auto model = graph_viewer->CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - ToGraphProtoInternal(*graph_viewer, *model_proto->mutable_graph()); + graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -982,19 +954,19 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // Consolidate supported node list if (supported_nodes_vector.size() > 1) { - nodes_vector.clear(); - for (const auto& group : supported_nodes_vector) { - if (!group.first.empty()) { - nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end()); - } - } - SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}}; - if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, false)) { - LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; - } else { - LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph"; - supported_nodes_vector = consolidated_supported_nodes_vector; + nodes_vector.clear(); + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end()); } + } + SubGraphCollection_t consolidated_supported_nodes_vector = 
{{nodes_vector, true}}; + if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, false)) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; + } else { + LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph"; + supported_nodes_vector = consolidated_supported_nodes_vector; + } } // Construct subgraph capability from node list @@ -1020,12 +992,14 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, return result; } -common::Status TensorrtExecutionProvider::Compile(const std::vector& fused_nodes, +common::Status TensorrtExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { - for (const auto* fused_node : fused_nodes) { + for (auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; // Build map from input name to its index in input definitions std::unordered_map input_map; - const auto& input_defs = fused_node->InputDefs(); + const auto& input_defs = fused_node.InputDefs(); input_map.reserve(input_defs.size()); for (size_t i = 0, end = input_defs.size(); i < end; ++i) { input_map[input_defs[i]->Name()] = i; @@ -1033,29 +1007,23 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse // Build map from output name to its index in output definitions std::unordered_map output_map; - const auto& output_defs = fused_node->OutputDefs(); + const auto& output_defs = fused_node.OutputDefs(); output_map.reserve(output_defs.size()); for (size_t i = 0, end = output_defs.size(); i < end; ++i) { output_map[output_defs[i]->Name()] = i; } // Reconstruct graph proto from fused node's function body - const auto* func_body = fused_node->GetFunctionBody(); - if (!func_body) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, 
"Function body is empty"); - } - const Graph& graph_body = func_body->Body(); - auto graph_body_viewer = graph_body.CreateGraphViewer(); - auto model = graph_body_viewer->CreateModel(*GetLogger()); + auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - *model_proto->mutable_graph() = *graph_body.ToGraphProto(); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; model_proto->SerializeToString(string_buf); if (dump_subgraphs_) { // Dump TensorRT subgraphs - std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); + std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); } @@ -1123,7 +1091,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse } // Set precision flags - std::string trt_node_name_with_precision = fused_node->Name(); + std::string trt_node_name_with_precision = fused_node.Name(); if (fp16_enable_ && int8_enable_) { trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); trt_node_name_with_precision += "_fp16_int8"; @@ -1140,7 +1108,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse // Set DLA if (fp16_enable_ || int8_enable_) { - if (dla_enable_ && dla_core_ >= 0) { //DLA can only run with FP16 and INT8 + if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 int number_of_dla_core = trt_builder->getNbDLACores(); if (number_of_dla_core == 0) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; @@ -1205,7 +1173,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse trt_context = tensorrt_ptr::unique_pointer(trt_engine->createExecutionContext()); if (trt_context == 
nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build execution context for fused node: " + fused_node->Name()); + "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); } } @@ -1220,7 +1188,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse } } - // If (1) engine cache enable is not set or (2) first time enable engine cache and no engine cache is present, // build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will // be built at runtime @@ -1231,7 +1198,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse trt_config->setInt8Calibrator(nullptr); if (!SetDynamicRange(*trt_network, dynamic_range_map)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node->Name()); + "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name()); } } @@ -1242,7 +1209,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse } if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build engine for fused node: " + fused_node->Name()); + "TensorRT EP could not build engine for fused node: " + fused_node.Name()); } if (engine_cache_enable_) @@ -1252,7 +1219,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse trt_context = tensorrt_ptr::unique_pointer(trt_engine->createExecutionContext()); if (trt_context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build execution context for fused node: " + fused_node->Name()); + "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); } } } @@ -1280,15 +1247,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse } // Save engine, context and input/output info to map - parsers_.emplace(fused_node->Name(), std::move(trt_parser)); 
- engines_.emplace(fused_node->Name(), std::move(trt_engine)); - contexts_.emplace(fused_node->Name(), std::move(trt_context)); - builders_.emplace(fused_node->Name(), std::move(trt_builder)); - networks_.emplace(fused_node->Name(), std::move(trt_network)); - input_info_[fused_node->Name()].push_back(input_indexes); - output_info_[fused_node->Name()].push_back(output_indexes); - output_info_[fused_node->Name()].push_back(output_types); - input_shape_ranges_[fused_node->Name()] = input_shape_ranges; + parsers_.emplace(fused_node.Name(), std::move(trt_parser)); + engines_.emplace(fused_node.Name(), std::move(trt_engine)); + contexts_.emplace(fused_node.Name(), std::move(trt_context)); + builders_.emplace(fused_node.Name(), std::move(trt_builder)); + networks_.emplace(fused_node.Name(), std::move(trt_network)); + input_info_[fused_node.Name()].push_back(input_indexes); + output_info_[fused_node.Name()].push_back(output_indexes); + output_info_[fused_node.Name()].push_back(output_types); + input_shape_ranges_[fused_node.Name()] = input_shape_ranges; // Create function state // TODO: remove default capture @@ -1310,7 +1277,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse compute_info.release_state_func = [](FunctionState state) { if (state) { // Serialize and save engine to cache - // + // // Note: only save engine to file if engine cache enable is set and engine is being updated due to input shape changed // or engine file is not previously existed TensorrtFuncState* trt_state = reinterpret_cast(state); @@ -1325,7 +1292,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { delete static_cast(state); ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt")); + "TensorRT EP could not call engine encryption function encrypt")); } } 
else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); @@ -1341,7 +1308,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; } } - + delete static_cast(state); } }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index f920626934..e82d92bcff 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -127,7 +127,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { int GetDeviceId() const { return device_id_; } - common::Status Compile(const std::vector& fused_nodes, + common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override; diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc b/onnxruntime/core/providers/tvm/tvm_execution_provider.cc index 12eae2262c..7b5d46277a 100644 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc +++ b/onnxruntime/core/providers/tvm/tvm_execution_provider.cc @@ -8,6 +8,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/kernel_registry.h" #include "core/framework/compute_capability.h" +#include "core/graph/graph_proto_serializer.h" #include "core/platform/env.h" #include "core/graph/model.h" @@ -101,35 +102,33 @@ TvmExecutionProvider::GetCapability(const GraphViewer& graph_viewer, return result; } -common::Status TvmExecutionProvider::Compile(const std::vector& nodes, +common::Status TvmExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { printOptions(); - for (auto* fused_node : nodes) { - auto func_body = fused_node->GetFunctionBody(); - if (!func_body) - return 
common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty"); - const std::string func_name = fused_node->Name(); - const Graph& node_graph = func_body->Body(); - Model model(node_graph.Name(), true, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), node_graph.DomainToVersionMap(), + for (auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; + const std::string func_name = fused_node.Name(); + Model model(graph_body_viewer.Name(), true, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), graph_body_viewer.DomainToVersionMap(), std::vector(), *GetLogger()); ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); - - *(model_proto.mutable_graph()) = node_graph.ToGraphProto(); + //TVM EP is using static lib approach, so invoke serializer directly. + GraphViewerToProto(graph_body_viewer, *model_proto.mutable_graph(), true, true); auto opset = model_proto.add_opset_import(); opset->set_domain(kOnnxDomain); - opset->set_version(node_graph.DomainToVersionMap().at(kOnnxDomain)); + opset->set_version(graph_body_viewer.DomainToVersionMap().at(kOnnxDomain)); std::string onnx_model_str; model_proto.SerializeToString(&onnx_model_str); compilers_[func_name] = std::make_shared(std::move(onnx_model_str), - fused_node->ModelPath().ToPathString(), + fused_node.ModelPath().ToPathString(), int(opset->version())); InputsInfoMap all_input_shapes; - auto mod = compileModel(func_name, node_graph, all_input_shapes); + auto mod = compileModel(func_name, graph_body_viewer, all_input_shapes); std::vector output_tensors; - prepareOutputTensors(mod, output_tensors, node_graph.GetOutputs().size()); + prepareOutputTensors(mod, output_tensors, graph_body_viewer.GetOutputs().size()); runners_[func_name] = std::make_shared(options_, mod, all_input_shapes, output_tensors); @@ -168,15 +167,15 @@ void 
TvmExecutionProvider::printOptions() { } std::shared_ptr TvmExecutionProvider::compileModel(const std::string& func_name, - const Graph& graph, + const GraphViewer& graph_viewer, InputsInfoMap& all_input_shapes) { all_input_shapes.clear(); TVMTensorShapes input_shapes; if (options_.freeze_weights) { - setInputShapesForFreezedNN(graph, input_shapes, all_input_shapes); + setInputShapesForFreezedNN(graph_viewer, input_shapes, all_input_shapes); } else { - setInputShapesForUnfreezedNN(graph, input_shapes, all_input_shapes); + setInputShapesForUnfreezedNN(graph_viewer, input_shapes, all_input_shapes); } std::shared_ptr mod = compilers_[func_name]->operator()(options_, input_shapes); @@ -184,14 +183,14 @@ std::shared_ptr TvmExecutionProvider::compileModel(const std::string& return mod; } -void TvmExecutionProvider::setInputShapesForFreezedNN(const Graph& graph, +void TvmExecutionProvider::setInputShapesForFreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph.GetInputsIncludingInitializers(); + const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); size_t indx = 0; for (const auto* node : all_nodes) { - if(!graph.IsInitializedTensor(node->Name())) { + if (!graph_viewer.IsInitializedTensor(node->Name())) { TensorShapeVector shape = getInputShape(node); all_input_shapes[indx++] = shape; input_shapes.emplace_back(shape); @@ -199,16 +198,16 @@ void TvmExecutionProvider::setInputShapesForFreezedNN(const Graph& graph, } } -void TvmExecutionProvider::setInputShapesForUnfreezedNN(const Graph& graph, +void TvmExecutionProvider::setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph.GetInputsIncludingInitializers(); + const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); size_t indx = 0; for (const auto* node : all_nodes) { 
TensorShapeVector shape = getInputShape(node); all_input_shapes[indx++] = shape; - if(!graph.IsInitializedTensor(node->Name())) { + if (!graph_viewer.IsInitializedTensor(node->Name())) { input_shapes.emplace_back(shape); } } diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.h b/onnxruntime/core/providers/tvm/tvm_execution_provider.h index 9d891ee292..a46f818cb4 100644 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.h +++ b/onnxruntime/core/providers/tvm/tvm_execution_provider.h @@ -40,7 +40,7 @@ class TvmExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& /*kernel_registries*/) const override; - common::Status Compile(const std::vector& fused_nodes, + common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; std::unique_ptr GetDataTransfer() const override; AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override; @@ -48,10 +48,10 @@ class TvmExecutionProvider : public IExecutionProvider { private: void printOptions(); std::shared_ptr compileModel(const std::string& func_name, - const Graph& graph, + const GraphViewer& graph_viewer, InputsInfoMap& inputs_info); - void setInputShapesForFreezedNN(const Graph& graph, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes); - void setInputShapesForUnfreezedNN(const Graph& graph, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes); + void setInputShapesForFreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes); + void setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes); TensorShapeVector getInputShape(const NodeArg* node); TensorShapeVector convertTensorShape(const ONNX_NAMESPACE::TensorShapeProto& shape_proto); void prepareOutputTensors(const std::shared_ptr& mod, std::vector& output_tensors, size_t num); diff --git 
a/onnxruntime/core/providers/vitisai/vitisai_custom_op.cc b/onnxruntime/core/providers/vitisai/vitisai_custom_op.cc index f4f2364bfd..f0a75189bc 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_custom_op.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_custom_op.cc @@ -32,28 +32,22 @@ namespace onnxruntime { namespace vitisai_ep { -static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node, +static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) { - const auto* node_function = fused_node->GetFunctionBody(); - - ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", - fused_node->Name()); - - const Graph& node_subgraph = node_function->Body(); - onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{}, PathString{}, - IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(), + onnxruntime::Model model{graph_viewer.Name(), true, ModelMetaData{}, PathString{}, + IOnnxRuntimeOpSchemaRegistryList{}, graph_viewer.DomainToVersionMap(), std::vector(), logger}; ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - *(model_proto.mutable_graph()) = node_subgraph.ToGraphProto(); + /* Fill the model's graph from the graph_viewer parameter; model_proto is a value, so its graph is reached with '.'. NOTE(review): the two 'true' flags are presumably include_initializers / include_outer_scope_args -- confirm against GraphViewer::ToProto. */ graph_viewer.ToProto(*model_proto.mutable_graph(), true, true); return model_proto; } VitisAICustomOp::VitisAICustomOp(const ComputeContext* context, - const onnxruntime::Node* fused_node, + const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& graph_viewer, const std::string& backend_type, const std::string& export_runtime_module, const std::string& load_runtime_module, @@ -68,7 +62,7 @@ VitisAICustomOp::VitisAICustomOp(const ComputeContext* context, allocator_ = context->allocator_handle; name_ = context->node_name; - model_proto_ = GetModelProtoFromFusedNode(fused_node, *GetLogger()); + model_proto_ =
GetModelProtoFromFusedNode(graph_viewer, *GetLogger()); std::istringstream model_stream{model_proto_.SerializeAsString()}; xg_ = pyxir::onnx::import_onnx_model(model_stream); @@ -78,12 +72,12 @@ VitisAICustomOp::VitisAICustomOp(const ComputeContext* context, if (load_runtime_module_.empty()) { pyxir::partition(xg_, std::vector{backend_type_}, ""); - auto input_defs = fused_node->InputDefs(); + auto input_defs = fused_node.InputDefs(); for (auto idef : input_defs) { in_tensor_names_.push_back(idef->Name()); } - auto output_defs = fused_node->OutputDefs(); + auto output_defs = fused_node.OutputDefs(); for (auto odef : output_defs) { out_tensor_names_.push_back(odef->Name()); } diff --git a/onnxruntime/core/providers/vitisai/vitisai_custom_op.h b/onnxruntime/core/providers/vitisai/vitisai_custom_op.h index 2175d3bb39..417bbe0500 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_custom_op.h +++ b/onnxruntime/core/providers/vitisai/vitisai_custom_op.h @@ -29,7 +29,8 @@ namespace vitisai_ep { class VitisAICustomOp { public: VitisAICustomOp(const ComputeContext* context, - const onnxruntime::Node* fused_node, + const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& graph_viewer, const std::string& backend_type, const std::string& export_runtime_module, const std::string& load_runtime_module, diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 008e79f26a..db0c5e209c 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -273,12 +273,14 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, return result; } -common::Status VitisAIExecutionProvider::Compile(const std::vector& fused_nodes, +common::Status VitisAIExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { - for (const auto& 
fused_node : fused_nodes) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; NodeComputeInfo compute_info; compute_info.create_state_func = [this, fused_node, logger = GetLogger()](ComputeContext* context, FunctionState* state) { - auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, backend_type_, export_runtime_module_, + /* NOTE(review): graph_body_viewer is used here, but the enclosing lambda's capture list is [this, fused_node, logger]; it likely needs to be captured. fused_node is now a const Node& and is captured by copy. Confirm this compiles, and that anything captured by reference outlives Compile(), since this callback is stored in node_compute_funcs. */ auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, graph_body_viewer, backend_type_, export_runtime_module_, load_runtime_module_, logger); *state = p; return 0; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 8b869c9890..1282e0283f 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -29,7 +29,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { int GetDeviceId() const { return device_id_; } - common::Status Compile(const std::vector& fused_nodes, + common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; private: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 10d09fa40c..a7bd0a7006 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -930,7 +930,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, // Do partitioning based on execution providers' capability.
GraphPartitioner partitioner(kernel_registry_manager, providers); - ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(), + ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), layout_transformer::TransformLayoutForCompilingEP, mode)); @@ -1153,7 +1153,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, std::unordered_map compiled_kernel_hashes; GraphPartitioner partitioner(kernel_registry_manager, providers); - ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(), + ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), layout_transformer::TransformLayoutForCompilingEP, GraphPartitioner::Mode::kOrtFormatLoad, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 4da4c066db..704a5a7f55 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -28,6 +28,7 @@ #include "core/util/math.h" #include "core/framework/sparse_utils.h" #include "core/common/string_helper.h" +#include "core/graph/graph_proto_serializer.h" #ifdef ENABLE_TRAINING #ifdef ENABLE_TRAINING_TORCH_INTEROP @@ -265,14 +266,11 @@ struct ProviderHostImpl : ProviderHost { void IExecutionProvider__TryInsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) override { return p->IExecutionProvider::TryInsertAllocator(allocator); } std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const std::vector& kernel_registries) override { return p->IExecutionProvider::GetCapability(graph_viewer, kernel_registries); } + // !!! 
this api will be deprecated soon common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes, std::vector& node_compute_funcs) override { return p->IExecutionProvider::Compile(fused_nodes, node_compute_funcs); } - common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes, std::string& dll_path) override { - return p->IExecutionProvider::Compile(fused_nodes, dll_path); - } - common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { return p->IExecutionProvider::Compile(fused_nodes_and_graphs, node_compute_funcs); } @@ -354,6 +352,8 @@ struct ProviderHostImpl : ProviderHost { #endif // TypeProto (wrapped) + std::unique_ptr TypeProto__construct() override { return std::make_unique(); } + void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) override { p->CopyFrom(*other); } const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) override { return p->tensor_type(); } ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) override { return p->mutable_tensor_type(); } int TypeProto__value_case(const ONNX_NAMESPACE::TypeProto* p) override { return p->value_case(); } @@ -441,6 +441,7 @@ struct ProviderHostImpl : ProviderHost { int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); } bool TensorProto_DataType_IsValid(int value) override { return ONNX_NAMESPACE::TensorProto::DataType_IsValid(value); } + void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) override { p->CopyFrom(*other); } // TensorProtos (wrapped) ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); } @@ -762,6 +763,9 @@ struct ProviderHostImpl : ProviderHost { const 
std::vector& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); } const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); } + void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override { + GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args); + } // Path (wrapped) PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index f36d19bc60..d375db14df 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1242,12 +1242,9 @@ TEST(ExecutionProviderTest, FunctionTest) { InferenceSession session_object_2{so, GetEnvironment()}; ASSERT_STATUS_OK( session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>())); - status = session_object_2.Load(model_file_name); - ASSERT_TRUE(status.IsOK()); - status = session_object_2.Initialize(); - ASSERT_TRUE(status.IsOK()); - status = session_object_2.Run(run_options, feeds, output_names, &fetches); - ASSERT_TRUE(status.IsOK()); + ASSERT_STATUS_OK(session_object_2.Load(model_file_name)); + ASSERT_STATUS_OK(session_object_2.Initialize()); + ASSERT_STATUS_OK(session_object_2.Run(run_options, feeds, output_names, &fetches)); VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); } diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index f632fd1ecc..95cf065c28 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -143,8 +143,8 @@ 
TEST_P(SessionStateTestP, TestInitializerProcessing) { DefaultLoggingManager().DefaultLogger(), profiler); GraphPartitioner partitioner(krm, execution_providers); - status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(), - layout_transformer::TransformLayoutForCompilingEP); + status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(), + layout_transformer::TransformLayoutForCompilingEP); ASSERT_TRUE(status.IsOK()) << status; ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -209,7 +209,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { // Partition the graph GraphPartitioner partitioner(krm, execution_providers); - status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(), + status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(), layout_transformer::TransformLayoutForCompilingEP); ASSERT_TRUE(status.IsOK()) << status; @@ -259,7 +259,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { // Partition the graph GraphPartitioner partitioner(krm, execution_providers); - status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(), + status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(), layout_transformer::TransformLayoutForCompilingEP); ASSERT_TRUE(status.IsOK()) << status; diff --git a/onnxruntime/test/nuphar_tvm/tvm_basic_test.cc b/onnxruntime/test/nuphar_tvm/tvm_basic_test.cc deleted file mode 100644 index 6fa1cf3b25..0000000000 --- a/onnxruntime/test/nuphar_tvm/tvm_basic_test.cc +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "gtest/gtest.h" -#include "core/common/logging/logging.h" -#include "core/framework/compute_capability.h" -#include "core/framework/execution_provider.h" -#include "core/framework/kernel_registry.h" -#include "core/framework/op_kernel.h" -#include "core/graph/graph_viewer.h" -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/session/inference_session.h" -#include "core/session/onnxruntime_cxx_api.h" -#include "test/framework/test_utils.h" -#include "test/test_environment.h" -#include "test/nuphar_tvm/tvm_demo/demo_compiler.h" - -#include - -namespace onnxruntime { - -using namespace tvm_demo; - -class TVMDemoKernel : public OpKernel { - public: - explicit TVMDemoKernel(const OpKernelInfo& info) : OpKernel(info) {} - - protected: - const TensorShape& GetOutputShape(OpKernelContext* context, int /*i*/) const { - return context->Input(0)->Shape(); - } -}; - -class UnionSet { - public: - UnionSet(int n) { - for (int i = 0; i < n; ++i) { - farthers_.push_back(i); - } - } - - int get(int x) { - if (farthers_[x] == x) { - return x; - } - return farthers_[x] = get(farthers_[x]); - } - - void merge(int x, int y) { - x = get(x); - y = get(y); - if (x != y) { - farthers_[y] = x; - } - } - - std::vector farthers_; -}; - -static DLDataType GetDataType(ONNXTensorElementDataType type) { - if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - return {kDLFloat, 64, 1}; - } else - ORT_THROW("not implement."); -} - -namespace test { - -struct TVMFuncState { - AllocateFunc test_allocate_func = nullptr; - DestroyFunc test_release_func = nullptr; - AllocatorHandle allocator = nullptr; - tvm::runtime::Module* module = nullptr; -}; - -class FuseExecutionProviderX : public CPUExecutionProvider { - public: - explicit FuseExecutionProviderX(const CPUExecutionProviderInfo& info) : CPUExecutionProvider(info) { - } - - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& /*kernel_registries*/) const override { - 
std::vector> result; - std::vector fused_nodes; - for (auto& node : graph_viewer.Nodes()) { - if (node.OpType() == "Mul") { - fused_nodes.push_back(node.Index()); - } - } - - UnionSet set(static_cast(fused_nodes.size())); - for (int i = 0; i < fused_nodes.size(); ++i) { - auto node = graph_viewer.GetNode(fused_nodes[i]); - for (auto it = node->InputNodesBegin(); it != node->InputNodesEnd(); ++it) { - auto index_it = std::find(fused_nodes.begin(), fused_nodes.end(), (*it).Index()); - if (index_it != fused_nodes.end()) { - set.merge(i, static_cast(index_it - fused_nodes.begin())); - } - } - } - - std::vector> groups; - groups.resize(fused_nodes.size()); - for (int i = 0; i < set.farthers_.size(); ++i) { - groups[set.get(i)].push_back(fused_nodes[i]); - } - - for (auto& group : groups) { - if (group.size() > 1) { - std::unique_ptr sub_graph = std::make_unique(); - std::set fused_inputs, fused_outputs; - for (auto index : group) { - sub_graph->nodes.push_back(index); - auto node = graph_viewer.GetNode(index); - for (auto input : node->InputDefs()) { - auto it = fused_outputs.find(input); - if (it != fused_outputs.end()) { - fused_outputs.erase(it); - } else { - fused_inputs.insert(input); - } - } - for (auto output : node->OutputDefs()) { - auto it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - } else { - fused_outputs.insert(output); - } - } - } - - auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "TVMFuse"; - meta_def->domain = "FuseTest"; - for (auto input : fused_inputs) { - meta_def->inputs.push_back(input->Name()); - } - - for (auto output : fused_outputs) { - meta_def->outputs.push_back(output->Name()); - } - - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - sub_graph->SetMetaDef(std::move(meta_def)); - //TODO:set fuse kernel func; - result.push_back( - std::make_unique(std::move(sub_graph))); - } - } - return result; - } - - 
common::Status Compile(const std::vector& fused_nodes, - std::vector& node_compute_funcs) override { - for (auto* fused_node : fused_nodes) { - auto func_body = fused_node->GetFunctionBody(); - if (!func_body) - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty"); - //1. Build tvm IR based on the Ort graph - auto demo_tvm_tensor_ctx = BuildTVMIR(func_body->Body()); - //2. Create schedule for the built tvm IRs - auto s = CreateSchedule(demo_tvm_tensor_ctx); - //3. Build tvm module - std::vector tvm_args; - for (auto& t : demo_tvm_tensor_ctx.inputs) { - tvm_args.push_back(t); - } - for (auto& t : demo_tvm_tensor_ctx.outputs) { - tvm_args.push_back(t); - } - - std::vector func_names; - auto module_ptr = std::make_shared(); - *module_ptr = BuildStackVMModule(s, tvm::build_config(), tvm_args, func_names); - modules_[fused_node->Name()] = module_ptr; - - NodeComputeInfo compute_info; - - compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { - auto* p = new TVMFuncState(); - *p = {context->allocate_func, context->release_func, context->allocator_handle, modules_[context->node_name].get()}; - *state = p; - return 0; - }; - - compute_info.release_state_func = [](FunctionState state) { - if (state) - delete static_cast(state); - }; - - //we use lambda to capture the tvm model, so we can use it to get the funciton. 
- compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) { - Ort::CustomOpApi ort{*api}; - - TVMFuncState* tvm_state = reinterpret_cast(state); - - std::vector> input_shapes; - std::vector> output_shapes; - - auto eval_func_name = "func"; - DLContext cpu_context = {kDLCPU, 0}; - size_t num_inputs = ort.KernelContext_GetInputCount(context); - size_t num_outputs = ort.KernelContext_GetOutputCount(context); - size_t n_args = num_inputs + num_outputs; - std::vector dl_tensors(n_args); - std::vector tvm_values(n_args); - std::vector tvm_type_codes(n_args); - for (auto i = 0; i < num_inputs; i++) { - const OrtValue* input_tensor = ort.KernelContext_GetInput(context, i); - auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); - auto tensor_type = ort.GetTensorElementType(tensor_info); - input_shapes.emplace_back(ort.GetTensorShape(tensor_info)); - ort.ReleaseTensorTypeAndShapeInfo(tensor_info); - - tvm_type_codes[i] = kNDArrayContainer; - dl_tensors[i].ctx = cpu_context; - dl_tensors[i].dtype = GetDataType(tensor_type); - dl_tensors[i].strides = nullptr; - dl_tensors[i].byte_offset = 0; - dl_tensors[i].data = const_cast(ort.GetTensorData(input_tensor)); - dl_tensors[i].ndim = input_shapes.back().size(); - dl_tensors[i].shape = input_shapes.back().data(); - tvm_values[i].v_handle = &dl_tensors[i]; - } - - for (auto i = 0; i < num_outputs; i++) { - //setup output tensor property - //todo: type should be set by framework. 
- output_shapes.push_back(input_shapes[i]); - OrtValue* output_tensor = ort.KernelContext_GetOutput(context, i, output_shapes[i].data(), output_shapes[i].size()); - auto tensor_info = ort.GetTensorTypeAndShape(output_tensor); - auto tensor_type = ort.GetTensorElementType(tensor_info); - ort.ReleaseTensorTypeAndShapeInfo(tensor_info); - - tvm_type_codes[num_inputs + i] = kNDArrayContainer; - dl_tensors[num_inputs + i].ctx = cpu_context; - dl_tensors[num_inputs + i].dtype = GetDataType(tensor_type); - dl_tensors[num_inputs + i].strides = nullptr; - dl_tensors[num_inputs + i].byte_offset = 0; - dl_tensors[num_inputs + i].data = ort.GetTensorMutableData(output_tensor); - dl_tensors[num_inputs + i].ndim = output_shapes.back().size(); - dl_tensors[num_inputs + i].shape = output_shapes.back().data(); - tvm_values[num_inputs + i].v_handle = &dl_tensors[num_inputs + i]; - } - - auto evaluate_func_ = tvm_state->module->GetFunction(eval_func_name); - tvm::TVMArgs tvm_args(&tvm_values[0], &tvm_type_codes[0], static_cast(n_args)); - tvm::TVMRetValue rvalue; - try { - evaluate_func_.CallPacked(tvm_args, &rvalue); - } catch (std::exception&) { - return Status(common::ONNXRUNTIME, common::FAIL); // TODO: Translate exception to error code - } - if (rvalue.type_code() != kNull) { - return Status(common::ONNXRUNTIME, common::FAIL); // TODO: get error code. 
- } else { - return Status::OK(); - } - }; - node_compute_funcs.push_back(compute_info); - } - - return Status::OK(); - } - - private: - std::unordered_map> modules_; -}; - -static void RunSession(InferenceSession& session_object, - RunOptions& run_options, - std::vector& dims_x, - std::vector& values_x, - std::vector& dims_y, - std::vector& values_y) { - // prepare inputs - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_x, values_x, &ml_value); - NameMLValMap feeds; - feeds.insert(std::make_pair("X1", ml_value)); - - // prepare outputs - std::vector output_names; - output_names.push_back("Y4"); - std::vector fetches; - - // Now run - common::Status st = session_object.Run(run_options, feeds, output_names, &fetches); - if (!st.IsOK()) { - std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; - } - EXPECT_TRUE(st.IsOK()); - ASSERT_EQ(1, fetches.size()); - auto& rtensor = fetches.front().Get(); - TensorShape expected_shape(dims_y); - EXPECT_EQ(expected_shape, rtensor.Shape()); - const std::vector found(rtensor.template Data(), rtensor.template Data() + expected_shape.Size()); - ASSERT_EQ(found.size(), values_y.size()); - for (size_t i = 0; i < found.size(); i++) - ASSERT_EQ(found[i], values_y[i]); -} - -static const std::string MODEL_URI = "testdata/fuse_mul_1.onnx"; - -TEST(TVMTest, CodeGen_Demo_for_Fuse_Mul) { - SessionOptions so; - - so.session_logid = "InferenceSessionTests.NoTimeout"; - - InferenceSession session_object{so, GetEnvironment()}; - CPUExecutionProviderInfo info; - auto tvm_xp = std::make_unique(info); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(tvm_xp)).IsOK()); - EXPECT_TRUE(session_object.Load(MODEL_URI).IsOK()); - EXPECT_TRUE(session_object.Initialize().IsOK()); - - RunOptions run_options; - run_options.run_tag = "one session/one tag"; - - // prepare inputs - std::vector dims_x = { - 6, - }; - std::vector values_x = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - - 
// prepare expected inputs and outputs - std::vector expected_dims_y = { - 6, - }; - // now the expected value should be Mul's result. - std::vector expected_values_y = {1.0, 32.0, 243.0, 1024.0, 3125.0, 7776.0}; - - // Now run - RunSession(session_object, run_options, dims_x, values_x, expected_dims_y, expected_values_y); -} -} // namespace test - -} // namespace onnxruntime - -TEST(TVMTest, Native_TVM) { - using namespace tvm; - auto n = var("n"); - Array shape; - shape.push_back(n); - auto A = placeholder(shape, Float(64), "A"); - auto B = placeholder(shape, Float(64), "B"); - auto D = placeholder(shape, Float(64), "D"); - auto C = compute( - A->shape, [&A, &B](Expr i) { - return A[i] + B[i]; - }, - "C"); - auto E = compute( - A->shape, [&C, &D](Expr i) { - return C[i] + D[i]; - }, - "E"); - - auto s = create_schedule({E->op}); - auto args = Array({A, B, D, E}); - std::unordered_map binds; - auto config = build_config(); -#ifdef USE_NUPHAR_TVM_WITH_LLVM - auto target = target::llvm(); -#else - auto target = target::stackvm(); -#endif - auto lowered = lower(s, args, "func", binds, config); - auto module = build(lowered, target, Target(), config); - auto func = module.GetFunction("func"); - - DLDataType dtype; - dtype.code = kDLFloat; - dtype.bits = 64; - dtype.lanes = 1; - DLContext ctx; - ctx.device_type = DLDeviceType::kDLCPU; - ctx.device_id = 0; - - std::vector v = {1.0, 2.0, 3.0}; - int64_t len = 3; - DLTensor tensor_A = {&v[0], ctx, 1, dtype, &len, nullptr, 0}; - DLTensor tensor_B = {&v[0], ctx, 1, dtype, &len, nullptr, 0}; - DLTensor tensor_D = {&v[0], ctx, 1, dtype, &len, nullptr, 0}; - - std::vector r; - r.resize(len); - DLTensor tensor_E = {&r[0], ctx, 1, dtype, &len, nullptr, 0}; - - TVMValue lvalues[4]; - int type_codes[4] = {kNDArrayContainer, kNDArrayContainer, kNDArrayContainer, kNDArrayContainer}; - lvalues[0].v_handle = &tensor_A; - lvalues[1].v_handle = &tensor_B; - lvalues[2].v_handle = &tensor_D; - lvalues[3].v_handle = &tensor_E; - - TVMArgs 
tvm_args(lvalues, type_codes, 4); - TVMRetValue rvalue; - func.CallPacked(tvm_args, &rvalue); - CHECK_EQ(rvalue.type_code(), kNull); - double expected[3] = {3.0, 6.0, 9.0}; - auto data_E = static_cast(tensor_E.data); - for (int i = 0; i < 3; i++) { - EXPECT_NEAR(*(data_E + i), expected[i], 0.001f); - } -} diff --git a/onnxruntime/test/nuphar_tvm/tvm_demo/demo_compiler.cc b/onnxruntime/test/nuphar_tvm/tvm_demo/demo_compiler.cc deleted file mode 100644 index ff2f9076a7..0000000000 --- a/onnxruntime/test/nuphar_tvm/tvm_demo/demo_compiler.cc +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test/nuphar_tvm/tvm_demo/demo_compiler.h" - -#include "core/codegen/passes/scheduler/schedule_utils.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h" -#include "core/codegen/passes/scheduler/tvm_scheduler.h" -#include "core/codegen/passes/scheduler/tvm_schedule_builder.h" - -#include -#include - -namespace onnxruntime { -namespace tvm_demo { - -// Create a dummy demo handle -static codegen::CodeGenHandle demo_handle; -// Create a dummy demo codegen context -static tvm_codegen::CodeGenContext demo_codegen_ctx(&demo_handle); - -// Translate an Ort graph into tvm IR -// Note this function is specific for this demo. -// This function uses specific way for graph traversal or constructing tvm placeholders. -// It may or may not work for a universal Ort graph. -// For a more general example, please check nuphar provider. 
-DemoTVMTensorCtx BuildTVMIR(const onnxruntime::Graph& graph) { - // Create OpIRRegistry that holds all OpIRCreators - std::unique_ptr op_ir_registry = - std::make_unique(); - - // Register all generic OpIRCreators - tvm_codegen::RegisterAllGenericOpIRCreators(op_ir_registry.get()); - - // Create OpIRBuilder - std::shared_ptr op_ir_builder = - std::make_shared("Demo_Op_IR_Builder"); - - // Attach all generic OpIRCreators from op_ir_registry to op_ir_builder - tvm_codegen::RegisterGenericOrtOpTypeDispatcher(op_ir_builder, op_ir_registry.get()); - - // Create DemoTVMTensorCtx holdings tvm IR - DemoTVMTensorCtx result; - - // Local lookup from name to tvm::Tensor - std::unordered_map tvm_tensors; - - // Note this is a simplified traversal that works specifically for this demo - // but may or may not work for an univerisal model. - // For more general traversal, please check nuphar provider. - for (auto& node : graph.Nodes()) { - tvm::Array inputs; - tvm::Array outputs; - - // Get inputs - for (auto& def : node.InputDefs()) { - const std::string& name = def->Name(); - auto iter = tvm_tensors.find(name); - // Always create placeholder when not finding a tensor - // Note it is for this demo. - // It may or may not work for a universal graph. 
- if (iter == tvm_tensors.end()) { - tvm_tensors[name] = - tvm::placeholder(ShapeToTvmArray(def, demo_codegen_ctx), - tvm_codegen::ToTvmType(TensorProtoDataType(def)), - name + "_placeholder"); - } - inputs.push_back(tvm_tensors[name]); - } - - // call OpIBuilder's Evaluate to build tvm IR - ORT_THROW_IF_ERROR(op_ir_builder->Evaluate(inputs, node, demo_codegen_ctx, outputs)); - - // Store outputs - for (size_t def_id = 0; def_id < node.OutputDefs().size(); ++def_id) { - const NodeArg* def = node.OutputDefs()[def_id]; - tvm_tensors[def->Name()] = outputs[def_id]; - } - } - - // put inputs to DemoTVMTensorCtx - for (auto& input : graph.GetInputs()) { - result.inputs.push_back(tvm_tensors[input->Name()]); - } - - // check initializer - for (auto& initializer : graph.GetAllInitializedTensors()) { - result.inputs.push_back(tvm_tensors[initializer.first]); - } - - // Only one output in this demo - auto& output = graph.GetOutputs()[0]; - result.outputs.push_back(tvm_tensors[output->Name()]); - return result; -} - -// Declare a Demo scheduler that always inserts compute_inline -DECLARE_TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM) - -// Define a Demo scheduler's Evaluate that always inserts compute_inline -bool TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - tvm_codegen::CodeGenContext&, - tvm_codegen::ScheduleContext& ctx_sched) { - return TryInlineSchedule(tensor, ctx_sched); -} - -// Register the always inline Scheduler to sched_registry -static void RegisterAlwaysInlineScheduler(tvm_codegen::TVMScheduleRegistry* sched_registry) { - sched_registry->Register( - std::make_unique()); -} - -// Declare a schedule dispatcher that always dispatches the always inline Scheduler -DECLARE_SCHEDULE_DISPATCHER_CLASS(DemoTVM) - -// Use a predefined key as DemoKey to dispatch the scheduler -constexpr auto predefined_key = "DemoKey"; - -// Define the schedule dispatcher's Find function -// that always dispatches the always inline 
Scheduler -// Note this dispatcher always returning a predefined_key is only for demo purpose. -// In practice, a dispatcher returns a key by checking tvm::Tensor, Node, -// or even meta data stored in CodeGenContext. -// Derived CodeGenContext allows compiler developers to store their specific meta data. -// For more detailed example, please check nuphar provider. -tvm_codegen::Scheduler* SCHEDULE_DISPATCHER_CLASS(DemoTVM)::Find( - const tvm::Tensor&, const Node*, tvm_codegen::CodeGenContext&) { - return DispatcherBase::Get(predefined_key); -} - -// Attach the always inline Scheduler to the above dispatcher -// and then attach the dispatcher to the scheduler builder -static void AttachAlwaysInlineScheduler(const std::shared_ptr& builder, - const tvm_codegen::TVMScheduleRegistry* registry) { - auto dispatcher = std::make_unique("DemoSchedulers"); - - // Using a predefined_key - dispatcher->Register(predefined_key, - registry->Get(TVM_SCHEDULER_STRING(AlwaysInline, DemoTVM))); - - builder->InsertDispatcher(std::move(dispatcher)); -} - -// Traverse tvm::Tensor and then schedule them -// Note this traversal is simplified and specific for this demo. -// For a more general traversal, please check nuphar provider. -static void TraverseAndSchedule( - std::shared_ptr& schedule_builder, - const tvm::Tensor& tensor, - tvm_codegen::ScheduleContext& ctx_schedule) { - ORT_THROW_IF_ERROR(schedule_builder->Evaluate(tensor, nullptr, demo_codegen_ctx, ctx_schedule)); - - // Traverse tensor's children (inputs) - for (auto& t : tensor->op->InputTensors()) { - // check whether it is a non-trivial tensor by checking its input size - if (t->op->InputTensors().size() > 0) { - TraverseAndSchedule(schedule_builder, t, ctx_schedule); - } - } -} - -// Create a TVM schedule by always inserting tvm's compute_inline. -// Note this schedule is specific for this demo. -// In practice, always inline might lead to bad performance -// or even illegal loop transformation for some backends. 
-// For a more general example, please check nuphar provider. -tvm::Schedule CreateSchedule(const DemoTVMTensorCtx& ctx) { - // Create TVMScheduleRegistry that holds all Scheduler - std::unique_ptr schedule_registry = - std::make_unique(); - - // Register the always inline Scheduler to schedule_registry - RegisterAlwaysInlineScheduler(schedule_registry.get()); - - // Create a DemoScheduleBuilder - std::shared_ptr schedule_builder = - std::make_shared("Demo_Schedule_Builder"); - - // Attach the demo inline scheduler to the schedule_builder - AttachAlwaysInlineScheduler(schedule_builder, schedule_registry.get()); - - // Create scheudule object - tvm::Array out_ops; - for (auto& t : ctx.outputs) { - out_ops.push_back(t->op); - } - - // Create scheudule context - tvm_codegen::ScheduleContext ctx_schedule(out_ops); - - // Traverse tvm::Tensor in a DFS way, and then schedule - for (auto& t : ctx.outputs) { - TraverseAndSchedule(schedule_builder, t, ctx_schedule); - } - - // Make sure all outputs compute_root (tvm's requirement) - for (auto& t : ctx.outputs) { - tvm_codegen::InsertRootSchedule(t, ctx_schedule); - } - - return ctx_schedule.schedule; -} - -// Build TVM Module with a schedule using tvm's stackvm. -// Note in real practice, please change stackvm to other backends. -// For a more detailed example, please check nuphar provider. 
-tvm::runtime::Module BuildStackVMModule(tvm::Schedule schedule, - tvm::BuildConfig config, - tvm::Array tvm_args, - std::vector& target_func_names) { - auto target = tvm::target::stackvm(); - std::string func_name = "func"; - auto args = tvm::Array(tvm_args); - std::unordered_map binds; - auto lowered = lower(schedule, args, "func", binds, config); - // Uncomment the following line to dump lowered func - // std::cout << "Dumping lowered func: " << lowered[0]->body; - target_func_names.push_back(func_name); - return build(lowered, target, tvm::Target(), config); -} - -} // namespace tvm_demo -} // namespace onnxruntime diff --git a/onnxruntime/test/nuphar_tvm/tvm_demo/demo_compiler.h b/onnxruntime/test/nuphar_tvm/tvm_demo/demo_compiler.h deleted file mode 100644 index 3905995293..0000000000 --- a/onnxruntime/test/nuphar_tvm/tvm_demo/demo_compiler.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/graph/graph_viewer.h" -#include -#include - -namespace onnxruntime { -namespace tvm_demo { -// A Demo data structure to hold tvm IR and context -struct DemoTVMTensorCtx { - tvm::Array inputs; - tvm::Array outputs; -}; - -// Translate an Ort graph into tvm IR -DemoTVMTensorCtx BuildTVMIR(const onnxruntime::Graph& graph); - -// Create a demo schedule for the tvm IR -tvm::Schedule CreateSchedule(const DemoTVMTensorCtx& ctx); - -// Build a demo tvm module with the tvm IR and schedule -tvm::runtime::Module BuildStackVMModule(tvm::Schedule schedule, - tvm::BuildConfig config, - tvm::Array tvm_args, - std::vector& target_func_names); - -} // namespace tvm_demo -} // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 9de0490d61..14b180c2d7 100644 --- 
a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -21,10 +21,6 @@ class InternalTestingExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; - FusionStyle GetFusionStyle() const override { - return FusionStyle::FilteredGraphViewer; - } - DataLayout GetPreferredLayout() const override; private: