mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[DML EP] Add an implementation for NonZero (#13768)
### Description
Add the NonZero op for DML.

### Motivation and Context
NonZero is used in a few transformer models, so having a DML implementation will stop large tensors from being transferred to the CPU and back to the GPU.
This commit is contained in:
parent
b9702587df
commit
b53bbe7370
4 changed files with 205 additions and 3 deletions
|
|
@ -980,6 +980,8 @@ Do not modify directly.*
|
|||
|||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Neg|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|
||||
|||6+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|
||||
|NonZero|*in* X:**T**<br> *out* Y:**tensor(int64)**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|
||||
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|
||||
|Not|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bool)|
|
||||
|OneHot|*in* indices:**T1**<br> *in* depth:**T2**<br> *in* values:**T3**<br> *out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -0,0 +1,192 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorNonZero: public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorNonZero(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1);
|
||||
|
||||
std::vector<DimensionType> inputShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
|
||||
|
||||
// Scalars have a rank of 0, but DML only supports 1 and more, which is the same
|
||||
if (inputShape.empty())
|
||||
{
|
||||
inputShape.push_back(1);
|
||||
}
|
||||
|
||||
uint32_t numElements = ComputeElementCountFromDimensions(inputShape);
|
||||
|
||||
gsl::span<const uint32_t> inputShapes[1] = {inputShape};
|
||||
DmlOperator::InitializeWithShapes(kernelCreationContext, std::nullopt, std::nullopt, inputShapes, std::nullopt, 1);
|
||||
|
||||
m_rank = static_cast<DimensionType>(inputShape.size());
|
||||
std::vector<DimensionType> outputCountShape = {1};
|
||||
std::vector<DimensionType> outputCoordinatesShape = {numElements, m_rank};
|
||||
|
||||
// TODO: Remove the doubled strides when DML supports native int64 for NonZero
|
||||
// TensorFlow outputs {rank, numElements}, but DML outputs {numElements, rank}
|
||||
std::vector<DimensionType> outputCoordinatesStrides = {2, numElements * 2};
|
||||
m_intermediateTensorDescs = {
|
||||
TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCountShape),
|
||||
TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCoordinatesShape, outputCoordinatesStrides),
|
||||
};
|
||||
|
||||
// If the input has no elements, bypass the DML execution
|
||||
if (numElements == 0)
|
||||
{
|
||||
m_emptyInput = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_emptyInput = false;
|
||||
m_outputCountShape = {1};
|
||||
m_outputCoordinatesShape = {static_cast<int64_t>(numElements), static_cast<int64_t>(m_rank)};
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> intermediateDescs(m_intermediateTensorDescs.size());
|
||||
for (size_t i = 0; i < intermediateDescs.size(); i++)
|
||||
{
|
||||
intermediateDescs[i] = m_intermediateTensorDescs[i].GetDmlDesc();
|
||||
}
|
||||
|
||||
DML_NONZERO_COORDINATES_OPERATOR_DESC nonzeroCoordinatesDesc = {};
|
||||
nonzeroCoordinatesDesc.InputTensor = &inputDescs[0];
|
||||
nonzeroCoordinatesDesc.OutputCountTensor = &intermediateDescs[0];
|
||||
nonzeroCoordinatesDesc.OutputCoordinatesTensor = &intermediateDescs[1];
|
||||
|
||||
// TODO: Remove this hack when DML supports native int64 for NonZero
|
||||
// We use the int64/uint32 stride hack here, so zero out the data before writing to it
|
||||
m_zeroOperator = InitializeZeroInt64Tensor(m_intermediateTensorDescs[1].GetBufferSizeInBytes());
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_NONZERO_COORDINATES, &nonzeroCoordinatesDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(const MLOperatorKernelContext& kernelContext)
|
||||
{
|
||||
ExecutionProviderImpl* executionProvider = static_cast<ExecutionProviderImpl*>(m_executionProvider.Get());
|
||||
|
||||
// Create the DML output tensor for the number of nonzero elements
|
||||
onnxruntime::Tensor outputCountDml(onnxruntime::DataTypeImpl::GetType<uint32_t>(), m_outputCountShape, executionProvider->GetGpuAllocator());
|
||||
Microsoft::WRL::ComPtr<IMLOperatorTensor> outputCountDmlWrapper = wil::MakeOrThrow<Windows::AI::MachineLearning::Adapter::TensorWrapper>(
|
||||
&outputCountDml,
|
||||
true,
|
||||
executionProvider,
|
||||
true);
|
||||
|
||||
// Create the DML output tensor for the coordinates (not cropped)
|
||||
onnxruntime::Tensor intermediateCoordinatesDml(onnxruntime::DataTypeImpl::GetType<int64_t>(), m_outputCoordinatesShape, executionProvider->GetGpuAllocator());
|
||||
Microsoft::WRL::ComPtr<IMLOperatorTensor> intermediateCoordinatesDmlWrapper = wil::MakeOrThrow<Windows::AI::MachineLearning::Adapter::TensorWrapper>(
|
||||
&intermediateCoordinatesDml,
|
||||
true,
|
||||
executionProvider,
|
||||
true);
|
||||
|
||||
std::vector<IMLOperatorTensor*> nonzeroCoordinatesInputTensors = GetInputTensors(kernelContext);
|
||||
std::vector<IMLOperatorTensor*> nonzeroCoordinatesOutputTensors = {outputCountDmlWrapper.Get(), intermediateCoordinatesDmlWrapper.Get()};
|
||||
|
||||
uint32_t nonzeroElementCount = 0;
|
||||
|
||||
if (!m_emptyInput)
|
||||
{
|
||||
ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
|
||||
m_compiledOperator.Get(),
|
||||
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
|
||||
gsl::make_span(nonzeroCoordinatesInputTensors),
|
||||
gsl::make_span(nonzeroCoordinatesOutputTensors)));
|
||||
|
||||
// Copy the number of nonzero elements back to the CPU
|
||||
onnxruntime::Tensor outputCountCpu(onnxruntime::DataTypeImpl::GetType<uint32_t>(), {1}, executionProvider->GetCpuInputAllocator());
|
||||
Microsoft::WRL::ComPtr<IMLOperatorTensor> outputCountCpuWrapper = wil::MakeOrThrow<Windows::AI::MachineLearning::Adapter::TensorWrapper>(
|
||||
&outputCountCpu,
|
||||
false,
|
||||
executionProvider,
|
||||
true);
|
||||
ORT_THROW_IF_FAILED(m_executionProvider->CopyTensor(
|
||||
outputCountCpuWrapper.Get(),
|
||||
nonzeroCoordinatesOutputTensors.front()));
|
||||
nonzeroElementCount = outputCountCpu.Data<uint32_t>()[0];
|
||||
}
|
||||
|
||||
// Create the final output tensor, which is cropped to the actual number of nonzero elements
|
||||
std::vector<uint32_t> outputSizes({m_rank, nonzeroElementCount});
|
||||
auto outputTensor = kernelContext.GetOutputTensor(0, outputSizes);
|
||||
|
||||
if (!m_emptyInput && nonzeroElementCount > 0)
|
||||
{
|
||||
// TODO: Remove this hack when DML supports native int64 for NonZero
|
||||
ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensor.GetInterface().Get());
|
||||
|
||||
ComPtr<IDMLCompiledOperator> sliceOperator = InitializeSlice(m_intermediateTensorDescs[1], nonzeroElementCount);
|
||||
|
||||
// Finally, we crop the output to the actual number of nonzero elements, thus removing the padding
|
||||
std::array<IMLOperatorTensor*, 1> sliceInputTensors = {nonzeroCoordinatesOutputTensors[1]};
|
||||
std::array<IMLOperatorTensor*, 1> sliceOutputTensors = {outputTensor.GetInterface().Get()};
|
||||
|
||||
ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
|
||||
sliceOperator.Get(),
|
||||
nullptr, // persistent resource binding
|
||||
sliceInputTensors,
|
||||
sliceOutputTensors));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
ComPtr<IDMLCompiledOperator> InitializeSlice(TensorDesc& inputDesc, uint32_t nonzeroElementCount)
|
||||
{
|
||||
assert(inputDesc.GetSizes().size() == 2);
|
||||
|
||||
uint32_t rank = inputDesc.GetSizes().back();
|
||||
std::array<uint32_t, 2> inputWindowOffsets = {0, 0};
|
||||
std::array<int32_t, 2> inputWindowStrides = {1, 1};
|
||||
std::array<uint32_t, 2> inputWindowSizes = {nonzeroElementCount, rank};
|
||||
|
||||
// TODO: Remove the doubled strides when DML supports native int64 for NonZero
|
||||
std::array<uint32_t, 2> outputStrides = {2, nonzeroElementCount * 2};
|
||||
TensorDesc outputDesc(inputDesc.GetDmlDataType(), inputWindowSizes, outputStrides);
|
||||
|
||||
const auto inputOpDesc = inputDesc.GetDmlDesc();
|
||||
const auto outputOpDesc = outputDesc.GetDmlDesc();
|
||||
|
||||
DML_SLICE1_OPERATOR_DESC sliceDesc = {};
|
||||
sliceDesc.DimensionCount = 2;
|
||||
sliceDesc.InputWindowOffsets = inputWindowOffsets.data();
|
||||
sliceDesc.InputWindowSizes = inputWindowSizes.data();
|
||||
sliceDesc.InputWindowStrides = inputWindowStrides.data();
|
||||
sliceDesc.InputTensor = &inputOpDesc;
|
||||
sliceDesc.OutputTensor = &outputOpDesc;
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE1, &sliceDesc };
|
||||
|
||||
ComPtr<IDMLOperator> dmlOperator;
|
||||
ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperator)));
|
||||
|
||||
ComPtr<IDMLCompiledOperator> dmlCompiledOperator;
|
||||
ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), GetExecutionFlags(), IID_PPV_ARGS(&dmlCompiledOperator)));
|
||||
|
||||
return dmlCompiledOperator;
|
||||
}
|
||||
|
||||
std::vector<TensorDesc> m_intermediateTensorDescs;
|
||||
onnxruntime::TensorShape m_outputCountShape;
|
||||
onnxruntime::TensorShape m_outputCoordinatesShape;
|
||||
ComPtr<IDMLCompiledOperator> m_zeroOperator;
|
||||
bool m_emptyInput = false;
|
||||
uint32_t m_rank = 0;
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(NonZero, DmlOperatorNonZero);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -79,6 +79,7 @@ struct OperatorRegistrationInformation
|
|||
std::optional<uint32_t> requiredInputCountForDmlGraphSupport;
|
||||
|
||||
MLOperatorSupportQueryFunction supportQueryFunction;
|
||||
bool allowDynamicInputShapes = false;
|
||||
};
|
||||
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Copy);
|
||||
|
|
@ -266,6 +267,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Size);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
|
||||
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
|
||||
|
|
@ -376,6 +378,9 @@ constexpr auto requiredConstantCpuInputs(Args... args)
|
|||
// Expands to the leading, comma-separated fields of one operator registration entry: the stringized
// operator name, the sc_sinceVer_<operatorName> constant taken from OnnxOperatorSet<version>, the ONNX
// domain, the Create<operatorName> function, a shape inference function instantiated from
// ShapeInferenceHelper_<operatorName>, a literal false, and then any additional caller-supplied arguments.
// NOTE(review): the meaning of the literal false depends on the field order of the registration struct,
// which is not fully visible here.
#define REG_INFO(version, operatorName, ...) \
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
|
||||
|
||||
// Same expansion as REG_INFO, except that nullptr is passed in place of the shape inference function.
// Used to register NonZero, whose kernel allocates its output with a size computed during Compute.
#define REG_INFO_DYNAMIC_OUTPUTS(version, operatorName, ...) \
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, false, ##__VA_ARGS__,
|
||||
|
||||
// Versioned operator: same expansion as REG_INFO, except that the creation function and the shape
// inference helper both carry the version as a suffix (Create<operatorName><version> and
// ShapeInferenceHelper_<operatorName><version>).
|
||||
#define REG_INFO_VER(version, operatorName, ...) \
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName##version>, false, ##__VA_ARGS__,
|
||||
|
|
@ -702,6 +707,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)},
|
||||
|
||||
// DmlFused operators
|
||||
{REG_INFO_MSDML(1, DmlFusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
|
|
@ -769,8 +776,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
|
||||
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
|
||||
|
||||
desc.options = information.shapeInferenceFunction ?
|
||||
MLOperatorKernelOptions::None : MLOperatorKernelOptions::AllowDynamicInputShapes;
|
||||
desc.options = information.allowDynamicInputShapes ?
|
||||
MLOperatorKernelOptions::AllowDynamicInputShapes : MLOperatorKernelOptions::None;
|
||||
|
||||
desc.minimumOperatorSetVersion = information.sinceVersion;
|
||||
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Compress = 9;
|
||||
static const int sc_sinceVer_EyeLike = 9;
|
||||
static const int sc_sinceVer_Scatter = 9;
|
||||
static const int sc_sinceVer_Nonzero = 9;
|
||||
static const int sc_sinceVer_NonZero = 9;
|
||||
static const int sc_sinceVer_Shrink = 9;
|
||||
static const int sc_sinceVer_Greater = 9;
|
||||
static const int sc_sinceVer_Less = 9;
|
||||
|
|
@ -303,6 +303,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Mod = 13;
|
||||
static const int sc_sinceVer_Mul = 13;
|
||||
static const int sc_sinceVer_Neg = 13;
|
||||
static const int sc_sinceVer_NonZero = 13;
|
||||
static const int sc_sinceVer_Pad = 13;
|
||||
static const int sc_sinceVer_Pow = 13;
|
||||
static const int sc_sinceVer_QuantizeLinear = 13;
|
||||
|
|
|
|||
Loading…
Reference in a new issue