Unify the Compile API for mobile build and normal build (#10632)

* use the lightweight compile api as default; use dnnl ep for testing

* apply to tensorrt ep

* fix the missing files

* fix build

* fix the copy issue on linux

* migrate migraphx and openvino ep

* fix openvino build break

* fix linux build

* fix unused parameter

* fix coreml build

* use graph view's filtered initializers

* fix openvino break

* fix tvm compile api

* fix tvm / rknpu / vitisai ep build

* add IsInitializedTensor in graph_viewer; fix nuphar build

* use serializer directly as tvm ep is still static lib

* fix the type mismatch

* fix the type mismatch

* fix merge conflict

* add a comment

* fix minimal build

* fix the DML EP's legacy approach

* save type/shape in dnnl IR

* fix linux break

* fix tvm failure

* dnnl ep: move initializer referenced out of dnnl subgraph

* Revert "add IsInitializedTensor in graph_viewer; fix nuphar build"

This reverts commit 1cc3c7f08c16fee4fe3309a67209eb769d479587.

* add IsInitializedTensor to graph viewer

* add the legacy code for nuphar build to temporarily make nuphar build work

* ignore internal test for nuphar

* remove the out of date tests

* keep the legacy API in EP for a while

* turn serializer into a static function

* update comments

* fix tvm build

* Update include/onnxruntime/core/framework/execution_provider.h

Co-authored-by: Pranav Sharma <prs@microsoft.com>

* Update include/onnxruntime/core/framework/execution_provider.h

Co-authored-by: Pranav Sharma <prs@microsoft.com>

* Update onnxruntime/core/framework/execution_provider.cc

Co-authored-by: Pranav Sharma <prs@microsoft.com>

* update comments; add warning message for legacy compile call

* add a flag to control out of scope arg in serialization

* fix trt  build; improve the test

* resolve merge errors

* fix a typo

Co-authored-by: Cheng Tang <chenta@microsoft.com>
Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
Tang, Cheng 2022-05-05 08:30:07 -07:00 committed by GitHub
parent eca4cbc419
commit 3f3c5fcd68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
52 changed files with 516 additions and 1161 deletions

View file

@ -30,6 +30,12 @@ if (onnxruntime_MINIMAL_BUILD)
"${ONNXRUNTIME_ROOT}/core/graph/function*"
)
# remove graph proto serializer
list(APPEND onnxruntime_graph_src_exclude_patterns
"${ONNXRUNTIME_ROOT}/core/graph/graph_proto_serializer.cc"
"${ONNXRUNTIME_ROOT}/core/graph/graph_proto_serializer.h"
)
# no optimizer support in base minimal build
# some optimizer support in extended minimal build
if (NOT onnxruntime_EXTENDED_MINIMAL_BUILD)

View file

@ -348,7 +348,7 @@ if (onnxruntime_USE_RKNPU)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_rknpu_src})
endif()
if (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD)
if ((NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_NUPHAR) OR onnxruntime_EXTENDED_MINIMAL_BUILD)
file(GLOB_RECURSE onnxruntime_test_providers_internal_testing_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/providers/internal_testing/*"
)
@ -464,11 +464,10 @@ if(onnxruntime_USE_COREML)
endif()
if(onnxruntime_USE_NUPHAR)
file(GLOB_RECURSE onnxruntime_test_nuphar_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/nuphar_tvm/*.h"
"${TEST_SRC_DIR}/nuphar_tvm/*.cc"
)
# The test cases under nuphar_tvm only verify some basic TVM showcases, which are already out of date.
# They are not directly related to Nuphar. Considering we now have an official TVM execution provider,
# keeping those test cases no longer brings any value.
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/framework/nuphar/*)
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nuphar)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nuphar)

View file

@ -197,48 +197,15 @@ class IExecutionProvider {
// TODO: temporary solution, need to unify the interface in EP and AllocatorManager
void TryInsertAllocator(AllocatorPtr allocator);
// creation of a fused node is not supported in a minimal build, so any EP enabled in that scenario must support
// compilation via GraphViewer instances.
#if !defined(ORT_MINIMAL_BUILD)
/**
Given a list of fused_node, return create_state/compute/release_state func for each node.
*/
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs);
/**
Given a list of fused_node, return a dll that expose functions for each node.
For each node, there should be three symbols:
Create_State_${node_name}
Compute_${node_name}
Release_State_${node_name}
*/
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::string& dll_path);
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
struct FusedNodeAndGraph {
const std::reference_wrapper<onnxruntime::Node> fused_node;
// GraphViewer that filters the full graph to the nodes that are covered by 'node'
const std::reference_wrapper<GraphViewer> filtered_graph;
};
/**
Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
return create_state/compute/release_state func for each node.
@remarks This is an optional interface that is only needed if the execution provider compiles nodes
in a scenario involving the minimal build. i.e. on a mobile or embedded device with ORT format model.
Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
as it is only valid for the duration of the call to Compile.
*/
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs);
#endif
// Fusion approach that is supported
// !!! The "Function" FusionStyle will be deprecated soon.
// !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style.
enum class FusionStyle {
// The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
// in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
@ -254,12 +221,35 @@ class IExecutionProvider {
};
virtual FusionStyle GetFusionStyle() const {
// existing EPs use this mode so default to it.
// newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return
// FilteredGraphViewer
return FusionStyle::Function;
// All of the EPs built into ORT have migrated to the FilteredGraphViewer style except Nuphar.
// For newer EPs, please avoid using the Function style as it will be deprecated soon.
return FusionStyle::FilteredGraphViewer;
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/**
* !!!! This API will be deprecated soon. If your execution provider overrides this API
* !!!! Please migrate it to the "Compile" API with FusedNodeAndGraph type.
Given a list of fused_node, return create_state/compute/release_state func for each node.
*/
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs);
/**
Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
return create_state/compute/release_state func for each node.
@remarks This is now the default interface when execution provider wants to compile nodes
for both minimal build and complete ort build.
Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
as it is only valid for the duration of the call to Compile.
*/
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs);
#endif
void SetLogger(const logging::Logger* logger) {
logger_ = logger;
}

View file

@ -1258,6 +1258,10 @@ class Graph {
return Resolve(default_options);
}
/** Returns a read-only reference to the outer scope NodeArg names stored in outer_scope_node_arg_names_. */
const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept{
return outer_scope_node_arg_names_;
}
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
flatbuffers::Offset<onnxruntime::fbs::Graph>& fbs_graph) const;

View file

@ -149,6 +149,10 @@ class GraphViewer {
/** Get the internal graph*/
const Graph& GetGraph() const { return *graph_; }
#if !defined(ORT_MINIMAL_BUILD)
const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept;
#endif
/**
returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
@param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name'
@ -156,6 +160,9 @@ class GraphViewer {
*/
bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const;
/** Check if a given name is an initializer tensor's name in this graph. */
bool IsInitializedTensor(const std::string& name) const;
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
@param check_outer_scope If true and the graph is a subgraph,

View file

@ -151,26 +151,21 @@ void IExecutionProvider::RegisterAllocator(std::shared_ptr<AllocatorManager>) {
return;
}
#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// !!!!This API will be deprecated soon. If your execution provider overrides this API
// !!!!Please migrate it to the "Compile" API with FusedNodeAndGraph type.
// Default implementation of the legacy Node based Compile overload: both arguments are ignored and a
// NOT_IMPLEMENTED status is returned whose message ends with this provider's type_.
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
"IExecutionProvider::Compile with fused Node is not implemented by " + type_);
}
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
std::string& /*dll_path*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
"IExecutionProvider::Compile with fused Node and dll path is not implemented by " + type_);
}
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Default implementation of the FusedNodeAndGraph based Compile overload: both arguments are ignored and a
// NOT_IMPLEMENTED status is returned whose message ends with this provider's type_.
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& /*fused_nodes_and_graphs*/,
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
"IExecutionProvider::Compile with FusedNodeAndGraph is not implemented by " + type_);
}
#endif
int IExecutionProvider::ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer,

View file

@ -42,6 +42,15 @@ NonCudaOps non_cuda;
using namespace ::onnxruntime::common;
namespace onnxruntime {
#if !defined(ORT_MINIMAL_BUILD)
// Fills in a KernelDef for a fused node from the node itself: the kernel name, domain and
// since-version are read from the op schema returned by node.Op(), and the provider is the
// execution provider type already assigned to the node.
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
  auto op_schema = node.Op();

  builder.SetName(op_schema->Name())
      .SetDomain(op_schema->domain())
      .SinceVersion(op_schema->SinceVersion())
      .Provider(node.GetExecutionProviderType());
}
#endif
// minimal KernelDef based on MetaDef instead of a Function based node
static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph::MetaDef& metadef,
@ -103,7 +112,7 @@ static void AssignNodes(Graph& graph, const IndexedSubGraph& capability,
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep,
GraphPartitioner::Mode mode, std::vector<std::unique_ptr<ComputeCapability>>& capabilities,
GraphPartitioner::Mode mode, std::vector<std::unique_ptr<ComputeCapability>>& capabilities,
TransformLayoutFunction transform_layout) {
{
GraphViewer graph_viewer(graph);
@ -142,15 +151,6 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg
}
#if !defined(ORT_MINIMAL_BUILD)
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
auto schema = node.Op();
builder.SetName(schema->Name())
.SetDomain(schema->domain())
.SinceVersion(schema->SinceVersion())
.Provider(node.GetExecutionProviderType());
}
/**
* Check if a node can be placed on a specific provider.
* Do nothing if the node is already assigned
@ -162,8 +162,8 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::No
* \return Fused node. Return nullptr if there is no fuse
*/
static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
const KernelRegistryManager& kernel_registry_mgr, const std::string& provider_type,
IExecutionProvider::FusionStyle fusion_style,
const std::string& provider_type,
GraphPartitioner::Mode mode,
int& fused_node_unique_id) {
Node* result = nullptr;
@ -216,7 +216,18 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
std::string node_name = oss.str();
Node* fused_node = nullptr;
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
// TODO1: DML currently uses a legacy approach.
// It registers a generic predefined kernel for all-purpose fusion,
// so it relies on the function body in the fused node during kernel creation,
// which happens after the graph partition phase.
// Ideally, this should be moved to the "Compile" call.
// Here we temporarily keep the function body for DML fusion.
// This needs to be removed after DML is migrated to the Compile-based approach.
// TODO2: Nuphar is no longer maintained; keep it on the old API temporarily.
// We want to deprecate Nuphar soon.
if (fusion_style == IExecutionProvider::FusionStyle::Function ||
provider_type == kDmlExecutionProvider ||
provider_type == kNupharExecutionProvider) {
fused_node = &graph.FuseSubGraph(capability, node_name);
} else {
// create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed
@ -226,10 +237,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
fused_node->SetExecutionProviderType(provider_type);
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *fused_node, provider_type)) {
result = fused_node;
}
result = fused_node;
} else {
// assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them.
// This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion
@ -250,7 +258,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
// for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up).
// assign any nodes to the EP that are currently unassigned, and that the EP can handle.
static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncManager& func_mgr,
static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
KernelRegistryManager& kernel_registry_mgr,
KernelRegistry& fused_kernel_registry,
IExecutionProvider& current_ep,
@ -268,7 +276,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
Graph* subgraph = entry.second;
// we pass through the FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr,
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
transform_layout_function));
}
@ -295,6 +303,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
const std::string& type = current_ep.Type();
auto fusion_style = current_ep.GetFusionStyle();
std::vector<Node*> nodes_to_compile;
// The fused node may map to an existing kernel, so it is fused but doesn't need to be compiled
// But we still need to finalize the graph fusion for those nodes.
std::vector<Node*> nodes_to_complete_fuse;
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_complete_fuse;
// filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with
// nodes_to_compile.
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_compile;
@ -311,42 +325,46 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
continue;
}
Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, type, fusion_style, mode, fused_node_unique_id);
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
if (n != nullptr) {
nodes_to_compile.push_back(n);
capabilities_to_compile.push_back(std::move(capability));
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) {
nodes_to_compile.push_back(n);
capabilities_to_compile.push_back(std::move(capability));
} else {
// there is a predefined kernel for the fused node. doesn't need compile, but need to complete the fusing.
nodes_to_complete_fuse.push_back(n);
capabilities_to_complete_fuse.push_back(std::move(capability));
}
}
}
// NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode
// even with single node, EP might still want to compile it.
// for example, it want to JIT an optimized kernel for LSTM with a given shape.
if (!nodes_to_compile.empty()) {
std::vector<NodeComputeInfo> node_compute_funcs;
if (export_dll) {
ORT_ENFORCE(fusion_style == IExecutionProvider::FusionStyle::Function,
"Must use Function based fusion when exporting compiled nodes to dll.");
}
// !!! The Function style fusion will be deprecated soon.
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
// TODO: Nuphar is no longer maintained. Use the old API temporarily.
// We want to deprecate it soon.
// Create a Function based node where the fused nodes have a new Graph instance.
static std::once_flag legacy_compile_method_warning_flag;
std::call_once(
legacy_compile_method_warning_flag, [](std::string_view ep_type) {
LOGS_DEFAULT(WARNING) << "Execution Provider: " << ep_type << " is still using Funciton style Compile API, "
<< " which will be deprecated soon, please migrate to the new Compile API based on "
<< " FilteredGraphViewer. ";
},
type);
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs));
if (export_dll) {
std::string dll_path;
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, dll_path));
if (node_compute_funcs.size() != nodes_to_compile.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
}
for (auto* node : nodes_to_compile) {
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path));
}
} else {
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs));
if (node_compute_funcs.size() != nodes_to_compile.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
}
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j])));
}
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j])));
}
for (auto* node : nodes_to_compile) {
@ -358,12 +376,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
return FunctionKernel::Create(func_mgr, info, out);
}));
}
} else {
// temporary storage for the GraphViewer for each IndexedSubGraph
std::vector<std::unique_ptr<GraphViewer>> viewers;
viewers.reserve(nodes_to_compile.size());
std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;
nodes_and_viewers.reserve(nodes_to_compile.size());
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
auto* node = nodes_to_compile[j];
@ -401,6 +419,21 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
}
}
}
// TODO: DML currently uses a legacy approach.
// The fuse is done in the FuseSubGraph function.
// This needs to be removed later when DML migrates to the Compile approach.
if (!nodes_to_complete_fuse.empty() && type != kDmlExecutionProvider) {
for (size_t j = 0, end = nodes_to_complete_fuse.size(); j < end; j++) {
auto* node = nodes_to_complete_fuse[j];
const auto& cur_capability = capabilities_to_complete_fuse[j];
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
}
ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph));
}
@ -458,7 +491,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
return Status::OK();
}
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry, Mode mode,
int& fused_node_unique_id,
TransformLayoutFunction transform_layout_function) const {
@ -467,7 +500,7 @@ Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll,
do {
// process full graph with each EP
for (const auto& ep : providers_) {
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_mgr_,
fused_kernel_registry, *ep, mode, fused_node_unique_id,
transform_layout_function));
}
@ -610,7 +643,7 @@ Status GraphPartitioner::PartitionOrtFormatModel(
return Status::OK();
}
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
TransformLayoutFunction transform_layout_function, Mode mode,
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes) const {
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
@ -634,19 +667,16 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode,
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, func_mgr, *fused_kernel_registry, mode,
fused_node_unique_id, transform_layout_function));
#else
ORT_UNUSED_PARAMETER(export_dll);
ORT_THROW("Not supported in this build.");
#endif //! defined(ORT_MINIMAL_BUILD)
} else {
ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided");
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes,
fused_node_unique_id, transform_layout_function));
}
if (!fused_kernel_registry->IsEmpty()) {
kernel_registry_mgr_.RegisterKernelRegistry(fused_kernel_registry);
}

