[DML EP] Don't fuse a capability outside the compile call (#13468)

### Description
The DML EP used to be a special case for capability fusion: it fused a
capability outside the IExecutionProvider::Compile() call. After the
recent re-architecture in #13131, it is no longer a special case.
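
A minimal sketch of the condition this change simplifies in `PlaceNode`. The enum, the constant's value, and both helper functions are hypothetical stand-ins written for illustration; they mirror the names in the diff but are not the ORT API.

```cpp
#include <string>

// Hypothetical stand-in for IExecutionProvider::FusionStyle (assumption).
enum class FusionStyle { Function, FilteredGraphViewer };

// Hypothetical stand-in for the provider-type constant used in the diff (assumption).
const std::string kDmlExecutionProvider = "DmlExecutionProvider";

// Before this change: the DML EP always got a fused node with a Function body,
// regardless of the fusion style it reported.
bool UsesFunctionBodyBefore(FusionStyle style, const std::string& provider_type) {
  return style == FusionStyle::Function || provider_type == kDmlExecutionProvider;
}

// After this change: the fusion style alone decides; the provider type is ignored.
bool UsesFunctionBodyAfter(FusionStyle style, const std::string& /*provider_type*/) {
  return style == FusionStyle::Function;
}
```

Under these assumptions, the two helpers differ only for a DML EP that reports a non-Function fusion style, which is the special case this change removes.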



### Motivation and Context
This change makes the DML EP consistent with the ORT design.
It does not fix an open issue.

Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
sumitsays 2022-10-26 15:21:33 -07:00 committed by GitHub
parent 1c8a22ec68
commit 490e4ddea5


@@ -295,15 +295,7 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
   std::string node_name = oss.str();
   Node* fused_node = nullptr;
-  // TODO1: The DML currently use some legacy approach.
-  // It registers a generic predefined kernel for all purpose fusion,
-  // so it rely on the function body in the fused node during kernel creation,
-  // which is after the graph partition phase.
-  // Ideally, it should be moved to "Compile" call.
-  // Here we temporary keep the function body for DML fusion
-  // Need to remove it after migrate DML to the Compile-based approach.
-  if (fusion_style == IExecutionProvider::FusionStyle::Function ||
-      provider_type == kDmlExecutionProvider) {
+  if (fusion_style == IExecutionProvider::FusionStyle::Function) {
     fused_node = &graph.FuseSubGraph(capability, node_name);
   } else {
     // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed
@@ -471,10 +463,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
     }
   }
-  // TODO: The DML currently use some legacy approach.
-  // The fuse is done in FuseSubGraph function.
-  // Need to remove it later when DML migrate to Compile approach
-  if (!nodes_to_complete_fuse.empty() && type != kDmlExecutionProvider) {
+  if (!nodes_to_complete_fuse.empty()) {
     for (size_t j = 0, end = nodes_to_complete_fuse.size(); j < end; j++) {
       auto* node = nodes_to_complete_fuse[j];