mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
[DML EP] Remove suffix removal adhoc logic for fusedNodeArgNames (#11879)
* DML EP: Remove suffix removal hack for fusedNodeArgName * Acknowledged PR comments * Removed reference from gsl::span Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
This commit is contained in:
parent
f97bd38c4f
commit
52f2b3bf89
7 changed files with 57 additions and 81 deletions
|
|
@ -19,7 +19,9 @@ namespace Dml
|
|||
FusedGraphKernel(
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap) : OpKernel(kernelInfo)
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames) : OpKernel(kernelInfo)
|
||||
{
|
||||
// Get the graph for the function which was created according to the computational
|
||||
// capacity returned by the execution provider's graph partitioner
|
||||
|
|
@ -45,14 +47,20 @@ namespace Dml
|
|||
ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider));
|
||||
}
|
||||
|
||||
TranslateAndCompileGraph(kernelInfo, graph, node.InputDefs(), node.OutputDefs(), graphNodePropertyMap, transferredInitializerMap);
|
||||
TranslateAndCompileGraph(
|
||||
kernelInfo,
|
||||
graph,
|
||||
fusedNodeInputArgOriginalNames,
|
||||
fusedNodeOutputArgOriginalNames,
|
||||
graphNodePropertyMap,
|
||||
transferredInitializerMap);
|
||||
}
|
||||
|
||||
void TranslateAndCompileGraph(
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const onnxruntime::Graph& graph,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames,
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
|
||||
)
|
||||
|
|
@ -68,7 +76,7 @@ namespace Dml
|
|||
m_inputsConstant.resize(graphInputCount);
|
||||
for (uint32_t i = 0; i < graphInputCount; ++i)
|
||||
{
|
||||
m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputDefs, transferredInitializerMap);
|
||||
m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputArgOriginalNames, transferredInitializerMap);
|
||||
}
|
||||
|
||||
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
|
||||
|
|
@ -77,8 +85,8 @@ namespace Dml
|
|||
m_inputsConstant.size(),
|
||||
transferredInitializerMap,
|
||||
graph,
|
||||
fusedNodeInputDefs,
|
||||
fusedNodeOutputDefs,
|
||||
fusedNodeInputArgOriginalNames,
|
||||
fusedNodeOutputArgOriginalNames,
|
||||
graphNodePropertyMap,
|
||||
device.Get(),
|
||||
m_executionHandle);
|
||||
|
|
@ -95,7 +103,7 @@ namespace Dml
|
|||
m_inputsConstant,
|
||||
kernelInfo,
|
||||
graphDesc,
|
||||
fusedNodeInputDefs,
|
||||
fusedNodeInputArgOriginalNames,
|
||||
m_inputsUsed,
|
||||
initInputBindings,
|
||||
initInputResources,
|
||||
|
|
@ -511,9 +519,16 @@ namespace Dml
|
|||
onnxruntime::OpKernel* CreateFusedGraphKernel(
|
||||
const onnxruntime::OpKernelInfo& info,
|
||||
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames
|
||||
)
|
||||
{
|
||||
return new FusedGraphKernel(info, graphNodePropertyMap, transferredInitializerMap);
|
||||
return new FusedGraphKernel(
|
||||
info,
|
||||
graphNodePropertyMap,
|
||||
transferredInitializerMap,
|
||||
fusedNodeInputArgOriginalNames,
|
||||
fusedNodeOutputArgOriginalNames);
|
||||
}
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ namespace Dml
|
|||
{
|
||||
onnxruntime::OpKernel* CreateFusedGraphKernel(
|
||||
const onnxruntime::OpKernelInfo& info,
|
||||
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames
|
||||
);
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -38,8 +38,8 @@ namespace Dml::GraphDescBuilder
|
|||
const size_t isConstGpuGraphInputCount,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
|
||||
const onnxruntime::Graph& graph,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames,
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
IDMLDevice* device,
|
||||
const void* executionHandle)
|
||||
|
|
@ -56,10 +56,9 @@ namespace Dml::GraphDescBuilder
|
|||
// Map from Lotus node argument names to input indices of the fused kernel node.
|
||||
std::unordered_map<std::string, uint32_t> nameToFusedNodeInputIndex;
|
||||
|
||||
for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex)
|
||||
for (size_t inputIndex = 0; inputIndex < fusedNodeInputArgOriginalNames.size(); ++inputIndex)
|
||||
{
|
||||
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(
|
||||
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name()));
|
||||
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(fusedNodeInputArgOriginalNames[inputIndex]);
|
||||
|
||||
if (!graphInput)
|
||||
{
|
||||
|
|
@ -168,13 +167,6 @@ namespace Dml::GraphDescBuilder
|
|||
{
|
||||
auto iter = nameToFusedNodeInputIndex.find(arg->Name());
|
||||
|
||||
// The graph input could be missing the suffix, so try to match without it.
|
||||
// This is part of a temporary workaround; see comments in GraphKernelHelper::GetFusedNodeArgNameMatchingGraph.
|
||||
if (iter == nameToFusedNodeInputIndex.end())
|
||||
{
|
||||
iter = nameToFusedNodeInputIndex.find(GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(arg->Name()));
|
||||
}
|
||||
|
||||
if (iter != nameToFusedNodeInputIndex.end())
|
||||
{
|
||||
// This is a graph input
|
||||
|
|
@ -240,10 +232,9 @@ namespace Dml::GraphDescBuilder
|
|||
assert(graphNodes.size() == orderedNodeIndices.size());
|
||||
|
||||
// Add graph output nodes, which might be in a different order from the encapsulating node
|
||||
for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex)
|
||||
for (size_t outputIndex = 0; outputIndex < fusedNodeOutputArgOriginalNames.size(); ++outputIndex)
|
||||
{
|
||||
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(
|
||||
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
|
||||
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(fusedNodeOutputArgOriginalNames[outputIndex]);
|
||||
|
||||
ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg");
|
||||
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
|
||||
|
|
|
|||
|
|
@ -45,8 +45,8 @@ namespace Dml
|
|||
const size_t isConstGpuGraphInputCount,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
|
||||
const onnxruntime::Graph& graph,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const gsl::span<const std::string> fusedNodeOutputArgOriginalNames,
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
IDMLDevice* device,
|
||||
const void* executionHandle);
|
||||
|
|
|
|||
|
|
@ -105,11 +105,11 @@ namespace GraphKernelHelper
|
|||
bool GetGraphInputConstness(
|
||||
uint32_t index,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
|
||||
{
|
||||
// Transferred initializers are uploaded to GPU memory
|
||||
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[index]->Name()));
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[index]);
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
return true;
|
||||
|
|
@ -139,7 +139,7 @@ namespace GraphKernelHelper
|
|||
const std::vector<uint8_t>& inputsConstant,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
_Out_ std::vector<bool>& inputsUsed,
|
||||
_Inout_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
|
||||
_Inout_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
|
||||
|
|
@ -154,7 +154,7 @@ namespace GraphKernelHelper
|
|||
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
|
||||
for (uint32_t i = 0; i < graphInputCount; i++)
|
||||
{
|
||||
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[i]);
|
||||
if (iter != transferredInitializerMap.end()) {
|
||||
initializerToLastInputIndexMap[&iter->second] = i;
|
||||
}
|
||||
|
|
@ -172,7 +172,7 @@ namespace GraphKernelHelper
|
|||
// initialization or execution). So just throw away the transferred initializer and skip this input.
|
||||
if (!inputsUsed[i])
|
||||
{
|
||||
transferredInitializerMap.erase(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
|
||||
transferredInitializerMap.erase(fusedNodeInputArgOriginalNames[i]);
|
||||
|
||||
if (inputRawData)
|
||||
{
|
||||
|
|
@ -183,7 +183,7 @@ namespace GraphKernelHelper
|
|||
}
|
||||
|
||||
// Look for the initializer among those transferred from the graph during partitioning
|
||||
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[i]);
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
std::byte* tensorPtr = nullptr;
|
||||
|
|
@ -320,43 +320,5 @@ namespace GraphKernelHelper
|
|||
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
|
||||
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
|
||||
}
|
||||
|
||||
// TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies.
// This shouldn't be necessary if Lotus exposes the inputs/outputs in the same order between the kernel
// for a function, and the graph for that function exposed as a kernel property. When the ordering
// mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed.
//
// Returns fusedNodeArgeName truncated at the first occurrence of the highest-priority suffix marker
// it contains, or the name unchanged when it contains none of the markers.
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName)
{
    // Markers are tried strictly in this order; the first marker found anywhere in the name wins,
    // even if a lower-priority marker occurs earlier in the string.
    //  1. Mem-copy suffix, probably followed by an incrementing number.
    //  2. Mem-copy suffix, not followed by an incrementing number.
    //  3. Token suffix.
    constexpr const char* suffixMarkers[] = {
        "_DmlExecutionProvider_",
        "_DmlExecutionProvider",
        "_token_",
    };

    for (const char* marker : suffixMarkers)
    {
        const std::string::size_type markerPos = fusedNodeArgeName.find(marker);
        if (markerPos != std::string::npos)
        {
            // Keep only the part of the name that precedes the marker.
            return fusedNodeArgeName.substr(0, markerPos);
        }
    }

    return fusedNodeArgeName;
}
|
||||
} // namespace GraphKernelHelper
|
||||
} // namespace Dml
|
||||
|
|
@ -41,7 +41,7 @@ namespace GraphKernelHelper
|
|||
bool GetGraphInputConstness(
|
||||
uint32_t index,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap);
|
||||
|
||||
void ProcessInputData(
|
||||
|
|
@ -50,7 +50,7 @@ namespace GraphKernelHelper
|
|||
const std::vector<uint8_t>& inputsConstant,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const gsl::span<const std::string> fusedNodeInputArgOriginalNames,
|
||||
_Out_ std::vector<bool>& inputsUsed,
|
||||
_Inout_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
|
||||
_Inout_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
|
||||
|
|
@ -68,8 +68,6 @@ namespace GraphKernelHelper
|
|||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
|
||||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
|
||||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
|
||||
|
||||
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName);
|
||||
|
||||
|
||||
} // namespace GraphKernelHelper
|
||||
} // namespace Dml
|
||||
|
|
@ -537,10 +537,18 @@ namespace Dml
|
|||
printf("\n");
|
||||
#endif
|
||||
|
||||
auto fused_kernel_func = [partitionNodePropsMap, transferredInitializerMap](onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
|
||||
// These nodeArgNames will be used while creating DML Graph inside FusedGraphKernel.cpp
|
||||
// The ordering of input/output nodeArgs in the vectors below is the same as in Node::Definitions::input_defs because
|
||||
// ORT populates these args as-is while creating the FusedNode in Graph::CreateFusedSubGraphNode().
|
||||
// Why do we need these names?
|
||||
// After partitioning and before reaching FusedGraphKernel, ORT may modify the input/output nodeArg names
|
||||
// present in FusedNode (Node::Definitions::input_defs) as part of some transformers like memcopy, or L1/L2/L3 transformers.
|
||||
std::vector<std::string> fusedNodeInputArgOriginalNames = def->inputs;
|
||||
std::vector<std::string> fusedNodeOutputArgOriginalNames = def->outputs;
|
||||
auto fused_kernel_func = [partitionNodePropsMap, transferredInitializerMap, fusedNodeInputArgOriginalNames, fusedNodeOutputArgOriginalNames](onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
|
||||
{
|
||||
out.reset(CreateFusedGraphKernel(info, partitionNodePropsMap, *transferredInitializerMap));
|
||||
return Status::OK();
|
||||
out.reset(CreateFusedGraphKernel(info, partitionNodePropsMap, *transferredInitializerMap, fusedNodeInputArgOriginalNames, fusedNodeOutputArgOriginalNames));
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// build the kernel definition on the fly, and register it to the fused_kernel_registry.
|
||||
|
|
|
|||
Loading…
Reference in a new issue