[DML EP] Add an implementation for NonZero (#13768)

### Description
Add the NonZero op for DML



### Motivation and Context
NonZero is used in a few transformer models, so having a DML
implementation will stop large tensors from being transferred to the CPU
and back to the GPU
This commit is contained in:
Patrice Vignola 2022-12-02 18:39:21 -08:00 committed by GitHub
parent b9702587df
commit b53bbe7370
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 205 additions and 3 deletions

View file

@ -980,6 +980,8 @@ Do not modify directly.*
|||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Neg|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|||6+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|NonZero|*in* X:**T**<br> *out* Y:**tensor(int64)**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|Not|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bool)|
|OneHot|*in* indices:**T1**<br> *in* depth:**T2**<br> *in* values:**T3**<br> *out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

View file

@ -0,0 +1,192 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h"
namespace Dml
{
class DmlOperatorNonZero: public DmlOperator
{
public:
DmlOperatorNonZero(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1);
std::vector<DimensionType> inputShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
// Scalars have a rank of 0, but DML only supports 1 and more, which is the same
if (inputShape.empty())
{
inputShape.push_back(1);
}
uint32_t numElements = ComputeElementCountFromDimensions(inputShape);
gsl::span<const uint32_t> inputShapes[1] = {inputShape};
DmlOperator::InitializeWithShapes(kernelCreationContext, std::nullopt, std::nullopt, inputShapes, std::nullopt, 1);
m_rank = static_cast<DimensionType>(inputShape.size());
std::vector<DimensionType> outputCountShape = {1};
std::vector<DimensionType> outputCoordinatesShape = {numElements, m_rank};
// TODO: Remove the doubled strides when DML supports native int64 for NonZero
// TensorFlow outputs {rank, numElements}, but DML outputs {numElements, rank}
std::vector<DimensionType> outputCoordinatesStrides = {2, numElements * 2};
m_intermediateTensorDescs = {
TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCountShape),
TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCoordinatesShape, outputCoordinatesStrides),
};
// If the input has no elements, bypass the DML execution
if (numElements == 0)
{
m_emptyInput = true;
}
else
{
m_emptyInput = false;
m_outputCountShape = {1};
m_outputCoordinatesShape = {static_cast<int64_t>(numElements), static_cast<int64_t>(m_rank)};
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> intermediateDescs(m_intermediateTensorDescs.size());
for (size_t i = 0; i < intermediateDescs.size(); i++)
{
intermediateDescs[i] = m_intermediateTensorDescs[i].GetDmlDesc();
}
DML_NONZERO_COORDINATES_OPERATOR_DESC nonzeroCoordinatesDesc = {};
nonzeroCoordinatesDesc.InputTensor = &inputDescs[0];
nonzeroCoordinatesDesc.OutputCountTensor = &intermediateDescs[0];
nonzeroCoordinatesDesc.OutputCoordinatesTensor = &intermediateDescs[1];
// TODO: Remove this hack when DML supports native int64 for NonZero
// We use the int64/uint32 stride hack here, so zero out the data before writing to it
m_zeroOperator = InitializeZeroInt64Tensor(m_intermediateTensorDescs[1].GetBufferSizeInBytes());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_NONZERO_COORDINATES, &nonzeroCoordinatesDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
}
void Compute(const MLOperatorKernelContext& kernelContext)
{
ExecutionProviderImpl* executionProvider = static_cast<ExecutionProviderImpl*>(m_executionProvider.Get());
// Create the DML output tensor for the number of nonzero elements
onnxruntime::Tensor outputCountDml(onnxruntime::DataTypeImpl::GetType<uint32_t>(), m_outputCountShape, executionProvider->GetGpuAllocator());
Microsoft::WRL::ComPtr<IMLOperatorTensor> outputCountDmlWrapper = wil::MakeOrThrow<Windows::AI::MachineLearning::Adapter::TensorWrapper>(
&outputCountDml,
true,
executionProvider,
true);
// Create the DML output tensor for the coordinates (not cropped)
onnxruntime::Tensor intermediateCoordinatesDml(onnxruntime::DataTypeImpl::GetType<int64_t>(), m_outputCoordinatesShape, executionProvider->GetGpuAllocator());
Microsoft::WRL::ComPtr<IMLOperatorTensor> intermediateCoordinatesDmlWrapper = wil::MakeOrThrow<Windows::AI::MachineLearning::Adapter::TensorWrapper>(
&intermediateCoordinatesDml,
true,
executionProvider,
true);
std::vector<IMLOperatorTensor*> nonzeroCoordinatesInputTensors = GetInputTensors(kernelContext);
std::vector<IMLOperatorTensor*> nonzeroCoordinatesOutputTensors = {outputCountDmlWrapper.Get(), intermediateCoordinatesDmlWrapper.Get()};
uint32_t nonzeroElementCount = 0;
if (!m_emptyInput)
{
ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(nonzeroCoordinatesInputTensors),
gsl::make_span(nonzeroCoordinatesOutputTensors)));
// Copy the number of nonzero elements back to the CPU
onnxruntime::Tensor outputCountCpu(onnxruntime::DataTypeImpl::GetType<uint32_t>(), {1}, executionProvider->GetCpuInputAllocator());
Microsoft::WRL::ComPtr<IMLOperatorTensor> outputCountCpuWrapper = wil::MakeOrThrow<Windows::AI::MachineLearning::Adapter::TensorWrapper>(
&outputCountCpu,
false,
executionProvider,
true);
ORT_THROW_IF_FAILED(m_executionProvider->CopyTensor(
outputCountCpuWrapper.Get(),
nonzeroCoordinatesOutputTensors.front()));
nonzeroElementCount = outputCountCpu.Data<uint32_t>()[0];
}
// Create the final output tensor, which is cropped to the actual number of nonzero elements
std::vector<uint32_t> outputSizes({m_rank, nonzeroElementCount});
auto outputTensor = kernelContext.GetOutputTensor(0, outputSizes);
if (!m_emptyInput && nonzeroElementCount > 0)
{
// TODO: Remove this hack when DML supports native int64 for NonZero
ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensor.GetInterface().Get());
ComPtr<IDMLCompiledOperator> sliceOperator = InitializeSlice(m_intermediateTensorDescs[1], nonzeroElementCount);
// Finally, we crop the output to the actual number of nonzero elements, thus removing the padding
std::array<IMLOperatorTensor*, 1> sliceInputTensors = {nonzeroCoordinatesOutputTensors[1]};
std::array<IMLOperatorTensor*, 1> sliceOutputTensors = {outputTensor.GetInterface().Get()};
ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
sliceOperator.Get(),
nullptr, // persistent resource binding
sliceInputTensors,
sliceOutputTensors));
}
}
private:
ComPtr<IDMLCompiledOperator> InitializeSlice(TensorDesc& inputDesc, uint32_t nonzeroElementCount)
{
assert(inputDesc.GetSizes().size() == 2);
uint32_t rank = inputDesc.GetSizes().back();
std::array<uint32_t, 2> inputWindowOffsets = {0, 0};
std::array<int32_t, 2> inputWindowStrides = {1, 1};
std::array<uint32_t, 2> inputWindowSizes = {nonzeroElementCount, rank};
// TODO: Remove the doubled strides when DML supports native int64 for NonZero
std::array<uint32_t, 2> outputStrides = {2, nonzeroElementCount * 2};
TensorDesc outputDesc(inputDesc.GetDmlDataType(), inputWindowSizes, outputStrides);
const auto inputOpDesc = inputDesc.GetDmlDesc();
const auto outputOpDesc = outputDesc.GetDmlDesc();
DML_SLICE1_OPERATOR_DESC sliceDesc = {};
sliceDesc.DimensionCount = 2;
sliceDesc.InputWindowOffsets = inputWindowOffsets.data();
sliceDesc.InputWindowSizes = inputWindowSizes.data();
sliceDesc.InputWindowStrides = inputWindowStrides.data();
sliceDesc.InputTensor = &inputOpDesc;
sliceDesc.OutputTensor = &outputOpDesc;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE1, &sliceDesc };
ComPtr<IDMLOperator> dmlOperator;
ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperator)));
ComPtr<IDMLCompiledOperator> dmlCompiledOperator;
ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), GetExecutionFlags(), IID_PPV_ARGS(&dmlCompiledOperator)));
return dmlCompiledOperator;
}
std::vector<TensorDesc> m_intermediateTensorDescs;
onnxruntime::TensorShape m_outputCountShape;
onnxruntime::TensorShape m_outputCoordinatesShape;
ComPtr<IDMLCompiledOperator> m_zeroOperator;
bool m_emptyInput = false;
uint32_t m_rank = 0;
};
DML_OP_DEFINE_CREATION_FUNCTION(NonZero, DmlOperatorNonZero);
} // namespace Dml

