Merged PR 5691446: QLinear Graph Support

Enables Graph Support for QLinearConv, ConvInt, QLinearMatMul, and MatMulInt

Related work items: #31249591
This commit is contained in:
Nick Feeney 2021-03-10 20:46:40 +00:00
parent 46e026e900
commit 50973de1a2
5 changed files with 16 additions and 29 deletions

View file

@ -134,29 +134,6 @@ namespace Dml::GraphDescBuilder
&graphNodeInfo
);
// Determine the number of valid inputs and outputs of this node. The graph currently supports opererators
// with unused inputs and outputs only at the end of each list.
uint32_t validOpInputCount = 0;
uint32_t validOpOutputCount = 0;
for (uint32_t i = 0; i < graphNodeInfo.kernelInputIndices.size(); ++i)
{
if (graphNodeInfo.kernelInputIndices[i] != std::numeric_limits<uint32_t>::max())
{
assert(i - validOpInputCount == 0);
++validOpInputCount;
}
}
for (uint32_t i = 0; i < graphNodeInfo.kernelOutputIndices.size(); ++i)
{
if (graphNodeInfo.kernelOutputIndices[i] != std::numeric_limits<uint32_t>::max())
{
assert(i - validOpOutputCount == 0);
++validOpOutputCount;
}
}
uint32_t nodeIndex = gsl::narrow_cast<uint32_t>(graphNodes.size());
AbstractOperatorDesc opDesc = *graphNodeInfo.desc; // Make a copy
@ -166,8 +143,13 @@ namespace Dml::GraphDescBuilder
std::vector<DmlBufferTensorDesc*> outputTensorDescs = opDesc.GetOutputTensors();
// Set connections of the new node
for (uint32_t inputIndex = 0; inputIndex < validOpInputCount; ++inputIndex)
for (uint32_t inputIndex = 0; inputIndex < graphNodeInfo.kernelInputIndices.size(); ++inputIndex)
{
if (graphNodeInfo.kernelInputIndices[inputIndex] == std::numeric_limits<uint32_t>::max())
{
continue;
}
uint32_t kernelInputIndex = graphNodeInfo.kernelInputIndices[inputIndex];
const onnxruntime::NodeArg* arg = node.InputDefs()[kernelInputIndex];
@ -219,8 +201,13 @@ namespace Dml::GraphDescBuilder
// Store the new node for lookup when downstream nodes consume it.
for (uint32_t outputIndex = 0; outputIndex < validOpOutputCount; ++outputIndex)
for (uint32_t outputIndex = 0; outputIndex < graphNodeInfo.kernelOutputIndices.size(); ++outputIndex)
{
if (graphNodeInfo.kernelOutputIndices[outputIndex] == std::numeric_limits<uint32_t>::max())
{
continue;
}
uint32_t kernelOutputIndex = graphNodeInfo.kernelOutputIndices[outputIndex];
const onnxruntime::NodeArg* arg = node.OutputDefs()[kernelOutputIndex];
if (arg->Exists())

View file

@ -576,10 +576,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
};
template<typename T>

View file

Before

Width:  |  Height:  |  Size: 28 KiB

After

Width:  |  Height:  |  Size: 28 KiB

View file

Before

Width:  |  Height:  |  Size: 52 KiB

After

Width:  |  Height:  |  Size: 52 KiB

View file

Before

Width:  |  Height:  |  Size: 67 KiB

After

Width:  |  Height:  |  Size: 67 KiB