mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
De-duplicate 1D scale and zero point tensors to scalars in DML kernels
This commit is contained in:
parent
b6c096d458
commit
2efa2ab081
12 changed files with 147 additions and 3 deletions
|
|
@ -82,7 +82,10 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
{
|
||||
uint32_t nodeCount;
|
||||
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
|
||||
|
||||
// TODO: Remove this
|
||||
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
|
||||
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
|
|
|
|||
|
|
@ -314,7 +314,11 @@ namespace Dml::GraphDescBuilder
|
|||
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
|
||||
// across the graph input as well as every consumer's unique constant node. However it is currently
|
||||
// only used for small inputs.
|
||||
uint32_t c_maxConstNodeDataSize = 64;
|
||||
|
||||
// TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
|
||||
// nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point
|
||||
// values have been de-duplicated with conversion to scalars in kernels.
|
||||
uint32_t c_maxConstNodeDataSize = 1024 * 1024;
|
||||
|
||||
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@
|
|||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
/*static*/ const uint32_t DmlOperator::zeroArray[8] = {};
|
||||
|
||||
DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
{
|
||||
ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
|
||||
|
|
@ -834,4 +837,80 @@ namespace Dml
|
|||
graphDesc.IntermediateEdges = dmlIntermediateEdges.data();
|
||||
}
|
||||
|
||||
/*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
const DML_TENSOR_DESC* tensor,
|
||||
uint32_t kernelInputIndex)
|
||||
{
|
||||
if (!tensor)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto constExpTensor = kernelInfo.TryGetConstantInputTensor(kernelInputIndex);
|
||||
if (!constExpTensor)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount();
|
||||
if (totalKernelInputElementCount <= 1)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t elementSize = 0;
|
||||
|
||||
switch (constExpTensor->GetTensorDataType())
|
||||
{
|
||||
case MLOperatorTensorDataType::UInt8:
|
||||
case MLOperatorTensorDataType::Int8:
|
||||
elementSize = 1;
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::Float16:
|
||||
case MLOperatorTensorDataType::UInt16:
|
||||
case MLOperatorTensorDataType::Int16:
|
||||
elementSize = 2;
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::/*Float32*/Float:
|
||||
case MLOperatorTensorDataType::UInt32:
|
||||
case MLOperatorTensorDataType::Int32:
|
||||
elementSize = 4;
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::/*Float64*/Double:
|
||||
case MLOperatorTensorDataType::UInt64:
|
||||
case MLOperatorTensorDataType::Int64:
|
||||
elementSize = 8;
|
||||
break;
|
||||
|
||||
default:
|
||||
return;
|
||||
}
|
||||
|
||||
const std::uint8_t* byteData = static_cast<const std::uint8_t*>(constExpTensor->GetByteData());
|
||||
|
||||
assert(tensor->Type == DML_TENSOR_TYPE_BUFFER);
|
||||
auto *bufferTensorDesc = const_cast<DML_BUFFER_TENSOR_DESC*>(static_cast<const DML_BUFFER_TENSOR_DESC*>(tensor->Desc));
|
||||
|
||||
for (size_t i = 1; i < totalKernelInputElementCount; ++i)
|
||||
{
|
||||
if (memcmp(byteData, byteData + i * elementSize, elementSize))
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0]))
|
||||
{
|
||||
assert(false);
|
||||
return;
|
||||
}
|
||||
|
||||
bufferTensorDesc->Strides = zeroArray;
|
||||
bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3;
|
||||
}
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -149,6 +149,11 @@ namespace Dml
|
|||
uint32_t minDimensionCount = NchwDimensionCount
|
||||
) const;
|
||||
|
||||
// Best-effort conversion of a constant kernel input into a broadcast scalar. If the constant
// tensor at kernelInputIndex has more than one element and all elements are bytewise identical,
// the buffer tensor desc referenced by "tensor" is rewritten in place to use all-zero strides
// and a total size of one element rounded up to a multiple of 4 bytes. In every other case
// (null "tensor", no constant tensor available, unsupported data type, differing elements, or
// more dimensions than zeroArray has entries) the desc is left unchanged.
static void TryConvertTensorToBroadcastScalar(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
const DML_TENSOR_DESC* tensor,
|
||||
uint32_t kernelInputIndex);
|
||||
|
||||
private:
|
||||
// For each input or output of the DML kernel, the corresponding input or output of the original
|
||||
// kernel. Entries for unused DML inputs are nullopt.
|
||||
|
|
@ -164,6 +169,7 @@ namespace Dml
|
|||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
|
||||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
|
||||
|
||||
static const uint32_t zeroArray[8];
|
||||
};
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -111,6 +111,8 @@ public:
|
|||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
// TODO: Port this to a graph description to enable DML graph optimization
|
||||
|
||||
dml::Graph graph(m_dmlDevice.Get());
|
||||
dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X];
|
||||
dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale];
|
||||
|
|
|
|||
|
|
@ -585,6 +585,9 @@ public:
|
|||
opDesc.ScaleTensor = &inputDescs[1];
|
||||
opDesc.ZeroPointTensor = &inputDescs[2];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);
|
||||
|
||||
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,6 +58,15 @@ public:
|
|||
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
|
||||
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
|
||||
AddDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT);
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
|
|
|
|||
|
|
@ -118,8 +118,8 @@ public:
|
|||
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
|
||||
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
|
||||
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
|
||||
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
|
||||
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
|
||||
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
|
||||
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
|
||||
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
|
||||
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
|
||||
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
|
||||
|
|
@ -129,6 +129,12 @@ public:
|
|||
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
|
||||
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint);
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -123,6 +123,9 @@ public:
|
|||
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
|
||||
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
|
||||
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1);
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2);
|
||||
|
||||
dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
|
||||
opDescs.push_back(&dmlOpDesc[inputIndex]);
|
||||
|
|
@ -154,6 +157,10 @@ public:
|
|||
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
|
||||
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
|
||||
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale);
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint);
|
||||
|
||||
const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
|
||||
opDescs.push_back(&opQuantizeDesc);
|
||||
|
||||
|
|
|
|||
|
|
@ -117,6 +117,15 @@ public:
|
|||
convDesc.EndPadding = kernelArgs.endPadding;
|
||||
convDesc.GroupCount = m_groupCount;
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,6 +104,15 @@ public:
|
|||
matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_Y_ZERO_POINT] : nullptr;
|
||||
matMulDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT);
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,6 +88,9 @@ public:
|
|||
dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale];
|
||||
dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point];
|
||||
dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc;
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale);
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point);
|
||||
|
||||
const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc};
|
||||
|
||||
|
|
@ -101,6 +104,10 @@ public:
|
|||
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale];
|
||||
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point];
|
||||
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale);
|
||||
TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point);
|
||||
|
||||
const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
|
|
|
|||
Loading…
Reference in a new issue