View file

@ -32,7 +32,7 @@ class GraphPartitioner {
}
// Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad.
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
Status Partition(Graph& graph, FuncManager& func_mgr,
TransformLayoutFunction transform_layout_function,
Mode mode = Mode::kNormal,
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes = nullptr) const;
@ -41,7 +41,7 @@ class GraphPartitioner {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner);
#if !defined(ORT_MINIMAL_BUILD)
Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
Status PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr,
KernelRegistry& fused_kernel_registry, Mode mode,
int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const;
#endif

View file

@ -276,9 +276,6 @@ class SessionState {
concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; }
concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; }
bool ExportDll() const noexcept { return export_fused_dll_; }
void SetExportDllFlag(bool flag) noexcept { export_fused_dll_ = flag; }
const FuncManager& GetFuncMgr() const noexcept { return fused_funcs_mgr_; }
FuncManager& GetMutableFuncMgr() noexcept { return fused_funcs_mgr_; }
@ -488,7 +485,6 @@ class SessionState {
concurrency::ThreadPool* const thread_pool_{};
concurrency::ThreadPool* const inter_op_thread_pool_{};
bool export_fused_dll_ = false;
const DataTransferManager& data_transfer_mgr_;
bool use_deterministic_compute_;

View file

@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/graph_proto_serializer.h"
namespace onnxruntime {
void GraphViewerToProto(const GraphViewer& graph_view,
ONNX_NAMESPACE::GraphProto& graph_proto,
bool include_initializer,
bool include_outer_scope_args) {
graph_proto.set_name(graph_view.Name());
graph_proto.set_doc_string(graph_view.Description());
for (const auto* input_arg : graph_view.GetInputsIncludingInitializers()) {
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
}
for (const auto* output_arg : graph_view.GetOutputs()) {
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
}
for (const auto* value_info : graph_view.GetValueInfo()) {
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
}
if (include_outer_scope_args){
// add the NodeArg info for outer scope NodeArgs so we capture the type information
for (const auto& name : graph_view.GetOuterScopeNodeArgNames()) {
auto* node_arg = graph_view.GetNodeArg(name);
ORT_ENFORCE(node_arg, "Outer scope node arg name '" + name + "'was added but does not exist. ");
*(graph_proto.mutable_value_info()->Add()) = node_arg->ToProto();
}
}
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) {
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
const gsl::not_null<const Node*> p_node{graph_view.GetNode(node_idx)};
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
// such as the optimizers are captured. otherwise we can end up saving an invalid graph.
p_node->ToProto(*node_proto, /* update_subgraphs */ true);
}
if (include_initializer) {
auto& initializers = graph_view.GetAllInitializedTensors();
for (auto& it : initializers) {
auto* p_initializer = graph_proto.add_initializer();
*p_initializer = *(it.second);
}
}
}
}

View file

@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/graph_viewer.h"
namespace onnxruntime {
/**
Serializes the content visible through a GraphViewer into a GraphProto.
@param graph_view GraphViewer to serialize.
@param graph_proto GraphProto that receives the name, doc string, inputs, outputs, value info and nodes.
@param include_initializer If true, the initializers returned by the viewer are also copied into graph_proto.
@param include_outer_scope_args If true, the outer scope NodeArgs are also added as value_info entries.
*/
void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args);
} // namespace onnxruntime

View file

@ -271,9 +271,19 @@ bool GraphViewer::IsConstantInitializer(const std::string& name, bool check_oute
return GetConstantInitializer(name, check_outer_scope) != nullptr;
}
// Forwards the query to IsInitializedTensor on the underlying graph_ and returns its result unchanged.
bool GraphViewer::IsInitializedTensor(const std::string& name) const {
return graph_->IsInitializedTensor(name);
}
// Forwards to GetConstantInitializer on the underlying graph_ with the same name and check_outer_scope flag.
const ONNX_NAMESPACE::TensorProto* GraphViewer::GetConstantInitializer(const std::string& initializer_name,
bool check_outer_scope) const {
return graph_->GetConstantInitializer(initializer_name, check_outer_scope);
}
#if !defined(ORT_MINIMAL_BUILD)
// Forwards to GetOuterScopeNodeArgNames on the underlying graph_ and returns the same reference.
const std::unordered_set<std::string>& GraphViewer::GetOuterScopeNodeArgNames() const noexcept {
return graph_->GetOuterScopeNodeArgNames();
}
#endif
} // namespace onnxruntime

View file

