mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Merged PR 5691446: QLinear Graph Support
Enables Graph Support for QLinearConv, ConvInt, QLinearMatMul, and MatMulInt Related work items: #31249591
This commit is contained in:
parent
46e026e900
commit
50973de1a2
5 changed files with 16 additions and 29 deletions
|
|
@ -134,29 +134,6 @@ namespace Dml::GraphDescBuilder
|
|||
&graphNodeInfo
|
||||
);
|
||||
|
||||
// Determine the number of valid inputs and outputs of this node. The graph currently supports opererators
|
||||
// with unused inputs and outputs only at the end of each list.
|
||||
uint32_t validOpInputCount = 0;
|
||||
uint32_t validOpOutputCount = 0;
|
||||
|
||||
for (uint32_t i = 0; i < graphNodeInfo.kernelInputIndices.size(); ++i)
|
||||
{
|
||||
if (graphNodeInfo.kernelInputIndices[i] != std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
assert(i - validOpInputCount == 0);
|
||||
++validOpInputCount;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < graphNodeInfo.kernelOutputIndices.size(); ++i)
|
||||
{
|
||||
if (graphNodeInfo.kernelOutputIndices[i] != std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
assert(i - validOpOutputCount == 0);
|
||||
++validOpOutputCount;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t nodeIndex = gsl::narrow_cast<uint32_t>(graphNodes.size());
|
||||
AbstractOperatorDesc opDesc = *graphNodeInfo.desc; // Make a copy
|
||||
|
||||
|
|
@ -166,8 +143,13 @@ namespace Dml::GraphDescBuilder
|
|||
std::vector<DmlBufferTensorDesc*> outputTensorDescs = opDesc.GetOutputTensors();
|
||||
|
||||
// Set connections of the new node
|
||||
for (uint32_t inputIndex = 0; inputIndex < validOpInputCount; ++inputIndex)
|
||||
for (uint32_t inputIndex = 0; inputIndex < graphNodeInfo.kernelInputIndices.size(); ++inputIndex)
|
||||
{
|
||||
if (graphNodeInfo.kernelInputIndices[inputIndex] == std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t kernelInputIndex = graphNodeInfo.kernelInputIndices[inputIndex];
|
||||
|
||||
const onnxruntime::NodeArg* arg = node.InputDefs()[kernelInputIndex];
|
||||
|
|
@ -219,8 +201,13 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
// Store the new node for lookup when downstream nodes consume it.
|
||||
|
||||
for (uint32_t outputIndex = 0; outputIndex < validOpOutputCount; ++outputIndex)
|
||||
for (uint32_t outputIndex = 0; outputIndex < graphNodeInfo.kernelOutputIndices.size(); ++outputIndex)
|
||||
{
|
||||
if (graphNodeInfo.kernelOutputIndices[outputIndex] == std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t kernelOutputIndex = graphNodeInfo.kernelOutputIndices[outputIndex];
|
||||
const onnxruntime::NodeArg* arg = node.OutputDefs()[kernelOutputIndex];
|
||||
if (arg->Exists())
|
||||
|
|
|
|||
|
|
@ -576,10 +576,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
|
||||
|
||||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
|
|
|
|||
0
onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/daisy.jpg
Normal file → Executable file
0
onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/daisy.jpg
Normal file → Executable file
|
Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 28 KiB |
0
onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/rose.jpg
Normal file → Executable file
0
onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/rose.jpg
Normal file → Executable file
|
Before Width: | Height: | Size: 52 KiB After Width: | Height: | Size: 52 KiB |
0
onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/tulip.jpg
Normal file → Executable file
0
onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/tulip.jpg
Normal file → Executable file
|
Before Width: | Height: | Size: 67 KiB After Width: | Height: | Size: 67 KiB |
Loading…
Reference in a new issue