mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Unify the Compile API for mobile build and normal build (#10632)
* use the lightweight compile api as default; use dnnl ep for testing * apply to tensorrt ep * fix the missing files * fix build * fix the copy issue on linux * migrate migraphx and openvino ep * fix openvino build break * fix linux build * fix unused parameter * fix coreml build * use graph view's filtered initializers * fix openvino break * fix tvm compile api * fix tvm / rknpu / vitisai ep build * add IsInitializedTensor in graph_viewer; fix nuphar build * use serializer directly as tvm ep is still static lib * fix the type mismatch * fix the type mismatch * fix merge conflict * add a comment * fix minimal build * fix the DML EP's legacy approach * save type/shape in dnnl IR * fix linux break * fix tvm failure * dnnl ep: move initializer referenced out of dnnl subgraph * Revert "add IsInitializedTensor in graph_viewer; fix nuphar build" This reverts commit 1cc3c7f08c16fee4fe3309a67209eb769d479587. * add IsInitializedTensor to graph viewer * add the legacy code for nuphar build to temporarily make nuphar build work * ignore internal test for nuphar * remove the out of date tests * keep the legacy API in EP for a while * turn serializer into a static function * update comments * fix tvm build * Update include/onnxruntime/core/framework/execution_provider.h Co-authored-by: Pranav Sharma <prs@microsoft.com> * Update include/onnxruntime/core/framework/execution_provider.h Co-authored-by: Pranav Sharma <prs@microsoft.com> * Update onnxruntime/core/framework/execution_provider.cc Co-authored-by: Pranav Sharma <prs@microsoft.com> * updatee comments; add warning message for legacy compil call * add a flag to control out of scope arg in serialization * fix trt build; improve the test * resolve merege errors * fix a typo Co-authored-by: Cheng Tang <chenta@microsoft.com> Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
parent
eca4cbc419
commit
3f3c5fcd68
52 changed files with 516 additions and 1161 deletions
|
|
@ -30,6 +30,12 @@ if (onnxruntime_MINIMAL_BUILD)
|
|||
"${ONNXRUNTIME_ROOT}/core/graph/function*"
|
||||
)
|
||||
|
||||
# remove graph proto serializer
|
||||
list(APPEND onnxruntime_graph_src_exclude_patterns
|
||||
"${ONNXRUNTIME_ROOT}/core/graph/graph_proto_serializer.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/graph/graph_proto_serializer.h"
|
||||
)
|
||||
|
||||
# no optimizer support in base minimal build
|
||||
# some optimizer support in extended minimal build
|
||||
if (NOT onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -348,7 +348,7 @@ if (onnxruntime_USE_RKNPU)
|
|||
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_rknpu_src})
|
||||
endif()
|
||||
|
||||
if (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
if ((NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_NUPHAR) OR onnxruntime_EXTENDED_MINIMAL_BUILD)
|
||||
file(GLOB_RECURSE onnxruntime_test_providers_internal_testing_src CONFIGURE_DEPENDS
|
||||
"${TEST_SRC_DIR}/providers/internal_testing/*"
|
||||
)
|
||||
|
|
@ -464,11 +464,10 @@ if(onnxruntime_USE_COREML)
|
|||
endif()
|
||||
|
||||
if(onnxruntime_USE_NUPHAR)
|
||||
file(GLOB_RECURSE onnxruntime_test_nuphar_src CONFIGURE_DEPENDS
|
||||
"${TEST_SRC_DIR}/nuphar_tvm/*.h"
|
||||
"${TEST_SRC_DIR}/nuphar_tvm/*.cc"
|
||||
)
|
||||
|
||||
# the test case under nuphar_tvm is only to verify some basic tvm show case, which is already out of date
|
||||
# it doesn't have relationship to nuphar directly. consider we have an official tvm execution provider now,
|
||||
# keep those test cases doesn't bring any value now.
|
||||
|
||||
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/framework/nuphar/*)
|
||||
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nuphar)
|
||||
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nuphar)
|
||||
|
|
|
|||
|
|
@ -197,48 +197,15 @@ class IExecutionProvider {
|
|||
// TODO: temparary sulotion, need to unify the interface in EP and AllocatorManager
|
||||
void TryInsertAllocator(AllocatorPtr allocator);
|
||||
|
||||
// creation of a fused node is not supported in a minimal build, so any EP enabled in that scenario must support
|
||||
// compilation via GraphViewer instances.
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
/**
|
||||
Given a list of fused_node, return create_state/compute/release_state func for each node.
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs);
|
||||
|
||||
/**
|
||||
Given a list of fused_node, return a dll that expose functions for each node.
|
||||
For each node, there should be three symbols:
|
||||
Create_State_${node_name}
|
||||
Compute_${node_name}
|
||||
Release_State_${node_name}
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::string& dll_path);
|
||||
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
struct FusedNodeAndGraph {
|
||||
const std::reference_wrapper<onnxruntime::Node> fused_node;
|
||||
// GraphViewer that filters the full graph to the nodes that are covered by 'node'
|
||||
const std::reference_wrapper<GraphViewer> filtered_graph;
|
||||
};
|
||||
|
||||
/**
|
||||
Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
|
||||
return create_state/compute/release_state func for each node.
|
||||
@remarks This is an optional interface that is only needed if the execution provider compiles nodes
|
||||
in a scenario involving the minimal build. i.e. on a mobile or embedded device with ORT format model.
|
||||
|
||||
Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
|
||||
as it is only valid for the duration of the call to Compile.
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs);
|
||||
#endif
|
||||
|
||||
// Fusion approach that is suppported
|
||||
// !!! The "Function" FusionStyle will be deprecated soon.
|
||||
// !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style.
|
||||
enum class FusionStyle {
|
||||
// The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
|
||||
// in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
|
||||
|
|
@ -254,12 +221,35 @@ class IExecutionProvider {
|
|||
};
|
||||
|
||||
virtual FusionStyle GetFusionStyle() const {
|
||||
// existing EPs use this mode so default to it.
|
||||
// newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return
|
||||
// FilteredGraphViewer
|
||||
return FusionStyle::Function;
|
||||
// All the ORT build in EP has migrate to FilteredGraphViewer style except Nuphar.
|
||||
// For newer EPs, please avoid use Function style as it will be deprecated soon.
|
||||
return FusionStyle::FilteredGraphViewer;
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
/**
|
||||
* !!!! This API will be deprecated soon. If your execution provider overrides this API
|
||||
* !!!! Please migrate it to the "Compile" API with FusedNodeAndGraph type.
|
||||
Given a list of fused_node, return create_state/compute/release_state func for each node.
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs);
|
||||
|
||||
/**
|
||||
Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
|
||||
return create_state/compute/release_state func for each node.
|
||||
@remarks This is now the default interface when execution provider wants to compile nodes
|
||||
for both minimal build and complete ort build.
|
||||
|
||||
Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
|
||||
as it is only valid for the duration of the call to Compile.
|
||||
*/
|
||||
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs);
|
||||
|
||||
#endif
|
||||
|
||||
void SetLogger(const logging::Logger* logger) {
|
||||
logger_ = logger;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1258,6 +1258,10 @@ class Graph {
|
|||
return Resolve(default_options);
|
||||
}
|
||||
|
||||
const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept{
|
||||
return outer_scope_node_arg_names_;
|
||||
}
|
||||
|
||||
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
flatbuffers::Offset<onnxruntime::fbs::Graph>& fbs_graph) const;
|
||||
|
||||
|
|
|
|||
|
|
@ -149,6 +149,10 @@ class GraphViewer {
|
|||
/** Get the internal graph*/
|
||||
const Graph& GetGraph() const { return *graph_; }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept;
|
||||
#endif
|
||||
|
||||
/**
|
||||
returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
|
||||
@param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name'
|
||||
|
|
@ -156,6 +160,9 @@ class GraphViewer {
|
|||
*/
|
||||
bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const;
|
||||
|
||||
/** Check if a given name is an initializer tensor's name in this graph. */
|
||||
bool IsInitializedTensor(const std::string& name) const;
|
||||
|
||||
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
|
||||
cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
|
||||
@param check_outer_scope If true and the graph is a subgraph,
|
||||
|
|
|
|||
|
|
@ -151,26 +151,21 @@ void IExecutionProvider::RegisterAllocator(std::shared_ptr<AllocatorManager>) {
|
|||
return;
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// !!!!This API will be deprecated soon. If your execution provider overrides this API
|
||||
// !!!!Please migrate it to the "Compile" API with FusedNodeAndGraph type.
|
||||
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
|
||||
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
|
||||
"IExecutionProvider::Compile with fused Node is not implemented by " + type_);
|
||||
}
|
||||
|
||||
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& /*fused_node*/,
|
||||
std::string& /*dll_path*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
|
||||
"IExecutionProvider::Compile with fused Node and dll path is not implemented by " + type_);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& /*fused_nodes_and_graphs*/,
|
||||
std::vector<NodeComputeInfo>& /*node_compute_funcs*/) {
|
||||
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED,
|
||||
"IExecutionProvider::Compile with FusedNodeAndGraph is not implemented by " + type_);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
int IExecutionProvider::ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer,
|
||||
|
|
|
|||
|
|
@ -42,6 +42,15 @@ NonCudaOps non_cuda;
|
|||
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
|
||||
auto schema = node.Op();
|
||||
builder.SetName(schema->Name())
|
||||
.SetDomain(schema->domain())
|
||||
.SinceVersion(schema->SinceVersion())
|
||||
.Provider(node.GetExecutionProviderType());
|
||||
}
|
||||
#endif
|
||||
|
||||
// minimal KernelDef based on MetaDef instead of a Function based node
|
||||
static void BuildFusedKernelDef(KernelDefBuilder& builder, const IndexedSubGraph::MetaDef& metadef,
|
||||
|
|
@ -103,7 +112,7 @@ static void AssignNodes(Graph& graph, const IndexedSubGraph& capability,
|
|||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep,
|
||||
GraphPartitioner::Mode mode, std::vector<std::unique_ptr<ComputeCapability>>& capabilities,
|
||||
GraphPartitioner::Mode mode, std::vector<std::unique_ptr<ComputeCapability>>& capabilities,
|
||||
TransformLayoutFunction transform_layout) {
|
||||
{
|
||||
GraphViewer graph_viewer(graph);
|
||||
|
|
@ -142,15 +151,6 @@ static Status GetCapabilityForEP(Graph& graph, KernelRegistryManager& kernel_reg
|
|||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::Node& node) {
|
||||
auto schema = node.Op();
|
||||
builder.SetName(schema->Name())
|
||||
.SetDomain(schema->domain())
|
||||
.SinceVersion(schema->SinceVersion())
|
||||
.Provider(node.GetExecutionProviderType());
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a node can be placed on a specific provider.
|
||||
* Do nothing if the node is already assigned
|
||||
|
|
@ -162,8 +162,8 @@ static void BuildFusedKernelDef(KernelDefBuilder& builder, const onnxruntime::No
|
|||
* \return Fused node. Return nullptr if there is no fuse
|
||||
*/
|
||||
static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
|
||||
const KernelRegistryManager& kernel_registry_mgr, const std::string& provider_type,
|
||||
IExecutionProvider::FusionStyle fusion_style,
|
||||
const std::string& provider_type,
|
||||
GraphPartitioner::Mode mode,
|
||||
int& fused_node_unique_id) {
|
||||
Node* result = nullptr;
|
||||
|
|
@ -216,7 +216,18 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
|
|||
std::string node_name = oss.str();
|
||||
|
||||
Node* fused_node = nullptr;
|
||||
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
|
||||
// TODO1: The DML currently use some legacy approach.
|
||||
// It registers a generic predefined kernel for all purpose fusion,
|
||||
// so it rely on the function body in the fused node during kernel creation,
|
||||
// which is after the graph partition phase.
|
||||
// Ideally, it should be moved to "Compile" call.
|
||||
// Here we temporary keep the function body for DML fusion
|
||||
// Need to remove it after migrate DML to the Compile-based approach.
|
||||
// TODO2: Nuphar is out of maintain, keep it with old API temporarily.
|
||||
// We want to deprecate Nuphar soon.
|
||||
if (fusion_style == IExecutionProvider::FusionStyle::Function ||
|
||||
provider_type == kDmlExecutionProvider ||
|
||||
provider_type == kNupharExecutionProvider) {
|
||||
fused_node = &graph.FuseSubGraph(capability, node_name);
|
||||
} else {
|
||||
// create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed
|
||||
|
|
@ -226,10 +237,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
|
|||
|
||||
fused_node->SetExecutionProviderType(provider_type);
|
||||
|
||||
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *fused_node, provider_type)) {
|
||||
result = fused_node;
|
||||
}
|
||||
result = fused_node;
|
||||
} else {
|
||||
// assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them.
|
||||
// This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion
|
||||
|
|
@ -250,7 +258,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
|
|||
|
||||
// for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up).
|
||||
// assign any nodes to the EP that are currently unassigned, and that the EP can handle.
|
||||
static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
||||
KernelRegistryManager& kernel_registry_mgr,
|
||||
KernelRegistry& fused_kernel_registry,
|
||||
IExecutionProvider& current_ep,
|
||||
|
|
@ -268,7 +276,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
Graph* subgraph = entry.second;
|
||||
// we pass through the export_dll value and FuncManager from the top level graph
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, export_dll, func_mgr, kernel_registry_mgr,
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
|
||||
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
|
||||
transform_layout_function));
|
||||
}
|
||||
|
|
@ -295,6 +303,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
const std::string& type = current_ep.Type();
|
||||
auto fusion_style = current_ep.GetFusionStyle();
|
||||
std::vector<Node*> nodes_to_compile;
|
||||
|
||||
// The fused node may map to an existing kernel, so it is fused but doesn't need to be compiled
|
||||
// But we still need to finalize the graph fusion for those nodes.
|
||||
std::vector<Node*> nodes_to_complete_fuse;
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_complete_fuse;
|
||||
|
||||
// filter out the ComputeCapability instances that do not need compiling so we have a std::vector that's 1:1 with
|
||||
// nodes_to_compile.
|
||||
std::vector<std::unique_ptr<ComputeCapability>> capabilities_to_compile;
|
||||
|
|
@ -311,42 +325,46 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
continue;
|
||||
}
|
||||
|
||||
Node* n = PlaceNode(graph, *capability->sub_graph, kernel_registry_mgr, type, fusion_style, mode, fused_node_unique_id);
|
||||
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
|
||||
if (n != nullptr) {
|
||||
nodes_to_compile.push_back(n);
|
||||
capabilities_to_compile.push_back(std::move(capability));
|
||||
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
|
||||
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) {
|
||||
nodes_to_compile.push_back(n);
|
||||
capabilities_to_compile.push_back(std::move(capability));
|
||||
} else {
|
||||
// there is a predefined kernel for the fused node. doesn't need compile, but need to complete the fusing.
|
||||
nodes_to_complete_fuse.push_back(n);
|
||||
capabilities_to_complete_fuse.push_back(std::move(capability));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode
|
||||
// even with single node, EP might still want to compile it.
|
||||
// for example, it want to JIT an optimized kernel for LSTM with a given shape.
|
||||
if (!nodes_to_compile.empty()) {
|
||||
std::vector<NodeComputeInfo> node_compute_funcs;
|
||||
|
||||
if (export_dll) {
|
||||
ORT_ENFORCE(fusion_style == IExecutionProvider::FusionStyle::Function,
|
||||
"Must use Function based fusion when exporting compiled nodes to dll.");
|
||||
}
|
||||
|
||||
// !!! The Function style fusion will be deprecated soon.
|
||||
if (fusion_style == IExecutionProvider::FusionStyle::Function) {
|
||||
// TODO: Nuphar is out of maintain. Use the old api temporarily.
|
||||
// We want to deprecate it soon.
|
||||
// Create a Function based node where the fused nodes have a new Graph instance.
|
||||
static std::once_flag legacy_compile_method_warning_flag;
|
||||
std::call_once(
|
||||
legacy_compile_method_warning_flag, [](std::string_view ep_type) {
|
||||
LOGS_DEFAULT(WARNING) << "Execution Provider: " << ep_type << " is still using Funciton style Compile API, "
|
||||
<< " which will be deprecated soon, please migrate to the new Compile API based on "
|
||||
<< " FilteredGraphViewer. ";
|
||||
},
|
||||
type);
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs));
|
||||
|
||||
if (export_dll) {
|
||||
std::string dll_path;
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, dll_path));
|
||||
if (node_compute_funcs.size() != nodes_to_compile.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
|
||||
}
|
||||
|
||||
for (auto* node : nodes_to_compile) {
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node->Name(), dll_path));
|
||||
}
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_to_compile, node_compute_funcs));
|
||||
|
||||
if (node_compute_funcs.size() != nodes_to_compile.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
|
||||
}
|
||||
|
||||
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j])));
|
||||
}
|
||||
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(nodes_to_compile[j]->Name(), std::move(node_compute_funcs[j])));
|
||||
}
|
||||
|
||||
for (auto* node : nodes_to_compile) {
|
||||
|
|
@ -358,12 +376,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
return FunctionKernel::Create(func_mgr, info, out);
|
||||
}));
|
||||
}
|
||||
|
||||
} else {
|
||||
// temporary storage for the GraphViewer for each IndexedSubGraph
|
||||
std::vector<std::unique_ptr<GraphViewer>> viewers;
|
||||
viewers.reserve(nodes_to_compile.size());
|
||||
std::vector<IExecutionProvider::FusedNodeAndGraph> nodes_and_viewers;
|
||||
nodes_and_viewers.reserve(nodes_to_compile.size());
|
||||
|
||||
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
|
||||
auto* node = nodes_to_compile[j];
|
||||
|
|
@ -401,6 +419,21 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, bool export_dll, FuncMa
|
|||
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: The DML currently use some legacy approach.
|
||||
// The fuse is done in FuseSubGraph function.
|
||||
// Need to remove it later when DML migrate to Compile approach
|
||||
if (!nodes_to_complete_fuse.empty() && type != kDmlExecutionProvider) {
|
||||
for (size_t j = 0, end = nodes_to_complete_fuse.size(); j < end; j++) {
|
||||
auto* node = nodes_to_complete_fuse[j];
|
||||
|
||||
const auto& cur_capability = capabilities_to_complete_fuse[j];
|
||||
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
|
||||
|
||||
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
|
||||
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(ValidateGraphPartitioning(graph));
|
||||
}
|
||||
|
|
@ -458,7 +491,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr,
|
||||
KernelRegistry& fused_kernel_registry, Mode mode,
|
||||
int& fused_node_unique_id,
|
||||
TransformLayoutFunction transform_layout_function) const {
|
||||
|
|
@ -467,7 +500,7 @@ Status GraphPartitioner::PartitionOnnxFormatModel(Graph& graph, bool export_dll,
|
|||
do {
|
||||
// process full graph with each EP
|
||||
for (const auto& ep : providers_) {
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, export_dll, func_mgr, kernel_registry_mgr_,
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_mgr_,
|
||||
fused_kernel_registry, *ep, mode, fused_node_unique_id,
|
||||
transform_layout_function));
|
||||
}
|
||||
|
|
@ -610,7 +643,7 @@ Status GraphPartitioner::PartitionOrtFormatModel(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
||||
TransformLayoutFunction transform_layout_function, Mode mode,
|
||||
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes) const {
|
||||
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
|
||||
|
|
@ -634,19 +667,16 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
|
|||
|
||||
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, export_dll, func_mgr, *fused_kernel_registry, mode,
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(graph, func_mgr, *fused_kernel_registry, mode,
|
||||
fused_node_unique_id, transform_layout_function));
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(export_dll);
|
||||
ORT_THROW("Not supported in this build.");
|
||||
#endif //! defined(ORT_MINIMAL_BUILD)
|
||||
} else {
|
||||
ORT_ENFORCE(compiled_kernel_hashes != nullptr, "Compiled kernel hashes must be provided");
|
||||
|
||||
ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(graph, func_mgr, *fused_kernel_registry, *compiled_kernel_hashes,
|
||||
fused_node_unique_id, transform_layout_function));
|
||||
}
|
||||
|
||||
if (!fused_kernel_registry->IsEmpty()) {
|
||||
kernel_registry_mgr_.RegisterKernelRegistry(fused_kernel_registry);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class GraphPartitioner {
|
|||
}
|
||||
|
||||
// Run partitioning. Provide compiled_kernel_hashes if mode is kOrtFormatLoad.
|
||||
Status Partition(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
Status Partition(Graph& graph, FuncManager& func_mgr,
|
||||
TransformLayoutFunction transform_layout_function,
|
||||
Mode mode = Mode::kNormal,
|
||||
std::unordered_map<std::string, HashValue>* compiled_kernel_hashes = nullptr) const;
|
||||
|
|
@ -41,7 +41,7 @@ class GraphPartitioner {
|
|||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
Status PartitionOnnxFormatModel(Graph& graph, bool export_dll, FuncManager& func_mgr,
|
||||
Status PartitionOnnxFormatModel(Graph& graph, FuncManager& func_mgr,
|
||||
KernelRegistry& fused_kernel_registry, Mode mode,
|
||||
int& fused_node_unique_id, TransformLayoutFunction transform_layout_function) const;
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -276,9 +276,6 @@ class SessionState {
|
|||
concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; }
|
||||
concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; }
|
||||
|
||||
bool ExportDll() const noexcept { return export_fused_dll_; }
|
||||
void SetExportDllFlag(bool flag) noexcept { export_fused_dll_ = flag; }
|
||||
|
||||
const FuncManager& GetFuncMgr() const noexcept { return fused_funcs_mgr_; }
|
||||
FuncManager& GetMutableFuncMgr() noexcept { return fused_funcs_mgr_; }
|
||||
|
||||
|
|
@ -488,7 +485,6 @@ class SessionState {
|
|||
concurrency::ThreadPool* const thread_pool_{};
|
||||
concurrency::ThreadPool* const inter_op_thread_pool_{};
|
||||
|
||||
bool export_fused_dll_ = false;
|
||||
const DataTransferManager& data_transfer_mgr_;
|
||||
|
||||
bool use_deterministic_compute_;
|
||||
|
|
|
|||
54
onnxruntime/core/graph/graph_proto_serializer.cc
Normal file
54
onnxruntime/core/graph/graph_proto_serializer.cc
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/graph_proto_serializer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void GraphViewerToProto(const GraphViewer& graph_view,
|
||||
ONNX_NAMESPACE::GraphProto& graph_proto,
|
||||
bool include_initializer,
|
||||
bool include_outer_scope_args) {
|
||||
graph_proto.set_name(graph_view.Name());
|
||||
graph_proto.set_doc_string(graph_view.Description());
|
||||
|
||||
for (const auto* input_arg : graph_view.GetInputsIncludingInitializers()) {
|
||||
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
|
||||
}
|
||||
|
||||
for (const auto* output_arg : graph_view.GetOutputs()) {
|
||||
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
|
||||
}
|
||||
|
||||
for (const auto* value_info : graph_view.GetValueInfo()) {
|
||||
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
|
||||
}
|
||||
|
||||
if (include_outer_scope_args){
|
||||
// add the NodeArg info for outer scope NodeArgs so we capture the type information
|
||||
for (const auto& name : graph_view.GetOuterScopeNodeArgNames()) {
|
||||
auto* node_arg = graph_view.GetNodeArg(name);
|
||||
ORT_ENFORCE(node_arg, "Outer scope node arg name '" + name + "'was added but does not exist. ");
|
||||
*(graph_proto.mutable_value_info()->Add()) = node_arg->ToProto();
|
||||
}
|
||||
}
|
||||
|
||||
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
|
||||
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) {
|
||||
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
|
||||
const gsl::not_null<const Node*> p_node{graph_view.GetNode(node_idx)};
|
||||
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
|
||||
// such as the optimizers are captured. otherwise we can end up saving an invalid graph.
|
||||
p_node->ToProto(*node_proto, /* update_subgraphs */ true);
|
||||
}
|
||||
|
||||
if (include_initializer) {
|
||||
auto& initializers = graph_view.GetAllInitializedTensors();
|
||||
for (auto& it : initializers) {
|
||||
auto* p_initializer = graph_proto.add_initializer();
|
||||
*p_initializer = *(it.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
11
onnxruntime/core/graph/graph_proto_serializer.h
Normal file
11
onnxruntime/core/graph/graph_proto_serializer.h
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/graph/graph_viewer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args);
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -271,9 +271,19 @@ bool GraphViewer::IsConstantInitializer(const std::string& name, bool check_oute
|
|||
return GetConstantInitializer(name, check_outer_scope) != nullptr;
|
||||
}
|
||||
|
||||
bool GraphViewer::IsInitializedTensor(const std::string& name) const {
|
||||
return graph_->IsInitializedTensor(name);
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* GraphViewer::GetConstantInitializer(const std::string& initializer_name,
|
||||
bool check_outer_scope) const {
|
||||
return graph_->GetConstantInitializer(initializer_name, check_outer_scope);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const std::unordered_set<std::string>& GraphViewer::GetOuterScopeNodeArgNames() const noexcept {
|
||||
return graph_->GetOuterScopeNodeArgNames();
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -20,9 +20,6 @@ class CoreMLExecutionProvider : public IExecutionProvider {
|
|||
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
|
||||
// we implement the Compile that takes FusedNodeAndGraph instances
|
||||
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
|
|
|||
|
|
@ -124,33 +124,6 @@ std::vector<std::vector<NodeIndex>> DNNLExecutionProvider::GetSupportedNodes(con
|
|||
return supported_node_vecs;
|
||||
}
|
||||
|
||||
void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) {
|
||||
for (const auto* input_arg : graph.GetInputs()) {
|
||||
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
|
||||
}
|
||||
|
||||
// Add all graph's initializers to the subgraph
|
||||
const auto& init_tensors = graph.GetAllInitializedTensors();
|
||||
for (const auto& tensor : init_tensors) {
|
||||
*(graph_proto.mutable_initializer()->Add()) = *(tensor.second);
|
||||
}
|
||||
|
||||
for (const auto* output_arg : graph.GetOutputs()) {
|
||||
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
|
||||
}
|
||||
|
||||
for (const auto* value_info : graph.GetValueInfo()) {
|
||||
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
|
||||
}
|
||||
|
||||
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
|
||||
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
|
||||
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
|
||||
const gsl::not_null<const Node*> p_node{graph.GetNode(node_idx)};
|
||||
p_node->ToProto(*node_proto);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapability(
|
||||
const GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& kernel_registries) const {
|
||||
|
|
@ -269,7 +242,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
|
|||
if (dump_subgraphs_) {
|
||||
auto model = graph_viewer.CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph());
|
||||
graph_viewer.ToProto(*model_proto->mutable_graph(), false, true);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
HashValue model_hash;
|
||||
int metadef_id = GenerateMetaDefId(graph_viewer, model_hash);
|
||||
|
|
@ -280,39 +253,34 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
|
|||
return result;
|
||||
}
|
||||
|
||||
Status DNNLExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
|
||||
Status DNNLExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
//follow from coreml ep's Compile
|
||||
for (const auto* fused_node : fused_nodes) {
|
||||
const auto* func_body = fused_node->GetFunctionBody();
|
||||
if (!func_body) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
|
||||
}
|
||||
const Graph& graph_body = func_body->Body();
|
||||
auto graph_body_viewer = graph_body.CreateGraphViewer();
|
||||
|
||||
for (auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
|
||||
const Node& fused_node = fused_node_graph.fused_node;
|
||||
if (dump_subgraphs_) {
|
||||
auto model = graph_body_viewer->CreateModel(*GetLogger());
|
||||
auto model = graph_body_viewer.CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
*model_proto->mutable_graph() = *graph_body.ToGraphProto();
|
||||
graph_body_viewer.ToProto(*model_proto->mutable_graph(), false, true);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
model_proto->SerializeToOstream(dump);
|
||||
}
|
||||
|
||||
//subgraph
|
||||
auto dnnl_subgraph = std::make_unique<ort_dnnl::DnnlSubgraph>(ort_dnnl::DnnlSubgraph(*graph_body_viewer.get()));
|
||||
subgraphs_.emplace(fused_node->Name(), std::move(dnnl_subgraph));
|
||||
auto dnnl_subgraph = std::make_unique<ort_dnnl::DnnlSubgraph>(ort_dnnl::DnnlSubgraph(graph_body_viewer));
|
||||
subgraphs_.emplace(fused_node.Name(), std::move(dnnl_subgraph));
|
||||
|
||||
//apply transformation to subgraph
|
||||
if (enable_fusion_) {
|
||||
ort_dnnl::DnnlGraphTransformer().Apply(*subgraphs_[fused_node->Name()].get());
|
||||
ort_dnnl::DnnlGraphTransformer().Apply(*subgraphs_[fused_node.Name()].get(), graph_body_viewer);
|
||||
}
|
||||
|
||||
//subgraph primitive
|
||||
auto dnnl_subgraph_primitive = std::make_unique<ort_dnnl::DnnlSubgraphPrimitive>(*subgraphs_[fused_node->Name()].get());
|
||||
auto dnnl_subgraph_primitive = std::make_unique<ort_dnnl::DnnlSubgraphPrimitive>(*subgraphs_[fused_node.Name()].get());
|
||||
{
|
||||
const auto& input_defs = fused_node->InputDefs();
|
||||
const auto& input_defs = fused_node.InputDefs();
|
||||
std::vector<std::string> onnx_input_names(input_defs.size());
|
||||
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
|
||||
onnx_input_names[i] = input_defs[i]->Name();
|
||||
|
|
@ -320,7 +288,7 @@ Status DNNLExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
|
|||
dnnl_subgraph_primitive->SetOrderedInputs(std::move(onnx_input_names));
|
||||
}
|
||||
{
|
||||
const auto& output_defs = fused_node->OutputDefs();
|
||||
const auto& output_defs = fused_node.OutputDefs();
|
||||
std::vector<std::string> onnx_output_names(output_defs.size());
|
||||
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
|
||||
onnx_output_names[i] = output_defs[i]->Name();
|
||||
|
|
@ -328,7 +296,7 @@ Status DNNLExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
|
|||
dnnl_subgraph_primitive->SetOrderedOutputs(std::move(onnx_output_names));
|
||||
}
|
||||
|
||||
subgraph_primitives_.emplace(fused_node->Name(), std::move(dnnl_subgraph_primitive));
|
||||
subgraph_primitives_.emplace(fused_node.Name(), std::move(dnnl_subgraph_primitive));
|
||||
|
||||
NodeComputeInfo compute_info;
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class DNNLExecutionProvider : public IExecutionProvider {
|
|||
GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
|
||||
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -67,14 +67,14 @@ void DnnlMatMulInteger::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& nod
|
|||
|
||||
if (has_a_zero_point) {
|
||||
auto zp_A_mem_desc_s32 = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
|
||||
auto tensor = node.Input(IN_A_ZERO_POINT);
|
||||
auto& tensor = node.Input(IN_A_ZERO_POINT);
|
||||
auto zp_A_mem_s32 = sp.GetMemoryAndReshape(tensor, zp_A_mem_desc_s32, eng);
|
||||
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = zp_A_mem_s32;
|
||||
}
|
||||
|
||||
if (has_b_zero_point) {
|
||||
auto zp_B_mem_desc_s32 = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
|
||||
auto tensor = node.Input(IN_B_ZERO_POINT);
|
||||
auto& tensor = node.Input(IN_B_ZERO_POINT);
|
||||
auto zp_B_mem_s32 = sp.GetMemoryAndReshape(tensor, zp_B_mem_desc_s32, eng);
|
||||
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zp_B_mem_s32;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -164,14 +164,14 @@ void DnnlQAttention::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node)
|
|||
|
||||
if (has_input_zero_point) {
|
||||
auto zp_mem_desc = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
|
||||
auto tensor = node.Input(INPUT_ZP);
|
||||
auto& tensor = node.Input(INPUT_ZP);
|
||||
auto zp_mem = sp.GetMemoryAndReshape(tensor, zp_mem_desc, eng);
|
||||
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = zp_mem;
|
||||
}
|
||||
|
||||
if (has_weights_zero_point) {
|
||||
auto zp_mem_desc = dnnl::memory::desc({1}, dnnl::memory::data_type::s32, {1});
|
||||
auto tensor = node.Input(WEIGHTS_ZP);
|
||||
auto& tensor = node.Input(WEIGHTS_ZP);
|
||||
auto zp_mem = sp.GetMemoryAndReshape(tensor, zp_mem_desc, eng);
|
||||
mem_map[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zp_mem;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,31 +10,51 @@ namespace ort_dnnl {
|
|||
DnnlTensor DnnlNode::empty_tensor_ = DnnlTensor("");
|
||||
|
||||
DnnlTensor::DnnlTensor(const NodeArg* arg) {
|
||||
arg_ = arg;
|
||||
if (!arg || !arg->Exists()) {
|
||||
tensor_name_ = "";
|
||||
} else {
|
||||
tensor_name_ = arg->Name();
|
||||
}
|
||||
// because the passed in ort graph will be released after compile
|
||||
// need to save the type/shape in dnnl IR
|
||||
arg_type_ = arg->Type();
|
||||
arg_type_proto_ = ONNX_NAMESPACE::TypeProto::Create();
|
||||
arg_type_proto_->copy_from(arg->TypeAsProto());
|
||||
}
|
||||
|
||||
DnnlTensor::DnnlTensor(std::string name) {
|
||||
tensor_name_ = name;
|
||||
arg_ = nullptr;
|
||||
arg_type_ = nullptr;
|
||||
arg_type_proto_ = nullptr;
|
||||
}
|
||||
|
||||
std::string DnnlTensor::Name() const {
|
||||
return tensor_name_;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorShapeProto* DnnlTensor::GetShape() const{
|
||||
if (arg_type_proto_ == nullptr || arg_type_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (arg_type_proto_->value_case() != ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType) {
|
||||
return nullptr;
|
||||
}
|
||||
auto& tensor_type = arg_type_proto_->tensor_type();
|
||||
if (tensor_type.has_shape()) {
|
||||
return &tensor_type.shape();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
dnnl::memory::dims DnnlTensor::Dim() const {
|
||||
if (arg_ == nullptr) {
|
||||
if (arg_type_proto_ == nullptr || arg_type_ == nullptr) {
|
||||
return dnnl::memory::dims();
|
||||
}
|
||||
auto shape_proto = arg_->Shape();
|
||||
auto* shape_proto = GetShape();
|
||||
// a shape without any information
|
||||
if (shape_proto == nullptr) {
|
||||
LOGS_DEFAULT(INFO) << "nullptr shape for " << arg_->Type() << ": " << arg_->Name();
|
||||
LOGS_DEFAULT(INFO) << "nullptr shape for " << arg_type_ << ": " << tensor_name_;
|
||||
return dnnl::memory::dims();
|
||||
}
|
||||
std::vector<int64_t> shape;
|
||||
|
|
@ -42,7 +62,7 @@ dnnl::memory::dims DnnlTensor::Dim() const {
|
|||
for (const auto& dim : dims) {
|
||||
bool has_dim_value = dim.value_case() == dim.kDimValue;
|
||||
if (!has_dim_value) {
|
||||
LOGS_DEFAULT(INFO) << "Dynamic shape for " << arg_->Type() << ": " << arg_->Name();
|
||||
LOGS_DEFAULT(INFO) << "Dynamic shape for " << arg_type_ << ": " << tensor_name_;
|
||||
shape.push_back(DNNL_RUNTIME_DIM_VAL);
|
||||
} else {
|
||||
shape.push_back(dim.dim_value());
|
||||
|
|
@ -57,7 +77,10 @@ dnnl::memory::dims DnnlTensor::Dim() const {
|
|||
}
|
||||
|
||||
dnnl::memory::data_type DnnlTensor::Type() const {
|
||||
auto data_type = arg_->TypeAsProto()->tensor_type().elem_type();
|
||||
if (arg_type_proto_ == nullptr) {
|
||||
ORT_THROW("Invoke DnnlTensor's arg_type_proto_ not initialized yet.");
|
||||
}
|
||||
auto data_type = arg_type_proto_->tensor_type().elem_type();
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
|
||||
return dnnl::memory::data_type::undef;
|
||||
|
|
@ -123,7 +146,7 @@ void DnnlTensor::RemoveConsumer(const DnnlNodeArg& arg) {
|
|||
}
|
||||
|
||||
DnnlNode::DnnlNode(const Node* node) {
|
||||
onnx_node_ = node;
|
||||
since_version_ = node->SinceVersion();
|
||||
name_ = node->Name();
|
||||
op_type_ = node->OpType();
|
||||
attr_->insert(node->GetAttributes());
|
||||
|
|
@ -176,11 +199,11 @@ NodeAttributes& DnnlNode::Attributes() {
|
|||
}
|
||||
|
||||
int DnnlNode::SinceVersion() {
|
||||
return onnx_node_->SinceVersion();
|
||||
return since_version_;
|
||||
}
|
||||
|
||||
DnnlSubgraph::DnnlSubgraph(const GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) {
|
||||
Build();
|
||||
DnnlSubgraph::DnnlSubgraph(const GraphViewer& graph_viewer) {
|
||||
Build(graph_viewer);
|
||||
is_dynamic_ = false;
|
||||
for (auto input : GetDnnlInputs()) {
|
||||
if (input->IsDynamic()) {
|
||||
|
|
@ -309,19 +332,11 @@ void DnnlSubgraph::AddNode(std::unique_ptr<DnnlNode> new_node) {
|
|||
dnnl_nodes_.back()->Index() = index;
|
||||
}
|
||||
|
||||
bool DnnlSubgraph::GetInitializedTensor(const std::string& arg_name, const ONNX_NAMESPACE::TensorProto*& value) {
|
||||
return graph_viewer_.GetInitializedTensor(arg_name, value);
|
||||
}
|
||||
|
||||
bool DnnlSubgraph::IsConstantInitializer(const std::string& arg_name, bool check_outer_scope) {
|
||||
return graph_viewer_.IsConstantInitializer(arg_name, check_outer_scope);
|
||||
}
|
||||
|
||||
void DnnlSubgraph::Build() {
|
||||
void DnnlSubgraph::Build(const GraphViewer& graph_viewer) {
|
||||
//establish nodes, tensors and nodeargs
|
||||
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (size_t i = 0; i < node_indices.size(); i++) {
|
||||
const auto* node(graph_viewer_.GetNode(node_indices[i]));
|
||||
const auto* node(graph_viewer.GetNode(node_indices[i]));
|
||||
AddNode(std::make_unique<DnnlNode>(node));
|
||||
auto dnnl_node = dnnl_nodes_.back().get();
|
||||
std::vector<DnnlTensor*> inputs;
|
||||
|
|
@ -361,15 +376,15 @@ void DnnlSubgraph::Build() {
|
|||
//graph inputs including initializers and outputs can be deleted by graph transformation (eg, gelu fusion)
|
||||
//delete unneeded inputs don't affect onnxruntime passing them as input data handle
|
||||
//delete unneeded outputs will cause ep to output to fewer data handles then expected
|
||||
for (const auto* node_arg : graph_viewer_.GetInputsIncludingInitializers()) {
|
||||
for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) {
|
||||
inputs_.push_back(dnnl_tensors_[node_arg->Name()].get());
|
||||
}
|
||||
|
||||
for (const auto* node_arg : graph_viewer_.GetOutputs()) {
|
||||
for (const auto* node_arg : graph_viewer.GetOutputs()) {
|
||||
outputs_.push_back(dnnl_tensors_[node_arg->Name()].get());
|
||||
}
|
||||
|
||||
for (auto& initializer : graph_viewer_.GetAllInitializedTensors()) {
|
||||
for (auto& initializer : graph_viewer.GetAllInitializedTensors()) {
|
||||
auto& name = initializer.first;
|
||||
initializers_.push_back(dnnl_tensors_[name].get());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,8 +55,12 @@ class DnnlTensor {
|
|||
void RemoveConsumer(const DnnlNodeArg& arg);
|
||||
|
||||
private:
|
||||
|
||||
const ONNX_NAMESPACE::TensorShapeProto* GetShape() const;
|
||||
|
||||
std::string tensor_name_;
|
||||
const NodeArg* arg_;
|
||||
ONNX_NAMESPACE::DataType arg_type_;
|
||||
std::unique_ptr<ONNX_NAMESPACE::TypeProto> arg_type_proto_;
|
||||
//a tensor can have no producer (input.initializer) or no consumer (output for subgraph)
|
||||
DnnlNodeArg producer_;
|
||||
std::vector<DnnlNodeArg> consumers_;
|
||||
|
|
@ -79,7 +83,7 @@ class DnnlNode {
|
|||
int SinceVersion();
|
||||
|
||||
private:
|
||||
const Node* onnx_node_ = nullptr;
|
||||
int since_version_;
|
||||
std::vector<DnnlTensor*> inputs_;
|
||||
std::vector<DnnlTensor*> outputs_;
|
||||
static DnnlTensor empty_tensor_;
|
||||
|
|
@ -101,7 +105,7 @@ class DnnlSubgraph {
|
|||
std::vector<DnnlTensor*> GetDnnlOutputs();
|
||||
std::vector<DnnlTensor*> GetDnnlInitializers();
|
||||
// build the subgraph IR
|
||||
void Build();
|
||||
void Build(const GraphViewer& graph_viewer);
|
||||
//check whether the subgraph is dynamic
|
||||
void TopoSort();
|
||||
bool IsDynamic();
|
||||
|
|
@ -110,9 +114,6 @@ class DnnlSubgraph {
|
|||
void AddTensor(std::unique_ptr<DnnlTensor> new_tensor);
|
||||
void RemoveTensor(const std::string& tensor_name);
|
||||
|
||||
bool GetInitializedTensor(const std::string& arg_name, const ONNX_NAMESPACE::TensorProto*& value);
|
||||
bool IsConstantInitializer(const std::string& arg_name, bool check_outer_scope);
|
||||
|
||||
private:
|
||||
//graph owns all nodes
|
||||
std::vector<std::unique_ptr<DnnlNode>> dnnl_nodes_;
|
||||
|
|
@ -122,7 +123,6 @@ class DnnlSubgraph {
|
|||
std::vector<DnnlTensor*> inputs_;
|
||||
std::vector<DnnlTensor*> outputs_; //output should never get deleted from graph transformation
|
||||
std::vector<DnnlTensor*> initializers_;
|
||||
const GraphViewer& graph_viewer_;
|
||||
bool is_dynamic_;
|
||||
};
|
||||
} // namespace ort_dnnl
|
||||
|
|
|
|||
|
|
@ -446,7 +446,7 @@ bool DnnlSubgraphPrimitive::HasMemory(std::string memory_name, dnnl::memory::des
|
|||
return false;
|
||||
}
|
||||
|
||||
void DnnlSubgraphPrimitive::SetMemory(DnnlTensor tensor, dnnl::memory mem, bool always_copy_output, bool is_scalar) {
|
||||
void DnnlSubgraphPrimitive::SetMemory(const DnnlTensor& tensor, dnnl::memory mem, bool always_copy_output, bool is_scalar) {
|
||||
if (always_copy_output) {
|
||||
outputs_are_always_copied_.insert(tensor.Name());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class DnnlSubgraphPrimitive {
|
|||
//set memory to a tensor (output)
|
||||
//if always_copy_output is true a copy of the memory will be made when the output is leaving the subgraph.
|
||||
//is_scalar is true to indicate a scalar output in order to allocate the correct onnxruntime output buffer
|
||||
void SetMemory(DnnlTensor tensor, dnnl::memory mem, bool always_copy_output = false, bool is_scalar = false);
|
||||
void SetMemory(const DnnlTensor& tensor, dnnl::memory mem, bool always_copy_output = false, bool is_scalar = false);
|
||||
void SetMemory(std::string memory_name, dnnl::memory mem);
|
||||
void SetInitializer(std::string memory_name, dnnl::memory mem);
|
||||
dnnl::memory::desc GetOutputInfo(std::string name);
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@ namespace onnxruntime {
|
|||
namespace ort_dnnl {
|
||||
|
||||
//apply all transformation rules in order
|
||||
void DnnlGraphTransformer::Apply(DnnlSubgraph& subgraph) {
|
||||
void DnnlGraphTransformer::Apply(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
|
||||
ConvRelu(subgraph);
|
||||
MatMulAdd(subgraph);
|
||||
Gelu(subgraph);
|
||||
FastGelu(subgraph);
|
||||
RemoveMatMulIntegerZP(subgraph);
|
||||
Gelu(subgraph, onnx_subgraph_viewer);
|
||||
FastGelu(subgraph, onnx_subgraph_viewer);
|
||||
RemoveMatMulIntegerZP(subgraph, onnx_subgraph_viewer);
|
||||
}
|
||||
|
||||
//resolve a fusion by replacing old_indices nodes with a new_node
|
||||
|
|
@ -162,13 +162,13 @@ bool IsScalar(const DnnlTensor& input_arg) {
|
|||
return dim_size == 0 || (dim_size == 1 && dim[0] == 1);
|
||||
}
|
||||
|
||||
bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(DnnlSubgraph& subgraph, DnnlTensor& input_arg, float expected_value) {
|
||||
bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlTensor& input_arg, float expected_value) {
|
||||
if (!IsScalar(input_arg)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
|
||||
if (!subgraph.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
|
||||
if (!onnx_subgraph_viewer.GetInitializedTensor(input_arg.Name(), tensor_proto)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ DnnlNode* FirstParentByType(DnnlNode* node, const std::string& parent_type) {
|
|||
After Fusion:
|
||||
[root]--> Gelu ==>
|
||||
*/
|
||||
void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
|
||||
void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
|
||||
static int gelu_index = 0;
|
||||
//traverse with max index as there will be empty nodes due to fusion
|
||||
size_t max_index = subgraph.GetMaxNodeIndex();
|
||||
|
|
@ -249,8 +249,8 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
|
|||
// Check second input is sqrt(2)
|
||||
// Some Bert models uses this approximation of SQRT2 in the Gelu function
|
||||
float approximated_sqrt_two = 1.4142099618911743f;
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, div_node->Input(1), approximated_sqrt_two) &&
|
||||
!IsInitilizedWithExpectedValue(subgraph, div_node->Input(1), static_cast<float>(M_SQRT2))) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, div_node->Input(1), approximated_sqrt_two) &&
|
||||
!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, div_node->Input(1), static_cast<float>(M_SQRT2))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -275,7 +275,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
|
|||
}
|
||||
|
||||
bool is_add_input0 = add_node->Input(0).Name() == erf_node->Output(0).Name();
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, add_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -304,7 +304,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
|
|||
if (!(is_mul2_input0 ^ is_mul2_input1)) {
|
||||
is_pattern_1 = false;
|
||||
}
|
||||
if (is_pattern_1 && !IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_input0 ? 1 : 0), 0.5f)) {
|
||||
if (is_pattern_1 && !IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_input0 ? 1 : 0), 0.5f)) {
|
||||
is_pattern_1 = false;
|
||||
}
|
||||
if (is_pattern_1 && !IsNodeFusable(subgraph, mul2_node)) {
|
||||
|
|
@ -329,7 +329,7 @@ void DnnlGraphTransformer::Gelu(DnnlSubgraph& subgraph) {
|
|||
ORT_THROW("Invalid Mul node");
|
||||
}
|
||||
bool is_mul2_first_input = mul2_node->Input(0).Name() == mul1_node->Output(0).Name();
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_first_input ? 1 : 0), 0.5f)) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_first_input ? 1 : 0), 0.5f)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
|
@ -369,14 +369,14 @@ The formula corresponding to Gelu activation subgraph :
|
|||
|
||||
where x is the input.
|
||||
*/
|
||||
void DnnlGraphTransformer::FastGelu(DnnlSubgraph& subgraph) {
|
||||
void DnnlGraphTransformer::FastGelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
|
||||
static int fastgelu_index = 0;
|
||||
//traverse with max index as there will be empty nodes due to fusion
|
||||
size_t max_index = subgraph.GetMaxNodeIndex();
|
||||
for (size_t index = 0; index < max_index; index++) {
|
||||
auto dnnl_node = subgraph.GetDnnlNode(index);
|
||||
if (!FastGeluFirstFormula(subgraph, dnnl_node, fastgelu_index)) {
|
||||
FastGeluSecondFormula(subgraph, dnnl_node, fastgelu_index);
|
||||
if (!FastGeluFirstFormula(subgraph, onnx_subgraph_viewer, dnnl_node, fastgelu_index)) {
|
||||
FastGeluSecondFormula(subgraph, onnx_subgraph_viewer, dnnl_node, fastgelu_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -389,7 +389,7 @@ The formula corresponding to Gelu activation subgraph :
|
|||
|
||||
where x is the input.
|
||||
*/
|
||||
bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode* mul1_node, int& fastgelu_index) {
|
||||
bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* mul1_node, int& fastgelu_index) {
|
||||
std::vector<size_t> gelu_indices;
|
||||
//----------mul(0.44715)------------------
|
||||
if (mul1_node == nullptr || mul1_node->OpType() != "Mul") {
|
||||
|
|
@ -398,7 +398,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
|
|||
int32_t mul1_input_index = -1;
|
||||
const float mul_val = 0.044715f;
|
||||
for (auto i = 0; i < 2; i++) {
|
||||
if (IsInitilizedWithExpectedValue(subgraph, mul1_node->Input(i), mul_val)) {
|
||||
if (IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul1_node->Input(i), mul_val)) {
|
||||
mul1_input_index = i;
|
||||
break;
|
||||
}
|
||||
|
|
@ -424,7 +424,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
|
|||
return false;
|
||||
}
|
||||
bool is_add_input0 = mul2_node->Output(0).Name() == add1_node->Input(0).Name();
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, add1_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add1_node->Input(is_add_input0 ? 1 : 0), 1.0f)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -450,7 +450,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
|
|||
int32_t mul4_input_index = -1;
|
||||
const float mul4_val = 0.7978845834732056f;
|
||||
for (auto i = 0; i < 2; i++) {
|
||||
if (IsInitilizedWithExpectedValue(subgraph, prev_mul4_node->Input(i), mul4_val)) {
|
||||
if (IsInitilizedWithExpectedValue(onnx_subgraph_viewer, prev_mul4_node->Input(i), mul4_val)) {
|
||||
mul4_input_index = i;
|
||||
break;
|
||||
}
|
||||
|
|
@ -464,7 +464,7 @@ bool DnnlGraphTransformer::FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode
|
|||
|
||||
auto tanh_node = mul3_node->Output(0).GetConsumers()[0].GetNode();
|
||||
int32_t x_input_index = (mul1_input_index == 0) ? 1 : 0;
|
||||
if (FastGeluFormulaCommon(subgraph, mul1_node, x_input_index, tanh_node, gelu_indices, fastgelu_index)) {
|
||||
if (FastGeluFormulaCommon(subgraph, onnx_subgraph_viewer, mul1_node, x_input_index, tanh_node, gelu_indices, fastgelu_index)) {
|
||||
if (debug_log_) {
|
||||
LOGS_DEFAULT(ERROR) << "FastGelu fusion found [" << fastgelu_index << "] (first formula)";
|
||||
}
|
||||
|
|
@ -481,15 +481,15 @@ The formula corresponding to Gelu activation subgraph :
|
|||
|
||||
where x is the input.
|
||||
*/
|
||||
void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNode* pow_node, int& fastgelu_index) {
|
||||
void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* pow_node, int& fastgelu_index) {
|
||||
std::vector<size_t> gelu_indices;
|
||||
//---------Pow-------------------
|
||||
if (pow_node == nullptr || pow_node->OpType() != "Pow") {
|
||||
return;
|
||||
}
|
||||
|
||||
auto pow_exponent = pow_node->Input(1);
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, pow_exponent, 3.0f)) {
|
||||
auto& pow_exponent = pow_node->Input(1);
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, pow_exponent, 3.0f)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -505,7 +505,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
|
|||
|
||||
float fastgelu_muliplyer = 0.044714998453855515f;
|
||||
bool is_mul1_input0 = pow_node->Output(0).Name() == mul1_node->Input(0).Name();
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, mul1_node->Input(is_mul1_input0 ? 1 : 0), fastgelu_muliplyer)) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul1_node->Input(is_mul1_input0 ? 1 : 0), fastgelu_muliplyer)) {
|
||||
return;
|
||||
}
|
||||
if (!IsNodeFusable(subgraph, mul1_node)) {
|
||||
|
|
@ -531,7 +531,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
|
|||
// constant is sqrt(2/pi)
|
||||
float fastgelu_sqrt_2_div_pi = 0.7978845834732056f;
|
||||
bool is_mul2_input0 = add1_node->Output(0).Name() == mul2_node->Input(0).Name();
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, mul2_node->Input(is_mul2_input0 ? 1 : 0), fastgelu_sqrt_2_div_pi)) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, mul2_node->Input(is_mul2_input0 ? 1 : 0), fastgelu_sqrt_2_div_pi)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -543,7 +543,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
|
|||
//----------Tanh------------------
|
||||
auto tanh_node = mul2_node->Output(0).GetConsumers()[0].GetNode();
|
||||
// since the first node is pow the x_input_index is always 0
|
||||
if (FastGeluFormulaCommon(subgraph, pow_node, 0, tanh_node, gelu_indices, fastgelu_index)) {
|
||||
if (FastGeluFormulaCommon(subgraph, onnx_subgraph_viewer, pow_node, 0, tanh_node, gelu_indices, fastgelu_index)) {
|
||||
if (debug_log_) {
|
||||
LOGS_DEFAULT(ERROR) << "FastGelu fusion found [" << fastgelu_index << "] (second formula)";
|
||||
}
|
||||
|
|
@ -559,7 +559,7 @@ void DnnlGraphTransformer::FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNod
|
|||
x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))),
|
||||
where x is the input.
|
||||
*/
|
||||
bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index) {
|
||||
bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index) {
|
||||
//----------Tanh------------------
|
||||
if (tanh_node == nullptr || tanh_node->OpType() != "Tanh") {
|
||||
return false;
|
||||
|
|
@ -575,7 +575,7 @@ bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNod
|
|||
return false;
|
||||
}
|
||||
bool is_add2_input0 = tanh_node->Output(0).Name() == add2_node->Input(0).Name();
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, add2_node->Input(is_add2_input0 ? 1 : 0), 1.0f)) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, add2_node->Input(is_add2_input0 ? 1 : 0), 1.0f)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -607,7 +607,7 @@ bool DnnlGraphTransformer::FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNod
|
|||
if (!(is_mul_input0 ^ is_mul_input1)) {
|
||||
return false;
|
||||
}
|
||||
if (!IsInitilizedWithExpectedValue(subgraph, prev_mul4_node->Input(is_mul_input0 ? 1 : 0), 0.5f)) {
|
||||
if (!IsInitilizedWithExpectedValue(onnx_subgraph_viewer, prev_mul4_node->Input(is_mul_input0 ? 1 : 0), 0.5f)) {
|
||||
return false;
|
||||
}
|
||||
if (!IsNodeFusable(subgraph, prev_mul4_node)) {
|
||||
|
|
@ -748,7 +748,7 @@ void DnnlGraphTransformer::MatMulAdd(DnnlSubgraph& subgraph) {
|
|||
}
|
||||
}
|
||||
|
||||
void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph) {
|
||||
void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer) {
|
||||
size_t max_index = subgraph.GetMaxNodeIndex();
|
||||
for (size_t index = 0; index < max_index; index++) {
|
||||
auto dnnl_node = subgraph.GetDnnlNode(index);
|
||||
|
|
@ -763,9 +763,9 @@ void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph) {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto b_zero_point = dnnl_node->Input(3);
|
||||
auto& b_zero_point = dnnl_node->Input(3);
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
|
||||
if (!subgraph.GetInitializedTensor(b_zero_point.Name(), tensor_proto)) {
|
||||
if (!onnx_subgraph_viewer.GetInitializedTensor(b_zero_point.Name(), tensor_proto)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@ namespace ort_dnnl {
|
|||
|
||||
class DnnlGraphTransformer {
|
||||
public:
|
||||
void Apply(DnnlSubgraph& subgraph);
|
||||
// The passed in onnx subgraph viewer is only valid during "Compile" phase,
|
||||
// so keep a reference to that onnx subgraph in DnnlSubgraph is risky.
|
||||
// passed in the onnx subgraph viewer explicitly to make sure we manage the lifetime correctly.
|
||||
void Apply(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer_);
|
||||
DnnlGraphTransformer() {
|
||||
const std::string debug_log_env = onnxruntime::GetEnvironmentVar("ORT_DNNL_DEBUG_LOG");
|
||||
if (!debug_log_env.empty()) {
|
||||
|
|
@ -18,15 +21,15 @@ class DnnlGraphTransformer {
|
|||
}
|
||||
|
||||
private:
|
||||
void Gelu(DnnlSubgraph& subgraph);
|
||||
void FastGelu(DnnlSubgraph& subgraph);
|
||||
bool FastGeluFirstFormula(DnnlSubgraph& subgraph, DnnlNode* node, int& fastgelu_index);
|
||||
void FastGeluSecondFormula(DnnlSubgraph& subgraph, DnnlNode* node, int& fastgelu_index);
|
||||
bool FastGeluFormulaCommon(DnnlSubgraph& subgraph, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index);
|
||||
bool IsInitilizedWithExpectedValue(DnnlSubgraph& subgraph, DnnlTensor& input_arg, float expected_value);
|
||||
void Gelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer);
|
||||
void FastGelu(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer);
|
||||
bool FastGeluFirstFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* node, int& fastgelu_index);
|
||||
void FastGeluSecondFormula(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* node, int& fastgelu_index);
|
||||
bool FastGeluFormulaCommon(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlNode* gelu_start_node, int32_t x_input_index, DnnlNode* tanh_node, std::vector<size_t>& gelu_indices, int& fastgelu_index);
|
||||
bool IsInitilizedWithExpectedValue(const onnxruntime::GraphViewer& onnx_subgraph_viewer, DnnlTensor& input_arg, float expected_value);
|
||||
void ConvRelu(DnnlSubgraph& subgraph);
|
||||
void MatMulAdd(DnnlSubgraph& subgraph);
|
||||
void RemoveMatMulIntegerZP(DnnlSubgraph& subgraph);
|
||||
void RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const onnxruntime::GraphViewer& onnx_subgraph_viewer);
|
||||
// This function checks a few things
|
||||
// - the node in question has a single output
|
||||
// - The output of the node is only consumed by a one other node
|
||||
|
|
|
|||
|
|
@ -660,34 +660,6 @@ static bool IsNodeSupported(const std::set<std::string>& op_set,
|
|||
return true;
|
||||
}
|
||||
|
||||
// Convert GraphViewer graph to GraphProto
|
||||
void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) {
|
||||
for (const auto* input_arg : graph.GetInputs()) {
|
||||
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
|
||||
}
|
||||
|
||||
// Add all graph's initializers to the subgraph
|
||||
const auto& init_tensors = graph.GetAllInitializedTensors();
|
||||
for (const auto& tensor : init_tensors) {
|
||||
*(graph_proto.mutable_initializer()->Add()) = *(tensor.second);
|
||||
}
|
||||
|
||||
for (const auto* output_arg : graph.GetOutputs()) {
|
||||
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
|
||||
}
|
||||
|
||||
for (const auto* value_info : graph.GetValueInfo()) {
|
||||
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
|
||||
}
|
||||
|
||||
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
|
||||
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
|
||||
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
|
||||
const gsl::not_null<const Node*> p_node{graph.GetNode(node_idx)};
|
||||
p_node->ToProto(*node_proto);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const {
|
||||
std::unordered_set<size_t> node_set;
|
||||
node_set.reserve(graph_nodes_index.size());
|
||||
|
|
@ -913,7 +885,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
|
|||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
auto model = graph_viewer.CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph());
|
||||
graph_viewer.ToProto(*model_proto->mutable_graph(), true, true);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
std::string onnx_string_buffer;
|
||||
model_proto->SerializeToString(onnx_string_buffer);
|
||||
|
|
@ -1012,30 +984,24 @@ bool get_input_output_names(const GraphViewer& graph,
|
|||
return no_input_shape;
|
||||
}
|
||||
|
||||
Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
migraphx::onnx_options options;
|
||||
bool no_input_shape = false;
|
||||
for (const auto& fused_node : fused_nodes) {
|
||||
for (const auto& fused_node_graph : fused_nodes) {
|
||||
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
|
||||
const Node& fused_node = fused_node_graph.fused_node;
|
||||
// map parameter input name to index
|
||||
std::unordered_map<std::string, std::size_t> input_name_index;
|
||||
const auto& input_defs = fused_node->InputDefs();
|
||||
const auto& input_defs = fused_node.InputDefs();
|
||||
input_name_index.reserve(input_defs.size());
|
||||
for (std::size_t i = 0; i < input_defs.size(); ++i) {
|
||||
input_name_index[input_defs[i]->Name()] = i;
|
||||
}
|
||||
|
||||
// Reconstruct graph proto from fused node's function body
|
||||
const auto* func_body = fused_node->GetFunctionBody();
|
||||
if (!func_body) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
|
||||
}
|
||||
|
||||
const Graph& graph_body = func_body->Body();
|
||||
auto graph_body_viewer = graph_body.CreateGraphViewer();
|
||||
auto model = graph_body_viewer->CreateModel(*GetLogger());
|
||||
auto model = graph_body_viewer.CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
*model_proto->mutable_graph() = *graph_body.ToGraphProto();
|
||||
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
std::string onnx_string_buffer;
|
||||
model_proto->SerializeToString(onnx_string_buffer);
|
||||
|
|
@ -1069,11 +1035,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
|
|||
}
|
||||
|
||||
// compile the program
|
||||
map_progs_[fused_node->Name()] = prog;
|
||||
map_progs_[fused_node.Name()] = prog;
|
||||
|
||||
map_onnx_string_[fused_node->Name()] = onnx_string_buffer;
|
||||
map_input_index_[fused_node->Name()] = input_name_index;
|
||||
map_no_input_shape_[fused_node->Name()] = no_input_shape;
|
||||
map_onnx_string_[fused_node.Name()] = onnx_string_buffer;
|
||||
map_input_index_[fused_node.Name()] = input_name_index;
|
||||
map_no_input_shape_[fused_node.Name()] = no_input_shape;
|
||||
NodeComputeInfo compute_info;
|
||||
compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
|
||||
std::unique_ptr<MIGraphXFuncState> p = std::make_unique<MIGraphXFuncState>();
|
||||
|
|
|
|||
|
|
@ -45,8 +45,8 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
|
|||
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& kernel_registries) const override;
|
||||
|
||||
Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
||||
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
|
||||
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;
|
||||
|
|
|
|||
|
|
@ -23,9 +23,6 @@ class NnapiExecutionProvider : public IExecutionProvider {
|
|||
GetCapability(const onnxruntime::GraphViewer& graph_view,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
|
||||
// we implement the Compile that takes FusedNodeAndGraph instances
|
||||
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
|
|
|||
|
|
@ -118,6 +118,13 @@ class NupharExecutionProvider : public IExecutionProvider {
|
|||
return iter->second.get();
|
||||
}
|
||||
|
||||
onnxruntime::IExecutionProvider::FusionStyle GetFusionStyle() const override {
|
||||
// existing EPs use this mode so default to it.
|
||||
// newer EPs that can use the cheaper approach, or need to run in a minimal build, should override to return
|
||||
// FilteredGraphViewer
|
||||
return onnxruntime::IExecutionProvider::FusionStyle::Function;
|
||||
}
|
||||
|
||||
private:
|
||||
void CreateTVMTarget();
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,9 @@ void BackendManager::ReleaseGlobalContext() {
|
|||
g_global_context.reset();
|
||||
}
|
||||
|
||||
BackendManager::BackendManager(const Node* fused_node, const logging::Logger& logger) {
|
||||
BackendManager::BackendManager(const onnxruntime::Node& fused_node,
|
||||
const onnxruntime::GraphViewer& subgraph,
|
||||
const logging::Logger& logger) {
|
||||
auto prec_str = GetGlobalContext().precision_str;
|
||||
if (prec_str == "FP32") {
|
||||
subgraph_context_.precision = InferenceEngine::Precision::FP32;
|
||||
|
|
@ -38,14 +40,14 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo
|
|||
|
||||
// Save the indexes of graph inputs among fused_node's inputDefs
|
||||
// (which also contains initializers).
|
||||
auto node_input_defs = fused_node->InputDefs();
|
||||
auto node_input_defs = fused_node.InputDefs();
|
||||
int i = 0;
|
||||
for (auto idef : node_input_defs) {
|
||||
subgraph_context_.input_names.insert({idef->Name(), i});
|
||||
i++;
|
||||
}
|
||||
|
||||
auto graph_inputs = fused_node->GetFunctionBody()->Body().GetInputs();
|
||||
auto graph_inputs = subgraph.GetInputs();
|
||||
for (auto input : graph_inputs) {
|
||||
if (GetGlobalContext().device_type.find("MYRIAD") != std::string::npos) {
|
||||
auto shape = input->Shape();
|
||||
|
|
@ -63,14 +65,14 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo
|
|||
subgraph_context_.input_indexes.push_back(index);
|
||||
}
|
||||
|
||||
auto graph_outputs_defs = fused_node->OutputDefs();
|
||||
auto graph_outputs_defs = fused_node.OutputDefs();
|
||||
i = 0;
|
||||
for (auto output_def : graph_outputs_defs) {
|
||||
subgraph_context_.output_names.insert({output_def->Name(), i});
|
||||
i++;
|
||||
}
|
||||
subgraph_context_.subgraph_name = fused_node->Name();
|
||||
model_proto_ = GetModelProtoFromFusedNode(fused_node, logger);
|
||||
subgraph_context_.subgraph_name = fused_node.Name();
|
||||
model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, logger);
|
||||
|
||||
if (ModelHasBatchedInputs(*model_proto_) &&
|
||||
GetGlobalContext().is_wholly_supported_graph &&
|
||||
|
|
@ -81,7 +83,7 @@ BackendManager::BackendManager(const Node* fused_node, const logging::Logger& lo
|
|||
concrete_backend_ = BackendFactory::MakeBackend(*model_copy, GetGlobalContext(), subgraph_context_);
|
||||
subgraph_context_.has_dynamic_input_shape = false;
|
||||
|
||||
} else if (ModelHasSymbolicInputDims(fused_node)) {
|
||||
} else if (ModelHasSymbolicInputDims(subgraph)) {
|
||||
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims. Defering backend initialization";
|
||||
subgraph_context_.has_dynamic_input_shape = true;
|
||||
} else {
|
||||
|
|
@ -123,9 +125,9 @@ bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& mod
|
|||
return has_batched_inputs;
|
||||
}
|
||||
|
||||
bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::Node* fused_node) const {
|
||||
bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const {
|
||||
bool has_sym_dims = false;
|
||||
auto graph_inputs = fused_node->GetFunctionBody()->Body().GetInputs();
|
||||
auto graph_inputs = subgraph.GetInputs();
|
||||
for (auto input : graph_inputs) {
|
||||
if (input->Shape() == nullptr) {
|
||||
has_sym_dims = true;
|
||||
|
|
@ -145,25 +147,23 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::Node* fused_no
|
|||
}
|
||||
|
||||
std::unique_ptr<ONNX_NAMESPACE::ModelProto>
|
||||
BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node,
|
||||
BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
|
||||
const onnxruntime::GraphViewer& subgraph,
|
||||
const logging::Logger& logger) const {
|
||||
const auto* node_function = fused_node->GetFunctionBody();
|
||||
const std::string& name = fused_node->Name();
|
||||
ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", name);
|
||||
|
||||
const onnxruntime::Graph& node_subgraph = node_function->Body();
|
||||
auto model = node_subgraph.CreateGraphViewer()->CreateModel(logger);
|
||||
auto model = subgraph.CreateModel(logger);
|
||||
|
||||
auto model_proto = model->ToProto();
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
||||
*model_proto->mutable_graph() = *node_subgraph.ToGraphProto();
|
||||
subgraph.ToProto(*model_proto->mutable_graph(), true, true);
|
||||
|
||||
#ifndef NDEBUG
|
||||
if (openvino_ep::backend_utils::IsDebugEnabled()) {
|
||||
const std::string& name = fused_node.Name();
|
||||
std::fstream dump(name + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
model_proto->SerializeToOstream(dump);
|
||||
}
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(fused_node);
|
||||
#endif
|
||||
|
||||
return model_proto;
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ namespace openvino_ep {
|
|||
// Singleton class that manages all the backends
|
||||
class BackendManager {
|
||||
public:
|
||||
BackendManager(const onnxruntime::Node* fused_node, const logging::Logger& logger);
|
||||
BackendManager(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger);
|
||||
void Compute(Ort::CustomOpApi api, OrtKernelContext* context);
|
||||
void ShutdownBackendManager();
|
||||
static GlobalContext& GetGlobalContext();
|
||||
|
|
@ -21,8 +21,8 @@ class BackendManager {
|
|||
|
||||
private:
|
||||
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
|
||||
const onnxruntime::Node* fused_node, const logging::Logger& logger) const;
|
||||
bool ModelHasSymbolicInputDims(const onnxruntime::Node* fused_node) const;
|
||||
const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const;
|
||||
bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const;
|
||||
bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const;
|
||||
|
||||
std::shared_ptr<ONNX_NAMESPACE::ModelProto>
|
||||
|
|
|
|||
|
|
@ -131,18 +131,21 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const
|
|||
}
|
||||
|
||||
common::Status OpenVINOExecutionProvider::Compile(
|
||||
const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
for (const auto& fused_node : fused_nodes) {
|
||||
NodeComputeInfo compute_info;
|
||||
|
||||
#if defined (OV_API_20)
|
||||
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true;
|
||||
# else
|
||||
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = false;
|
||||
#endif
|
||||
for (const auto& fused_node_graph : fused_nodes) {
|
||||
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
|
||||
const Node& fused_node = fused_node_graph.fused_node;
|
||||
|
||||
std::shared_ptr<openvino_ep::BackendManager> backend_manager = std::make_shared<openvino_ep::BackendManager>(fused_node, *GetLogger());
|
||||
NodeComputeInfo compute_info;
|
||||
|
||||
#if defined(OV_API_20)
|
||||
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true;
|
||||
#else
|
||||
openvino_ep::BackendManager::GetGlobalContext().use_api_2 = false;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<openvino_ep::BackendManager> backend_manager = std::make_shared<openvino_ep::BackendManager>(fused_node, graph_body_viewer, *GetLogger());
|
||||
|
||||
compute_info.create_state_func =
|
||||
[backend_manager](ComputeContext* context, FunctionState* state) {
|
||||
|
|
|
|||
|
|
@ -151,8 +151,8 @@ class OpenVINOExecutionProvider : public IExecutionProvider {
|
|||
GetCapability(const GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& kernel_registries) const override;
|
||||
|
||||
Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
||||
const void* GetExecutionHandle() const noexcept override {
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -256,30 +256,24 @@ RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
return result;
|
||||
}
|
||||
|
||||
common::Status RknpuExecutionProvider::Compile(
|
||||
const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
for (const auto* fused_node : fused_nodes) {
|
||||
// Reconstruct graph proto from fused node's function body
|
||||
const auto* func_body = fused_node->GetFunctionBody();
|
||||
if (!func_body) {
|
||||
return common::Status(common::ONNXRUNTIME,
|
||||
common::INVALID_ARGUMENT, "Function body is empty");
|
||||
}
|
||||
const Graph& graph_body = func_body->Body();
|
||||
onnxruntime::Model model(graph_body.Name(), true, ModelMetaData(),
|
||||
common::Status RknpuExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
|
||||
const Node& fused_node = fused_node_graph.fused_node;
|
||||
onnxruntime::Model model(graph_body_viewer.Name(), true, ModelMetaData(),
|
||||
PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(),
|
||||
graph_body.DomainToVersionMap(),
|
||||
graph_body_viewer.DomainToVersionMap(),
|
||||
std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
*GetLogger());
|
||||
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
|
||||
*(model_proto.mutable_graph()) = graph_body.ToGraphProto();
|
||||
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
|
||||
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
||||
// Build map from input name to its index in input definitions
|
||||
std::unordered_map<std::string, int> input_map;
|
||||
const auto& input_defs = fused_node->InputDefs();
|
||||
const auto& input_defs = fused_node.InputDefs();
|
||||
input_map.reserve(input_defs.size());
|
||||
for (int i = 0, end = input_defs.size(); i < end; ++i) {
|
||||
input_map[input_defs[i]->Name()] = i;
|
||||
|
|
@ -287,15 +281,15 @@ common::Status RknpuExecutionProvider::Compile(
|
|||
|
||||
// Build map from output name to its index in output definitions
|
||||
std::unordered_map<std::string, int> output_map;
|
||||
const auto& output_defs = fused_node->OutputDefs();
|
||||
const auto& output_defs = fused_node.OutputDefs();
|
||||
output_map.reserve(output_defs.size());
|
||||
for (int i = 0, end = output_defs.size(); i < end; ++i) {
|
||||
output_map[output_defs[i]->Name()] = i;
|
||||
}
|
||||
|
||||
model_proto_[fused_node->Name()] = model_proto;
|
||||
input_info_[fused_node->Name()] = input_map;
|
||||
output_info_[fused_node->Name()] = output_map;
|
||||
model_proto_[fused_node.Name()] = model_proto;
|
||||
input_info_[fused_node.Name()] = input_map;
|
||||
output_info_[fused_node.Name()] = output_map;
|
||||
|
||||
NodeComputeInfo compute_info;
|
||||
compute_info.create_state_func = [&](ComputeContext* context,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class RknpuExecutionProvider : public IExecutionProvider {
|
|||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -291,17 +291,11 @@ std::vector<std::unique_ptr<ComputeCapability>> IExecutionProvider::GetCapabilit
|
|||
const std::vector<const KernelRegistry*>& kernel_registries) const {
|
||||
return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_registries);
|
||||
}
|
||||
|
||||
// !!! This API will be deprecated soon.
|
||||
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
return g_host->IExecutionProvider__Compile(this, fused_nodes, node_compute_funcs);
|
||||
}
|
||||
|
||||
common::Status IExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::string& dll_path) {
|
||||
return g_host->IExecutionProvider__Compile(this, fused_nodes, dll_path);
|
||||
}
|
||||
|
||||
common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
return g_host->IExecutionProvider__Compile(this, fused_nodes_and_graphs, node_compute_funcs);
|
||||
|
|
|
|||
|
|
@ -218,8 +218,9 @@ struct ProviderHost {
|
|||
virtual void IExecutionProvider__TryInsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) = 0;
|
||||
virtual std::vector<std::unique_ptr<ComputeCapability>> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& kernel_registries) = 0;
|
||||
//!!! This API will be deprecated soon
|
||||
virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::vector<NodeComputeInfo>& node_compute_funcs) = 0;
|
||||
virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::string& dll_path) = 0;
|
||||
|
||||
virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs, std::vector<NodeComputeInfo>& node_compute_funcs) = 0;
|
||||
|
||||
virtual int IExecutionProvider__GenerateMetaDefId(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) = 0;
|
||||
|
|
@ -286,6 +287,8 @@ struct ProviderHost {
|
|||
#endif
|
||||
|
||||
// TypeProto
|
||||
virtual std::unique_ptr<ONNX_NAMESPACE::TypeProto> TypeProto__construct() = 0;
|
||||
virtual void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) = 0;
|
||||
virtual const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) = 0;
|
||||
|
||||
|
|
@ -368,6 +371,7 @@ struct ProviderHost {
|
|||
virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0;
|
||||
virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0;
|
||||
virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0;
|
||||
virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0;
|
||||
|
||||
virtual bool TensorProto_DataType_IsValid(int value) = 0;
|
||||
|
||||
|
|
@ -670,6 +674,8 @@ struct ProviderHost {
|
|||
virtual const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) = 0;
|
||||
virtual const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0;
|
||||
|
||||
virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0;
|
||||
|
||||
// Path
|
||||
virtual PathString Path__ToPathString(const Path* p) noexcept = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -159,6 +159,8 @@ struct TensorProto final {
|
|||
|
||||
static bool DataType_IsValid(int value) { return g_host->TensorProto_DataType_IsValid(value); }
|
||||
|
||||
void copy_from(const TensorProto* other) { return g_host->TensorProto__CopyFrom(this, other); }
|
||||
|
||||
TensorProto() = delete;
|
||||
TensorProto(const TensorProto&) = delete;
|
||||
};
|
||||
|
|
@ -240,6 +242,8 @@ struct TypeProto_Sequence final {
|
|||
};
|
||||
|
||||
struct TypeProto final {
|
||||
static std::unique_ptr<TypeProto> Create() { return g_host->TypeProto__construct(); }
|
||||
|
||||
const TypeProto_Tensor& tensor_type() const { return g_host->TypeProto__tensor_type(this); }
|
||||
TypeProto_Tensor* mutable_tensor_type() { return g_host->TypeProto__mutable_tensor_type(this); }
|
||||
|
||||
|
|
@ -268,7 +272,10 @@ struct TypeProto final {
|
|||
|
||||
ValueCase value_case() const { return ValueCase(g_host->TypeProto__value_case(this)); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(TypeProto)
|
||||
void copy_from(const TypeProto* other) { return g_host->TypeProto__CopyFrom(this, other); }
|
||||
|
||||
TypeProto() = delete;
|
||||
TypeProto(const TypeProto&) = delete;
|
||||
};
|
||||
|
||||
struct ValueInfoProto final {
|
||||
|
|
@ -705,6 +712,8 @@ struct GraphViewer final {
|
|||
const std::vector<NodeIndex>& GetNodesInTopologicalOrder() const { return g_host->GraphViewer__GetNodesInTopologicalOrder(this); }
|
||||
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); }
|
||||
|
||||
void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); }
|
||||
|
||||
GraphViewer() = delete;
|
||||
GraphViewer(const GraphViewer&) = delete;
|
||||
void operator=(const GraphViewer&) = delete;
|
||||
|
|
|
|||
|
|
@ -516,34 +516,6 @@ Status TensorrtExecutionProvider::SetComputeStream(void* stream) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Convert GraphViewer graph to GraphProto
|
||||
void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) {
|
||||
for (const auto* input_arg : graph.GetInputs()) {
|
||||
*(graph_proto.mutable_input()->Add()) = input_arg->ToProto();
|
||||
}
|
||||
|
||||
// Add all graph's initializers to the subgraph
|
||||
const auto& init_tensors = graph.GetAllInitializedTensors();
|
||||
for (const auto& tensor : init_tensors) {
|
||||
*(graph_proto.mutable_initializer()->Add()) = *(tensor.second);
|
||||
}
|
||||
|
||||
for (const auto* output_arg : graph.GetOutputs()) {
|
||||
*(graph_proto.mutable_output()->Add()) = output_arg->ToProto();
|
||||
}
|
||||
|
||||
for (const auto* value_info : graph.GetValueInfo()) {
|
||||
*(graph_proto.mutable_value_info()->Add()) = value_info->ToProto();
|
||||
}
|
||||
|
||||
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
|
||||
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
|
||||
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
|
||||
const gsl::not_null<const Node*> p_node{graph.GetNode(node_idx)};
|
||||
p_node->ToProto(*node_proto);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const {
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
std::unordered_set<size_t> node_set;
|
||||
|
|
@ -799,7 +771,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
auto graph_viewer = graph_build.CreateGraphViewer();
|
||||
auto model = graph_viewer->CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
ToGraphProtoInternal(*graph_viewer, *model_proto->mutable_graph());
|
||||
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
||||
std::string string_buf;
|
||||
|
|
@ -982,19 +954,19 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
|
||||
// Consolidate supported node list
|
||||
if (supported_nodes_vector.size() > 1) {
|
||||
nodes_vector.clear();
|
||||
for (const auto& group : supported_nodes_vector) {
|
||||
if (!group.first.empty()) {
|
||||
nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end());
|
||||
}
|
||||
}
|
||||
SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}};
|
||||
if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, false)) {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation";
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph";
|
||||
supported_nodes_vector = consolidated_supported_nodes_vector;
|
||||
nodes_vector.clear();
|
||||
for (const auto& group : supported_nodes_vector) {
|
||||
if (!group.first.empty()) {
|
||||
nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end());
|
||||
}
|
||||
}
|
||||
SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}};
|
||||
if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, false)) {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation";
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph";
|
||||
supported_nodes_vector = consolidated_supported_nodes_vector;
|
||||
}
|
||||
}
|
||||
|
||||
// Construct subgraph capability from node list
|
||||
|
|
@ -1020,12 +992,14 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
return result;
|
||||
}
|
||||
|
||||
common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
|
||||
common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
for (const auto* fused_node : fused_nodes) {
|
||||
for (auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
|
||||
const Node& fused_node = fused_node_graph.fused_node;
|
||||
// Build map from input name to its index in input definitions
|
||||
std::unordered_map<std::string, size_t> input_map;
|
||||
const auto& input_defs = fused_node->InputDefs();
|
||||
const auto& input_defs = fused_node.InputDefs();
|
||||
input_map.reserve(input_defs.size());
|
||||
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
|
||||
input_map[input_defs[i]->Name()] = i;
|
||||
|
|
@ -1033,29 +1007,23 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
|
||||
// Build map from output name to its index in output definitions
|
||||
std::unordered_map<std::string, size_t> output_map;
|
||||
const auto& output_defs = fused_node->OutputDefs();
|
||||
const auto& output_defs = fused_node.OutputDefs();
|
||||
output_map.reserve(output_defs.size());
|
||||
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
|
||||
output_map[output_defs[i]->Name()] = i;
|
||||
}
|
||||
|
||||
// Reconstruct graph proto from fused node's function body
|
||||
const auto* func_body = fused_node->GetFunctionBody();
|
||||
if (!func_body) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
|
||||
}
|
||||
const Graph& graph_body = func_body->Body();
|
||||
auto graph_body_viewer = graph_body.CreateGraphViewer();
|
||||
auto model = graph_body_viewer->CreateModel(*GetLogger());
|
||||
auto model = graph_body_viewer.CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
*model_proto->mutable_graph() = *graph_body.ToGraphProto();
|
||||
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
std::string string_buf;
|
||||
model_proto->SerializeToString(string_buf);
|
||||
|
||||
if (dump_subgraphs_) {
|
||||
// Dump TensorRT subgraphs
|
||||
std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
model_proto->SerializeToOstream(dump);
|
||||
}
|
||||
|
||||
|
|
@ -1123,7 +1091,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
}
|
||||
|
||||
// Set precision flags
|
||||
std::string trt_node_name_with_precision = fused_node->Name();
|
||||
std::string trt_node_name_with_precision = fused_node.Name();
|
||||
if (fp16_enable_ && int8_enable_) {
|
||||
trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
|
||||
trt_node_name_with_precision += "_fp16_int8";
|
||||
|
|
@ -1140,7 +1108,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
|
||||
// Set DLA
|
||||
if (fp16_enable_ || int8_enable_) {
|
||||
if (dla_enable_ && dla_core_ >= 0) { //DLA can only run with FP16 and INT8
|
||||
if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8
|
||||
int number_of_dla_core = trt_builder->getNbDLACores();
|
||||
if (number_of_dla_core == 0) {
|
||||
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core";
|
||||
|
|
@ -1205,7 +1173,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
|
||||
if (trt_context == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not build execution context for fused node: " + fused_node->Name());
|
||||
"TensorRT EP could not build execution context for fused node: " + fused_node.Name());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1220,7 +1188,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// If (1) engine cache enable is not set or (2) first time enable engine cache and no engine cache is present,
|
||||
// build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will
|
||||
// be built at runtime
|
||||
|
|
@ -1231,7 +1198,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
trt_config->setInt8Calibrator(nullptr);
|
||||
if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node->Name());
|
||||
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1242,7 +1209,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
}
|
||||
if (trt_engine == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not build engine for fused node: " + fused_node->Name());
|
||||
"TensorRT EP could not build engine for fused node: " + fused_node.Name());
|
||||
}
|
||||
|
||||
if (engine_cache_enable_)
|
||||
|
|
@ -1252,7 +1219,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
|
||||
if (trt_context == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not build execution context for fused node: " + fused_node->Name());
|
||||
"TensorRT EP could not build execution context for fused node: " + fused_node.Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1280,15 +1247,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
}
|
||||
|
||||
// Save engine, context and input/output info to map
|
||||
parsers_.emplace(fused_node->Name(), std::move(trt_parser));
|
||||
engines_.emplace(fused_node->Name(), std::move(trt_engine));
|
||||
contexts_.emplace(fused_node->Name(), std::move(trt_context));
|
||||
builders_.emplace(fused_node->Name(), std::move(trt_builder));
|
||||
networks_.emplace(fused_node->Name(), std::move(trt_network));
|
||||
input_info_[fused_node->Name()].push_back(input_indexes);
|
||||
output_info_[fused_node->Name()].push_back(output_indexes);
|
||||
output_info_[fused_node->Name()].push_back(output_types);
|
||||
input_shape_ranges_[fused_node->Name()] = input_shape_ranges;
|
||||
parsers_.emplace(fused_node.Name(), std::move(trt_parser));
|
||||
engines_.emplace(fused_node.Name(), std::move(trt_engine));
|
||||
contexts_.emplace(fused_node.Name(), std::move(trt_context));
|
||||
builders_.emplace(fused_node.Name(), std::move(trt_builder));
|
||||
networks_.emplace(fused_node.Name(), std::move(trt_network));
|
||||
input_info_[fused_node.Name()].push_back(input_indexes);
|
||||
output_info_[fused_node.Name()].push_back(output_indexes);
|
||||
output_info_[fused_node.Name()].push_back(output_types);
|
||||
input_shape_ranges_[fused_node.Name()] = input_shape_ranges;
|
||||
|
||||
// Create function state
|
||||
// TODO: remove default capture
|
||||
|
|
@ -1310,7 +1277,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
compute_info.release_state_func = [](FunctionState state) {
|
||||
if (state) {
|
||||
// Serialize and save engine to cache
|
||||
//
|
||||
//
|
||||
// Note: only save engine to file if engine cache enable is set and engine is being updated due to input shape changed
|
||||
// or engine file is not previously existed
|
||||
TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
|
||||
|
|
@ -1325,7 +1292,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
|
||||
delete static_cast<TensorrtFuncState*>(state);
|
||||
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine encryption function encrypt"));
|
||||
"TensorRT EP could not call engine encryption function encrypt"));
|
||||
}
|
||||
} else {
|
||||
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
|
||||
|
|
@ -1341,7 +1308,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
delete static_cast<TensorrtFuncState*>(state);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
|
||||
int GetDeviceId() const { return device_id_; }
|
||||
|
||||
common::Status Compile(const std::vector<Node*>& fused_nodes,
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
||||
AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override;
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/graph/graph_proto_serializer.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/graph/model.h"
|
||||
|
||||
|
|
@ -101,35 +102,33 @@ TvmExecutionProvider::GetCapability(const GraphViewer& graph_viewer,
|
|||
return result;
|
||||
}
|
||||
|
||||
common::Status TvmExecutionProvider::Compile(const std::vector<Node*>& nodes,
|
||||
common::Status TvmExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
printOptions();
|
||||
for (auto* fused_node : nodes) {
|
||||
auto func_body = fused_node->GetFunctionBody();
|
||||
if (!func_body)
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
|
||||
const std::string func_name = fused_node->Name();
|
||||
const Graph& node_graph = func_body->Body();
|
||||
Model model(node_graph.Name(), true, ModelMetaData(), PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), node_graph.DomainToVersionMap(),
|
||||
for (auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
|
||||
const Node& fused_node = fused_node_graph.fused_node;
|
||||
const std::string func_name = fused_node.Name();
|
||||
Model model(graph_body_viewer.Name(), true, ModelMetaData(), PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), graph_body_viewer.DomainToVersionMap(),
|
||||
std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
|
||||
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
|
||||
|
||||
*(model_proto.mutable_graph()) = node_graph.ToGraphProto();
|
||||
//TVM EP is using static lib approach, so invoke serializer directly.
|
||||
GraphViewerToProto(graph_body_viewer, *model_proto.mutable_graph(), true, true);
|
||||
auto opset = model_proto.add_opset_import();
|
||||
opset->set_domain(kOnnxDomain);
|
||||
opset->set_version(node_graph.DomainToVersionMap().at(kOnnxDomain));
|
||||
opset->set_version(graph_body_viewer.DomainToVersionMap().at(kOnnxDomain));
|
||||
|
||||
std::string onnx_model_str;
|
||||
model_proto.SerializeToString(&onnx_model_str);
|
||||
compilers_[func_name] = std::make_shared<Compiler>(std::move(onnx_model_str),
|
||||
fused_node->ModelPath().ToPathString(),
|
||||
fused_node.ModelPath().ToPathString(),
|
||||
int(opset->version()));
|
||||
InputsInfoMap all_input_shapes;
|
||||
auto mod = compileModel(func_name, node_graph, all_input_shapes);
|
||||
auto mod = compileModel(func_name, graph_body_viewer, all_input_shapes);
|
||||
|
||||
std::vector<DLTensor> output_tensors;
|
||||
prepareOutputTensors(mod, output_tensors, node_graph.GetOutputs().size());
|
||||
prepareOutputTensors(mod, output_tensors, graph_body_viewer.GetOutputs().size());
|
||||
|
||||
runners_[func_name] = std::make_shared<Runner>(options_, mod, all_input_shapes, output_tensors);
|
||||
|
||||
|
|
@ -168,15 +167,15 @@ void TvmExecutionProvider::printOptions() {
|
|||
}
|
||||
|
||||
std::shared_ptr<TvmModule> TvmExecutionProvider::compileModel(const std::string& func_name,
|
||||
const Graph& graph,
|
||||
const GraphViewer& graph_viewer,
|
||||
InputsInfoMap& all_input_shapes) {
|
||||
all_input_shapes.clear();
|
||||
|
||||
TVMTensorShapes input_shapes;
|
||||
if (options_.freeze_weights) {
|
||||
setInputShapesForFreezedNN(graph, input_shapes, all_input_shapes);
|
||||
setInputShapesForFreezedNN(graph_viewer, input_shapes, all_input_shapes);
|
||||
} else {
|
||||
setInputShapesForUnfreezedNN(graph, input_shapes, all_input_shapes);
|
||||
setInputShapesForUnfreezedNN(graph_viewer, input_shapes, all_input_shapes);
|
||||
}
|
||||
|
||||
std::shared_ptr<TvmModule> mod = compilers_[func_name]->operator()(options_, input_shapes);
|
||||
|
|
@ -184,14 +183,14 @@ std::shared_ptr<TvmModule> TvmExecutionProvider::compileModel(const std::string&
|
|||
return mod;
|
||||
}
|
||||
|
||||
void TvmExecutionProvider::setInputShapesForFreezedNN(const Graph& graph,
|
||||
void TvmExecutionProvider::setInputShapesForFreezedNN(const GraphViewer& graph_viewer,
|
||||
TVMTensorShapes& input_shapes,
|
||||
InputsInfoMap& all_input_shapes) {
|
||||
const std::vector<const NodeArg*>& all_nodes = graph.GetInputsIncludingInitializers();
|
||||
const std::vector<const NodeArg*>& all_nodes = graph_viewer.GetInputsIncludingInitializers();
|
||||
|
||||
size_t indx = 0;
|
||||
for (const auto* node : all_nodes) {
|
||||
if(!graph.IsInitializedTensor(node->Name())) {
|
||||
if (!graph_viewer.IsInitializedTensor(node->Name())) {
|
||||
TensorShapeVector shape = getInputShape(node);
|
||||
all_input_shapes[indx++] = shape;
|
||||
input_shapes.emplace_back(shape);
|
||||
|
|
@ -199,16 +198,16 @@ void TvmExecutionProvider::setInputShapesForFreezedNN(const Graph& graph,
|
|||
}
|
||||
}
|
||||
|
||||
void TvmExecutionProvider::setInputShapesForUnfreezedNN(const Graph& graph,
|
||||
void TvmExecutionProvider::setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer,
|
||||
TVMTensorShapes& input_shapes,
|
||||
InputsInfoMap& all_input_shapes) {
|
||||
const std::vector<const NodeArg*>& all_nodes = graph.GetInputsIncludingInitializers();
|
||||
const std::vector<const NodeArg*>& all_nodes = graph_viewer.GetInputsIncludingInitializers();
|
||||
|
||||
size_t indx = 0;
|
||||
for (const auto* node : all_nodes) {
|
||||
TensorShapeVector shape = getInputShape(node);
|
||||
all_input_shapes[indx++] = shape;
|
||||
if(!graph.IsInitializedTensor(node->Name())) {
|
||||
if (!graph_viewer.IsInitializedTensor(node->Name())) {
|
||||
input_shapes.emplace_back(shape);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class TvmExecutionProvider : public IExecutionProvider {
|
|||
GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
|
||||
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;
|
||||
AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override;
|
||||
|
|
@ -48,10 +48,10 @@ class TvmExecutionProvider : public IExecutionProvider {
|
|||
private:
|
||||
void printOptions();
|
||||
std::shared_ptr<tvm::TvmModule> compileModel(const std::string& func_name,
|
||||
const Graph& graph,
|
||||
const GraphViewer& graph_viewer,
|
||||
InputsInfoMap& inputs_info);
|
||||
void setInputShapesForFreezedNN(const Graph& graph, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
|
||||
void setInputShapesForUnfreezedNN(const Graph& graph, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
|
||||
void setInputShapesForFreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
|
||||
void setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, TVMTensorShapes& input_shapes, InputsInfoMap& all_input_shapes);
|
||||
TensorShapeVector getInputShape(const NodeArg* node);
|
||||
TensorShapeVector convertTensorShape(const ONNX_NAMESPACE::TensorShapeProto& shape_proto);
|
||||
void prepareOutputTensors(const std::shared_ptr<tvm::TvmModule>& mod, std::vector<DLTensor>& output_tensors, size_t num);
|
||||
|
|
|
|||
|
|
@ -32,28 +32,22 @@
|
|||
namespace onnxruntime {
|
||||
namespace vitisai_ep {
|
||||
|
||||
static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node,
|
||||
static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const logging::Logger& logger) {
|
||||
const auto* node_function = fused_node->GetFunctionBody();
|
||||
|
||||
ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ",
|
||||
fused_node->Name());
|
||||
|
||||
const Graph& node_subgraph = node_function->Body();
|
||||
onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{}, PathString{},
|
||||
IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(),
|
||||
onnxruntime::Model model{graph_viewer.Name(), true, ModelMetaData{}, PathString{},
|
||||
IOnnxRuntimeOpSchemaRegistryList{}, graph_viewer.DomainToVersionMap(),
|
||||
std::vector<ONNX_NAMESPACE::FunctionProto>(), logger};
|
||||
|
||||
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
|
||||
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
||||
*(model_proto.mutable_graph()) = node_subgraph.ToGraphProto();
|
||||
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
|
||||
|
||||
return model_proto;
|
||||
}
|
||||
|
||||
VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
|
||||
const onnxruntime::Node* fused_node,
|
||||
const onnxruntime::Node& fused_node,
|
||||
const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& backend_type,
|
||||
const std::string& export_runtime_module,
|
||||
const std::string& load_runtime_module,
|
||||
|
|
@ -68,7 +62,7 @@ VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
|
|||
allocator_ = context->allocator_handle;
|
||||
name_ = context->node_name;
|
||||
|
||||
model_proto_ = GetModelProtoFromFusedNode(fused_node, *GetLogger());
|
||||
model_proto_ = GetModelProtoFromFusedNode(graph_viewer, *GetLogger());
|
||||
std::istringstream model_stream{model_proto_.SerializeAsString()};
|
||||
xg_ = pyxir::onnx::import_onnx_model(model_stream);
|
||||
|
||||
|
|
@ -78,12 +72,12 @@ VitisAICustomOp::VitisAICustomOp(const ComputeContext* context,
|
|||
if (load_runtime_module_.empty()) {
|
||||
pyxir::partition(xg_, std::vector<std::string>{backend_type_}, "");
|
||||
|
||||
auto input_defs = fused_node->InputDefs();
|
||||
auto input_defs = fused_node.InputDefs();
|
||||
for (auto idef : input_defs) {
|
||||
in_tensor_names_.push_back(idef->Name());
|
||||
}
|
||||
|
||||
auto output_defs = fused_node->OutputDefs();
|
||||
auto output_defs = fused_node.OutputDefs();
|
||||
for (auto odef : output_defs) {
|
||||
out_tensor_names_.push_back(odef->Name());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ namespace vitisai_ep {
|
|||
class VitisAICustomOp {
|
||||
public:
|
||||
VitisAICustomOp(const ComputeContext* context,
|
||||
const onnxruntime::Node* fused_node,
|
||||
const onnxruntime::Node& fused_node,
|
||||
const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& backend_type,
|
||||
const std::string& export_runtime_module,
|
||||
const std::string& load_runtime_module,
|
||||
|
|
|
|||
|
|
@ -273,12 +273,14 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
return result;
|
||||
}
|
||||
|
||||
common::Status VitisAIExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
common::Status VitisAIExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
for (const auto& fused_node : fused_nodes) {
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
|
||||
const Node& fused_node = fused_node_graph.fused_node;
|
||||
NodeComputeInfo compute_info;
|
||||
compute_info.create_state_func = [this, fused_node, logger = GetLogger()](ComputeContext* context, FunctionState* state) {
|
||||
auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, backend_type_, export_runtime_module_,
|
||||
auto* p = new vitisai_ep::VitisAICustomOp(context, fused_node, graph_body_viewer, backend_type_, export_runtime_module_,
|
||||
load_runtime_module_, logger);
|
||||
*state = p;
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class VitisAIExecutionProvider : public IExecutionProvider {
|
|||
|
||||
int GetDeviceId() const { return device_id_; }
|
||||
|
||||
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -930,7 +930,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
|
||||
// Do partitioning based on execution providers' capability.
|
||||
GraphPartitioner partitioner(kernel_registry_manager, providers);
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph,
|
||||
session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayoutForCompilingEP, mode));
|
||||
|
||||
|
|
@ -1153,7 +1153,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
|
|||
std::unordered_map<std::string, HashValue> compiled_kernel_hashes;
|
||||
|
||||
GraphPartitioner partitioner(kernel_registry_manager, providers);
|
||||
ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(),
|
||||
ORT_RETURN_IF_ERROR(partitioner.Partition(graph,
|
||||
session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayoutForCompilingEP,
|
||||
GraphPartitioner::Mode::kOrtFormatLoad,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@
|
|||
#include "core/util/math.h"
|
||||
#include "core/framework/sparse_utils.h"
|
||||
#include "core/common/string_helper.h"
|
||||
#include "core/graph/graph_proto_serializer.h"
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
#ifdef ENABLE_TRAINING_TORCH_INTEROP
|
||||
|
|
@ -265,14 +266,11 @@ struct ProviderHostImpl : ProviderHost {
|
|||
void IExecutionProvider__TryInsertAllocator(IExecutionProvider* p, AllocatorPtr allocator) override { return p->IExecutionProvider::TryInsertAllocator(allocator); }
|
||||
std::vector<std::unique_ptr<ComputeCapability>> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& kernel_registries) override { return p->IExecutionProvider::GetCapability(graph_viewer, kernel_registries); }
|
||||
// !!! this api will be deprecated soon
|
||||
common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::vector<NodeComputeInfo>& node_compute_funcs) override {
|
||||
return p->IExecutionProvider::Compile(fused_nodes, node_compute_funcs);
|
||||
}
|
||||
|
||||
common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<onnxruntime::Node*>& fused_nodes, std::string& dll_path) override {
|
||||
return p->IExecutionProvider::Compile(fused_nodes, dll_path);
|
||||
}
|
||||
|
||||
common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs, std::vector<NodeComputeInfo>& node_compute_funcs) override {
|
||||
return p->IExecutionProvider::Compile(fused_nodes_and_graphs, node_compute_funcs);
|
||||
}
|
||||
|
|
@ -354,6 +352,8 @@ struct ProviderHostImpl : ProviderHost {
|
|||
#endif
|
||||
|
||||
// TypeProto (wrapped)
|
||||
std::unique_ptr<ONNX_NAMESPACE::TypeProto> TypeProto__construct() override { return std::make_unique<ONNX_NAMESPACE::TypeProto>(); }
|
||||
void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) override { p->CopyFrom(*other); }
|
||||
const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) override { return p->tensor_type(); }
|
||||
ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) override { return p->mutable_tensor_type(); }
|
||||
int TypeProto__value_case(const ONNX_NAMESPACE::TypeProto* p) override { return p->value_case(); }
|
||||
|
|
@ -441,6 +441,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); }
|
||||
|
||||
bool TensorProto_DataType_IsValid(int value) override { return ONNX_NAMESPACE::TensorProto::DataType_IsValid(value); }
|
||||
void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) override { p->CopyFrom(*other); }
|
||||
|
||||
// TensorProtos (wrapped)
|
||||
ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); }
|
||||
|
|
@ -762,6 +763,9 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); }
|
||||
const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); }
|
||||
void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override {
|
||||
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args);
|
||||
}
|
||||
|
||||
// Path (wrapped)
|
||||
PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); }
|
||||
|
|
|
|||
|
|
@ -1242,12 +1242,9 @@ TEST(ExecutionProviderTest, FunctionTest) {
|
|||
InferenceSession session_object_2{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(
|
||||
session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
status = session_object_2.Load(model_file_name);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session_object_2.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session_object_2.Run(run_options, feeds, output_names, &fetches);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
ASSERT_STATUS_OK(session_object_2.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object_2.Initialize());
|
||||
ASSERT_STATUS_OK(session_object_2.Run(run_options, feeds, output_names, &fetches));
|
||||
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -143,8 +143,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
DefaultLoggingManager().DefaultLogger(), profiler);
|
||||
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayoutForCompilingEP);
|
||||
status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayoutForCompilingEP);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
|
||||
|
|
@ -209,7 +209,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
|
|||
|
||||
// Partition the graph
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
|
||||
status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayoutForCompilingEP);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
|
|
@ -259,7 +259,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
|
|||
|
||||
// Partition the graph
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr(),
|
||||
status = partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
|
||||
layout_transformer::TransformLayoutForCompilingEP);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,409 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/nuphar_tvm/tvm_demo/demo_compiler.h"
|
||||
|
||||
#include <tvm/runtime/ndarray.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
using namespace tvm_demo;
|
||||
|
||||
class TVMDemoKernel : public OpKernel {
|
||||
public:
|
||||
explicit TVMDemoKernel(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
|
||||
protected:
|
||||
const TensorShape& GetOutputShape(OpKernelContext* context, int /*i*/) const {
|
||||
return context->Input<Tensor>(0)->Shape();
|
||||
}
|
||||
};
|
||||
|
||||
class UnionSet {
|
||||
public:
|
||||
UnionSet(int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
farthers_.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
int get(int x) {
|
||||
if (farthers_[x] == x) {
|
||||
return x;
|
||||
}
|
||||
return farthers_[x] = get(farthers_[x]);
|
||||
}
|
||||
|
||||
void merge(int x, int y) {
|
||||
x = get(x);
|
||||
y = get(y);
|
||||
if (x != y) {
|
||||
farthers_[y] = x;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> farthers_;
|
||||
};
|
||||
|
||||
static DLDataType GetDataType(ONNXTensorElementDataType type) {
|
||||
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
|
||||
return {kDLFloat, 64, 1};
|
||||
} else
|
||||
ORT_THROW("not implement.");
|
||||
}
|
||||
|
||||
namespace test {
|
||||
|
||||
struct TVMFuncState {
|
||||
AllocateFunc test_allocate_func = nullptr;
|
||||
DestroyFunc test_release_func = nullptr;
|
||||
AllocatorHandle allocator = nullptr;
|
||||
tvm::runtime::Module* module = nullptr;
|
||||
};
|
||||
|
||||
class FuseExecutionProviderX : public CPUExecutionProvider {
|
||||
public:
|
||||
explicit FuseExecutionProviderX(const CPUExecutionProviderInfo& info) : CPUExecutionProvider(info) {
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override {
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
std::vector<onnxruntime::NodeIndex> fused_nodes;
|
||||
for (auto& node : graph_viewer.Nodes()) {
|
||||
if (node.OpType() == "Mul") {
|
||||
fused_nodes.push_back(node.Index());
|
||||
}
|
||||
}
|
||||
|
||||
UnionSet set(static_cast<int>(fused_nodes.size()));
|
||||
for (int i = 0; i < fused_nodes.size(); ++i) {
|
||||
auto node = graph_viewer.GetNode(fused_nodes[i]);
|
||||
for (auto it = node->InputNodesBegin(); it != node->InputNodesEnd(); ++it) {
|
||||
auto index_it = std::find(fused_nodes.begin(), fused_nodes.end(), (*it).Index());
|
||||
if (index_it != fused_nodes.end()) {
|
||||
set.merge(i, static_cast<int>(index_it - fused_nodes.begin()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<onnxruntime::NodeIndex>> groups;
|
||||
groups.resize(fused_nodes.size());
|
||||
for (int i = 0; i < set.farthers_.size(); ++i) {
|
||||
groups[set.get(i)].push_back(fused_nodes[i]);
|
||||
}
|
||||
|
||||
for (auto& group : groups) {
|
||||
if (group.size() > 1) {
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
|
||||
std::set<const onnxruntime::NodeArg*> fused_inputs, fused_outputs;
|
||||
for (auto index : group) {
|
||||
sub_graph->nodes.push_back(index);
|
||||
auto node = graph_viewer.GetNode(index);
|
||||
for (auto input : node->InputDefs()) {
|
||||
auto it = fused_outputs.find(input);
|
||||
if (it != fused_outputs.end()) {
|
||||
fused_outputs.erase(it);
|
||||
} else {
|
||||
fused_inputs.insert(input);
|
||||
}
|
||||
}
|
||||
for (auto output : node->OutputDefs()) {
|
||||
auto it = fused_inputs.find(output);
|
||||
if (it != fused_inputs.end()) {
|
||||
fused_inputs.erase(it);
|
||||
} else {
|
||||
fused_outputs.insert(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
|
||||
meta_def->name = "TVMFuse";
|
||||
meta_def->domain = "FuseTest";
|
||||
for (auto input : fused_inputs) {
|
||||
meta_def->inputs.push_back(input->Name());
|
||||
}
|
||||
|
||||
for (auto output : fused_outputs) {
|
||||
meta_def->outputs.push_back(output->Name());
|
||||
}
|
||||
|
||||
meta_def->since_version = 1;
|
||||
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
|
||||
sub_graph->SetMetaDef(std::move(meta_def));
|
||||
//TODO:set fuse kernel func;
|
||||
result.push_back(
|
||||
std::make_unique<ComputeCapability>(std::move(sub_graph)));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override {
|
||||
for (auto* fused_node : fused_nodes) {
|
||||
auto func_body = fused_node->GetFunctionBody();
|
||||
if (!func_body)
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
|
||||
//1. Build tvm IR based on the Ort graph
|
||||
auto demo_tvm_tensor_ctx = BuildTVMIR(func_body->Body());
|
||||
//2. Create schedule for the built tvm IRs
|
||||
auto s = CreateSchedule(demo_tvm_tensor_ctx);
|
||||
//3. Build tvm module
|
||||
std::vector<tvm::Tensor> tvm_args;
|
||||
for (auto& t : demo_tvm_tensor_ctx.inputs) {
|
||||
tvm_args.push_back(t);
|
||||
}
|
||||
for (auto& t : demo_tvm_tensor_ctx.outputs) {
|
||||
tvm_args.push_back(t);
|
||||
}
|
||||
|
||||
std::vector<std::string> func_names;
|
||||
auto module_ptr = std::make_shared<tvm::runtime::Module>();
|
||||
*module_ptr = BuildStackVMModule(s, tvm::build_config(), tvm_args, func_names);
|
||||
modules_[fused_node->Name()] = module_ptr;
|
||||
|
||||
NodeComputeInfo compute_info;
|
||||
|
||||
compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
|
||||
auto* p = new TVMFuncState();
|
||||
*p = {context->allocate_func, context->release_func, context->allocator_handle, modules_[context->node_name].get()};
|
||||
*state = p;
|
||||
return 0;
|
||||
};
|
||||
|
||||
compute_info.release_state_func = [](FunctionState state) {
|
||||
if (state)
|
||||
delete static_cast<TVMFuncState*>(state);
|
||||
};
|
||||
|
||||
//we use lambda to capture the tvm model, so we can use it to get the funciton.
|
||||
compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
|
||||
Ort::CustomOpApi ort{*api};
|
||||
|
||||
TVMFuncState* tvm_state = reinterpret_cast<TVMFuncState*>(state);
|
||||
|
||||
std::vector<std::vector<int64_t>> input_shapes;
|
||||
std::vector<std::vector<int64_t>> output_shapes;
|
||||
|
||||
auto eval_func_name = "func";
|
||||
DLContext cpu_context = {kDLCPU, 0};
|
||||
size_t num_inputs = ort.KernelContext_GetInputCount(context);
|
||||
size_t num_outputs = ort.KernelContext_GetOutputCount(context);
|
||||
size_t n_args = num_inputs + num_outputs;
|
||||
std::vector<DLTensor> dl_tensors(n_args);
|
||||
std::vector<TVMValue> tvm_values(n_args);
|
||||
std::vector<int> tvm_type_codes(n_args);
|
||||
for (auto i = 0; i < num_inputs; i++) {
|
||||
const OrtValue* input_tensor = ort.KernelContext_GetInput(context, i);
|
||||
auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
|
||||
auto tensor_type = ort.GetTensorElementType(tensor_info);
|
||||
input_shapes.emplace_back(ort.GetTensorShape(tensor_info));
|
||||
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
|
||||
|
||||
tvm_type_codes[i] = kNDArrayContainer;
|
||||
dl_tensors[i].ctx = cpu_context;
|
||||
dl_tensors[i].dtype = GetDataType(tensor_type);
|
||||
dl_tensors[i].strides = nullptr;
|
||||
dl_tensors[i].byte_offset = 0;
|
||||
dl_tensors[i].data = const_cast<double*>(ort.GetTensorData<double>(input_tensor));
|
||||
dl_tensors[i].ndim = input_shapes.back().size();
|
||||
dl_tensors[i].shape = input_shapes.back().data();
|
||||
tvm_values[i].v_handle = &dl_tensors[i];
|
||||
}
|
||||
|
||||
for (auto i = 0; i < num_outputs; i++) {
|
||||
//setup output tensor property
|
||||
//todo: type should be set by framework.
|
||||
output_shapes.push_back(input_shapes[i]);
|
||||
OrtValue* output_tensor = ort.KernelContext_GetOutput(context, i, output_shapes[i].data(), output_shapes[i].size());
|
||||
auto tensor_info = ort.GetTensorTypeAndShape(output_tensor);
|
||||
auto tensor_type = ort.GetTensorElementType(tensor_info);
|
||||
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
|
||||
|
||||
tvm_type_codes[num_inputs + i] = kNDArrayContainer;
|
||||
dl_tensors[num_inputs + i].ctx = cpu_context;
|
||||
dl_tensors[num_inputs + i].dtype = GetDataType(tensor_type);
|
||||
dl_tensors[num_inputs + i].strides = nullptr;
|
||||
dl_tensors[num_inputs + i].byte_offset = 0;
|
||||
dl_tensors[num_inputs + i].data = ort.GetTensorMutableData<double>(output_tensor);
|
||||
dl_tensors[num_inputs + i].ndim = output_shapes.back().size();
|
||||
dl_tensors[num_inputs + i].shape = output_shapes.back().data();
|
||||
tvm_values[num_inputs + i].v_handle = &dl_tensors[num_inputs + i];
|
||||
}
|
||||
|
||||
auto evaluate_func_ = tvm_state->module->GetFunction(eval_func_name);
|
||||
tvm::TVMArgs tvm_args(&tvm_values[0], &tvm_type_codes[0], static_cast<int>(n_args));
|
||||
tvm::TVMRetValue rvalue;
|
||||
try {
|
||||
evaluate_func_.CallPacked(tvm_args, &rvalue);
|
||||
} catch (std::exception&) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL); // TODO: Translate exception to error code
|
||||
}
|
||||
if (rvalue.type_code() != kNull) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL); // TODO: get error code.
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
node_compute_funcs.push_back(compute_info);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::shared_ptr<tvm::runtime::Module>> modules_;
|
||||
};
|
||||
|
||||
static void RunSession(InferenceSession& session_object,
|
||||
RunOptions& run_options,
|
||||
std::vector<int64_t>& dims_x,
|
||||
std::vector<double>& values_x,
|
||||
std::vector<int64_t>& dims_y,
|
||||
std::vector<double>& values_y) {
|
||||
// prepare inputs
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<double>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_x, values_x, &ml_value);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X1", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("Y4");
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
// Now run
|
||||
common::Status st = session_object.Run(run_options, feeds, output_names, &fetches);
|
||||
if (!st.IsOK()) {
|
||||
std::cout << "Run returned status: " << st.ErrorMessage() << std::endl;
|
||||
}
|
||||
EXPECT_TRUE(st.IsOK());
|
||||
ASSERT_EQ(1, fetches.size());
|
||||
auto& rtensor = fetches.front().Get<Tensor>();
|
||||
TensorShape expected_shape(dims_y);
|
||||
EXPECT_EQ(expected_shape, rtensor.Shape());
|
||||
const std::vector<double> found(rtensor.template Data<double>(), rtensor.template Data<double>() + expected_shape.Size());
|
||||
ASSERT_EQ(found.size(), values_y.size());
|
||||
for (size_t i = 0; i < found.size(); i++)
|
||||
ASSERT_EQ(found[i], values_y[i]);
|
||||
}
|
||||
|
||||
static const std::string MODEL_URI = "testdata/fuse_mul_1.onnx";
|
||||
|
||||
TEST(TVMTest, CodeGen_Demo_for_Fuse_Mul) {
|
||||
SessionOptions so;
|
||||
|
||||
so.session_logid = "InferenceSessionTests.NoTimeout";
|
||||
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
CPUExecutionProviderInfo info;
|
||||
auto tvm_xp = std::make_unique<FuseExecutionProviderX>(info);
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(tvm_xp)).IsOK());
|
||||
EXPECT_TRUE(session_object.Load(MODEL_URI).IsOK());
|
||||
EXPECT_TRUE(session_object.Initialize().IsOK());
|
||||
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = "one session/one tag";
|
||||
|
||||
// prepare inputs
|
||||
std::vector<int64_t> dims_x = {
|
||||
6,
|
||||
};
|
||||
std::vector<double> values_x = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_y = {
|
||||
6,
|
||||
};
|
||||
// now the expected value should be Mul's result.
|
||||
std::vector<double> expected_values_y = {1.0, 32.0, 243.0, 1024.0, 3125.0, 7776.0};
|
||||
|
||||
// Now run
|
||||
RunSession(session_object, run_options, dims_x, values_x, expected_dims_y, expected_values_y);
|
||||
}
|
||||
} // namespace test
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
TEST(TVMTest, Native_TVM) {
|
||||
using namespace tvm;
|
||||
auto n = var("n");
|
||||
Array<Expr> shape;
|
||||
shape.push_back(n);
|
||||
auto A = placeholder(shape, Float(64), "A");
|
||||
auto B = placeholder(shape, Float(64), "B");
|
||||
auto D = placeholder(shape, Float(64), "D");
|
||||
auto C = compute(
|
||||
A->shape, [&A, &B](Expr i) {
|
||||
return A[i] + B[i];
|
||||
},
|
||||
"C");
|
||||
auto E = compute(
|
||||
A->shape, [&C, &D](Expr i) {
|
||||
return C[i] + D[i];
|
||||
},
|
||||
"E");
|
||||
|
||||
auto s = create_schedule({E->op});
|
||||
auto args = Array<Tensor>({A, B, D, E});
|
||||
std::unordered_map<Tensor, Buffer> binds;
|
||||
auto config = build_config();
|
||||
#ifdef USE_NUPHAR_TVM_WITH_LLVM
|
||||
auto target = target::llvm();
|
||||
#else
|
||||
auto target = target::stackvm();
|
||||
#endif
|
||||
auto lowered = lower(s, args, "func", binds, config);
|
||||
auto module = build(lowered, target, Target(), config);
|
||||
auto func = module.GetFunction("func");
|
||||
|
||||
DLDataType dtype;
|
||||
dtype.code = kDLFloat;
|
||||
dtype.bits = 64;
|
||||
dtype.lanes = 1;
|
||||
DLContext ctx;
|
||||
ctx.device_type = DLDeviceType::kDLCPU;
|
||||
ctx.device_id = 0;
|
||||
|
||||
std::vector<double> v = {1.0, 2.0, 3.0};
|
||||
int64_t len = 3;
|
||||
DLTensor tensor_A = {&v[0], ctx, 1, dtype, &len, nullptr, 0};
|
||||
DLTensor tensor_B = {&v[0], ctx, 1, dtype, &len, nullptr, 0};
|
||||
DLTensor tensor_D = {&v[0], ctx, 1, dtype, &len, nullptr, 0};
|
||||
|
||||
std::vector<double> r;
|
||||
r.resize(len);
|
||||
DLTensor tensor_E = {&r[0], ctx, 1, dtype, &len, nullptr, 0};
|
||||
|
||||
TVMValue lvalues[4];
|
||||
int type_codes[4] = {kNDArrayContainer, kNDArrayContainer, kNDArrayContainer, kNDArrayContainer};
|
||||
lvalues[0].v_handle = &tensor_A;
|
||||
lvalues[1].v_handle = &tensor_B;
|
||||
lvalues[2].v_handle = &tensor_D;
|
||||
lvalues[3].v_handle = &tensor_E;
|
||||
|
||||
TVMArgs tvm_args(lvalues, type_codes, 4);
|
||||
TVMRetValue rvalue;
|
||||
func.CallPacked(tvm_args, &rvalue);
|
||||
CHECK_EQ(rvalue.type_code(), kNull);
|
||||
double expected[3] = {3.0, 6.0, 9.0};
|
||||
auto data_E = static_cast<double*>(tensor_E.data);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
EXPECT_NEAR(*(data_E + i), expected[i], 0.001f);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "test/nuphar_tvm/tvm_demo/demo_compiler.h"
|
||||
|
||||
#include "core/codegen/passes/scheduler/schedule_utils.h"
|
||||
#include "core/codegen/passes/utils/ort_tvm_utils.h"
|
||||
#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h"
|
||||
#include "core/codegen/passes/scheduler/tvm_scheduler.h"
|
||||
#include "core/codegen/passes/scheduler/tvm_schedule_builder.h"
|
||||
|
||||
#include <tvm/tvm.h>
|
||||
#include <tvm/build_module.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace tvm_demo {
|
||||
|
||||
// Create a dummy demo handle
|
||||
static codegen::CodeGenHandle demo_handle;
|
||||
// Create a dummy demo codegen context
|
||||
static tvm_codegen::CodeGenContext demo_codegen_ctx(&demo_handle);
|
||||
|
||||
// Translate an Ort graph into tvm IR
|
||||
// Note this function is specific for this demo.
|
||||
// This function uses specific way for graph traversal or constructing tvm placeholders.
|
||||
// It may or may not work for a universal Ort graph.
|
||||
// For a more general example, please check nuphar provider.
|
||||
DemoTVMTensorCtx BuildTVMIR(const onnxruntime::Graph& graph) {
|
||||
// Create OpIRRegistry that holds all OpIRCreators
|
||||
std::unique_ptr<tvm_codegen::OpIRRegistry> op_ir_registry =
|
||||
std::make_unique<tvm_codegen::OpIRRegistry>();
|
||||
|
||||
// Register all generic OpIRCreators
|
||||
tvm_codegen::RegisterAllGenericOpIRCreators(op_ir_registry.get());
|
||||
|
||||
// Create OpIRBuilder
|
||||
std::shared_ptr<tvm_codegen::TVMIRBuilder> op_ir_builder =
|
||||
std::make_shared<tvm_codegen::TVMIRBuilder>("Demo_Op_IR_Builder");
|
||||
|
||||
// Attach all generic OpIRCreators from op_ir_registry to op_ir_builder
|
||||
tvm_codegen::RegisterGenericOrtOpTypeDispatcher(op_ir_builder, op_ir_registry.get());
|
||||
|
||||
// Create DemoTVMTensorCtx holdings tvm IR
|
||||
DemoTVMTensorCtx result;
|
||||
|
||||
// Local lookup from name to tvm::Tensor
|
||||
std::unordered_map<std::string, tvm::Tensor> tvm_tensors;
|
||||
|
||||
// Note this is a simplified traversal that works specifically for this demo
|
||||
// but may or may not work for an univerisal model.
|
||||
// For more general traversal, please check nuphar provider.
|
||||
for (auto& node : graph.Nodes()) {
|
||||
tvm::Array<tvm::Tensor> inputs;
|
||||
tvm::Array<tvm::Tensor> outputs;
|
||||
|
||||
// Get inputs
|
||||
for (auto& def : node.InputDefs()) {
|
||||
const std::string& name = def->Name();
|
||||
auto iter = tvm_tensors.find(name);
|
||||
// Always create placeholder when not finding a tensor
|
||||
// Note it is for this demo.
|
||||
// It may or may not work for a universal graph.
|
||||
if (iter == tvm_tensors.end()) {
|
||||
tvm_tensors[name] =
|
||||
tvm::placeholder(ShapeToTvmArray(def, demo_codegen_ctx),
|
||||
tvm_codegen::ToTvmType(TensorProtoDataType(def)),
|
||||
name + "_placeholder");
|
||||
}
|
||||
inputs.push_back(tvm_tensors[name]);
|
||||
}
|
||||
|
||||
// call OpIBuilder's Evaluate to build tvm IR
|
||||
ORT_THROW_IF_ERROR(op_ir_builder->Evaluate(inputs, node, demo_codegen_ctx, outputs));
|
||||
|
||||
// Store outputs
|
||||
for (size_t def_id = 0; def_id < node.OutputDefs().size(); ++def_id) {
|
||||
const NodeArg* def = node.OutputDefs()[def_id];
|
||||
tvm_tensors[def->Name()] = outputs[def_id];
|
||||
}
|
||||
}
|
||||
|
||||
// put inputs to DemoTVMTensorCtx
|
||||
for (auto& input : graph.GetInputs()) {
|
||||
result.inputs.push_back(tvm_tensors[input->Name()]);
|
||||
}
|
||||
|
||||
// check initializer
|
||||
for (auto& initializer : graph.GetAllInitializedTensors()) {
|
||||
result.inputs.push_back(tvm_tensors[initializer.first]);
|
||||
}
|
||||
|
||||
// Only one output in this demo
|
||||
auto& output = graph.GetOutputs()[0];
|
||||
result.outputs.push_back(tvm_tensors[output->Name()]);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Declare a Demo scheduler that always inserts compute_inline
|
||||
DECLARE_TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM)
|
||||
|
||||
// Define a Demo scheduler's Evaluate that always inserts compute_inline
|
||||
bool TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM)::Evaluate(
|
||||
const tvm::Tensor& tensor,
|
||||
const Node*,
|
||||
tvm_codegen::CodeGenContext&,
|
||||
tvm_codegen::ScheduleContext& ctx_sched) {
|
||||
return TryInlineSchedule(tensor, ctx_sched);
|
||||
}
|
||||
|
||||
// Register the always inline Scheduler to sched_registry
|
||||
static void RegisterAlwaysInlineScheduler(tvm_codegen::TVMScheduleRegistry* sched_registry) {
|
||||
sched_registry->Register(
|
||||
std::make_unique<TVM_SCHEDULER_CLASS(AlwaysInline, DemoTVM)>());
|
||||
}
|
||||
|
||||
// Declare a schedule dispatcher that always dispatches the always inline Scheduler
|
||||
DECLARE_SCHEDULE_DISPATCHER_CLASS(DemoTVM)
|
||||
|
||||
// Use a predefined key as DemoKey to dispatch the scheduler
|
||||
constexpr auto predefined_key = "DemoKey";
|
||||
|
||||
// Define the schedule dispatcher's Find function
|
||||
// that always dispatches the always inline Scheduler
|
||||
// Note this dispatcher always returning a predefined_key is only for demo purpose.
|
||||
// In practice, a dispatcher returns a key by checking tvm::Tensor, Node,
|
||||
// or even meta data stored in CodeGenContext.
|
||||
// Derived CodeGenContext allows compiler developers to store their specific meta data.
|
||||
// For more detailed example, please check nuphar provider.
|
||||
tvm_codegen::Scheduler* SCHEDULE_DISPATCHER_CLASS(DemoTVM)::Find(
|
||||
const tvm::Tensor&, const Node*, tvm_codegen::CodeGenContext&) {
|
||||
return DispatcherBase::Get(predefined_key);
|
||||
}
|
||||
|
||||
// Attach the always inline Scheduler to the above dispatcher
|
||||
// and then attach the dispatcher to the scheduler builder
|
||||
static void AttachAlwaysInlineScheduler(const std::shared_ptr<tvm_codegen::TVMScheduleBuilder>& builder,
|
||||
const tvm_codegen::TVMScheduleRegistry* registry) {
|
||||
auto dispatcher = std::make_unique<SCHEDULE_DISPATCHER_CLASS(DemoTVM)>("DemoSchedulers");
|
||||
|
||||
// Using a predefined_key
|
||||
dispatcher->Register(predefined_key,
|
||||
registry->Get(TVM_SCHEDULER_STRING(AlwaysInline, DemoTVM)));
|
||||
|
||||
builder->InsertDispatcher(std::move(dispatcher));
|
||||
}
|
||||
|
||||
// Traverse tvm::Tensor and then schedule them
|
||||
// Note this traversal is simplified and specific for this demo.
|
||||
// For a more general traversal, please check nuphar provider.
|
||||
static void TraverseAndSchedule(
|
||||
std::shared_ptr<tvm_codegen::TVMScheduleBuilder>& schedule_builder,
|
||||
const tvm::Tensor& tensor,
|
||||
tvm_codegen::ScheduleContext& ctx_schedule) {
|
||||
ORT_THROW_IF_ERROR(schedule_builder->Evaluate(tensor, nullptr, demo_codegen_ctx, ctx_schedule));
|
||||
|
||||
// Traverse tensor's children (inputs)
|
||||
for (auto& t : tensor->op->InputTensors()) {
|
||||
// check whether it is a non-trivial tensor by checking its input size
|
||||
if (t->op->InputTensors().size() > 0) {
|
||||
TraverseAndSchedule(schedule_builder, t, ctx_schedule);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a TVM schedule by always inserting tvm's compute_inline.
|
||||
// Note this schedule is specific for this demo.
|
||||
// In practice, always inline might lead to bad performance
|
||||
// or even illegal loop transformation for some backends.
|
||||
// For a more general example, please check nuphar provider.
|
||||
tvm::Schedule CreateSchedule(const DemoTVMTensorCtx& ctx) {
|
||||
// Create TVMScheduleRegistry that holds all Scheduler
|
||||
std::unique_ptr<tvm_codegen::TVMScheduleRegistry> schedule_registry =
|
||||
std::make_unique<tvm_codegen::TVMScheduleRegistry>();
|
||||
|
||||
// Register the always inline Scheduler to schedule_registry
|
||||
RegisterAlwaysInlineScheduler(schedule_registry.get());
|
||||
|
||||
// Create a DemoScheduleBuilder
|
||||
std::shared_ptr<tvm_codegen::TVMScheduleBuilder> schedule_builder =
|
||||
std::make_shared<tvm_codegen::TVMScheduleBuilder>("Demo_Schedule_Builder");
|
||||
|
||||
// Attach the demo inline scheduler to the schedule_builder
|
||||
AttachAlwaysInlineScheduler(schedule_builder, schedule_registry.get());
|
||||
|
||||
// Create scheudule object
|
||||
tvm::Array<tvm::Operation> out_ops;
|
||||
for (auto& t : ctx.outputs) {
|
||||
out_ops.push_back(t->op);
|
||||
}
|
||||
|
||||
// Create scheudule context
|
||||
tvm_codegen::ScheduleContext ctx_schedule(out_ops);
|
||||
|
||||
// Traverse tvm::Tensor in a DFS way, and then schedule
|
||||
for (auto& t : ctx.outputs) {
|
||||
TraverseAndSchedule(schedule_builder, t, ctx_schedule);
|
||||
}
|
||||
|
||||
// Make sure all outputs compute_root (tvm's requirement)
|
||||
for (auto& t : ctx.outputs) {
|
||||
tvm_codegen::InsertRootSchedule(t, ctx_schedule);
|
||||
}
|
||||
|
||||
return ctx_schedule.schedule;
|
||||
}
|
||||
|
||||
// Build TVM Module with a schedule using tvm's stackvm.
|
||||
// Note in real practice, please change stackvm to other backends.
|
||||
// For a more detailed example, please check nuphar provider.
|
||||
tvm::runtime::Module BuildStackVMModule(tvm::Schedule schedule,
|
||||
tvm::BuildConfig config,
|
||||
tvm::Array<tvm::Tensor> tvm_args,
|
||||
std::vector<std::string>& target_func_names) {
|
||||
auto target = tvm::target::stackvm();
|
||||
std::string func_name = "func";
|
||||
auto args = tvm::Array<tvm::Tensor>(tvm_args);
|
||||
std::unordered_map<tvm::Tensor, tvm::Buffer> binds;
|
||||
auto lowered = lower(schedule, args, "func", binds, config);
|
||||
// Uncomment the following line to dump lowered func
|
||||
// std::cout << "Dumping lowered func: " << lowered[0]->body;
|
||||
target_func_names.push_back(func_name);
|
||||
return build(lowered, target, tvm::Target(), config);
|
||||
}
|
||||
|
||||
} // namespace tvm_demo
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include <tvm/tvm.h>
|
||||
#include <tvm/build_module.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace tvm_demo {
|
||||
// A Demo data structure to hold tvm IR and context
|
||||
struct DemoTVMTensorCtx {
|
||||
tvm::Array<tvm::Tensor> inputs;
|
||||
tvm::Array<tvm::Tensor> outputs;
|
||||
};
|
||||
|
||||
// Translate an Ort graph into tvm IR
|
||||
DemoTVMTensorCtx BuildTVMIR(const onnxruntime::Graph& graph);
|
||||
|
||||
// Create a demo schedule for the tvm IR
|
||||
tvm::Schedule CreateSchedule(const DemoTVMTensorCtx& ctx);
|
||||
|
||||
// Build a demo tvm module with the tvm IR and schedule
|
||||
tvm::runtime::Module BuildStackVMModule(tvm::Schedule schedule,
|
||||
tvm::BuildConfig config,
|
||||
tvm::Array<tvm::Tensor> tvm_args,
|
||||
std::vector<std::string>& target_func_names);
|
||||
|
||||
} // namespace tvm_demo
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -21,10 +21,6 @@ class InternalTestingExecutionProvider : public IExecutionProvider {
|
|||
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
||||
FusionStyle GetFusionStyle() const override {
|
||||
return FusionStyle::FilteredGraphViewer;
|
||||
}
|
||||
|
||||
DataLayout GetPreferredLayout() const override;
|
||||
|
||||
private:
|
||||
|
|
|
|||
Loading…
Reference in a new issue