@ -20,9 +20,6 @@ class CoreMLExecutionProvider : public IExecutionProvider {
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
// we implement the Compile that takes FusedNodeAndGraph instances
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;

View file

@ -124,33 +124,6 @@ std::vector<std::vector<NodeIndex>> DNNLExecutionProvider::GetSupportedNodes(con
return supported_node_vecs;
}
void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) {
for (const auto* input_arg : graph.GetInputs()) {
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
}
// Add all graph's initializers to the subgraph
const auto& init_tensors = graph.GetAllInitializedTensors();
for (const auto& tensor : init_tensors) {
*(graph_proto.mutable_initializer()->Add()) = *(tensor.second);
}
for (const auto* output_arg : graph.GetOutputs()) {
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
}
for (const auto* value_info : graph.GetValueInfo()) {
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
}
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
const gsl::not_null<const Node*> p_node{graph.GetNode(node_idx)};
p_node->ToProto(*node_proto);
}
}
std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapability(
const GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const {
@ -269,7 +242,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
if (dump_subgraphs_) {
auto model = graph_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();
ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph());
graph_viewer.ToProto(*model_proto->mutable_graph(), false, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
HashValue model_hash;
int metadef_id = GenerateMetaDefId(graph_viewer, model_hash);
@ -280,39 +253,34 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
return result;
}
Status DNNLExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
Status DNNLExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
//follow from coreml ep's Compile
for (const auto* fused_node : fused_nodes) {
const auto* func_body = fused_node->GetFunctionBody();
if (!func_body) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
}
const Graph& graph_body = func_body->Body();
auto graph_body_viewer = graph_body.CreateGraphViewer();
for (auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
if (dump_subgraphs_) {
auto model = graph_body_viewer->CreateModel(*GetLogger());
auto model = graph_body_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();
*model_proto->mutable_graph() = *graph_body.ToGraphProto();
graph_body_viewer.ToProto(*model_proto->mutable_graph(), false, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
}
//subgraph
auto dnnl_subgraph = std::make_unique<ort_dnnl::DnnlSubgraph>(ort_dnnl::DnnlSubgraph(*graph_body_viewer.get()));
subgraphs_.emplace(fused_node->Name(), std::move(dnnl_subgraph));
auto dnnl_subgraph = std::make_unique<ort_dnnl::DnnlSubgraph>(ort_dnnl::DnnlSubgraph(graph_body_viewer));
subgraphs_.emplace(fused_node.Name(), std::move(dnnl_subgraph));
//apply transformation to subgraph
if (enable_fusion_) {
ort_dnnl::DnnlGraphTransformer().Apply(*subgraphs_[fused_node->Name()].get());
ort_dnnl::DnnlGraphTransformer().Apply(*subgraphs_[fused_node.Name()].get(), graph_body_viewer);
}
//subgraph primitive
auto dnnl_subgraph_primitive = std::make_unique<ort_dnnl::DnnlSubgraphPrimitive>(*subgraphs_[fused_node->Name()].get());
auto dnnl_subgraph_primitive = std::make_unique<ort_dnnl::DnnlSubgraphPrimitive>(*subgraphs_[fused_node.Name()].get());
{
const auto& input_defs = fused_node->InputDefs();
const auto& input_defs = fused_node.InputDefs();
std::vector<std::string> onnx_input_names(input_defs.size());
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
onnx_input_names[i] = input_defs[i]->Name();
@ -320,7 +288,7 @@ Status DNNLExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
dnnl_subgraph_primitive->SetOrderedInputs(std::move(onnx_input_names));
}
{
const auto& output_defs = fused_node->OutputDefs();
const auto& output_defs = fused_node.OutputDefs();
std::vector<std::string> onnx_output_names(output_defs.size());
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
onnx_output_names[i] = output_defs[i]->Name();
@ -328,7 +296,7 @@ Status DNNLExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
dnnl_subgraph_primitive->SetOrderedOutputs(std::move(onnx_output_names));
}
subgraph_primitives_.emplace(fused_node->Name(), std::move(dnnl_subgraph_primitive));
subgraph_primitives_.emplace(fused_node.Name(), std::move(dnnl_subgraph_primitive));
NodeComputeInfo compute_info;

View file

@ -34,7 +34,7 @@ class DNNLExecutionProvider : public IExecutionProvider {
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
private:

View file

@ -67,14 +67,14 @@ void DnnlMatMulInteger::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& nod
if (has_a_zero_point) {
auto zp_A_mem_desc_s32 = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
auto tensor = node.Input(IN_A_ZERO_POINT);
auto& tensor = node.Input(IN_A_ZERO_POINT);
auto zp_A_mem_s32 = sp.GetMemoryAndReshape(tensor, zp_A_mem_desc_s32, eng);
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = zp_A_mem_s32;
}
if (has_b_zero_point) {
auto zp_B_mem_desc_s32 = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
auto tensor = node.Input(IN_B_ZERO_POINT);
auto& tensor = node.Input(IN_B_ZERO_POINT);
auto zp_B_mem_s32 = sp.GetMemoryAndReshape(tensor, zp_B_mem_desc_s32, eng);
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zp_B_mem_s32;
}

View file

@ -164,14 +164,14 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
if (has_input_zero_point) {
auto zp_mem_desc = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
auto tensor = node.Input(INPUT_ZP);
auto& tensor = node.Input(INPUT_ZP);
auto zp_mem = sp.GetMemoryAndReshape(tensor, zp_mem_desc, eng);
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = zp_mem;
}
if (has_weights_zero_point) {
auto zp_mem_desc = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
auto tensor = node.Input(WEIGHTS_ZP);
auto& tensor = node.Input(WEIGHTS_ZP);
auto zp_mem = sp.GetMemoryAndReshape(tensor, zp_mem_desc, eng);
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zp_mem;
}

View file

@ -10,31 +10,51 @@ namespace ort_dnnl {
DnnlTensor DnnlNode::empty_tensor_ = DnnlTensor("");
// Builds the DNNL IR tensor that mirrors an ORT NodeArg.
//
// arg: the ORT node argument; may be null or refer to a non-existing (optional) argument,
//      in which case the tensor gets an empty name.
//
// The type string and a private copy of the TypeProto are captured here so that shape and
// element type queries do not have to go back to the ORT graph later.
DnnlTensor::DnnlTensor(const NodeArg* arg) {
  // NOTE(review): raw pointer into the ORT graph; per the note below it must not be
  // dereferenced once Compile has finished.
  arg_ = arg;
  if (!arg || !arg->Exists()) {
    tensor_name_ = "";
  } else {
    tensor_name_ = arg->Name();
  }
  // Start from the "no type information" state; GetShape()/Dim() treat null members as
  // "unknown" instead of crashing.
  arg_type_ = nullptr;
  arg_type_proto_ = nullptr;
  // The check above accepts a null arg, so it must not be dereferenced below.
  if (arg == nullptr) {
    return;
  }
  // because the passed in ort graph will be released after compile
  // need to save the type/shape in dnnl IR
  arg_type_ = arg->Type();
  // TypeAsProto() hands back a pointer; only copy when there is something to copy from.
  // NOTE(review): assumes an untyped NodeArg yields nullptr here - confirm against NodeArg.
  const auto* type_proto = arg->TypeAsProto();
  if (type_proto != nullptr) {
    arg_type_proto_ = ONNX_NAMESPACE::TypeProto::Create();
    arg_type_proto_->copy_from(type_proto);
  }
}
// Creates a tensor that is known only by its name: no ORT NodeArg is attached and no
// type/shape information is recorded (all three pointers are null).
DnnlTensor::DnnlTensor(std::string name)
    : tensor_name_(name),
      arg_(nullptr),
      arg_type_(nullptr),
      arg_type_proto_(nullptr) {
}
// Returns a copy of the tensor's name (empty for the placeholder/empty tensor).
std::string DnnlTensor::Name() const {
  return std::string(tensor_name_);
}
// Returns the shape stored in the saved TypeProto, or nullptr when no shape is available:
//  - no type information was captured for this tensor,
//  - the saved type is not a tensor type,
//  - the tensor type carries no shape.
// The returned pointer refers to data owned by arg_type_proto_.
const ONNX_NAMESPACE::TensorShapeProto* DnnlTensor::GetShape() const {
  const bool has_type_info = !(arg_type_proto_ == nullptr || arg_type_ == nullptr);
  if (!has_type_info) {
    return nullptr;
  }

  const bool is_tensor =
      arg_type_proto_->value_case() == ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType;
  if (!is_tensor) {
    return nullptr;
  }

  auto& saved_tensor_type = arg_type_proto_->tensor_type();
  return saved_tensor_type.has_shape() ? &saved_tensor_type.shape() : nullptr;
}
dnnl::memory::dims DnnlTensor::Dim() const {
if (arg_ == nullptr) {
if (arg_type_proto_ == nullptr || arg_type_ == nullptr) {
return dnnl::memory::dims();
}
auto shape_proto = arg_->Shape();
auto* shape_proto = GetShape();
// a shape without any information
if (shape_proto == nullptr) {
LOGS_DEFAULT(INFO) << "nullptr shape for " << arg_->Type() << ": " << arg_->Name();
LOGS_DEFAULT(INFO) << "nullptr shape for " << arg_type_ << ": " << tensor_name_;
return dnnl::memory::dims();
}
std::vector<int64_t> shape;
@ -42,7 +62,7 @@ dnnl::memory::dims DnnlTensor::Dim() const {
for (const auto& dim : dims) {
bool has_dim_value = dim.value_case() == dim.kDimValue;
if (!has_dim_value) {
LOGS_DEFAULT(INFO) << "Dynamic shape for " << arg_->Type() << ": " << arg_->Name();
LOGS_DEFAULT(INFO) << "Dynamic shape for " << arg_type_ << ": " << tensor_name_;
shape.push_back(DNNL_RUNTIME_DIM_VAL);
} else {
shape.push_back(dim.dim_value());
@ -57,7 +77,10 @@ dnnl::memory::dims DnnlTensor::Dim() const {
}
dnnl::memory::data_type DnnlTensor::Type() const {
auto data_type = arg_->TypeAsProto()->tensor_type().elem_type();
if (arg_type_proto_ == nullptr) {
ORT_THROW("Invoke DnnlTensor's arg_type_proto_ not initialized yet.");
}
auto data_type = arg_type_proto_->tensor_type().elem_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
return dnnl::memory::data_type::undef;
@ -123,7 +146,7 @@ void DnnlTensor::RemoveConsumer(const DnnlNodeArg& arg) {
}
DnnlNode::DnnlNode(const Node* node) {
onnx_node_ = node;
since_version_ = node->SinceVersion();
name_ = node->Name();
op_type_ = node->OpType();
attr_->insert(node->GetAttributes());
@ -176,11 +199,11 @@ NodeAttributes& DnnlNode::Attributes() {
}
// Returns the opset version the wrapped operator was introduced in.
//
// The value is the copy cached in since_version_ when the DnnlNode was constructed.
// It is deliberately NOT read through onnx_node_: the ORT graph handed to the EP is only
// valid during Compile, so dereferencing the original Node afterwards would be a
// use-after-free.
int DnnlNode::SinceVersion() {
  return since_version_;
}
DnnlSubgraph::DnnlSubgraph(const GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) {
Build();
DnnlSubgraph::DnnlSubgraph(const GraphViewer& graph_viewer) {
Build(graph_viewer);
is_dynamic_ = false;
for (auto input : GetDnnlInputs()) {
if (input->IsDynamic()) {
@ -309,19 +332,11 @@ void DnnlSubgraph::AddNode(std::unique_ptr<DnnlNode> new_node) {
dnnl_nodes_.back()->Index() = index;
}
// Looks up the initializer named `arg_name` by forwarding to the stored graph viewer.
// Returns the viewer's result; `value` is passed through for the viewer to fill in.
// NOTE(review): relies on graph_viewer_ still being valid when this is called - confirm
// against the viewer's lifetime.
bool DnnlSubgraph::GetInitializedTensor(const std::string& arg_name, const ONNX_NAMESPACE::TensorProto*& value) {
  const bool found = graph_viewer_.GetInitializedTensor(arg_name, value);
  return found;
}
// Asks the stored graph viewer whether `arg_name` is a constant initializer;
// `check_outer_scope` is forwarded unchanged.
// NOTE(review): relies on graph_viewer_ still being valid when this is called - confirm
// against the viewer's lifetime.
bool DnnlSubgraph::IsConstantInitializer(const std::string& arg_name, bool check_outer_scope) {
  const bool is_constant = graph_viewer_.IsConstantInitializer(arg_name, check_outer_scope);
  return is_constant;
}
void DnnlSubgraph::Build() {
void DnnlSubgraph::Build(const GraphViewer& graph_viewer) {
//establish nodes, tensors and nodeargs
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer_.GetNode(node_indices[i]));
const auto* node(graph_viewer.GetNode(node_indices[i]));
AddNode(std::make_unique<DnnlNode>(node));
auto dnnl_node = dnnl_nodes_.back().get();
std::vector<DnnlTensor*> inputs;
@ -361,15 +376,15 @@ void DnnlSubgraph::Build() {
//graph inputs including initializers and outputs can be deleted by graph transformation (eg, gelu fusion)
//deleting unneeded inputs doesn't affect onnxruntime passing them as input data handles
//deleting unneeded outputs will cause the ep to output to fewer data handles than expected
for (const auto* node_arg : graph_viewer_.GetInputsIncludingInitializers()) {
for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) {
inputs_.push_back(dnnl_tensors_[node_arg->Name()].get());
}
for (const auto* node_arg : graph_viewer_.GetOutputs()) {
for (const auto* node_arg : graph_viewer.GetOutputs()) {
outputs_.push_back(dnnl_tensors_[node_arg->Name()].get());
}
for (auto& initializer : graph_viewer_.GetAllInitializedTensors()) {
for (auto& initializer : graph_viewer.GetAllInitializedTensors()) {
auto& name = initializer.first;
initializers_.push_back(dnnl_tensors_[name].get());
}

View file

@ -55,8 +55,12 @@ class DnnlTensor {
void RemoveConsumer(const DnnlNodeArg& arg);
private:
const ONNX_NAMESPACE::TensorShapeProto* GetShape() const;
std::string tensor_name_;
const NodeArg* arg_;
ONNX_NAMESPACE::DataType arg_type_;
std::unique_ptr<ONNX_NAMESPACE::TypeProto> arg_type_proto_;
//a tensor can have no producer (input.initializer) or no consumer (output for subgraph)
DnnlNodeArg producer_;
std::vector<DnnlNodeArg> consumers_;
@ -79,7 +83,7 @@ class DnnlNode {
int SinceVersion();
private:
const Node* onnx_node_ = nullptr;
int since_version_;
std::vector<DnnlTensor*> inputs_;
std::vector<DnnlTensor*> outputs_;
static DnnlTensor empty_tensor_;
@ -101,7 +105,7 @@ class DnnlSubgraph {
std::vector<DnnlTensor*> GetDnnlOutputs();
std::vector<DnnlTensor*> GetDnnlInitializers();
// build the subgraph IR
void Build();
void Build(const GraphViewer& graph_viewer);
//check whether the subgraph is dynamic
void TopoSort();
bool IsDynamic();
@ -110,9 +114,6 @@ class DnnlSubgraph {
void AddTensor(std::unique_ptr<DnnlTensor> new_tensor);
void RemoveTensor(const std::string& tensor_name);
bool GetInitializedTensor(const std::string& arg_name, const ONNX_NAMESPACE::TensorProto*& value);
bool IsConstantInitializer(const std::string& arg_name, bool check_outer_scope);
private:
//graph owns all nodes
std::vector<std::unique_ptr<DnnlNode>> dnnl_nodes_;
@ -122,7 +123,6 @@ class DnnlSubgraph {
std::vector<DnnlTensor*> inputs_;
std::vector<DnnlTensor*> outputs_; //output should never get deleted from graph transformation
std::vector<DnnlTensor*> initializers_;
const GraphViewer& graph_viewer_;
bool is_dynamic_;
};
} // namespace ort_dnnl

View file

@ -446,7 +446,7 @@ bool DnnlSubgraphPrimitive::HasMemory(std::string memory_name, dnnl::memory::des
return false;
}
void DnnlSubgraphPrimitive::SetMemory(DnnlTensor tensor, dnnl::memory mem, bool always_copy_output, bool is_scalar) {
void DnnlSubgraphPrimitive::SetMemory(const DnnlTensor& tensor, dnnl::memory mem, bool always_copy_output, bool is_scalar) {
if (always_copy_output) {
outputs_are_always_copied_.insert(tensor.Name());
}

View file

@ -59,7 +59,7 @@ class DnnlSubgraphPrimitive {
//set memory to a tensor (output)
//if always_copy_output is true a copy of the memory will be made when the output is leaving the subgraph.
//is_scalar is true to indicate a scalar output in order to allocate the correct onnxruntime output buffer
void SetMemory(DnnlTensor tensor, dnnl::memory mem, bool always_copy_output = false, bool is_scalar = false);
void SetMemory(const DnnlTensor& tensor, dnnl::memory mem, bool always_copy_output = false, bool is_scalar = false);
void SetMemory(std::string memory_name, dnnl::memory mem);
void SetInitializer(std::string memory_name, dnnl::memory mem);
dnnl::memory::desc GetOutputInfo(std::string name);

View file

@ -12,12 +12,12 @@ namespace onnxruntime {
namespace ort_dnnl {
//apply all transformation rules in order
void DnnlGraphTransformer::Apply(DnnlSubgraph& subgraph) {
void DnnlGraphTransformer::Apply(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
ConvRelu(subgraph);
MatMulAdd(subgraph);
Gelu(subgraph);
FastGelu(subgraph);
RemoveMatMulIntegerZP(subgraph);
Gelu(subgraph, onnx_subgraph_viewer);
FastGelu(subgraph, onnx_subgraph_viewer);
RemoveMatMulIntegerZP(subgraph, onnx_subgraph_viewer);
}
//resolve a fusion by replacing old_indices nodes with a new_node
@ -162,13 +162,13 @@ bool IsScalar(const DnnlTensor& input_arg) {
return dim_size == 0 || (dim_size == 1 && dim[0] == 1);
}
bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(DnnlSubgraph& subgraph, DnnlTensor& input_arg, float expected_value) {
bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlTensor& input_arg, float expected_value) {
if (!IsScalar(input_arg)) {
return false;
}
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
if (!subgraph.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
if (!onnx_subgraph_viewer.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
return false;
}
@ -234,7 +234,7 @@ DnnlNode* FirstParentByType(DnnlNode* node, const std::string& parent_type) {
After Fusion:
[root]--> Gelu ==>
*/
void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
static int gelu_index = 0;
//traverse with max index as there will be empty nodes due to fusion
size_t max_index = subgraph.GetMaxNodeIndex();
@ -249,8 +249,8 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
// Check second input is sqrt(2)
// Some Bert models use this approximation of SQRT2 in the Gelu function
float approximated_sqrt_two = 1.4142099618911743f;
if (!IsInitilizedWithExpectedValue(subgraph, div_node->Input(1), approximated_sqrt_two) &&
!IsInitilizedWithExpectedValue(subgraph, div_node->Input(1), static_cast<float>(M_SQRT2))) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, div_node->Input(1), approximated_sqrt_two) &&
!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, div_node->Input(1), static_cast<float>(M_SQRT2))) {
continue;
}
@ -275,7 +275,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
}
bool is_add_input0 = add_node->Input(0).Name() == erf_node->Output(0).Name();
if (!IsInitilizedWithExpectedValue(subgraph, add_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
continue;
}
@ -304,7 +304,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
if (!(is_mul2_input0 ^ is_mul2_input1)) {
is_pattern_1 = false;
}
if (is_pattern_1 && !IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_input0 ? 1 : 0), 0.5f)) {
if (is_pattern_1 && !IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_input0 ? 1 : 0), 0.5f)) {
is_pattern_1 = false;
}
if (is_pattern_1 && !IsNodeFusable(subgraph, mul2_node)) {
@ -329,7 +329,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
ORT_THROW("Invalid Mul node");
}
bool is_mul2_first_input = mul2_node->Input(0).Name() == mul1_node->Output(0).Name();
if (!IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_first_input ? 1 : 0), 0.5f)) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_first_input ? 1 : 0), 0.5f)) {
continue;
}
}
@ -369,14 +369,14 @@ The formula corresponding to Gelu activation subgraph :
where x is the input.
*/
void DnnlGraphTransformer::FastGelu(DnnlSubgraph& subgraph) {
void DnnlGraphTransformer::FastGelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
static int fastgelu_index = 0;
//traverse with max index as there will be empty nodes due to fusion
size_t max_index = subgraph.GetMaxNodeIndex();
for (size_t index = 0; index < max_index; index++) {
auto dnnl_node = subgraph.GetDnnlNode(index);
if (!FastGeluFirstFormula(subgraph, dnnl_node, fastgelu_index)) {
FastGeluSecondFormula(subgraph, dnnl_node, fastgelu_index);
if (!FastGeluFirstFormula(subgraph, onnx_subgraph_viewer, dnnl_node, fastgelu_index)) {
FastGeluSecondFormula(subgraph, onnx_subgraph_viewer, dnnl_node, fastgelu_index);
}
}
}
@ -389,7 +389,7 @@ The formula corresponding to Gelu activation subgraph :
where x is the input.
*/
bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode* mul1_node, int& fastgelu_index) {
bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* mul1_node, int& fastgelu_index) {
std::vector<size_t> gelu_indices;
//----------mul(0.44715)------------------
if (mul1_node == nullptr || mul1_node->OpType() != "Mul") {
@ -398,7 +398,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
int32_t mul1_input_index = -1;
const float mul_val = 0.044715f;
for (auto i = 0; i < 2; i++) {
if (IsInitilizedWithExpectedValue(subgraph, mul1_node->Input(i), mul_val)) {
if (IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul1_node->Input(i), mul_val)) {
mul1_input_index = i;
break;
}
@ -424,7 +424,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
return false;
}
bool is_add_input0 = mul2_node->Output(0).Name() == add1_node->Input(0).Name();
if (!IsInitilizedWithExpectedValue(subgraph, add1_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add1_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
return false;
}
@ -450,7 +450,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
int32_t mul4_input_index = -1;
const float mul4_val = 0.7978845834732056f;
for (auto i = 0; i < 2; i++) {
if (IsInitilizedWithExpectedValue(subgraph, prev_mul4_node->Input(i), mul4_val)) {
if (IsInitilizedWithExpectedValue(onnx_subgraph_viewer, prev_mul4_node->Input(i), mul4_val)) {
mul4_input_index = i;
break;
}
@ -464,7 +464,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
auto tanh_node = mul3_node->Output(0).GetConsumers()[0].GetNode();
int32_t x_input_index = (mul1_input_index == 0) ? 1 : 0;
if (FastGeluFormulaCommon(subgraph, mul1_node, x_input_index, tanh_node, gelu_indices, fastgelu_index)) {
if (FastGeluFormulaCommon(subgraph, onnx_subgraph_viewer, mul1_node, x_input_index, tanh_node, gelu_indices, fastgelu_index)) {
if (debug_log_) {
LOGS_DEFAULT(ERROR) << "FastGelu fusion found [" << fastgelu_index << "] (first formula)";
}
@ -481,15 +481,15 @@ The formula corresponding to Gelu activation subgraph :
where x is the input.
*/
void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNode* pow_node, int& fastgelu_index) {
void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* pow_node, int& fastgelu_index) {
std::vector<size_t> gelu_indices;
//---------Pow-------------------
if (pow_node == nullptr || pow_node->OpType() != "Pow") {
return;
}
auto pow_exponent = pow_node->Input(1);
if (!IsInitilizedWithExpectedValue(subgraph, pow_exponent, 3.0f)) {
auto& pow_exponent = pow_node->Input(1);
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, pow_exponent, 3.0f)) {
return;
}
@ -505,7 +505,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
float fastgelu_muliplyer = 0.044714998453855515f;
bool is_mul1_input0 = pow_node->Output(0).Name() == mul1_node->Input(0).Name();
if (!IsInitilizedWithExpectedValue(subgraph, mul1_node->Input(is_mul1_input0 ? 1 : 0), fastgelu_muliplyer)) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul1_node->Input(is_mul1_input0 ? 1 : 0), fastgelu_muliplyer)) {
return;
}
if (!IsNodeFusable(subgraph, mul1_node)) {
@ -531,7 +531,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
// constant is sqrt(2/pi)
float fastgelu_sqrt_2_div_pi = 0.7978845834732056f;
bool is_mul2_input0 = add1_node->Output(0).Name() == mul2_node->Input(0).Name();
if (!IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_input0 ? 1 : 0), fastgelu_sqrt_2_div_pi)) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_input0 ? 1 : 0), fastgelu_sqrt_2_div_pi)) {
return;
}
@ -543,7 +543,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
//----------Tanh------------------
auto tanh_node = mul2_node->Output(0).GetConsumers()[0].GetNode();
// since the first node is pow the x_input_index is always 0
if (FastGeluFormulaCommon(subgraph, pow_node, 0, tanh_node, gelu_indices, fastgelu_index)) {
if (FastGeluFormulaCommon(subgraph, onnx_subgraph_viewer, pow_node, 0, tanh_node, gelu_indices, fastgelu_index)) {
if (debug_log_) {
LOGS_DEFAULT(ERROR) << "FastGelu fusion found [" << fastgelu_index << "] (second formula)";
}
@ -559,7 +559,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))),
where x is the input.
*/
bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index) {
bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index) {
//----------Tanh------------------
if (tanh_node == nullptr || tanh_node->OpType() != "Tanh") {
return false;
@ -575,7 +575,7 @@ bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNod
return false;
}
bool is_add2_input0 = tanh_node->Output(0).Name() == add2_node->Input(0).Name();
if (!IsInitilizedWithExpectedValue(subgraph, add2_node->Input(is_add2_input0 ? 1 : 0), 1.0f)) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add2_node->Input(is_add2_input0 ? 1 : 0), 1.0f)) {
return false;
}
@ -607,7 +607,7 @@ bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNod
if (!(is_mul_input0 ^ is_mul_input1)) {
return false;
}
if (!IsInitilizedWithExpectedValue(subgraph, prev_mul4_node->Input(is_mul_input0 ? 1 : 0), 0.5f)) {
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, prev_mul4_node->Input(is_mul_input0 ? 1 : 0), 0.5f)) {
return false;
}
if (!IsNodeFusable(subgraph, prev_mul4_node)) {
@ -748,7 +748,7 @@ void DnnlGraphTransformer::MatMulAdd(DnnlSubgraph& subgraph) {
}
}
void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph) {
void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
size_t max_index = subgraph.GetMaxNodeIndex();
for (size_t index = 0; index < max_index; index++) {
auto dnnl_node = subgraph.GetDnnlNode(index);
@ -763,9 +763,9 @@ void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph) {
continue;
}
auto b_zero_point = dnnl_node->Input(3);
auto& b_zero_point = dnnl_node->Input(3);
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
if (!subgraph.GetInitializedTensor(b_zero_point.Name(), tensor_proto)) {
if (!onnx_subgraph_viewer.GetInitializedTensor(b_zero_point.Name(), tensor_proto)) {
continue;
}

View file

@ -9,7 +9,10 @@ namespace ort_dnnl {
class DnnlGraphTransformer {
public:
void Apply(DnnlSubgraph& subgraph);
// The passed-in onnx subgraph viewer is only valid during the "Compile" phase,
// so keeping a reference to that onnx subgraph in DnnlSubgraph is risky.
// Pass in the onnx subgraph viewer explicitly to make sure we manage the lifetime correctly.
void Apply(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer_);
DnnlGraphTransformer() {
const std::string debug_log_env = onnxruntime::GetEnvironmentVar("ORT_DNNL_DEBUG_LOG");
if (!debug_log_env.empty()) {
@ -18,15 +21,15 @@ class DnnlGraphTransformer {
}
private:
void Gelu(DnnlSubgraph& subgraph);
void FastGelu(DnnlSubgraph& subgraph);
bool FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode* node, int& fastgelu_index);
void FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNode* node, int& fastgelu_index);
bool FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index);
bool IsInitilizedWithExpectedValue(DnnlSubgraph& subgraph, DnnlTensor& input_arg, float expected_value);
void Gelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer);
void FastGelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer);
bool FastGeluFirstFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* node, int& fastgelu_index);
void FastGeluSecondFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* node, int& fastgelu_index);
bool FastGeluFormulaCommon(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index);
bool IsInitilizedWithExpectedValue(const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlTensor& input_arg, float expected_value);
void ConvRelu(DnnlSubgraph& subgraph);
void MatMulAdd(DnnlSubgraph& subgraph);
void RemoveMatMulIntegerZP(DnnlSubgraph& subgraph);
void RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer);
// This function checks a few things
// - the node in question has a single output
// - The output of the node is only consumed by a one other node

View file

@ -660,34 +660,6 @@ static bool IsNodeSupported(const std::set<std::string>& op_set,
return true;
}
// Convert GraphViewer graph to GraphProto
void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) {
for (const auto* input_arg : graph.GetInputs()) {
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
}
// Add all graph's initializers to the subgraph
const auto& init_tensors = graph.GetAllInitializedTensors();
for (const auto& tensor : init_tensors) {
*(graph_proto.mutable_initializer()->Add()) = *(tensor.second);
}
for (const auto* output_arg : graph.GetOutputs()) {
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
}
for (const auto* value_info : graph.GetValueInfo()) {
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
}
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
const gsl::not_null<const Node*> p_node{graph.GetNode(node_idx)};
p_node->ToProto(*node_proto);
}
}
std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const {
std::unordered_set<size_t> node_set;
node_set.reserve(graph_nodes_index.size());
@ -913,7 +885,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
std::vector<std::unique_ptr<ComputeCapability>> result;
auto model = graph_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();
ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph());
graph_viewer.ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string onnx_string_buffer;
model_proto->SerializeToString(onnx_string_buffer);
@ -1012,30 +984,24 @@ bool get_input_output_names(const GraphViewer& graph,
return no_input_shape;
}
Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
migraphx::onnx_options options;
bool no_input_shape = false;
for (const auto& fused_node : fused_nodes) {
for (const auto& fused_node_graph : fused_nodes) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
// map parameter input name to index
std::unordered_map<std::string, std::size_t> input_name_index;
const auto& input_defs = fused_node->InputDefs();
const auto& input_defs = fused_node.InputDefs();
input_name_index.reserve(input_defs.size());
for (std::size_t i = 0; i < input_defs.size(); ++i) {
input_name_index[input_defs[i]->Name()] = i;
}
// Reconstruct graph proto from fused node's function body
const auto* func_body = fused_node->GetFunctionBody();
if (!func_body) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
}
const Graph& graph_body = func_body->Body();
auto graph_body_viewer = graph_body.CreateGraphViewer();
auto model = graph_body_viewer->CreateModel(*GetLogger());
auto model = graph_body_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();
*model_proto->mutable_graph() = *graph_body.ToGraphProto();
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string onnx_string_buffer;
model_proto->SerializeToString(onnx_string_buffer);
@ -1069,11 +1035,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
}
// compile the program
map_progs_[fused_node->Name()] = prog;
map_progs_[fused_node.Name()] = prog;
map_onnx_string_[fused_node->Name()] = onnx_string_buffer;
map_input_index_[fused_node->Name()] = input_name_index;
map_no_input_shape_[fused_node->Name()] = no_input_shape;
map_onnx_string_[fused_node.Name()] = onnx_string_buffer;
map_input_index_[fused_node.Name()] = input_name_index;
map_no_input_shape_[fused_node.Name()] = no_input_shape;
NodeComputeInfo compute_info;
compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
std::unique_ptr<MIGraphXFuncState> p = std::make_unique<MIGraphXFuncState>();

View file

@ -45,8 +45,8 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const override;
Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;

View file

@ -23,9 +23,6 @@ class NnapiExecutionProvider : public IExecutionProvider {
GetCapability(const onnxruntime::GraphViewer& graph_view,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
// we implement the Compile that takes FusedNodeAndGraph instances
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;

View file

@ -118,6 +118,13 @@ class NupharExecutionProvider : public IExecutionProvider {
return iter->second.get();
}
// Nuphar explicitly reports the function-based fusion style (FusionStyle::Function)
// rather than FusionStyle::FilteredGraphViewer.
// NOTE(review): presumably kept so Nuphar continues to use the legacy Compile path that
// takes fused nodes with function bodies - confirm before removing this override.
onnxruntime::IExecutionProvider::FusionStyle GetFusionStyle() const override {
  using FusionStyle = onnxruntime::IExecutionProvider::FusionStyle;
  return FusionStyle::Function;
}
private:
void CreateTVMTarget();

View file

@ -26,7 +26,9 @@ void BackendManager::ReleaseGlobalContext() {
g_global_context.reset();
}
BackendManager::BackendManager(const Node* fused_node, const logging::Logger& logger) {
BackendManager::BackendManager(const onnxruntime::Node& fused_node,
const onnxruntime::GraphViewer& subgraph,
const logging::Logger& logger) {
auto prec_str = GetGlobalContext().precision_str;
if (prec_str == "FP32") {
subgraph_context_.precision = InferenceEngine::Precision::FP32;
@ -38,14 +40,14 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo
// Save the indexes of graph inputs among fused_node's inputDefs
// (which also contains initializers).
auto node_input_defs = fused_node->InputDefs();
auto node_input_defs = fused_node.InputDefs();
int i = 0;
for (auto idef : node_input_defs) {
subgraph_context_.input_names.insert({idef->Name(), i});
i++;
}
auto graph_inputs = fused_node->GetFunctionBody()->Body().GetInputs();
auto graph_inputs = subgraph.GetInputs();
for (auto input : graph_inputs) {
if (GetGlobalContext().device_type.find("MYRIAD") != std::string::npos) {
auto shape = input->Shape();
@ -63,14 +65,14 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo
subgraph_context_.input_indexes.push_back(index);
}
auto graph_outputs_defs = fused_node->OutputDefs();
auto graph_outputs_defs = fused_node.OutputDefs();
i = 0;
for (auto output_def : graph_outputs_defs) {
subgraph_context_.output_names.insert({output_def->Name(), i});
i++;
}
subgraph_context_.subgraph_name = fused_node->Name();
model_proto_ = GetModelProtoFromFusedNode(fused_node, logger);
subgraph_context_.subgraph_name = fused_node.Name();
model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, logger);
if (ModelHasBatchedInputs(*model_proto_) &&
GetGlobalContext().is_wholly_supported_graph &&
@ -81,7 +83,7 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo
concrete_backend_ = BackendFactory::MakeBackend(*model_copy, GetGlobalContext(), subgraph_context_);
subgraph_context_.has_dynamic_input_shape = false;
} else if (ModelHasSymbolicInputDims(fused_node)) {
} else if (ModelHasSymbolicInputDims(subgraph)) {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims. Defering backend initialization";
subgraph_context_.has_dynamic_input_shape = true;
} else {
@ -123,9 +125,9 @@ bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& mod
return has_batched_inputs;
}
bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::Node* fused_node) const {
bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const {
bool has_sym_dims = false;
auto graph_inputs = fused_node->GetFunctionBody()->Body().GetInputs();
auto graph_inputs = subgraph.GetInputs();
for (auto input : graph_inputs) {
if (input->Shape() == nullptr) {
has_sym_dims = true;
@ -145,25 +147,23 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::Node* fused_no
}
std::unique_ptr<ONNX_NAMESPACE::ModelProto>
BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node,
BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
const onnxruntime::GraphViewer& subgraph,
const logging::Logger& logger) const {
const auto* node_function = fused_node->GetFunctionBody();
const std::string& name = fused_node->Name();
ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", name);
const onnxruntime::Graph& node_subgraph = node_function->Body();
auto model = node_subgraph.CreateGraphViewer()->CreateModel(logger);
auto model = subgraph.CreateModel(logger);
auto model_proto = model->ToProto();
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
*model_proto->mutable_graph() = *node_subgraph.ToGraphProto();
subgraph.ToProto(*model_proto->mutable_graph(), true, true);
#ifndef NDEBUG
if (openvino_ep::backend_utils::IsDebugEnabled()) {
const std::string& name = fused_node.Name();
std::fstream dump(name + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
}
#else
ORT_UNUSED_PARAMETER(fused_node);
#endif
return model_proto;

View file

@ -13,7 +13,7 @@ namespace openvino_ep {
// Singleton class that manages all the backends
class BackendManager {
public:
BackendManager(const onnxruntime::Node* fused_node, const logging::Logger& logger);
BackendManager(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger);
void Compute(Ort::CustomOpApi api, OrtKernelContext* context);
void ShutdownBackendManager();
static GlobalContext& GetGlobalContext();
@ -21,8 +21,8 @@ class BackendManager {
private:
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
const onnxruntime::Node* fused_node, const logging::Logger& logger) const;
bool ModelHasSymbolicInputDims(const onnxruntime::Node* fused_node) const;
const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const;
bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const;
bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const;
std::shared_ptr<ONNX_NAMESPACE::ModelProto>

View file

@ -131,18 +131,21 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const
}
common::Status OpenVINOExecutionProvider::Compile(
const std::vector<onnxruntime::Node*>& fused_nodes,
const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto& fused_node : fused_nodes) {
NodeComputeInfo compute_info;
#if defined (OV_API_20)
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true;
# else
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = false;
#endif
for (const auto& fused_node_graph : fused_nodes) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
std::shared_ptr<openvino_ep::BackendManager> backend_manager = std::make_shared<openvino_ep::BackendManager>(fused_node, *GetLogger());
NodeComputeInfo compute_info;
#if defined(OV_API_20)
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true;
#else
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = false;
#endif
std::shared_ptr<openvino_ep::BackendManager> backend_manager = std::make_shared<openvino_ep::BackendManager>(fused_node, graph_body_viewer, *GetLogger());
compute_info.create_state_func =
[backend_manager](ComputeContext* context, FunctionState* state) {

View file

@ -151,8 +151,8 @@ class OpenVINOExecutionProvider : public IExecutionProvider {
GetCapability(const GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const override;
Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
const void* GetExecutionHandle() const noexcept override {
return nullptr;

View file

@ -256,30 +256,24 @@ RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
return result;
}
common::Status RknpuExecutionProvider::Compile(
const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto* fused_node : fused_nodes) {
// Reconstruct graph proto from fused node's function body
const auto* func_body = fused_node->GetFunctionBody();
if (!func_body) {
return common::Status(common::ONNXRUNTIME,
common::INVALID_ARGUMENT, "Function body is empty");
}
const Graph& graph_body = func_body->Body();
onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(),
common::Status RknpuExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
onnxruntime::Model model(graph_body_viewer.Name(), true, ModelMetaData(),
PathString(),
IOnnxRuntimeOpSchemaRegistryList(),
graph_body.DomainToVersionMap(),
graph_body_viewer.DomainToVersionMap(),
std::vector<ONNX_NAMESPACE::FunctionProto>(),
*GetLogger());
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
*(model_proto.mutable_graph()) = graph_body.ToGraphProto();
// Serialize the filtered graph view into the model's GraphProto. The two flags are
// include_initializers and include_outer_scope_args (see GraphViewer::ToProto).
// model_proto is a ModelProto value here, not a pointer, so it must be accessed with '.'.
graph_body_viewer.ToProto(*model_proto.mutable_graph(), true, true);
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
// Build map from input name to its index in input definitions
std::unordered_map<std::string, int> input_map;
const auto& input_defs = fused_node->InputDefs();
const auto& input_defs = fused_node.InputDefs();
input_map.reserve(input_defs.size());
for (int i = 0, end = input_defs.size(); i < end; ++i) {
input_map[input_defs[i]->Name()] = i;
@ -287,15 +281,15 @@ common::Status RknpuExecutionProvider::Compile(
// Build map from output name to its index in output definitions
std::unordered_map<std::string, int> output_map;
const auto& output_defs = fused_node->OutputDefs();
const auto& output_defs = fused_node.OutputDefs();
output_map.reserve(output_defs.size());
for (int i = 0, end = output_defs.size(); i < end; ++i) {
output_map[output_defs[i]->Name()] = i;
}
model_proto_[fused_node->Name()] = model_proto;
input_info_[fused_node->Name()] = input_map;
output_info_[fused_node->Name()] = output_map;
model_proto_[fused_node.Name()] = model_proto;
input_info_[fused_node.Name()] = input_map;
output_info_[fused_node.Name()] = output_map;
NodeComputeInfo compute_info;
compute_info.create_state_func = [&](ComputeContext* context,

View file

@ -20,7 +20,7 @@ class RknpuExecutionProvider : public IExecutionProvider {
std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;

View file

@ -291,17 +291,11 @@ std::vector<std::unique_ptr<ComputeCapability>> IExecutionProvider::GetCapabilit
const std::vector<const KernelRegistry*>& kernel_registries) const {
return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_registries);
}
// Legacy Compile overload that receives the fused nodes directly as Node pointers.
// It does no work of its own: the call is forwarded to the provider host (g_host) together with `this`.
// !!! This API will be deprecated soon.
// NOTE(review): the overload taking FusedNodeAndGraph instances appears to be the intended replacement -
// confirm before adding new callers of this one.
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
return g_host->IExecutionProvider__Compile(this, fused_nodes, node_compute_funcs);
}
// Compile overload that takes the fused nodes plus a dll_path string passed by non-const reference.
// It does no work of its own: the call is forwarded to the provider host (g_host) together with `this`,
// and the host's Status is returned unchanged.
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::string& dll_path) {
return g_host->IExecutionProvider__Compile(this, fused_nodes, dll_path);
}
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
return g_host->IExecutionProvider__Compile(this, fused_nodes_and_graphs, node_compute_funcs);

View file

@ -218,8 +218,9 @@ struct ProviderHost {
virtual void IExecutionProvider__TryInsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) = 0;
virtual std::vector<std::unique_ptr<ComputeCapability>> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) = 0;
//!!! This API will be deprecated soon
virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::vector<NodeComputeInfo>& node_compute_funcs) = 0;
virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::string& dll_path) = 0;
virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs, std::vector<NodeComputeInfo>& node_compute_funcs) = 0;
virtual int IExecutionProvider__GenerateMetaDefId(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) = 0;
@ -286,6 +287,8 @@ struct ProviderHost {
#endif
// TypeProto
virtual std::unique_ptr<ONNX_NAMESPACE::TypeProto> TypeProto__construct() = 0;
virtual void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) = 0;
virtual const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) = 0;
virtual ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) = 0;
@ -368,6 +371,7 @@ struct ProviderHost {
virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0;
virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0;
virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0;
virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0;
virtual bool TensorProto_DataType_IsValid(int value) = 0;
@ -670,6 +674,8 @@ struct ProviderHost {
virtual const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) = 0;
virtual const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0;
virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0;
// Path
virtual PathString Path__ToPathString(const Path* p) noexcept = 0;

View file

@ -159,6 +159,8 @@ struct TensorProto final {
static bool DataType_IsValid(int value) { return g_host->TensorProto_DataType_IsValid(value); }
void copy_from(const TensorProto* other) { return g_host->TensorProto__CopyFrom(this, other); }
TensorProto() = delete;
TensorProto(const TensorProto&) = delete;
};
@ -240,6 +242,8 @@ struct TypeProto_Sequence final {
};
struct TypeProto final {
static std::unique_ptr<TypeProto> Create() { return g_host->TypeProto__construct(); }
const TypeProto_Tensor& tensor_type() const { return g_host->TypeProto__tensor_type(this); }
TypeProto_Tensor* mutable_tensor_type() { return g_host->TypeProto__mutable_tensor_type(this); }
@ -268,7 +272,10 @@ struct TypeProto final {
ValueCase value_case() const { return ValueCase(g_host->TypeProto__value_case(this)); }
PROVIDER_DISALLOW_ALL(TypeProto)
void copy_from(const TypeProto* other) { return g_host->TypeProto__CopyFrom(this, other); }
TypeProto() = delete;
TypeProto(const TypeProto&) = delete;
};
struct ValueInfoProto final {
@ -705,6 +712,8 @@ struct GraphViewer final {
const std::vector<NodeIndex>& GetNodesInTopologicalOrder() const { return g_host->GraphViewer__GetNodesInTopologicalOrder(this); }
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); }
void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); }
GraphViewer() = delete;
GraphViewer(const GraphViewer&) = delete;
void operator=(const GraphViewer&) = delete;

View file

@ -516,34 +516,6 @@ Status TensorrtExecutionProvider::SetComputeStream(void* stream) {
return Status::OK();
}
// Convert GraphViewer graph to GraphProto.
//
// Appends the viewer's inputs, initializers, outputs, value_info entries and nodes to graph_proto.
// Nothing already present in graph_proto is cleared by this function, so callers should pass a
// GraphProto that does not already hold these fields.
//
// graph        the viewer to serialize (read only).
// graph_proto  destination proto; modified in place.
void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) {
// Graph inputs, in the order GetInputs() reports them.
for (const auto* input_arg : graph.GetInputs()) {
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
}
// Add all graph's initializers to the subgraph.
// Every entry returned by GetAllInitializedTensors() is copied; no filtering happens here.
const auto& init_tensors = graph.GetAllInitializedTensors();
for (const auto& tensor : init_tensors) {
// tensor.second is dereferenced and its contents copied into the newly added initializer slot.
*(graph_proto.mutable_initializer()->Add()) = *(tensor.second);
}
// Graph outputs, in the order GetOutputs() reports them.
for (const auto* output_arg : graph.GetOutputs()) {
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
}
// Entries from GetValueInfo() go into the proto's value_info field.
for (const auto* value_info : graph.GetValueInfo()) {
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
}
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
// gsl::not_null guards against a null pointer coming back from add_node() / GetNode().
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
const gsl::not_null<const Node*> p_node{graph.GetNode(node_idx)};
p_node->ToProto(*node_proto);
}
}
std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const {
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
std::unordered_set<size_t> node_set;
@ -799,7 +771,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
auto graph_viewer = graph_build.CreateGraphViewer();
auto model = graph_viewer->CreateModel(*GetLogger());
auto model_proto = model->ToProto();
ToGraphProtoInternal(*graph_viewer, *model_proto->mutable_graph());
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string string_buf;
@ -982,19 +954,19 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
// Consolidate supported node list
if (supported_nodes_vector.size() > 1) {
nodes_vector.clear();
for (const auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end());
}
}
SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}};
if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, false)) {
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation";
} else {
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph";
supported_nodes_vector = consolidated_supported_nodes_vector;
nodes_vector.clear();
for (const auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end());
}
}
SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}};
if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, false)) {
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation";
} else {
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph";
supported_nodes_vector = consolidated_supported_nodes_vector;
}
}
// Construct subgraph capability from node list
@ -1020,12 +992,14 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
return result;
}
common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto* fused_node : fused_nodes) {
for (auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
// Build map from input name to its index in input definitions
std::unordered_map<std::string, size_t> input_map;
const auto& input_defs = fused_node->InputDefs();
const auto& input_defs = fused_node.InputDefs();
input_map.reserve(input_defs.size());
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
input_map[input_defs[i]->Name()] = i;
@ -1033,29 +1007,23 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Build map from output name to its index in output definitions
std::unordered_map<std::string, size_t> output_map;
const auto& output_defs = fused_node->OutputDefs();
const auto& output_defs = fused_node.OutputDefs();
output_map.reserve(output_defs.size());
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
output_map[output_defs[i]->Name()] = i;
}
// Reconstruct graph proto from fused node's function body
const auto* func_body = fused_node->GetFunctionBody();
if (!func_body) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
}
const Graph& graph_body = func_body->Body();
auto graph_body_viewer = graph_body.CreateGraphViewer();
auto model = graph_body_viewer->CreateModel(*GetLogger());
auto model = graph_body_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();
*model_proto->mutable_graph() = *graph_body.ToGraphProto();
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string string_buf;
model_proto->SerializeToString(string_buf);
if (dump_subgraphs_) {
// Dump TensorRT subgraphs
std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
}
@ -1123,7 +1091,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
// Set precision flags
std::string trt_node_name_with_precision = fused_node->Name();
std::string trt_node_name_with_precision = fused_node.Name();
if (fp16_enable_ && int8_enable_) {
trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
trt_node_name_with_precision += "_fp16_int8";
@ -1140,7 +1108,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Set DLA
if (fp16_enable_ || int8_enable_) {
if (dla_enable_ && dla_core_ >= 0) { //DLA can only run with FP16 and INT8
if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8
int number_of_dla_core = trt_builder->getNbDLACores();
if (number_of_dla_core == 0) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core";
@ -1205,7 +1173,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
if (trt_context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build execution context for fused node: " + fused_node->Name());
"TensorRT EP could not build execution context for fused node: " + fused_node.Name());
}
}
@ -1220,7 +1188,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
}
// If (1) engine cache enable is not set or (2) first time enable engine cache and no engine cache is present,
// build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will
// be built at runtime
@ -1231,7 +1198,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
trt_config->setInt8Calibrator(nullptr);
if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node->Name());
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name());
}
}
@ -1242,7 +1209,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build engine for fused node: " + fused_node->Name());
"TensorRT EP could not build engine for fused node: " + fused_node.Name());
}
if (engine_cache_enable_)
@ -1252,7 +1219,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
if (trt_context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build execution context for fused node: " + fused_node->Name());
"TensorRT EP could not build execution context for fused node: " + fused_node.Name());
}
}
}
@ -1280,15 +1247,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
// Save engine, context and input/output info to map
parsers_.emplace(fused_node->Name(), std::move(trt_parser));
engines_.emplace(fused_node->Name(), std::move(trt_engine));
contexts_.emplace(fused_node->Name(), std::move(trt_context));
builders_.emplace(fused_node->Name(), std::move(trt_builder));
networks_.emplace(fused_node->Name(), std::move(trt_network));
input_info_[fused_node->Name()].push_back(input_indexes);
output_info_[fused_node->Name()].push_back(output_indexes);
output_info_[fused_node->Name()].push_back(output_types);
input_shape_ranges_[fused_node->Name()] = input_shape_ranges;
parsers_.emplace(fused_node.Name(), std::move(trt_parser));
engines_.emplace(fused_node.Name(), std::move(trt_engine));
contexts_.emplace(fused_node.Name(), std::move(trt_context));
builders_.emplace(fused_node.Name(), std::move(trt_builder));
networks_.emplace(fused_node.Name(), std::move(trt_network));
input_info_[fused_node.Name()].push_back(input_indexes);
output_info_[fused_node.Name()].push_back(output_indexes);
output_info_[fused_node.Name()].push_back(output_types);
input_shape_ranges_[fused_node.Name()] = input_shape_ranges;
// Create function state
// TODO: remove default capture
@ -1310,7 +1277,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
compute_info.release_state_func = [](FunctionState state) {
if (state) {
// Serialize and save engine to cache
//
//
// Note: only save engine to file if engine cache enable is set and engine is being updated due to input shape changed
// or engine file is not previously existed
TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
@ -1325,7 +1292,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
delete static_cast<TensorrtFuncState*>(state);
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt"));
"TensorRT EP could not call engine encryption function encrypt"));
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
@ -1341,7 +1308,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
}
}
delete static_cast<TensorrtFuncState*>(state);
}
};

