De-duplicate 1D scale and zero point tensors to scalars in DML kernels (#18862)

### Description
Cleanup and rebase of [this PR](https://github.com/microsoft/onnxruntime/pull/18629).



### Motivation and Context

---------

Co-authored-by: Christian Larson <chrilaMSFT@users.noreply.github.com>
Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com>
Co-authored-by: Jeff Bloomfield <jeffbloo@microsoft.com>
Co-authored-by: Anagha Rao <anagrao@microsoft.com>
This commit is contained in:
tbqh 2024-01-02 13:22:30 -06:00 committed by Jeff Bloomfield
parent bdaeebd6ff
commit 70d3f682a7
11 changed files with 153 additions and 9 deletions

View file

@ -85,7 +85,10 @@ namespace Windows::AI::MachineLearning::Adapter
{
uint32_t nodeCount = 0;
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
// TODO (jeffbloo): Remove this
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;

View file

@ -6,6 +6,9 @@
namespace Dml
{
// Shared all-zero stride array. TryConvertTensorToBroadcastScalar points a buffer tensor
// desc's Strides at this array, and refuses to convert descs whose DimensionCount exceeds
// its length (8).
/*static*/ const uint32_t DmlOperator::zeroArray[8] = {};
DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
{
ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
@ -824,4 +827,84 @@ namespace Dml
graphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}
// Collapses a constant tensor whose elements are all identical into a broadcast scalar.
//
// When the kernel input at `kernelInputIndex` is a constant CPU tensor with more than one
// element, and every element is bytewise equal to the first, the buffer tensor desc that
// `tensor` points at is rewritten in place: its strides become all zero and its total size
// shrinks to a single element (rounded up to 4 bytes).
//
// The function is a no-op when:
//   - `tensor` is null, is not a buffer tensor, or carries no desc,
//   - the input is not available as constant CPU data,
//   - the input has zero or one element,
//   - the element data type is not one of the 1/2/4/8-byte types listed below,
//   - any element differs from the first,
//   - the desc has more dimensions than `zeroArray` has entries.
//
// kernelInfo        Kernel creation context used to look up the constant input.
// tensor            Desc to rewrite. Taken as const because that is how operator descs
//                   expose it, but the pointed-to DML_BUFFER_TENSOR_DESC is mutated.
// kernelInputIndex  Index of the kernel input that backs `tensor`.
/*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar(
    const MLOperatorKernelCreationContext& kernelInfo,
    const DML_TENSOR_DESC* tensor,
    uint32_t kernelInputIndex)
{
    if (!tensor)
    {
        return;
    }

    auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(kernelInputIndex);
    if (!constExpTensor)
    {
        return;
    }
    else if (!constExpTensor->IsCpuData())
    {
        return;
    }

    // A tensor with zero or one element has nothing to de-duplicate.
    uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount();
    if (totalKernelInputElementCount <= 1)
    {
        return;
    }

    // Byte width of one element; data types not listed here are left untouched.
    uint32_t elementSize = 0;
    switch (constExpTensor->GetTensorDataType())
    {
    case MLOperatorTensorDataType::UInt8:
    case MLOperatorTensorDataType::Int8:
        elementSize = 1;
        break;

    case MLOperatorTensorDataType::Float16:
    case MLOperatorTensorDataType::UInt16:
    case MLOperatorTensorDataType::Int16:
        elementSize = 2;
        break;

    case MLOperatorTensorDataType::/*Float32*/Float:
    case MLOperatorTensorDataType::UInt32:
    case MLOperatorTensorDataType::Int32:
        elementSize = 4;
        break;

    case MLOperatorTensorDataType::/*Float64*/Double:
    case MLOperatorTensorDataType::UInt64:
    case MLOperatorTensorDataType::Int64:
        elementSize = 8;
        break;

    default:
        return;
    }

    const std::uint8_t* byteData = static_cast<const std::uint8_t*>(constExpTensor->GetByteData());

    // The assert keeps the debug-build diagnostic. The runtime check keeps release builds,
    // where the assert compiles away, from dereferencing a missing or non-buffer desc.
    assert(tensor->Type == DML_TENSOR_TYPE_BUFFER);
    if (tensor->Type != DML_TENSOR_TYPE_BUFFER || !tensor->Desc)
    {
        return;
    }

    // The desc is reached through a const pointer but is mutated below.
    auto *bufferTensorDesc = const_cast<DML_BUFFER_TENSOR_DESC*>(static_cast<const DML_BUFFER_TENSOR_DESC*>(tensor->Desc));

    // Bail out at the first element that differs bytewise from element 0.
    for (size_t i = 1; i < totalKernelInputElementCount; ++i)
    {
        if (memcmp(byteData, byteData + i * elementSize, elementSize))
        {
            return;
        }
    }

    // zeroArray must supply one zero stride per dimension of the desc.
    if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0]))
    {
        assert(false);
        return;
    }

    // Zero strides make every coordinate read element 0, so only one element's worth of
    // bytes (rounded up to the next multiple of 4) is declared as the tensor size.
    bufferTensorDesc->Strides = zeroArray;
    bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3u;
}
} // namespace Dml

