diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8cedb7ecc0..19bfdc83aa 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -980,6 +980,8 @@ Do not modify directly.*
|||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|||6+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
+|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
+|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)|
|OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNonZero.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNonZero.cpp
new file mode 100644
index 0000000000..62374963ff
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNonZero.cpp
@@ -0,0 +1,192 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
+#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h"
+
+namespace Dml
+{
+
+class DmlOperatorNonZero: public DmlOperator // DML kernel for NonZero: writes the coordinates of the input's nonzero elements as a {rank, count} output
+{
+public:
+ DmlOperatorNonZero(const MLOperatorKernelCreationContext& kernelCreationContext)
+ : DmlOperator(kernelCreationContext)
+ {
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1); // exactly one input tensor
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1); // exactly one output tensor
+
+ std::vector inputShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
+
+ // Scalars have a rank of 0, but DML only supports rank 1 and above, so treat a scalar as the equivalent 1-element 1D tensor
+ if (inputShape.empty())
+ {
+ inputShape.push_back(1);
+ }
+
+ uint32_t numElements = ComputeElementCountFromDimensions(inputShape); // also the maximum possible number of nonzero coordinates
+
+ gsl::span inputShapes[1] = {inputShape};
+ DmlOperator::InitializeWithShapes(kernelCreationContext, std::nullopt, std::nullopt, inputShapes, std::nullopt, 1);
+
+ m_rank = static_cast(inputShape.size());
+ std::vector outputCountShape = {1}; // single-element tensor that receives the nonzero count
+ std::vector outputCoordinatesShape = {numElements, m_rank}; // worst case: every input element is nonzero
+
+ // TODO: Remove the doubled strides when DML supports native int64 for NonZero
+ // The ONNX output is {rank, numElements}, but DML outputs {numElements, rank}; these strides transpose it and space the uint32 values 8 bytes apart
+ std::vector outputCoordinatesStrides = {2, numElements * 2};
+ m_intermediateTensorDescs = {
+ TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCountShape),
+ TensorDesc(DML_TENSOR_DATA_TYPE_UINT32, outputCoordinatesShape, outputCoordinatesStrides),
+ };
+
+ // If the input has no elements, bypass the DML execution (no operator is compiled in that case)
+ if (numElements == 0)
+ {
+ m_emptyInput = true;
+ }
+ else
+ {
+ m_emptyInput = false;
+ m_outputCountShape = {1}; // shapes used by Compute to allocate the intermediate tensors
+ m_outputCoordinatesShape = {static_cast(numElements), static_cast(m_rank)};
+
+ std::vector inputDescs = GetDmlInputDescs();
+ std::vector intermediateDescs(m_intermediateTensorDescs.size());
+ for (size_t i = 0; i < intermediateDescs.size(); i++)
+ {
+ intermediateDescs[i] = m_intermediateTensorDescs[i].GetDmlDesc();
+ }
+
+ DML_NONZERO_COORDINATES_OPERATOR_DESC nonzeroCoordinatesDesc = {};
+ nonzeroCoordinatesDesc.InputTensor = &inputDescs[0];
+ nonzeroCoordinatesDesc.OutputCountTensor = &intermediateDescs[0];
+ nonzeroCoordinatesDesc.OutputCoordinatesTensor = &intermediateDescs[1];
+
+ // TODO: Remove this hack when DML supports native int64 for NonZero
+ // We use the int64/uint32 stride hack here, so zero out the data before writing to it
+ m_zeroOperator = InitializeZeroInt64Tensor(m_intermediateTensorDescs[1].GetBufferSizeInBytes()); // NOTE(review): sized from the uncropped intermediate buffer, but Compute runs it on the cropped output tensor - confirm this is intended
+
+ DML_OPERATOR_DESC opDesc = { DML_OPERATOR_NONZERO_COORDINATES, &nonzeroCoordinatesDesc };
+ SetDmlOperatorDesc(opDesc, kernelCreationContext);
+ }
+ }
+
+ void Compute(const MLOperatorKernelContext& kernelContext) // runs the NonZero operator, reads the count back on the CPU, then slices the coordinates into the output
+ {
+ ExecutionProviderImpl* executionProvider = static_cast(m_executionProvider.Get());
+
+ // Create the DML output tensor for the number of nonzero elements
+ onnxruntime::Tensor outputCountDml(onnxruntime::DataTypeImpl::GetType(), m_outputCountShape, executionProvider->GetGpuAllocator()); // NOTE(review): also allocated when m_emptyInput is true, where the constructor never assigns the shape members - confirm
+ Microsoft::WRL::ComPtr outputCountDmlWrapper = wil::MakeOrThrow(
+ &outputCountDml,
+ true,
+ executionProvider,
+ true);
+
+ // Create the DML output tensor for the coordinates (not cropped, i.e. sized for the worst case)
+ onnxruntime::Tensor intermediateCoordinatesDml(onnxruntime::DataTypeImpl::GetType(), m_outputCoordinatesShape, executionProvider->GetGpuAllocator());
+ Microsoft::WRL::ComPtr intermediateCoordinatesDmlWrapper = wil::MakeOrThrow(
+ &intermediateCoordinatesDml,
+ true,
+ executionProvider,
+ true);
+
+ std::vector nonzeroCoordinatesInputTensors = GetInputTensors(kernelContext);
+ std::vector nonzeroCoordinatesOutputTensors = {outputCountDmlWrapper.Get(), intermediateCoordinatesDmlWrapper.Get()}; // order matches m_intermediateTensorDescs: count first, coordinates second
+
+ uint32_t nonzeroElementCount = 0; // stays 0 when the input is empty
+
+ if (!m_emptyInput)
+ {
+ ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
+ m_compiledOperator.Get(),
+ m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
+ gsl::make_span(nonzeroCoordinatesInputTensors),
+ gsl::make_span(nonzeroCoordinatesOutputTensors)));
+
+ // Copy the number of nonzero elements back to the CPU, since the output shape depends on it
+ onnxruntime::Tensor outputCountCpu(onnxruntime::DataTypeImpl::GetType(), {1}, executionProvider->GetCpuInputAllocator());
+ Microsoft::WRL::ComPtr outputCountCpuWrapper = wil::MakeOrThrow(
+ &outputCountCpu,
+ false,
+ executionProvider,
+ true);
+ ORT_THROW_IF_FAILED(m_executionProvider->CopyTensor(
+ outputCountCpuWrapper.Get(),
+ nonzeroCoordinatesOutputTensors.front()));
+ nonzeroElementCount = outputCountCpu.Data()[0];
+ }
+
+ // Create the final output tensor, which is cropped to the actual number of nonzero elements
+ std::vector outputSizes({m_rank, nonzeroElementCount}); // {rank, count}
+ auto outputTensor = kernelContext.GetOutputTensor(0, outputSizes);
+
+ if (!m_emptyInput && nonzeroElementCount > 0)
+ {
+ // TODO: Remove this hack when DML supports native int64 for NonZero
+ ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensor.GetInterface().Get());
+
+ ComPtr sliceOperator = InitializeSlice(m_intermediateTensorDescs[1], nonzeroElementCount); // created and compiled on every call because the window size is only known at run time
+
+ // Finally, we crop the output to the actual number of nonzero elements, thus removing the padding
+ std::array sliceInputTensors = {nonzeroCoordinatesOutputTensors[1]};
+ std::array sliceOutputTensors = {outputTensor.GetInterface().Get()};
+
+ ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
+ sliceOperator.Get(),
+ nullptr, // persistent resource binding
+ sliceInputTensors,
+ sliceOutputTensors));
+ }
+ }
+
+private:
+ ComPtr InitializeSlice(TensorDesc& inputDesc, uint32_t nonzeroElementCount) // creates and compiles a SLICE1 operator selecting the first nonzeroElementCount rows of inputDesc
+ {
+ assert(inputDesc.GetSizes().size() == 2); // expects the 2D {numElements, rank} coordinates descriptor
+
+ uint32_t rank = inputDesc.GetSizes().back();
+ std::array inputWindowOffsets = {0, 0}; // window starts at the first coordinate
+ std::array inputWindowStrides = {1, 1}; // no elements skipped inside the window
+ std::array inputWindowSizes = {nonzeroElementCount, rank};
+
+ // TODO: Remove the doubled strides when DML supports native int64 for NonZero
+ std::array outputStrides = {2, nonzeroElementCount * 2}; // same transposing, doubled layout as the intermediate tensor, but with the cropped row count
+ TensorDesc outputDesc(inputDesc.GetDmlDataType(), inputWindowSizes, outputStrides);
+
+ const auto inputOpDesc = inputDesc.GetDmlDesc();
+ const auto outputOpDesc = outputDesc.GetDmlDesc();
+
+ DML_SLICE1_OPERATOR_DESC sliceDesc = {};
+ sliceDesc.DimensionCount = 2;
+ sliceDesc.InputWindowOffsets = inputWindowOffsets.data();
+ sliceDesc.InputWindowSizes = inputWindowSizes.data();
+ sliceDesc.InputWindowStrides = inputWindowStrides.data();
+ sliceDesc.InputTensor = &inputOpDesc;
+ sliceDesc.OutputTensor = &outputOpDesc;
+
+ DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE1, &sliceDesc };
+
+ ComPtr dmlOperator;
+ ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperator)));
+
+ ComPtr dmlCompiledOperator;
+ ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), GetExecutionFlags(), IID_PPV_ARGS(&dmlCompiledOperator)));
+
+ return dmlCompiledOperator;
+ }
+
+ std::vector m_intermediateTensorDescs; // [0] = nonzero count, [1] = uncropped coordinates
+ onnxruntime::TensorShape m_outputCountShape; // assigned only when the input is non-empty
+ onnxruntime::TensorShape m_outputCoordinatesShape; // assigned only when the input is non-empty
+ ComPtr m_zeroOperator; // assigned only when the input is non-empty
+ bool m_emptyInput = false; // true when the input has zero elements
+ uint32_t m_rank = 0; // input rank, with scalars counted as rank 1
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(NonZero, DmlOperatorNonZero); // presumably defines the CreateNonZero factory that the registration macros reference via Create##operatorName - confirm against the macro definition
+
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 0c0d08628c..bef3bf66d3 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -79,6 +79,7 @@ struct OperatorRegistrationInformation
std::optional requiredInputCountForDmlGraphSupport;
MLOperatorSupportQueryFunction supportQueryFunction;
+ bool allowDynamicInputShapes = false;
};
DML_OP_EXTERN_CREATION_FUNCTION(Copy);
@@ -266,6 +267,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
+DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
@@ -376,6 +378,9 @@ constexpr auto requiredConstantCpuInputs(Args... args)
#define REG_INFO(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__,
+#define REG_INFO_DYNAMIC_OUTPUTS(version, operatorName, ...) /* Same fields as REG_INFO, but passes nullptr where REG_INFO passes ShapeInferenceFunction */ \
+ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, false, ##__VA_ARGS__,
+
// Versioned operator
#define REG_INFO_VER(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, ##__VA_ARGS__,
@@ -702,6 +707,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
+ {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)},
+ {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)},
// DmlFused operators
{REG_INFO_MSDML(1, DmlFusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@@ -769,8 +776,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
- desc.options = information.shapeInferenceFunction ?
- MLOperatorKernelOptions::None : MLOperatorKernelOptions::AllowDynamicInputShapes;
+ desc.options = information.allowDynamicInputShapes ?
+ MLOperatorKernelOptions::AllowDynamicInputShapes : MLOperatorKernelOptions::None;
desc.minimumOperatorSetVersion = information.sinceVersion;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 224c6fe4af..8b8f4fa4a4 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -157,7 +157,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Compress = 9;
static const int sc_sinceVer_EyeLike = 9;
static const int sc_sinceVer_Scatter = 9;
- static const int sc_sinceVer_Nonzero = 9;
+ static const int sc_sinceVer_NonZero = 9;
static const int sc_sinceVer_Shrink = 9;
static const int sc_sinceVer_Greater = 9;
static const int sc_sinceVer_Less = 9;
@@ -303,6 +303,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Mod = 13;
static const int sc_sinceVer_Mul = 13;
static const int sc_sinceVer_Neg = 13;
+ static const int sc_sinceVer_NonZero = 13;
static const int sc_sinceVer_Pad = 13;
static const int sc_sinceVer_Pow = 13;
static const int sc_sinceVer_QuantizeLinear = 13;