View file

@ -127,7 +127,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
int GetDeviceId() const { return device_id_; }
common::Status Compile(const std::vector<Node*>& fused_nodes,
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override;

View file

@ -8,6 +8,7 @@
#include "core/framework/tensorprotoutils.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/compute_capability.h"
#include "core/graph/graph_proto_serializer.h"
#include "core/platform/env.h"
#include "core/graph/model.h"
@ -101,35 +102,33 @@ TvmExecutionProvider::GetCapability(const GraphViewer& graph_viewer,
return result;
}
common::Status TvmExecutionProvider::Compile(const std::vector<Node*>& nodes,
common::Status TvmExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
printOptions();
for (auto* fused_node : nodes) {
auto func_body = fused_node->GetFunctionBody();
if (!func_body)
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
const std::string func_name = fused_node->Name();
const Graph& node_graph = func_body->Body();
Model model(node_graph.Name(), true, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), node_graph.DomainToVersionMap(),
for (auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
const std::string func_name = fused_node.Name();
Model model(graph_body_viewer.Name(), true, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), graph_body_viewer.DomainToVersionMap(),
std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
*(model_proto.mutable_graph()) = node_graph.ToGraphProto();
//TVM EP is using static lib approach, so invoke serializer directly.
GraphViewerToProto(graph_body_viewer, *model_proto.mutable_graph(), true, true);
auto opset = model_proto.add_opset_import();
opset->set_domain(kOnnxDomain);
opset->set_version(node_graph.DomainToVersionMap().at(kOnnxDomain));
opset->set_version(graph_body_viewer.DomainToVersionMap().at(kOnnxDomain));
std::string onnx_model_str;
model_proto.SerializeToString(&onnx_model_str);
compilers_[func_name] = std::make_shared<Compiler>(std::move(onnx_model_str),
fused_node->ModelPath().ToPathString(),
fused_node.ModelPath().ToPathString(),
int(opset->version()));
InputsInfoMap all_input_shapes;
auto mod = compileModel(func_name, node_graph, all_input_shapes);
auto mod = compileModel(func_name, graph_body_viewer, all_input_shapes);
std::vector<DLTensor> output_tensors;
prepareOutputTensors(mod, output_tensors, node_graph.GetOutputs().size());
prepareOutputTensors(mod, output_tensors, graph_body_viewer.GetOutputs().size());
runners_[func_name] = std::make_shared<Runner>(options_, mod, all_input_shapes, output_tensors);
@ -168,15 +167,15 @@ void TvmExecutionProvider::printOptions() {
}
std::shared_ptr<TvmModule> TvmExecutionProvider::compileModel(const std::string& func_name,
const Graph& graph,
const GraphViewer& graph_viewer,
InputsInfoMap& all_input_shapes) {
all_input_shapes.clear();
TVMTensorShapes input_shapes;
if (options_.freeze_weights) {
setInputShapesForFreezedNN(graph, input_shapes, all_input_shapes);
setInputShapesForFreezedNN(graph_viewer, input_shapes, all_input_shapes);
} else {
setInputShapesForUnfreezedNN(graph, input_shapes, all_input_shapes);
setInputShapesForUnfreezedNN(graph_viewer, input_shapes, all_input_shapes);
}
std::shared_ptr<TvmModule> mod = compilers_[func_name]->operator()(options_, input_shapes);
@ -184,14 +183,14 @@ std::shared_ptr<TvmModule> TvmExecutionProvider::compileModel(const std::string&
return mod;
}
void TvmExecutionProvider::setInputShapesForFreezedNN(const Graph& graph,
void TvmExecutionProvider::setInputShapesForFreezedNN(const GraphViewer& graph_viewer,
TVMTensorShapes& input_shapes,
InputsInfoMap& all_input_shapes) {
const std::vector<const NodeArg*>& all_nodes = graph.GetInputsIncludingInitializers();
const std::vector<const NodeArg*>& all_nodes = graph_viewer.GetInputsIncludingInitializers();
size_t indx = 0;
for (const auto* node : all_nodes) {
if(!graph.IsInitializedTensor(node->Name())) {
if (!graph_viewer.IsInitializedTensor(node->Name())) {
TensorShapeVector shape = getInputShape(node);
all_input_shapes[indx++] = shape;
input_shapes.emplace_back(shape);
@ -199,16 +198,16 @@ void TvmExecutionProvider::setInputShapesForFreezedNN(const Graph& graph,
}
}
void TvmExecutionProvider::setInputShapesForUnfreezedNN(const Graph& graph,
void TvmExecutionProvider::setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer,
TVMTensorShapes& input_shapes,
InputsInfoMap& all_input_shapes) {
const std::vector<const NodeArg*>& all_nodes = graph.GetInputsIncludingInitializers();
const std::vector<const NodeArg*>& all_nodes = graph_viewer.GetInputsIncludingInitializers();
size_t indx = 0;
for (const auto* node : all_nodes) {
TensorShapeVector shape = getInputShape(node);
all_input_shapes[indx++] = shape;
if(!graph.IsInitializedTensor(node->Name())) {
if (!graph_viewer.IsInitializedTensor(node->Name())) {
input_shapes.emplace_back(shape);
}
}

View file

@ -40,7 +40,7 @@ class TvmExecutionProvider : public IExecutionProvider {
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;
AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override;
@ -48,10 +48,10 @@ class TvmExecutionProvider : public IExecutionProvider {
private:
void printOptions();
std::shared_ptr<tvm::TvmModule> compileModel(const std::string& func_name,
const Graph& graph,
const GraphViewer& graph_viewer,
InputsInfoMap& inputs_info);
void setInputShapesForFreezedNN(const Graph& graph, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
void setInputShapesForUnfreezedNN(const Graph& graph, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
void setInputShapesForFreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
void setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
TensorShapeVector getInputShape(const NodeArg* node);
TensorShapeVector convertTensorShape(const ONNX_NAMESPACE::TensorShapeProto& shape_proto);
void prepareOutputTensors(const std::shared_ptr<tvm::TvmModule>& mod, std::vector<DLTensor>& output_tensors, size_t num);

View file

@ -32,28 +32,22 @@
namespace onnxruntime {
namespace vitisai_ep {
// Builds a standalone ONNX ModelProto for the graph exposed by `graph_viewer`.
//
// A temporary onnxruntime::Model is created only to obtain a ModelProto
// skeleton that carries the viewer's name and its domain-to-opset-version map;
// the graph payload is then serialized from the viewer itself.
//
// @param graph_viewer  Viewer over the graph to serialize.
// @param logger        Logger handed to the temporary Model.
// @return The populated ModelProto, with ir_version set to
//         ONNX_NAMESPACE::Version::IR_VERSION.
static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::GraphViewer& graph_viewer,
                                                             const logging::Logger& logger) {
  onnxruntime::Model model{graph_viewer.Name(), true, ModelMetaData{}, PathString{},
                           IOnnxRuntimeOpSchemaRegistryList{}, graph_viewer.DomainToVersionMap(),
                           std::vector<ONNX_NAMESPACE::FunctionProto>(), logger};

  ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
  model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

  // Serialize through the `graph_viewer` parameter (no `graph_body_viewer`
  // exists in this scope) and use `.` because `model_proto` is a value, not a
  // pointer. The two `true` flags request initializers and outer-scope args.
  // NOTE(review): assumes GraphViewer exposes a ToProto member in this build
  // configuration; if it does not, call the free function
  // GraphViewerToProto(graph_viewer, *model_proto.mutable_graph(), true, true)
  // from core/graph/graph_proto_serializer.h instead.
  graph_viewer.ToProto(*model_proto.mutable_graph(), true, true);

  return model_proto;
}
VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
const onnxruntime::Node* fused_node,
const onnxruntime::Node& fused_node,
const onnxruntime::GraphViewer& graph_viewer,
const std::string& backend_type,
const std::string& export_runtime_module,
const std::string& load_runtime_module,
@ -68,7 +62,7 @@ VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
allocator_ = context->allocator_handle;
name_ = context->node_name;
model_proto_ = GetModelProtoFromFusedNode(fused_node, *GetLogger());
model_proto_ = GetModelProtoFromFusedNode(graph_viewer, *GetLogger());
std::istringstream model_stream{model_proto_.SerializeAsString()};
xg_ = pyxir::onnx::import_onnx_model(model_stream);
@ -78,12 +72,12 @@ VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
if (load_runtime_module_.empty()) {
pyxir::partition(xg_, std::vector<std::string>{backend_type_}, "");
auto input_defs = fused_node->InputDefs();
auto input_defs = fused_node.InputDefs();
for (auto idef : input_defs) {
in_tensor_names_.push_back(idef->Name());
}
auto output_defs = fused_node->OutputDefs();
auto output_defs = fused_node.OutputDefs();
for (auto odef : output_defs) {
out_tensor_names_.push_back(odef->Name());
}

View file

@ -29,7 +29,8 @@ namespace vitisai_ep {
class VitisAICustomOp {
public:
VitisAICustomOp(const ComputeContext* context,
const onnxruntime::Node* fused_node,
const onnxruntime::Node& fused_node,
const onnxruntime::GraphViewer& graph_viewer,
const std::string& backend_type,
const std::string& export_runtime_module,
const std::string& load_runtime_module,

View file

@ -273,12 +273,14 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
return result;
}
common::Status VitisAIExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
common::Status VitisAIExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto& fused_node : fused_nodes) {
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
NodeComputeInfo compute_info;
compute_info.create_state_func = [this, fused_node, logger = GetLogger()](ComputeContext* context, FunctionState* state) {
auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, backend_type_, export_runtime_module_,
auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, graph_body_viewer, backend_type_, export_runtime_module_,
load_runtime_module_, logger);
*state = p;
return 0;

View file

@ -29,7 +29,7 @@ class VitisAIExecutionProvider : public IExecutionProvider {
int GetDeviceId() const { return device_id_; }
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
private:

View file

@ -930,7 +930,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// Do partitioning based on execution providers' capability.
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph,
session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayoutForCompilingEP, mode));
@ -1153,7 +1153,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
std::unordered_map<std::string, HashValue> compiled_kernel_hashes;
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(),
ORT_RETURN_IF_ERROR(partitioner.Partition(graph,
session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayoutForCompilingEP,
GraphPartitioner::Mode::kOrtFormatLoad,

View file

@ -28,6 +28,7 @@
#include "core/util/math.h"
#include "core/framework/sparse_utils.h"
#include "core/common/string_helper.h"
#include "core/graph/graph_proto_serializer.h"
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_TORCH_INTEROP
@ -265,14 +266,11 @@ struct ProviderHostImpl : ProviderHost {
void IExecutionProvider__TryInsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) override { return p->IExecutionProvider::TryInsertAllocator(allocator); }
std::vector<std::unique_ptr<ComputeCapability>> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) override { return p->IExecutionProvider::GetCapability(graph_viewer, kernel_registries); }
// !!! this api will be deprecated soon
common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::vector<NodeComputeInfo>& node_compute_funcs) override {
return p->IExecutionProvider::Compile(fused_nodes, node_compute_funcs);
}
common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::string& dll_path) override {
return p->IExecutionProvider::Compile(fused_nodes, dll_path);
}
common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs, std::vector<NodeComputeInfo>& node_compute_funcs) override {
return p->IExecutionProvider::Compile(fused_nodes_and_graphs, node_compute_funcs);
}
@ -354,6 +352,8 @@ struct ProviderHostImpl : ProviderHost {
#endif
// TypeProto (wrapped)
std::unique_ptr<ONNX_NAMESPACE::TypeProto> TypeProto__construct() override { return std::make_unique<ONNX_NAMESPACE::TypeProto>(); }
void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) override { p->CopyFrom(*other); }
const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) override { return p->tensor_type(); }
ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) override { return p->mutable_tensor_type(); }
int TypeProto__value_case(const ONNX_NAMESPACE::TypeProto* p) override { return p->value_case(); }
@ -441,6 +441,7 @@ struct ProviderHostImpl : ProviderHost {
int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); }
bool TensorProto_DataType_IsValid(int value) override { return ONNX_NAMESPACE::TensorProto::DataType_IsValid(value); }
void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) override { p->CopyFrom(*other); }
// TensorProtos (wrapped)
ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); }
@ -762,6 +763,9 @@ struct ProviderHostImpl : ProviderHost {
const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); }
const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); }
void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override {
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args);
}
// Path (wrapped)
PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); }

View file

@ -1242,12 +1242,9 @@ TEST(ExecutionProviderTest, FunctionTest) {
InferenceSession session_object_2{so, GetEnvironment()};
ASSERT_STATUS_OK(
session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
status = session_object_2.Load(model_file_name);
ASSERT_TRUE(status.IsOK());
status = session_object_2.Initialize();
ASSERT_TRUE(status.IsOK());
status = session_object_2.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(status.IsOK());
ASSERT_STATUS_OK(session_object_2.Load(model_file_name));
ASSERT_STATUS_OK(session_object_2.Initialize());
ASSERT_STATUS_OK(session_object_2.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
}

View file

@ -143,8 +143,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
DefaultLoggingManager().DefaultLogger(), profiler);
GraphPartitioner partitioner(krm, execution_providers);
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayoutForCompilingEP);
status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayoutForCompilingEP);
ASSERT_TRUE(status.IsOK()) << status;
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
@ -209,7 +209,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
// Partition the graph
GraphPartitioner partitioner(krm, execution_providers);
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayoutForCompilingEP);
ASSERT_TRUE(status.IsOK()) << status;
@ -259,7 +259,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
// Partition the graph
GraphPartitioner partitioner(krm, execution_providers);
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayoutForCompilingEP);
ASSERT_TRUE(status.IsOK()) << status;

View file

@ -1,409 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "core/common/logging/logging.h"
#include "core/framework/compute_capability.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/op_kernel.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/framework/test_utils.h"
#include "test/test_environment.h"
#include "test/nuphar_tvm/tvm_demo/demo_compiler.h"
#include <tvm/runtime/ndarray.h>
namespace onnxruntime {
using namespace tvm_demo;
// Thin OpKernel base for the TVM demo kernels. It adds a helper that reports
// an output shape by echoing the shape of the first input.
class TVMDemoKernel : public OpKernel {
 public:
  explicit TVMDemoKernel(const OpKernelInfo& info) : OpKernel(info) {}

 protected:
  // Returns the shape of input 0; the requested output index is ignored.
  const TensorShape& GetOutputShape(OpKernelContext* context, int /*i*/) const {
    const Tensor* first_input = context->Input<Tensor>(0);
    return first_input->Shape();
  }
};
// Disjoint-set (union-find) over the integers [0, n).
//
// Every element starts in its own set. merge() joins two sets and get()
// returns the representative of the set containing an element, compressing the
// path it walked so later lookups are shorter.
class UnionSet {
 public:
  // Creates n singleton sets; a non-positive n yields an empty structure.
  UnionSet(int n) {
    if (n > 0) {
      farthers_.reserve(static_cast<std::vector<int>::size_type>(n));
    }
    for (int i = 0; i < n; ++i) {
      farthers_.push_back(i);
    }
  }

  // Returns the representative of the set that contains x.
  // x must be a valid index in [0, n); no bounds check is performed.
  //
  // Implemented iteratively: a long parent chain (e.g. built by merging
  // i with i-1 for increasing i) would otherwise need one stack frame per
  // link and could overflow the stack.
  int get(int x) {
    // Pass 1: walk up to the root.
    int root = x;
    while (farthers_[root] != root) {
      root = farthers_[root];
    }
    // Pass 2: point every node on the walked path directly at the root.
    while (farthers_[x] != root) {
      const int next = farthers_[x];
      farthers_[x] = root;
      x = next;
    }
    return root;
  }

  // Joins the sets containing x and y; the representative of x's set becomes
  // the representative of the merged set.
  void merge(int x, int y) {
    x = get(x);
    y = get(y);
    if (x != y) {
      farthers_[y] = x;
    }
  }

  // Parent links, indexed by element. Public (and spelled as-is) because
  // callers in this file read its size directly.
  std::vector<int> farthers_;
};
// Maps an ONNX tensor element type to the matching DLPack data type.
// Only ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE is handled (64-bit float, one
// lane); any other element type raises via ORT_THROW.
static DLDataType GetDataType(ONNXTensorElementDataType type) {
  if (type != ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
    ORT_THROW("not implement.");
  }
  return {kDLFloat, 64, 1};
}
namespace test {
// Per-kernel state handed back to the runtime by the fused node's
// create_state_func and later passed to compute_func / release_state_func.
// It is filled by positional aggregate assignment in Compile, so the member
// order below must not change.
struct TVMFuncState {
// Allocation callback copied from ComputeContext::allocate_func.
AllocateFunc test_allocate_func = nullptr;
// Release callback copied from ComputeContext::release_func.
DestroyFunc test_release_func = nullptr;
// Allocator handle copied from ComputeContext::allocator_handle.
AllocatorHandle allocator = nullptr;
// Compiled TVM module for the fused node. Raw pointer obtained via
// shared_ptr::get() from the provider's module map; release_state_func
// deletes only this struct, not the module.
tvm::runtime::Module* module = nullptr;
};
// Test execution provider derived from the CPU provider. GetCapability claims
// groups of connected "Mul" nodes as fused subgraphs, and Compile turns each
// fused node into a TVM module (via the tvm_demo helpers) plus the
// create/compute/release callbacks the runtime invokes for it.
class FuseExecutionProviderX : public CPUExecutionProvider {
public:
explicit FuseExecutionProviderX(const CPUExecutionProviderInfo& info) : CPUExecutionProvider(info) {
}
// Returns one ComputeCapability per group of two or more "Mul" nodes that
// are linked to each other through input edges. Kernel registries are unused.
std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override {
std::vector<std::unique_ptr<ComputeCapability>> result;
// Step 1: collect the indices of every Mul node.
std::vector<onnxruntime::NodeIndex> fused_nodes;
for (auto& node : graph_viewer.Nodes()) {
if (node.OpType() == "Mul") {
fused_nodes.push_back(node.Index());
}
}
// Step 2: union Mul nodes whose producer (input node) is also a Mul node.
// Positions in fused_nodes are used as union-find element ids.
UnionSet set(static_cast<int>(fused_nodes.size()));
// NOTE(review): `int i` is compared against an unsigned size() here and in
// the grouping loop below (signed/unsigned comparison).
for (int i = 0; i < fused_nodes.size(); ++i) {
auto node = graph_viewer.GetNode(fused_nodes[i]);
for (auto it = node->InputNodesBegin(); it != node->InputNodesEnd(); ++it) {
auto index_it = std::find(fused_nodes.begin(), fused_nodes.end(), (*it).Index());
if (index_it != fused_nodes.end()) {
set.merge(i, static_cast<int>(index_it - fused_nodes.begin()));
}
}
}
// Step 3: bucket node indices by their set representative. Buckets that do
// not correspond to a representative stay empty.
std::vector<std::vector<onnxruntime::NodeIndex>> groups;
groups.resize(fused_nodes.size());
for (int i = 0; i < set.farthers_.size(); ++i) {
groups[set.get(i)].push_back(fused_nodes[i]);
}
// Step 4: emit a fused subgraph for every group with more than one node.
for (auto& group : groups) {
if (group.size() > 1) {
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
// Boundary computation: a NodeArg that is produced by one node of the group
// and consumed by another cancels out of both sets; what remains are the
// subgraph's external inputs and outputs.
// NOTE(review): the sets are ordered by pointer value, so the order of
// meta_def inputs/outputs presumably follows allocation addresses — confirm
// that callers do not rely on a particular order.
std::set<const onnxruntime::NodeArg*> fused_inputs, fused_outputs;
for (auto index : group) {
sub_graph->nodes.push_back(index);
auto node = graph_viewer.GetNode(index);
for (auto input : node->InputDefs()) {
auto it = fused_outputs.find(input);
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
} else {
fused_inputs.insert(input);
}
}
for (auto output : node->OutputDefs()) {
auto it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
} else {
fused_outputs.insert(output);
}
}
}
// Describe the fused node: fixed name/domain, boundary args by name.
auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
meta_def->name = "TVMFuse";
meta_def->domain = "FuseTest";
for (auto input : fused_inputs) {
meta_def->inputs.push_back(input->Name());
}
for (auto output : fused_outputs) {
meta_def->outputs.push_back(output->Name());
}
meta_def->since_version = 1;
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
sub_graph->SetMetaDef(std::move(meta_def));
//TODO:set fuse kernel func;
result.push_back(
std::make_unique<ComputeCapability>(std::move(sub_graph)));
}
}
return result;
}
// Compiles every fused node into a TVM module and appends one NodeComputeInfo
// per node to node_compute_funcs. Returns INVALID_ARGUMENT if a fused node has
// no function body; otherwise returns OK.
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override {
for (auto* fused_node : fused_nodes) {
auto func_body = fused_node->GetFunctionBody();
if (!func_body)
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
//1. Build tvm IR based on the Ort graph
auto demo_tvm_tensor_ctx = BuildTVMIR(func_body->Body());
//2. Create schedule for the built tvm IRs
auto s = CreateSchedule(demo_tvm_tensor_ctx);
//3. Build tvm module
// The module's argument list is all inputs followed by all outputs.
std::vector<tvm::Tensor> tvm_args;
for (auto& t : demo_tvm_tensor_ctx.inputs) {
tvm_args.push_back(t);
}
for (auto& t : demo_tvm_tensor_ctx.outputs) {
tvm_args.push_back(t);
}
std::vector<std::string> func_names;
auto module_ptr = std::make_shared<tvm::runtime::Module>();
*module_ptr = BuildStackVMModule(s, tvm::build_config(), tvm_args, func_names);
// Keep the module alive in the provider, keyed by the fused node's name.
modules_[fused_node->Name()] = module_ptr;
NodeComputeInfo compute_info;
// Creates the per-kernel TVMFuncState. The [=] capture implicitly captures
// `this` (modules_ is a member), so the provider must outlive this callback.
// NOTE(review): modules_[...] uses operator[], which would insert an empty
// entry and yield a null module pointer if context->node_name is not a key —
// presumably node_name always equals the fused node's name; verify.
compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
auto* p = new TVMFuncState();
*p = {context->allocate_func, context->release_func, context->allocator_handle, modules_[context->node_name].get()};
*state = p;
return 0;
};
// Frees the TVMFuncState allocated above (the TVM module is not owned by it).
compute_info.release_state_func = [](FunctionState state) {
if (state)
delete static_cast<TVMFuncState*>(state);
};
// compute_func captures nothing; the TVM module reaches it through the
// TVMFuncState passed in as `state`, and is used to look up the function.
compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
Ort::CustomOpApi ort{*api};
TVMFuncState* tvm_state = reinterpret_cast<TVMFuncState*>(state);
// Shape storage must outlive the DLTensor views that point into it.
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
auto eval_func_name = "func";
DLContext cpu_context = {kDLCPU, 0};
size_t num_inputs = ort.KernelContext_GetInputCount(context);
size_t num_outputs = ort.KernelContext_GetOutputCount(context);
size_t n_args = num_inputs + num_outputs;
// One DLTensor / TVMValue / type code per argument: inputs first, then outputs.
std::vector<DLTensor> dl_tensors(n_args);
std::vector<TVMValue> tvm_values(n_args);
std::vector<int> tvm_type_codes(n_args);
// Wrap every input tensor as a DLTensor view over the ORT-owned buffer.
// Data is read as double; GetDataType rejects any other element type.
for (auto i = 0; i < num_inputs; i++) {
const OrtValue* input_tensor = ort.KernelContext_GetInput(context, i);
auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
auto tensor_type = ort.GetTensorElementType(tensor_info);
input_shapes.emplace_back(ort.GetTensorShape(tensor_info));
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
tvm_type_codes[i] = kNDArrayContainer;
dl_tensors[i].ctx = cpu_context;
dl_tensors[i].dtype = GetDataType(tensor_type);
dl_tensors[i].strides = nullptr;
dl_tensors[i].byte_offset = 0;
dl_tensors[i].data = const_cast<double*>(ort.GetTensorData<double>(input_tensor));
dl_tensors[i].ndim = input_shapes.back().size();
dl_tensors[i].shape = input_shapes.back().data();
tvm_values[i].v_handle = &dl_tensors[i];
}
// Allocate every output with the shape of the input at the same position.
// NOTE(review): input_shapes[i] assumes num_outputs <= num_inputs — confirm.
for (auto i = 0; i < num_outputs; i++) {
//setup output tensor property
//todo: type should be set by framework.
output_shapes.push_back(input_shapes[i]);
OrtValue* output_tensor = ort.KernelContext_GetOutput(context, i, output_shapes[i].data(), output_shapes[i].size());
auto tensor_info = ort.GetTensorTypeAndShape(output_tensor);
auto tensor_type = ort.GetTensorElementType(tensor_info);
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
tvm_type_codes[num_inputs + i] = kNDArrayContainer;
dl_tensors[num_inputs + i].ctx = cpu_context;
dl_tensors[num_inputs + i].dtype = GetDataType(tensor_type);
dl_tensors[num_inputs + i].strides = nullptr;
dl_tensors[num_inputs + i].byte_offset = 0;
dl_tensors[num_inputs + i].data = ort.GetTensorMutableData<double>(output_tensor);
dl_tensors[num_inputs + i].ndim = output_shapes.back().size();
dl_tensors[num_inputs + i].shape = output_shapes.back().data();
tvm_values[num_inputs + i].v_handle = &dl_tensors[num_inputs + i];
}
// Look up the compiled entry point and invoke it with the packed arguments.
auto evaluate_func_ = tvm_state->module->GetFunction(eval_func_name);
tvm::TVMArgs tvm_args(&tvm_values[0], &tvm_type_codes[0], static_cast<int>(n_args));
tvm::TVMRetValue rvalue;
try {
evaluate_func_.CallPacked(tvm_args, &rvalue);
} catch (std::exception&) {
return Status(common::ONNXRUNTIME, common::FAIL); // TODO: Translate exception to error code
}
// Any non-null return value is reported as a failure.
if (rvalue.type_code() != kNull) {
return Status(common::ONNXRUNTIME, common::FAIL); // TODO: get error code.
} else {
return Status::OK();
}
};
node_compute_funcs.push_back(compute_info);
}
return Status::OK();
}
private:
// Compiled TVM modules keyed by fused node name; owns the modules that the
// per-kernel TVMFuncState instances point to.
std::unordered_map<std::string, std::shared_ptr<tvm::runtime::Module>> modules_;
};
// Runs `session_object` once and checks the single fetched tensor.
//
// Feeds a double tensor (dims_x / values_x) as graph input "X1", fetches
// output "Y4", and verifies that its shape equals dims_y and that every
// element equals the corresponding entry of values_y. A failing Run status is
// printed to stdout before being reported through gtest.
static void RunSession(InferenceSession& session_object,
                       RunOptions& run_options,
                       std::vector<int64_t>& dims_x,
                       std::vector<double>& values_x,
                       std::vector<int64_t>& dims_y,
                       std::vector<double>& values_y) {
  // Input feed.
  OrtValue input_value;
  CreateMLValue<double>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_x, values_x, &input_value);
  NameMLValMap feeds;
  feeds.insert(std::make_pair("X1", input_value));

  // Requested output and the container that receives it.
  std::vector<std::string> output_names{"Y4"};
  std::vector<OrtValue> fetches;

  // Execute.
  common::Status run_status = session_object.Run(run_options, feeds, output_names, &fetches);
  if (!run_status.IsOK()) {
    std::cout << "Run returned status: " << run_status.ErrorMessage() << std::endl;
  }
  EXPECT_TRUE(run_status.IsOK());
  ASSERT_EQ(1, fetches.size());

  // Shape check.
  const auto& result = fetches.front().Get<Tensor>();
  TensorShape expected_shape(dims_y);
  EXPECT_EQ(expected_shape, result.Shape());

  // Element-wise value check (exact equality).
  const double* first = result.Data<double>();
  const std::vector<double> found(first, first + expected_shape.Size());
  ASSERT_EQ(found.size(), values_y.size());
  for (size_t idx = 0; idx < found.size(); ++idx) {
    ASSERT_EQ(found[idx], values_y[idx]);
  }
}
// Model used by the fuse demo test below.
static const std::string MODEL_URI = "testdata/fuse_mul_1.onnx";

