De-duplicate 1D scale and zero point tensors to scalars in DML kernels

This commit is contained in:
Jeff Bloomfield 2023-10-05 17:48:14 -07:00
parent b6c096d458
commit 2efa2ab081
12 changed files with 147 additions and 3 deletions

View file

@ -82,7 +82,10 @@ namespace Windows::AI::MachineLearning::Adapter
{
uint32_t nodeCount;
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
// TODO: Remove this
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;

View file

@ -314,7 +314,11 @@ namespace Dml::GraphDescBuilder
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// only used for small inputs.
uint32_t c_maxConstNodeDataSize = 64;
// TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
// nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point
// values that have been de-duplicated with conversion to scalars in kernels.
uint32_t c_maxConstNodeDataSize = 1024 * 1024;
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());

View file

@ -6,6 +6,9 @@
namespace Dml
{
/*static*/ const uint32_t DmlOperator::zeroArray[8] = {};
DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
{
ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
@ -834,4 +837,80 @@ namespace Dml
graphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}
/*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar(
const MLOperatorKernelCreationContext& kernelInfo,
const DML_TENSOR_DESC* tensor,
uint32_t kernelInputIndex)
{
if (!tensor)
{
return;
}
auto constExpTensor = kernelInfo.TryGetConstantInputTensor(kernelInputIndex);
if (!constExpTensor)
{
return;
}
uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount();
if (totalKernelInputElementCount <= 1)
{
return;
}
uint32_t elementSize = 0;
switch (constExpTensor->GetTensorDataType())
{
case MLOperatorTensorDataType::UInt8:
case MLOperatorTensorDataType::Int8:
elementSize = 1;
break;
case MLOperatorTensorDataType::Float16:
case MLOperatorTensorDataType::UInt16:
case MLOperatorTensorDataType::Int16:
elementSize = 2;
break;
case MLOperatorTensorDataType::/*Float32*/Float:
case MLOperatorTensorDataType::UInt32:
case MLOperatorTensorDataType::Int32:
elementSize = 4;
break;
case MLOperatorTensorDataType::/*Float64*/Double:
case MLOperatorTensorDataType::UInt64:
case MLOperatorTensorDataType::Int64:
elementSize = 8;
break;
default:
return;
}
const std::uint8_t* byteData = static_cast<const std::uint8_t*>(constExpTensor->GetByteData());
assert(tensor->Type == DML_TENSOR_TYPE_BUFFER);
auto *bufferTensorDesc = const_cast<DML_BUFFER_TENSOR_DESC*>(static_cast<const DML_BUFFER_TENSOR_DESC*>(tensor->Desc));
for (size_t i = 1; i < totalKernelInputElementCount; ++i)
{
if (memcmp(byteData, byteData + i * elementSize, elementSize))
{
return;
}
}
if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0]))
{
assert(false);
return;
}
bufferTensorDesc->Strides = zeroArray;
bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3;
}
} // namespace Dml

View file

@ -149,6 +149,11 @@ namespace Dml
uint32_t minDimensionCount = NchwDimensionCount
) const;
static void TryConvertTensorToBroadcastScalar(
const MLOperatorKernelCreationContext& kernelInfo,
const DML_TENSOR_DESC* tensor,
uint32_t kernelInputIndex);
private:
// For each input or output of the DML kernel, the corresponding input or output of the original
// kernel. Entries for unused DML inputs are nullopt.
@ -164,6 +169,7 @@ namespace Dml
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
static const uint32_t zeroArray[8];
};
} // namespace Dml

View file

@ -111,6 +111,8 @@ public:
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
// TODO: Port this to a graph description to enable DML graph optimization
dml::Graph graph(m_dmlDevice.Get());
dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X];
dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale];

View file

@ -585,6 +585,9 @@ public:
opDesc.ScaleTensor = &inputDescs[1];
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}

View file

@ -58,6 +58,15 @@ public:
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT);
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);

View file

@ -118,8 +118,8 @@ public:
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
@ -129,6 +129,12 @@ public:
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint);
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}

View file

@ -123,6 +123,9 @@ public:
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1);
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2);
dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
opDescs.push_back(&dmlOpDesc[inputIndex]);
@ -154,6 +157,10 @@ public:
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint);
const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
opDescs.push_back(&opQuantizeDesc);

View file

@ -117,6 +117,15 @@ public:
convDesc.EndPadding = kernelArgs.endPadding;
convDesc.GroupCount = m_groupCount;
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}

View file

@ -104,6 +104,15 @@ public:
matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_Y_ZERO_POINT] : nullptr;
matMulDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}

View file

@ -88,6 +88,9 @@ public:
dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale];
dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point];
dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc;
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point);
const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc};
@ -101,6 +104,10 @@ public:
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point);
const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
MLOperatorGraphDesc operatorGraphDesc = {};