diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8cedb7ecc0..19bfdc83aa 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -980,6 +980,8 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| |||6+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| +|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| +|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| |Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)| |OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNonZero.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNonZero.cpp new file mode 100644 index 0000000000..62374963ff --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNonZero.cpp @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" +#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" +#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h" + +namespace Dml +{ + +class DmlOperatorNonZero: public DmlOperator +{ +public: + DmlOperatorNonZero(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1); + + std::vector inputShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); + + // Scalars have a rank of 0, but DML only supports rank 1 and higher, so treat a scalar as a 1D tensor with a single element + if (inputShape.empty()) + { + inputShape.push_back(1); + } + + uint32_t numElements = ComputeElementCountFromDimensions(inputShape); + + gsl::span inputShapes[1] = {inputShape}; + DmlOperator::InitializeWithShapes(kernelCreationContext, std::nullopt, std::nullopt, inputShapes, std::nullopt, 1); + + m_rank = static_cast(inputShape.size()); + std::vector outputCountShape = {1}; + std::vector outputCoordinatesShape = {numElements, m_rank}; + + // TODO: Remove the doubled strides when DML supports native int64 for NonZero + // ONNX outputs {rank, numElements}, but DML outputs {numElements, rank} + std::vector
outputCoordinatesStrides = {2, numElements * 2}; + m_intermediateTensorDescs = { + TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCountShape), + TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCoordinatesShape, outputCoordinatesStrides), + }; + + // If the input has no elements, bypass the DML execution + if (numElements == 0) + { + m_emptyInput = true; + } + else + { + m_emptyInput = false; + m_outputCountShape = {1}; + m_outputCoordinatesShape = {static_cast(numElements), static_cast(m_rank)}; + + std::vector inputDescs = GetDmlInputDescs(); + std::vector intermediateDescs(m_intermediateTensorDescs.size()); + for (size_t i = 0; i < intermediateDescs.size(); i++) + { + intermediateDescs[i] = m_intermediateTensorDescs[i].GetDmlDesc(); + } + + DML_NONZERO_COORDINATES_OPERATOR_DESC nonzeroCoordinatesDesc = {}; + nonzeroCoordinatesDesc.InputTensor = &inputDescs[0]; + nonzeroCoordinatesDesc.OutputCountTensor = &intermediateDescs[0]; + nonzeroCoordinatesDesc.OutputCoordinatesTensor = &intermediateDescs[1]; + + // TODO: Remove this hack when DML supports native int64 for NonZero + // We use the int64/uint32 stride hack here, so zero out the data before writing to it + m_zeroOperator = InitializeZeroInt64Tensor(m_intermediateTensorDescs[1].GetBufferSizeInBytes()); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_NONZERO_COORDINATES, &nonzeroCoordinatesDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } + } + + void Compute(const MLOperatorKernelContext& kernelContext) + { + ExecutionProviderImpl* executionProvider = static_cast(m_executionProvider.Get()); + + // Create the DML output tensor for the number of nonzero elements + onnxruntime::Tensor outputCountDml(onnxruntime::DataTypeImpl::GetType(), m_outputCountShape, executionProvider->GetGpuAllocator()); + Microsoft::WRL::ComPtr outputCountDmlWrapper = wil::MakeOrThrow( + &outputCountDml, + true, + executionProvider, + true); + + // Create the DML output tensor for the coordinates (not cropped) + 
onnxruntime::Tensor intermediateCoordinatesDml(onnxruntime::DataTypeImpl::GetType(), m_outputCoordinatesShape, executionProvider->GetGpuAllocator()); + Microsoft::WRL::ComPtr intermediateCoordinatesDmlWrapper = wil::MakeOrThrow( + &intermediateCoordinatesDml, + true, + executionProvider, + true); + + std::vector nonzeroCoordinatesInputTensors = GetInputTensors(kernelContext); + std::vector nonzeroCoordinatesOutputTensors = {outputCountDmlWrapper.Get(), intermediateCoordinatesDmlWrapper.Get()}; + + uint32_t nonzeroElementCount = 0; + + if (!m_emptyInput) + { + ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator( + m_compiledOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + gsl::make_span(nonzeroCoordinatesInputTensors), + gsl::make_span(nonzeroCoordinatesOutputTensors))); + + // Copy the number of nonzero elements back to the CPU + onnxruntime::Tensor outputCountCpu(onnxruntime::DataTypeImpl::GetType(), {1}, executionProvider->GetCpuInputAllocator()); + Microsoft::WRL::ComPtr outputCountCpuWrapper = wil::MakeOrThrow( + &outputCountCpu, + false, + executionProvider, + true); + ORT_THROW_IF_FAILED(m_executionProvider->CopyTensor( + outputCountCpuWrapper.Get(), + nonzeroCoordinatesOutputTensors.front())); + nonzeroElementCount = outputCountCpu.Data()[0]; + } + + // Create the final output tensor, which is cropped to the actual number of nonzero elements + std::vector outputSizes({m_rank, nonzeroElementCount}); + auto outputTensor = kernelContext.GetOutputTensor(0, outputSizes); + + if (!m_emptyInput && nonzeroElementCount > 0) + { + // TODO: Remove this hack when DML supports native int64 for NonZero + ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensor.GetInterface().Get()); + + ComPtr sliceOperator = InitializeSlice(m_intermediateTensorDescs[1], nonzeroElementCount); + + // Finally, we crop the output to the actual number of nonzero elements, thus removing the padding + std::array sliceInputTensors = 
{nonzeroCoordinatesOutputTensors[1]}; + std::array sliceOutputTensors = {outputTensor.GetInterface().Get()}; + + ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator( + sliceOperator.Get(), + nullptr, // persistent resource binding + sliceInputTensors, + sliceOutputTensors)); + } + } + +private: + ComPtr InitializeSlice(TensorDesc& inputDesc, uint32_t nonzeroElementCount) + { + assert(inputDesc.GetSizes().size() == 2); + + uint32_t rank = inputDesc.GetSizes().back(); + std::array inputWindowOffsets = {0, 0}; + std::array inputWindowStrides = {1, 1}; + std::array inputWindowSizes = {nonzeroElementCount, rank}; + + // TODO: Remove the doubled strides when DML supports native int64 for NonZero + std::array outputStrides = {2, nonzeroElementCount * 2}; + TensorDesc outputDesc(inputDesc.GetDmlDataType(), inputWindowSizes, outputStrides); + + const auto inputOpDesc = inputDesc.GetDmlDesc(); + const auto outputOpDesc = outputDesc.GetDmlDesc(); + + DML_SLICE1_OPERATOR_DESC sliceDesc = {}; + sliceDesc.DimensionCount = 2; + sliceDesc.InputWindowOffsets = inputWindowOffsets.data(); + sliceDesc.InputWindowSizes = inputWindowSizes.data(); + sliceDesc.InputWindowStrides = inputWindowStrides.data(); + sliceDesc.InputTensor = &inputOpDesc; + sliceDesc.OutputTensor = &outputOpDesc; + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE1, &sliceDesc }; + + ComPtr dmlOperator; + ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperator))); + + ComPtr dmlCompiledOperator; + ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), GetExecutionFlags(), IID_PPV_ARGS(&dmlCompiledOperator))); + + return dmlCompiledOperator; + } + + std::vector m_intermediateTensorDescs; + onnxruntime::TensorShape m_outputCountShape; + onnxruntime::TensorShape m_outputCoordinatesShape; + ComPtr m_zeroOperator; + bool m_emptyInput = false; + uint32_t m_rank = 0; +}; + +DML_OP_DEFINE_CREATION_FUNCTION(NonZero, DmlOperatorNonZero); + +} // namespace Dml diff --git 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 0c0d08628c..bef3bf66d3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -79,6 +79,7 @@ struct OperatorRegistrationInformation std::optional requiredInputCountForDmlGraphSupport; MLOperatorSupportQueryFunction supportQueryFunction; + bool allowDynamicInputShapes = false; }; DML_OP_EXTERN_CREATION_FUNCTION(Copy); @@ -266,6 +267,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu); DML_OP_EXTERN_CREATION_FUNCTION(Shape); DML_OP_EXTERN_CREATION_FUNCTION(Size); DML_OP_EXTERN_CREATION_FUNCTION(Attention); +DML_OP_EXTERN_CREATION_FUNCTION(NonZero); DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); DML_OP_EXTERN_QUERY_FUNCTION(Slice); @@ -376,6 +378,9 @@ constexpr auto requiredConstantCpuInputs(Args... args) #define REG_INFO(version, operatorName, ...) \ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, +#define REG_INFO_DYNAMIC_OUTPUTS(version, operatorName, ...) \ + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, false, ##__VA_ARGS__, + // Versioned operator #define REG_INFO_VER(version, operatorName, ...) 
\ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, ##__VA_ARGS__, @@ -702,6 +707,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)}, + {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)}, // DmlFused operators {REG_INFO_MSDML(1, DmlFusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -769,8 +776,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) // The graph must be configured with operators from only the legacy DML API, or only the new DML API bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported); - desc.options = information.shapeInferenceFunction ? - MLOperatorKernelOptions::None : MLOperatorKernelOptions::AllowDynamicInputShapes; + desc.options = information.allowDynamicInputShapes ? 
+ MLOperatorKernelOptions::AllowDynamicInputShapes : MLOperatorKernelOptions::None; desc.minimumOperatorSetVersion = information.sinceVersion; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 224c6fe4af..8b8f4fa4a4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -157,7 +157,7 @@ namespace OperatorHelper static const int sc_sinceVer_Compress = 9; static const int sc_sinceVer_EyeLike = 9; static const int sc_sinceVer_Scatter = 9; - static const int sc_sinceVer_Nonzero = 9; + static const int sc_sinceVer_NonZero = 9; static const int sc_sinceVer_Shrink = 9; static const int sc_sinceVer_Greater = 9; static const int sc_sinceVer_Less = 9; @@ -303,6 +303,7 @@ namespace OperatorHelper static const int sc_sinceVer_Mod = 13; static const int sc_sinceVer_Mul = 13; static const int sc_sinceVer_Neg = 13; + static const int sc_sinceVer_NonZero = 13; static const int sc_sinceVer_Pad = 13; static const int sc_sinceVer_Pow = 13; static const int sc_sinceVer_QuantizeLinear = 13;