// End-to-end check: registers FuseExecutionProviderX, loads the fuse_mul_1
// model, and verifies that a 6-element input produces the expected output,
// where each expected value is the fifth power of the matching input.
TEST(TVMTest, CodeGen_Demo_for_Fuse_Mul) {
  SessionOptions session_options;
  session_options.session_logid = "InferenceSessionTests.NoTimeout";
  InferenceSession session_object{session_options, GetEnvironment()};

  CPUExecutionProviderInfo info;
  EXPECT_TRUE(session_object.RegisterExecutionProvider(std::make_unique<FuseExecutionProviderX>(info)).IsOK());
  EXPECT_TRUE(session_object.Load(MODEL_URI).IsOK());
  EXPECT_TRUE(session_object.Initialize().IsOK());

  RunOptions run_options;
  run_options.run_tag = "one session/one tag";

  // Input tensor.
  std::vector<int64_t> dims_x{6};
  std::vector<double> values_x{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

  // Expected output tensor: the result of the fused Mul chain.
  std::vector<int64_t> expected_dims_y{6};
  std::vector<double> expected_values_y{1.0, 32.0, 243.0, 1024.0, 3125.0, 7776.0};

  RunSession(session_object, run_options, dims_x, values_x, expected_dims_y, expected_values_y);
}
} // namespace test
} // namespace onnxruntime
// Exercises TVM directly, without ONNX Runtime: builds E = (A + B) + D over a
// symbolic-length 1-D double tensor, compiles it for the LLVM or StackVM
// target, and runs it with A, B and D all viewing the same buffer {1, 2, 3},
// so the expected result is {3, 6, 9}.
TEST(TVMTest, Native_TVM) {
  using namespace tvm;

  // Symbolic 1-D shape [n].
  auto n = var("n");
  Array<Expr> shape;
  shape.push_back(n);

  // Three inputs and two chained element-wise additions.
  auto A = placeholder(shape, Float(64), "A");
  auto B = placeholder(shape, Float(64), "B");
  auto D = placeholder(shape, Float(64), "D");
  auto C = compute(
      A->shape, [&](Expr i) { return A[i] + B[i]; }, "C");
  auto E = compute(
      A->shape, [&](Expr i) { return C[i] + D[i]; }, "E");

  // Schedule, lower and build a module exposing "func".
  auto s = create_schedule({E->op});
  auto args = Array<Tensor>({A, B, D, E});
  std::unordered_map<Tensor, Buffer> binds;
  auto config = build_config();
#ifdef USE_NUPHAR_TVM_WITH_LLVM
  auto target = target::llvm();
#else
  auto target = target::stackvm();
#endif
  auto lowered = lower(s, args, "func", binds, config);
  auto module = build(lowered, target, Target(), config);
  auto func = module.GetFunction("func");

  // 64-bit float, single lane, on CPU device 0.
  DLDataType dtype;
  dtype.code = kDLFloat;
  dtype.bits = 64;
  dtype.lanes = 1;
  DLContext ctx;
  ctx.device_type = DLDeviceType::kDLCPU;
  ctx.device_id = 0;

  // All three inputs view the same buffer; the output gets its own storage.
  std::vector<double> v = {1.0, 2.0, 3.0};
  int64_t len = 3;
  std::vector<double> r;
  r.resize(len);
  const auto view_of = [&](double* data) {
    return DLTensor{data, ctx, 1, dtype, &len, nullptr, 0};
  };
  DLTensor tensor_A = view_of(&v[0]);
  DLTensor tensor_B = view_of(&v[0]);
  DLTensor tensor_D = view_of(&v[0]);
  DLTensor tensor_E = view_of(&r[0]);

  // Pack the four tensors in the order (A, B, D, E) used when lowering.
  TVMValue lvalues[4];
  int type_codes[4] = {kNDArrayContainer, kNDArrayContainer, kNDArrayContainer, kNDArrayContainer};
  DLTensor* const packed[4] = {&tensor_A, &tensor_B, &tensor_D, &tensor_E};
  for (int k = 0; k < 4; ++k) {
    lvalues[k].v_handle = packed[k];
  }

  TVMArgs tvm_args(lvalues, type_codes, 4);
  TVMRetValue rvalue;
  func.CallPacked(tvm_args, &rvalue);
  CHECK_EQ(rvalue.type_code(), kNull);

  // Compare against 3 * v.
  double expected[3] = {3.0, 6.0, 9.0};
  auto data_E = static_cast<double*>(tensor_E.data);
  for (int i = 0; i < 3; i++) {
    EXPECT_NEAR(data_E[i], expected[i], 0.001f);
  }
}