View file

@ -79,6 +79,7 @@ struct OperatorRegistrationInformation
std::optional<uint32_t> requiredInputCountForDmlGraphSupport;
MLOperatorSupportQueryFunction supportQueryFunction;
bool allowDynamicInputShapes = false;
};
DML_OP_EXTERN_CREATION_FUNCTION(Copy);
@ -266,6 +267,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
@ -376,6 +378,9 @@ constexpr auto requiredConstantCpuInputs(Args... args)
#define REG_INFO(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
#define REG_INFO_DYNAMIC_OUTPUTS(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, false, ##__VA_ARGS__,
// Versioned operator
#define REG_INFO_VER(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName##version>, false, ##__VA_ARGS__,
@ -702,6 +707,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)},
{REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)},
// DmlFused operators
{REG_INFO_MSDML(1, DmlFusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@ -769,8 +776,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
desc.options = information.shapeInferenceFunction ?
MLOperatorKernelOptions::None : MLOperatorKernelOptions::AllowDynamicInputShapes;
desc.options = information.allowDynamicInputShapes ?
MLOperatorKernelOptions::AllowDynamicInputShapes : MLOperatorKernelOptions::None;
desc.minimumOperatorSetVersion = information.sinceVersion;

View file

@ -157,7 +157,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Compress = 9;
static const int sc_sinceVer_EyeLike = 9;
static const int sc_sinceVer_Scatter = 9;
static const int sc_sinceVer_Nonzero = 9;
static const int sc_sinceVer_NonZero = 9;
static const int sc_sinceVer_Shrink = 9;
static const int sc_sinceVer_Greater = 9;
static const int sc_sinceVer_Less = 9;
@ -303,6 +303,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Mod = 13;
static const int sc_sinceVer_Mul = 13;
static const int sc_sinceVer_Neg = 13;
static const int sc_sinceVer_NonZero = 13;
static const int sc_sinceVer_Pad = 13;
static const int sc_sinceVer_Pow = 13;
static const int sc_sinceVer_QuantizeLinear = 13;