[DML EP] Remove suffix removal adhoc logic for fusedNodeArgNames (#11879)

* DML EP: Remove suffix removal hack for fusedNodeArgName

* Acknowledged PR comments

* Removed reference from gsl::span

Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
This commit is contained in:
sumitsays 2022-06-17 17:04:16 -07:00 committed by GitHub
parent f97bd38c4f
commit 52f2b3bf89
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 57 additions and 81 deletions

View file

@ -19,7 +19,9 @@ namespace Dml
FusedGraphKernel(
const onnxruntime::OpKernelInfo& kernelInfo,
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap) : OpKernel(kernelInfo)
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames) : OpKernel(kernelInfo)
{
// Get the graph for the function which was created according to the computational
// capacity returned by the execution provider's graph partitioner
@ -45,14 +47,20 @@ namespace Dml
ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider));
}
TranslateAndCompileGraph(kernelInfo, graph, node.InputDefs(), node.OutputDefs(), graphNodePropertyMap, transferredInitializerMap);
TranslateAndCompileGraph(
kernelInfo,
graph,
fusedNodeInputArgOriginalNames,
fusedNodeOutputArgOriginalNames,
graphNodePropertyMap,
transferredInitializerMap);
}
void TranslateAndCompileGraph(
const onnxruntime::OpKernelInfo& kernelInfo,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
)
@ -68,7 +76,7 @@ namespace Dml
m_inputsConstant.resize(graphInputCount);
for (uint32_t i = 0; i < graphInputCount; ++i)
{
m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputDefs, transferredInitializerMap);
m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputArgOriginalNames, transferredInitializerMap);
}
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
@ -77,8 +85,8 @@ namespace Dml
m_inputsConstant.size(),
transferredInitializerMap,
graph,
fusedNodeInputDefs,
fusedNodeOutputDefs,
fusedNodeInputArgOriginalNames,
fusedNodeOutputArgOriginalNames,
graphNodePropertyMap,
device.Get(),
m_executionHandle);
@ -95,7 +103,7 @@ namespace Dml
m_inputsConstant,
kernelInfo,
graphDesc,
fusedNodeInputDefs,
fusedNodeInputArgOriginalNames,
m_inputsUsed,
initInputBindings,
initInputResources,
@ -511,9 +519,16 @@ namespace Dml
onnxruntime::OpKernel* CreateFusedGraphKernel(
const onnxruntime::OpKernelInfo& info,
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames
)
{
return new FusedGraphKernel(info, graphNodePropertyMap, transferredInitializerMap);
return new FusedGraphKernel(
info,
graphNodePropertyMap,
transferredInitializerMap,
fusedNodeInputArgOriginalNames,
fusedNodeOutputArgOriginalNames);
}
} // namespace Dml

View file

@ -8,7 +8,9 @@ namespace Dml
{
onnxruntime::OpKernel* CreateFusedGraphKernel(
const onnxruntime::OpKernelInfo& info,
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames
);
} // namespace Dml

View file