View file

@ -1,226 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test/nuphar_tvm/tvm_demo/demo_compiler.h"
#include "core/codegen/passes/scheduler/schedule_utils.h"
#include "core/codegen/passes/utils/ort_tvm_utils.h"
#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h"
#include "core/codegen/passes/scheduler/tvm_scheduler.h"
#include "core/codegen/passes/scheduler/tvm_schedule_builder.h"
#include <tvm/tvm.h>
#include <tvm/build_module.h>
namespace onnxruntime {
namespace tvm_demo {
// File-local dummy codegen handle shared by the whole demo.
static codegen::CodeGenHandle demo_handle;
// File-local dummy codegen context built around demo_handle. It is declared
// after demo_handle because its constructor receives the handle's address,
// and it is passed to every IR-building and scheduling call in this file.
static tvm_codegen::CodeGenContext demo_codegen_ctx(&demo_handle);
// Translates an ORT graph into TVM IR for this demo.
//
// Nodes are visited in the order graph.Nodes() yields them. An input that has
// no TVM tensor yet becomes a placeholder named "<arg>_placeholder"; each
// node's outputs are produced by the generic op IR builder. The returned
// context lists, as inputs, the tensors of the graph inputs followed by those
// of all initializers, and, as its only output, the tensor of graph output 0.
//
// The traversal and placeholder strategy are specific to this demo and may
// not work for an arbitrary graph; see the nuphar provider for a general one.
// Throws (via ORT_THROW_IF_ERROR) if building IR for a node fails.
DemoTVMTensorCtx BuildTVMIR(const onnxruntime::Graph& graph) {
  // Registry owning every generic OpIRCreator.
  auto creator_registry = std::make_unique<tvm_codegen::OpIRRegistry>();
  tvm_codegen::RegisterAllGenericOpIRCreators(creator_registry.get());

  // IR builder that dispatches to the creators held by the registry.
  auto ir_builder = std::make_shared<tvm_codegen::TVMIRBuilder>("Demo_Op_IR_Builder");
  tvm_codegen::RegisterGenericOrtOpTypeDispatcher(ir_builder, creator_registry.get());

  DemoTVMTensorCtx result;

  // NodeArg name -> TVM tensor built so far.
  std::unordered_map<std::string, tvm::Tensor> tensor_by_name;

  for (auto& node : graph.Nodes()) {
    tvm::Array<tvm::Tensor> inputs;
    tvm::Array<tvm::Tensor> outputs;

    // Resolve inputs, creating a placeholder for any name not seen before.
    for (const auto* def : node.InputDefs()) {
      const std::string& name = def->Name();
      auto known = tensor_by_name.find(name);
      if (known == tensor_by_name.end()) {
        known = tensor_by_name
                    .emplace(name,
                             tvm::placeholder(ShapeToTvmArray(def, demo_codegen_ctx),
                                              tvm_codegen::ToTvmType(TensorProtoDataType(def)),
                                              name + "_placeholder"))
                    .first;
      }
      inputs.push_back(known->second);
    }

    // Build the IR for this node.
    ORT_THROW_IF_ERROR(ir_builder->Evaluate(inputs, node, demo_codegen_ctx, outputs));

    // Record the produced tensors under their output names.
    size_t out_idx = 0;
    for (const auto* out_def : node.OutputDefs()) {
      tensor_by_name[out_def->Name()] = outputs[out_idx];
      ++out_idx;
    }
  }

  // Graph inputs first ...
  for (const auto* graph_input : graph.GetInputs()) {
    result.inputs.push_back(tensor_by_name[graph_input->Name()]);
  }

  // ... then every initializer.
  for (const auto& initializer : graph.GetAllInitializedTensors()) {
    result.inputs.push_back(tensor_by_name[initializer.first]);
  }

  // The demo exposes exactly one output: graph output 0.
  const auto& first_output = graph.GetOutputs()[0];
  result.outputs.push_back(tensor_by_name[first_output->Name()]);
  return result;
}
// Declare the demo scheduler class (AlwaysInline, DemoTVM); the class itself
// is generated by the macro.
DECLARE_TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM)
// Evaluate for the demo scheduler: ignores the node and the codegen context
// and forwards the tensor and schedule context to TryInlineSchedule,
// returning whatever that call returns.
bool TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM)::Evaluate(
const tvm::Tensor& tensor,
const Node*,
tvm_codegen::CodeGenContext&,
tvm_codegen::ScheduleContext& ctx_sched) {
return TryInlineSchedule(tensor, ctx_sched);
}
// Creates the always-inline demo scheduler and hands ownership of it to
// sched_registry.
static void RegisterAlwaysInlineScheduler(tvm_codegen::TVMScheduleRegistry* sched_registry) {
  auto always_inline = std::make_unique<TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM)>();
  sched_registry->Register(std::move(always_inline));
}
// Declare the demo schedule dispatcher class (DemoTVM); the class itself is
// generated by the macro.
DECLARE_SCHEDULE_DISPATCHER_CLASS(DemoTVM)
// The single key under which the demo scheduler is registered and looked up.
constexpr auto predefined_key = "DemoKey";
// Find for the demo dispatcher: ignores all three arguments and always looks
// up predefined_key, so it dispatches whatever scheduler was registered under
// that key (the always-inline scheduler in this demo).
// Always returning a fixed key is for demo purposes only.
// In practice, a dispatcher returns a key by checking tvm::Tensor, Node,
// or even metadata stored in CodeGenContext.
// A derived CodeGenContext allows compiler developers to store their own metadata.
// For a more detailed example, please check the nuphar provider.
tvm_codegen::Scheduler* SCHEDULE_DISPATCHER_CLASS(DemoTVM)::Find(
const tvm::Tensor&, const Node*, tvm_codegen::CodeGenContext&) {
return DispatcherBase::Get(predefined_key);
}
// Wires the always-inline scheduler into `builder`: creates a demo dispatcher
// named "DemoSchedulers", registers the scheduler fetched from `registry`
// under predefined_key, and then moves the dispatcher into the builder.
static void AttachAlwaysInlineScheduler(const std::shared_ptr<tvm_codegen::TVMScheduleBuilder>& builder,
                                        const tvm_codegen::TVMScheduleRegistry* registry) {
  std::unique_ptr<SCHEDULE_DISPATCHER_CLASS(DemoTVM)> demo_dispatcher =
      std::make_unique<SCHEDULE_DISPATCHER_CLASS(DemoTVM)>("DemoSchedulers");

  // The dispatcher's Find always asks for predefined_key, so register there.
  demo_dispatcher->Register(predefined_key,
                            registry->Get(TVM_SCHEDULER_STRING(AlwaysInline, DemoTVM)));

  builder->InsertDispatcher(std::move(demo_dispatcher));
}
// Schedules `tensor` and then, depth-first, every input tensor that itself has
// inputs. Tensors without inputs are not visited.
// This traversal is simplified and specific to the demo; see the nuphar
// provider for a general one. Throws (via ORT_THROW_IF_ERROR) if scheduling a
// tensor fails.
static void TraverseAndSchedule(
    std::shared_ptr<tvm_codegen::TVMScheduleBuilder>& schedule_builder,
    const tvm::Tensor& tensor,
    tvm_codegen::ScheduleContext& ctx_schedule) {
  ORT_THROW_IF_ERROR(schedule_builder->Evaluate(tensor, nullptr, demo_codegen_ctx, ctx_schedule));

  for (const auto& child : tensor->op->InputTensors()) {
    // Skip tensors that have no inputs of their own.
    if (child->op->InputTensors().size() == 0) {
      continue;
    }
    TraverseAndSchedule(schedule_builder, child, ctx_schedule);
  }
}
// Creates a TVM schedule for the outputs in `ctx` by always trying
// compute_inline, and then marking every output as a root schedule.
// This policy is specific to the demo: always inlining might lead to bad
// performance or even illegal loop transformations on some backends. See the
// nuphar provider for a general example.
tvm::Schedule CreateSchedule(const DemoTVMTensorCtx& ctx) {
  // Registry that owns the always-inline scheduler.
  auto schedule_registry = std::make_unique<tvm_codegen::TVMScheduleRegistry>();
  RegisterAlwaysInlineScheduler(schedule_registry.get());

  // Builder that dispatches to that scheduler.
  auto schedule_builder = std::make_shared<tvm_codegen::TVMScheduleBuilder>("Demo_Schedule_Builder");
  AttachAlwaysInlineScheduler(schedule_builder, schedule_registry.get());

  // The schedule context is created from the operations of all outputs.
  tvm::Array<tvm::Operation> out_ops;
  for (const auto& out_tensor : ctx.outputs) {
    out_ops.push_back(out_tensor->op);
  }
  tvm_codegen::ScheduleContext ctx_schedule(out_ops);

  // Depth-first scheduling starting from each output.
  for (const auto& out_tensor : ctx.outputs) {
    TraverseAndSchedule(schedule_builder, out_tensor, ctx_schedule);
  }

  // Every output must be compute_root (a TVM requirement).
  for (const auto& out_tensor : ctx.outputs) {
    tvm_codegen::InsertRootSchedule(out_tensor, ctx_schedule);
  }

  return ctx_schedule.schedule;
}
// Lower `schedule` and build a tvm runtime module targeting tvm's stackvm.
// The name of the generated function ("func") is appended to `target_func_names`.
// In real use, replace stackvm with another backend; see the nuphar provider
// for a more detailed example.
tvm::runtime::Module BuildStackVMModule(tvm::Schedule schedule,
                                        tvm::BuildConfig config,
                                        tvm::Array<tvm::Tensor> tvm_args,
                                        std::vector<std::string>& target_func_names) {
  const auto stackvm_target = tvm::target::stackvm();
  const std::string func_name = "func";

  tvm::Array<tvm::Tensor> lower_args(tvm_args);
  std::unordered_map<tvm::Tensor, tvm::Buffer> no_binds;  // left empty: no tensor-to-buffer bindings are supplied
  auto lowered_funcs = lower(schedule, lower_args, func_name, no_binds, config);
  // Uncomment the following line to dump the lowered func
  // std::cout << "Dumping lowered func: " << lowered_funcs[0]->body;

  target_func_names.push_back(func_name);
  return build(lowered_funcs, stackvm_target, tvm::Target(), config);
}
} // namespace tvm_demo
} // namespace onnxruntime