View file

@ -149,6 +149,11 @@ namespace Dml
uint32_t minDimensionCount = NchwDimensionCount
) const;
// If the kernel input at kernelInputIndex is constant CPU data holding more than one
// element, and all elements are bytewise identical, rewrites the buffer tensor desc
// referenced by `tensor` in place as a broadcast scalar: all-zero strides and a total
// size of one element rounded up to 4 bytes. Otherwise (including when `tensor` is null)
// the desc is left untouched. Note that the desc is modified even though `tensor` is a
// pointer-to-const.
static void TryConvertTensorToBroadcastScalar(
const MLOperatorKernelCreationContext& kernelInfo,
const DML_TENSOR_DESC* tensor,
uint32_t kernelInputIndex);
private:
// For each input or output of the DML kernel, the corresponding input or output of the original
// kernel. Entries for unused DML inputs are nullopt.
@ -164,6 +169,7 @@ namespace Dml
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
// All-zero strides used by TryConvertTensorToBroadcastScalar for broadcast-scalar descs.
static const uint32_t zeroArray[8];
};
} // namespace Dml

View file

@ -111,6 +111,8 @@ public:
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
// TODO (jeffbloo): Port this to a graph description to enable DML graph optimization
dml::Graph graph(m_dmlDevice.Get());
dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X];
dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale];

View file

@ -586,6 +586,9 @@ public:
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
};

View file

@ -8,15 +8,15 @@ namespace Dml
class DmlOperatorQLinearAdd : public DmlOperator
{
enum InputTensors {
IN_A,
enum InputTensors {
IN_A,
IN_A_SCALE,
IN_A_ZERO_POINT,
IN_B,
IN_A_ZERO_POINT,
IN_B,
IN_B_SCALE,
IN_B_ZERO_POINT,
IN_C_SCALE,
IN_C_ZERO_POINT
IN_C_SCALE,
IN_C_ZERO_POINT
};
public:
@ -56,9 +56,18 @@ public:
AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE];
AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr;
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT);
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}

View file

@ -118,8 +118,8 @@ public:
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
@ -129,6 +129,12 @@ public:
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint);
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}

View file

@ -123,6 +123,9 @@ public:
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1);
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2);
dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
opDescs.push_back(&dmlOpDesc[inputIndex]);
@ -154,6 +157,10 @@ public:
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint);
const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
opDescs.push_back(&opQuantizeDesc);

View file

@ -117,6 +117,15 @@ public:
convDesc.EndPadding = kernelArgs.endPadding;
convDesc.GroupCount = m_groupCount;
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}

View file

@ -104,6 +104,15 @@ public:
matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_Y_ZERO_POINT] : nullptr;
matMulDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}

View file

@ -88,6 +88,9 @@ public:
dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale];
dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point];
dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc;
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point);
const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc};
@ -101,6 +104,10 @@ public:
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale);
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point);
const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
MLOperatorGraphDesc operatorGraphDesc = {};