@ -38,8 +38,8 @@ namespace Dml::GraphDescBuilder
const size_t isConstGpuGraphInputCount,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle)
@ -56,10 +56,9 @@ namespace Dml::GraphDescBuilder
// Map from Lotus node argument names to input indices of the fused kernel node.
std::unordered_map<std::string, uint32_t> nameToFusedNodeInputIndex;
for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex)
for (size_t inputIndex = 0; inputIndex < fusedNodeInputArgOriginalNames.size(); ++inputIndex)
{
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name()));
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(fusedNodeInputArgOriginalNames[inputIndex]);
if (!graphInput)
{
@ -168,13 +167,6 @@ namespace Dml::GraphDescBuilder
{
auto iter = nameToFusedNodeInputIndex.find(arg->Name());
// The graph input could be missing the suffix, so try to match without it.
// This is part of a temporary workaround; see comments in GraphKernelHelper::GetFusedNodeArgNameMatchingGraph.
if (iter == nameToFusedNodeInputIndex.end())
{
iter = nameToFusedNodeInputIndex.find(GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(arg->Name()));
}
if (iter != nameToFusedNodeInputIndex.end())
{
// This is a graph input
@ -240,10 +232,9 @@ namespace Dml::GraphDescBuilder
assert(graphNodes.size() == orderedNodeIndices.size());
// Add graph output nodes, which might be in a different order from the encapsulating node
for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex)
for (size_t outputIndex = 0; outputIndex < fusedNodeOutputArgOriginalNames.size(); ++outputIndex)
{
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(fusedNodeOutputArgOriginalNames[outputIndex]);
ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg");
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());

View file

@ -45,8 +45,8 @@ namespace Dml
const size_t isConstGpuGraphInputCount,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle);

View file

@ -105,11 +105,11 @@ namespace GraphKernelHelper
bool GetGraphInputConstness(
uint32_t index,
const onnxruntime::OpKernelInfo& kernelInfo,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
{
// Transferred initializers are uploaded to GPU memory
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[index]->Name()));
auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[index]);
if (iter != transferredInitializerMap.end())
{
return true;
@ -139,7 +139,7 @@ namespace GraphKernelHelper
const std::vector<uint8_t>& inputsConstant,
const onnxruntime::OpKernelInfo& kernelInfo,
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
_Out_ std::vector<bool>& inputsUsed,
_Inout_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
_Inout_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
@ -154,7 +154,7 @@ namespace GraphKernelHelper
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
for (uint32_t i = 0; i < graphInputCount; i++)
{
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[i]);
if (iter != transferredInitializerMap.end()) {
initializerToLastInputIndexMap[&iter->second] = i;
}
@ -172,7 +172,7 @@ namespace GraphKernelHelper
// initialization or execution). So just throw away the transferred initializer and skip this input.
if (!inputsUsed[i])
{
transferredInitializerMap.erase(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
transferredInitializerMap.erase(fusedNodeInputArgOriginalNames[i]);
if (inputRawData)
{
@ -183,7 +183,7 @@ namespace GraphKernelHelper
}
// Look for the initializer among those transferred from the graph during partitioning
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[i]);
if (iter != transferredInitializerMap.end())
{
std::byte* tensorPtr = nullptr;
@ -320,43 +320,5 @@ namespace GraphKernelHelper
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}
// TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies.
// This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel
// for a function, and the graph for that function exposed as a kernel property. When the ordering
// mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed.
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName)
{
const char* suffix = nullptr;
// The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number.
if (!suffix)
{
suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_");
}
// The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number.
if (!suffix)
{
suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider");
}
if (!suffix)
{
suffix = strstr(fusedNodeArgeName.c_str(), "_token_");
}
if (suffix)
{
return std::string(
fusedNodeArgeName.begin(),
fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str())
);
}
else
{
return fusedNodeArgeName;
}
}
} // namespace GraphKernelHelper
} // namespace Dml

View file

@ -41,7 +41,7 @@ namespace GraphKernelHelper
bool GetGraphInputConstness(
uint32_t index,
const onnxruntime::OpKernelInfo& kernelInfo,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap);
void ProcessInputData(
@ -50,7 +50,7 @@ namespace GraphKernelHelper
const std::vector<uint8_t>& inputsConstant,
const onnxruntime::OpKernelInfo& kernelInfo,
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
_Out_ std::vector<bool>& inputsUsed,
_Inout_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
_Inout_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
@ -68,8 +68,6 @@ namespace GraphKernelHelper
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName);
} // namespace GraphKernelHelper
} // namespace Dml

View file

@ -537,10 +537,18 @@ namespace Dml
printf("\n");
#endif
auto fused_kernel_func = [partitionNodePropsMap, transferredInitializerMap](onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
// These nodeArgNames will be used while creating DML Graph inside FusedGraphKernel.cpp
// Ordering of input/output nodeArgs in below vector will be same as Node::Definitions::input_defs because
// ORT is populating these args as it is while creating the FusedNode at Graph::CreateFusedSubGraphNode()
// Why we need these names?
// After Partitioning and before reaching to FusedGraphKernel, ORT may modify the input/output nodeArg names
// present in FusedNode (Node::Definitions::input_defs) as part of some transformers like memcopy, or L1/L2/L3 transformers.
std::vector<std::string> fusedNodeInputArgOriginalNames = def->inputs;
std::vector<std::string> fusedNodeOutputArgOriginalNames = def->outputs;
auto fused_kernel_func = [partitionNodePropsMap, transferredInitializerMap, fusedNodeInputArgOriginalNames, fusedNodeOutputArgOriginalNames](onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
{
out.reset(CreateFusedGraphKernel(info, partitionNodePropsMap, *transferredInitializerMap));
return Status::OK();
out.reset(CreateFusedGraphKernel(info, partitionNodePropsMap, *transferredInitializerMap, fusedNodeInputArgOriginalNames, fusedNodeOutputArgOriginalNames));
return Status::OK();
};
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.