View file

@ -1,31 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/graph/graph_viewer.h"
#include <tvm/tvm.h>
#include <tvm/build_module.h>
namespace onnxruntime {
namespace tvm_demo {
// A demo data structure that holds tvm IR and context.
// It is the value returned by BuildTVMIR and the argument taken by CreateSchedule.
struct DemoTVMTensorCtx {
// Input tensors of the tvm IR.
// NOTE(review): presumably the tensors created for the graph's inputs — confirm against BuildTVMIR.
tvm::Array<tvm::Tensor> inputs;
// Output tensors of the tvm IR; CreateSchedule iterates over these to collect
// output operations and to drive scheduling.
tvm::Array<tvm::Tensor> outputs;
};
// Translate an Ort graph into tvm IR.
// Takes the onnxruntime::Graph to translate and returns the result as a DemoTVMTensorCtx.
// NOTE(review): the definition is not visible from this header; how the returned
// inputs/outputs arrays are populated should be confirmed there.
DemoTVMTensorCtx BuildTVMIR(const onnxruntime::Graph& graph);
// Create a demo schedule for the tvm IR.
// Takes the DemoTVMTensorCtx describing the IR and returns the tvm::Schedule built for it.
tvm::Schedule CreateSchedule(const DemoTVMTensorCtx& ctx);
// Build a demo tvm module with the tvm IR and schedule.
//   schedule          - the schedule to build from (e.g. the result of CreateSchedule)
//   config            - tvm build configuration
//   tvm_args          - tensors passed as the arguments of the built function
//   target_func_names - output parameter that the function writes into
//                       (NOTE(review): presumably the names of the generated functions — confirm in the definition)
// Returns the built tvm::runtime::Module.
tvm::runtime::Module BuildStackVMModule(tvm::Schedule schedule,
tvm::BuildConfig config,
tvm::Array<tvm::Tensor> tvm_args,
std::vector<std::string>& target_func_names);
} // namespace tvm_demo
} // namespace onnxruntime

View file

@ -21,10 +21,6 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
// Overrides the base-class Compile and reports its outcome as a common::Status.
// NOTE(review): presumably fills node_compute_funcs with the compute info for the
// entries of fused_nodes — confirm against the out-of-line definition.
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
// Overrides the base-class GetFusionStyle; this provider always reports
// FusionStyle::FilteredGraphViewer.
FusionStyle GetFusionStyle() const override {
return FusionStyle::FilteredGraphViewer;
}
// Overrides the base-class GetPreferredLayout; only declared here, so the
// DataLayout it returns is determined by the out-of-line definition.
DataLayout GetPreferredLayout() const override;
private: