diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index 118aba78a6..eb67973d3c 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -20,7 +20,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.2.1.0) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.3.0.0) # Restore nuget packages, which will pull down the DirectML redist package add_custom_command( diff --git a/onnxruntime/core/providers/dml/.clang-format b/onnxruntime/core/providers/dml/.clang-format new file mode 100644 index 0000000000..4322280747 --- /dev/null +++ b/onnxruntime/core/providers/dml/.clang-format @@ -0,0 +1,3 @@ +# Readability matters. Prevent syntax noise in pull requests for people who +# accidentally leave enabled the auto-formatting options in Visual Studio. +DisableFormat: true diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 0415976b58..8f58597f45 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -94,11 +94,16 @@ namespace Windows::AI::MachineLearning::Adapter const void* executionHandle, DmlGraphNodeCreateInfo* graphNodeCreateInfo )>; - + struct GraphNodeFactoryRegistration { GraphNodeFactory factory; std::optional requiredInputCount; + + // The operator inputs/outputs must be a floating point data type. 
When true, + // if the node's tensor data type is not-floating point, the node is partitioned + // separately (unless the input/output is a CPU constant input, which is okay, + // as those can be read directly by the DML operator in the DML_OPERATOR_DESC). bool requiresFloatFormatsExceptConstInputs = false; }; @@ -109,6 +114,20 @@ namespace Windows::AI::MachineLearning::Adapter std::vector requiredConstantCpuInputs; std::optional graphNodeFactoryRegistration; KernelSupportQuery supportQuery; + + // Many ONNX operators use 64-bit tensors, but most DML operators only support + // 32-bit indices. This flag indicates to the graph whether it's okay to compute + // the result using 32-bit tensors (ignoring the upper bits) via doubled strides. + bool supportedWith64BitTensorsVia32BitStrides = false; + + // When true, the input to the current operator may come from any execution + // provider. Otherwise it must have come from another DML node to assume it's safe + // to use 64-bit to 32-bit striding. + bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false; + + // Operator supports true 64-bit tensors directly, no strides needed. + // So fallback to strided 32-bit only occurs when the device lacks 64-bit support. 
+ bool prefer64BitTensorsDirectly = false; }; using InternalRegistrationInfoMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index ff4061bc8e..888110d477 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -334,6 +334,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( bool supportsGraph, const uint32_t* requiredInputCountForGraph, bool requiresFloatFormatsForGraph, + bool supportedWith64BitTensorsVia32BitStrides, + bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, + bool prefer64BitTensorsDirectly, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept try { @@ -456,6 +459,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( { auto regInfo = std::make_shared(); regInfo->requiredConstantCpuInputs = constantCpuInputCapture; + regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides; + regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp; + regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly; // Only internal operators support usage in DML graphs if (supportsGraph) @@ -527,8 +533,14 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( else { // Currently unsupported for external operators - if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph || - requiresFloatFormatsForGraph || requiredConstantCpuInputs) + if (canAliasFirstInput || + supportsGraph || + requiredInputCountForGraph || + requiresFloatFormatsForGraph || + requiredConstantCpuInputs || + supportedWith64BitTensorsVia32BitStrides || + 
supportedWith64BitTensorsVia32BitStridesFromAnyEp || + prefer64BitTensorsDirectly) { THROW_HR(E_INVALIDARG); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index 51075019ed..2bd9bf8b1a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -42,6 +42,9 @@ class AbiCustomRegistry : public WRL::Base dimensions, MLOperatorTensorDataType tensorDataType) { return ComputeElementCountFromDimensions(dimensions) * GetByteSizeFromMlDataType(tensorDataType); @@ -90,4 +108,40 @@ size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor) return ComputeByteSizeFromDimensions(gsl::make_span(dimensions.data(), dimensionCount), tensor.GetTensorDataType()); } +uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice) +{ + uint32_t deviceTypeMask = 0u; + + // Form the bitmask of all supported data types. 
+ for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i) + { + DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast(i) }; + DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {}; + + THROW_IF_FAILED(dmlDevice->CheckFeatureSupport( + DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT, + sizeof(dataTypeQuery), + &dataTypeQuery, + sizeof(dataTypeSupport), + &dataTypeSupport + )); + + deviceTypeMask |= (dataTypeSupport.IsSupported << i); + } + + return deviceTypeMask; +} + +void GetDescendingPackedStrides(gsl::span sizes, /*out*/ gsl::span strides) +{ + assert(sizes.size() == strides.size()); + + uint32_t stride = 1; + for (size_t i = strides.size(); i-- > 0; ) + { + strides[i] = stride; + stride *= sizes[i]; + } +} + } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h index 9bfcb83f32..0f0c533558 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h @@ -10,13 +10,16 @@ namespace Dml { using namespace OperatorHelper; - static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX; + static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX1; DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataType(MLOperatorTensorDataType tensorDataType); DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataType tensorDataType) noexcept; + DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dmlElementType) noexcept; MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tensorDataType); size_t ComputeByteSizeFromDimensions(gsl::span dimensions, MLOperatorTensorDataType tensorDataType); size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor); + uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice); + void GetDescendingPackedStrides(gsl::span sizes, /*out*/ 
gsl::span strides); bool IsSigned(DML_TENSOR_DATA_TYPE dataType); @@ -40,6 +43,12 @@ namespace Dml UINT elementSizeInBytes = 0; switch (dataType) { + case DML_TENSOR_DATA_TYPE_FLOAT64: + case DML_TENSOR_DATA_TYPE_UINT64: + case DML_TENSOR_DATA_TYPE_INT64: + elementSizeInBytes = 8; + break; + case DML_TENSOR_DATA_TYPE_FLOAT32: case DML_TENSOR_DATA_TYPE_UINT32: case DML_TENSOR_DATA_TYPE_INT32: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 7f1617855b..011cdcbe24 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -376,8 +376,9 @@ namespace Dml { assert(!m_closed); + const size_t sourceSizeInBytes = ComputeByteSizeFromTensor(*src); const size_t dataSizeInBytes = ComputeByteSizeFromTensor(*dst); - THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != ComputeByteSizeFromTensor(*src)); // Tensors must be the same size + THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != sourceSizeInBytes); // Tensors must be the same size if (dataSizeInBytes == 0) { @@ -461,7 +462,7 @@ namespace Dml } CATCH_RETURN(); - uint32_t ExecutionProviderImpl::GetSuppportedDeviceDataTypeMask() const + uint32_t ExecutionProviderImpl::GetSupportedDeviceDataTypeMask() const { // The DML provider registers all supported kernels up-front regardless of actual device capability, // but this is problematic later when executing the graph because DirectML will fail to create @@ -470,26 +471,7 @@ namespace Dml // handle them, similar to the fallback in CUDAExecutionProvider::GetCapability for certain RNN/GRU/Conv // attributes. - uint32_t deviceTypeMask = 0u; - - // Form the bitmask of all supported data types. 
- for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i) - { - DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast(i) }; - DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {}; - - THROW_IF_FAILED(m_dmlDevice->CheckFeatureSupport( - DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT, - sizeof(dataTypeQuery), - &dataTypeQuery, - sizeof(dataTypeSupport), - &dataTypeSupport - )); - - deviceTypeMask |= (dataTypeSupport.IsSupported << i); - } - - return deviceTypeMask; + return Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get()); } std::vector> @@ -498,7 +480,7 @@ namespace Dml const std::vector& registries) const { std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_"; - uint32_t deviceDataTypeMask = GetSuppportedDeviceDataTypeMask(); + uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); return PartitionGraph( graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 670021be41..27359e5ebf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -16,10 +16,9 @@ using Base = Microsoft::WRL::RuntimeClass< TInterfaces...>; } -using namespace Microsoft::WRL; - namespace Dml { + using Microsoft::WRL::ComPtr; class PooledUploadHeap; class ReadbackHeap; class ExecutionContext; @@ -87,7 +86,7 @@ namespace Dml const std::vector& registries ) const; - uint32_t GetSuppportedDeviceDataTypeMask() const; + uint32_t GetSupportedDeviceDataTypeMask() const; onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const; onnxruntime::common::Status CopyTensors(const std::vector& src_dst_pairs) const; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h index bf565bddea..4689bae39d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h @@ -7,6 +7,7 @@ union ActivationOperatorDescUnion { DML_ACTIVATION_IDENTITY_OPERATOR_DESC identity; DML_ACTIVATION_ELU_OPERATOR_DESC elu; + DML_ACTIVATION_CELU_OPERATOR_DESC celu; DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax; DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid; DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu; @@ -36,6 +37,7 @@ struct ActivationOperatorDesc switch (activationType) { case DML_OPERATOR_ACTIVATION_ELU: return { activationType, &params.elu }; + case DML_OPERATOR_ACTIVATION_CELU: return { activationType, &params.celu }; case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, &params.hardmax }; case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, &params.sigmoid }; case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, &params.identity }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index b8af6e20a0..2b177faeda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,8 +24,8 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 120; - static constexpr size_t ActivationFunctionCount = 19; + static constexpr auto ValueCount = 141; + static constexpr size_t ActivationFunctionCount = 20; }; template <> @@ -62,7 +62,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 3; + static constexpr auto ValueCount = 4; }; template <> 
@@ -86,7 +86,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 2; + static constexpr auto ValueCount = 4; }; template <> @@ -113,6 +113,12 @@ struct EnumTraits static constexpr auto ValueCount = 3; }; +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 1; +}; + template constexpr auto EnumValueCount = EnumTraits::ValueCount; @@ -273,6 +279,18 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL; +}; + template <> struct OperatorDescTraits { @@ -393,6 +411,18 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REDUCE; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ARGMIN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ARGMAX; +}; + template <> struct OperatorDescTraits { @@ -747,6 +777,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE1; +}; + template <> struct OperatorDescTraits { @@ -771,12 +807,108 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_AND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_OR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr 
DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_XOR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_NOT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_COUNT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_RELU_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RANDOM_GENERATOR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_NONZERO_COORDINATES; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ADAM_OPTIMIZER; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND1; +}; + template <> struct OperatorDescTraits { static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_ELU; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_CELU; +}; + template <> struct OperatorDescTraits { @@ -993,6 +1125,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_L using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC; 
}; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT> { @@ -1113,6 +1257,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REDUCE> using DescType = DML_REDUCE_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ARGMIN> +{ + using DescType = DML_ARGMIN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ARGMAX> +{ + using DescType = DML_ARGMAX_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING> { @@ -1467,6 +1623,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZ using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE1> +{ + using DescType = DML_RESAMPLE1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER> { @@ -1491,12 +1653,108 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_CONVO using DescType = DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_AND> +{ + using DescType = DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_OR> +{ + using DescType = DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_XOR> 
+{ + using DescType = DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_NOT> +{ + using DescType = DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_COUNT> +{ + using DescType = DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_RELU_GRAD> +{ + using DescType = DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING_GRAD> +{ + using DescType = DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING_GRAD> +{ + using DescType = DML_MAX_POOLING_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RANDOM_GENERATOR> +{ + using DescType = DML_RANDOM_GENERATOR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_NONZERO_COORDINATES> +{ + using DescType = DML_NONZERO_COORDINATES_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE_GRAD> +{ + using DescType = DML_RESAMPLE_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE_GRAD> +{ + using DescType = DML_SLICE_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ADAM_OPTIMIZER> +{ + using DescType = DML_ADAM_OPTIMIZER_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN> +{ + using DescType = DML_ROI_ALIGN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND1> +{ + using DescType = DML_GATHER_ND1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { 
using DescType = DML_ACTIVATION_ELU_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_CELU> +{ + using DescType = DML_ACTIVATION_CELU_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX> { @@ -1652,6 +1910,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: @@ -1692,6 +1954,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... 
args return std::invoke(std::forward(visitor), DML_GEMM_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_REDUCE: return std::invoke(std::forward(visitor), DML_REDUCE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ARGMIN: + return std::invoke(std::forward(visitor), DML_ARGMIN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ARGMAX: + return std::invoke(std::forward(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_AVERAGE_POOLING: return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_LP_POOLING: @@ -1820,8 +2086,40 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_CONVOLUTION_INTEGER_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_AND: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_OR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_RELU_GRAD: + return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_AVERAGE_POOLING_GRAD: + return 
std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MAX_POOLING_GRAD: + return std::invoke(std::forward(visitor), DML_MAX_POOLING_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_RANDOM_GENERATOR: + return std::invoke(std::forward(visitor), DML_RANDOM_GENERATOR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_NONZERO_COORDINATES: + return std::invoke(std::forward(visitor), DML_NONZERO_COORDINATES_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_RESAMPLE_GRAD: + return std::invoke(std::forward(visitor), DML_RESAMPLE_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SLICE_GRAD: + return std::invoke(std::forward(visitor), DML_SLICE_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ADAM_OPTIMIZER: + return std::invoke(std::forward(visitor), DML_ADAM_OPTIMIZER_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ROI_ALIGN: + return std::invoke(std::forward(visitor), DML_ROI_ALIGN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER_ND1: + return std::invoke(std::forward(visitor), DML_GATHER_ND1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_CELU: + return std::invoke(std::forward(visitor), DML_ACTIVATION_CELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_HARDMAX: return std::invoke(std::forward(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: @@ -1887,6 +2185,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS"; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN"; case 
DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL"; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT"; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR"; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR"; @@ -1907,6 +2207,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION"; case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM"; case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE"; + case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; + case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; @@ -1971,6 +2273,21 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY"; case DML_OPERATOR_CONVOLUTION_INTEGER: return "DML_OPERATOR_CONVOLUTION_INTEGER"; case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION"; + case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return "DML_OPERATOR_ELEMENT_WISE_BIT_AND"; + case DML_OPERATOR_ELEMENT_WISE_BIT_OR: return "DML_OPERATOR_ELEMENT_WISE_BIT_OR"; + case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: return "DML_OPERATOR_ELEMENT_WISE_BIT_XOR"; + case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: return "DML_OPERATOR_ELEMENT_WISE_BIT_NOT"; + case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: return "DML_OPERATOR_ELEMENT_WISE_BIT_COUNT"; + 
case DML_OPERATOR_ACTIVATION_RELU_GRAD: return "DML_OPERATOR_ACTIVATION_RELU_GRAD"; + case DML_OPERATOR_AVERAGE_POOLING_GRAD: return "DML_OPERATOR_AVERAGE_POOLING_GRAD"; + case DML_OPERATOR_MAX_POOLING_GRAD: return "DML_OPERATOR_MAX_POOLING_GRAD"; + case DML_OPERATOR_RANDOM_GENERATOR: return "DML_OPERATOR_RANDOM_GENERATOR"; + case DML_OPERATOR_NONZERO_COORDINATES: return "DML_OPERATOR_NONZERO_COORDINATES"; + case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD"; + case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD"; + case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER"; + case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN"; + case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 4d6b40dc14..ec3d24070b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -296,6 +296,34 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA { + 
"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA_FIELDS[2] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -599,6 +627,38 @@ constexpr DML_OPERATOR_SCHEMA DML_REDUCE_OPERATOR_SCHEMA { DML_REDUCE_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ARGMIN_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, + DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ARGMIN_OPERATOR_SCHEMA { + "DML_OPERATOR_ARGMIN", + DML_OPERATOR_ARGMIN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_ARGMIN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ARGMAX_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ARGMAX_OPERATOR_SCHEMA { + "DML_OPERATOR_ARGMAX", + DML_OPERATOR_ARGMAX, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_ARGMAX_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1628,7 +1688,7 @@ constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIEL DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterScaleTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false }, + 
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1648,6 +1708,247 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA { DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_BIT_AND", + DML_OPERATOR_ELEMENT_WISE_BIT_AND, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_BIT_OR", + DML_OPERATOR_ELEMENT_WISE_BIT_OR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + 
DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_BIT_XOR", + DML_OPERATOR_ELEMENT_WISE_BIT_XOR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_BIT_NOT", + DML_OPERATOR_ELEMENT_WISE_BIT_NOT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 2, + DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_BIT_COUNT", + DML_OPERATOR_ELEMENT_WISE_BIT_COUNT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD 
DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_RELU_GRAD", + DML_OPERATOR_ACTIVATION_RELU_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_AVERAGE_POOLING_GRAD", + DML_OPERATOR_AVERAGE_POOLING_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD 
DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_MAX_POOLING_GRAD", + DML_OPERATOR_MAX_POOLING_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_RANDOM_GENERATOR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputStateTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputStateTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Type", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_RANDOM_GENERATOR_OPERATOR_SCHEMA { + "DML_OPERATOR_RANDOM_GENERATOR", + DML_OPERATOR_RANDOM_GENERATOR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_RANDOM_GENERATOR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_NONZERO_COORDINATES_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCountTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCoordinatesTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_NONZERO_COORDINATES_OPERATOR_SCHEMA { + 
"DML_OPERATOR_NONZERO_COORDINATES", + DML_OPERATOR_NONZERO_COORDINATES, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_NONZERO_COORDINATES_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD_OPERATOR_SCHEMA_FIELDS[7] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "Scales", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "InputPixelOffsets", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_RESAMPLE_GRAD", + DML_OPERATOR_RESAMPLE_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 7, + DML_RESAMPLE_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SLICE_GRAD_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowOffsets", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowSizes", false }, + DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT_ARRAY, "InputWindowStrides", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SLICE_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_SLICE_GRAD", + DML_OPERATOR_SLICE_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_SLICE_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA_FIELDS[12] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputParametersTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputFirstMomentTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputSecondMomentTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "GradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "TrainingStepTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputParametersTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputFirstMomentTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSecondMomentTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "LearningRate", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta1", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta2", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA { + "DML_OPERATOR_ADAM_OPTIMIZER", + DML_OPERATOR_ADAM_OPTIMIZER, + 
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 12, + DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS[11] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ReductionFunction", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleX", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleY", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutOfBoundsInputValue", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinimumSamplesPerOutput", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaximumSamplesPerOutput", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_OPERATOR_SCHEMA { + "DML_OPERATOR_ROI_ALIGN", + DML_OPERATOR_ROI_ALIGN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 11, + DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BatchDimensionCount", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND1_OPERATOR_SCHEMA { + "DML_OPERATOR_GATHER_ND1", + DML_OPERATOR_GATHER_ND1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1662,6 +1963,20 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_ELU_OPERATOR_SCHEMA { DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_CELU_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_CELU_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_CELU", + DML_OPERATOR_ACTIVATION_CELU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ACTIVATION_CELU_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS[2] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index d837dbe4a6..7630e98489 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -145,6 +145,22 @@ inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_ OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC& desc) { return { @@ 
-328,6 +344,26 @@ inline std::vector GetFields(const DML_REDUCE_OPERATOR_DESC& desc OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), }; } +inline std::vector GetFields(const DML_ARGMIN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.AxisDirection))), + }; +} +inline std::vector GetFields(const DML_ARGMAX_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.AxisDirection))), + }; +} inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_DESC& desc) { return { @@ -993,6 +1029,157 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_CONVOLUTI OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast(desc.GroupCount))), }; } +inline std::vector GetFields(const DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA.Fields[0], 
ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC& 
desc) +{ + return { + OperatorField(&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + }; +} +inline std::vector GetFields(const DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} +inline std::vector GetFields(const DML_MAX_POOLING_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[2], 
ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + }; +} +inline std::vector GetFields(const DML_RANDOM_GENERATOR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputStateTensor))), + OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputStateTensor))), + OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Type))), + }; +} +inline std::vector GetFields(const DML_NONZERO_COORDINATES_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputCountTensor))), + OperatorField(&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputCoordinatesTensor))), + }; +} +inline std::vector GetFields(const DML_RESAMPLE_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InterpolationMode))), + OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Scales), desc.DimensionCount)), + OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.InputPixelOffsets), desc.DimensionCount)), + 
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputPixelOffsets), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_SLICE_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InputWindowOffsets), desc.DimensionCount)), + OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.InputWindowSizes), desc.DimensionCount)), + OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.InputWindowStrides), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_ADAM_OPTIMIZER_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputParametersTensor))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputFirstMomentTensor))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputSecondMomentTensor))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.GradientTensor))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.TrainingStepTensor))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputParametersTensor))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputFirstMomentTensor))), + 
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputSecondMomentTensor))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.LearningRate))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.Beta1))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.Beta2))), + OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Epsilon))), + }; +} +inline std::vector GetFields(const DML_ROI_ALIGN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ROITensor))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BatchIndicesTensor))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.ReductionFunction))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.InterpolationMode))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.SpatialScaleX))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.SpatialScaleY))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutOfBoundsInputValue))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.MinimumSamplesPerOutput))), + OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.MaximumSamplesPerOutput))), + }; +} +inline std::vector GetFields(const 
DML_GATHER_ND1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InputDimensionCount))), + OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.IndicesDimensionCount))), + OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BatchDimensionCount))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1001,6 +1188,14 @@ inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DE OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), }; } +inline std::vector GetFields(const DML_ACTIVATION_CELU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_CELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_CELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_CELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_HARDMAX_OPERATOR_DESC& desc) { return { @@ -1165,6 +1360,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return 
DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: return DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA; @@ -1185,6 +1382,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_CONVOLUTION: return DML_CONVOLUTION_OPERATOR_SCHEMA; case DML_OPERATOR_GEMM: return DML_GEMM_OPERATOR_SCHEMA; case DML_OPERATOR_REDUCE: return DML_REDUCE_OPERATOR_SCHEMA; + case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA; + case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA; case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA; case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA; case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA; @@ -1249,7 +1448,23 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA; case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_BIT_OR: return DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: return DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: return 
DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: return DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_RELU_GRAD: return DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_AVERAGE_POOLING_GRAD: return DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_MAX_POOLING_GRAD: return DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_RANDOM_GENERATOR: return DML_RANDOM_GENERATOR_OPERATOR_SCHEMA; + case DML_OPERATOR_NONZERO_COORDINATES: return DML_NONZERO_COORDINATES_OPERATOR_SCHEMA; + case DML_OPERATOR_RESAMPLE_GRAD: return DML_RESAMPLE_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_SLICE_GRAD: return DML_SLICE_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_ADAM_OPTIMIZER: return DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA; + case DML_OPERATOR_ROI_ALIGN: return DML_ROI_ALIGN_OPERATOR_SCHEMA; + case DML_OPERATOR_GATHER_ND1: return DML_GATHER_ND1_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA; @@ -1345,6 +1560,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case 
DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return AbstractOperatorDesc( &DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA, @@ -1425,6 +1648,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_REDUCE_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ARGMIN: + return AbstractOperatorDesc( + &DML_ARGMIN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ARGMAX: + return AbstractOperatorDesc( + &DML_ARGMAX_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_AVERAGE_POOLING: return AbstractOperatorDesc( &DML_AVERAGE_POOLING_OPERATOR_SCHEMA, @@ -1681,10 +1912,74 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_BIT_AND: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_BIT_OR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_RELU_GRAD: + return AbstractOperatorDesc( + &DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_AVERAGE_POOLING_GRAD: + return AbstractOperatorDesc( + &DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA, + 
GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MAX_POOLING_GRAD: + return AbstractOperatorDesc( + &DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_RANDOM_GENERATOR: + return AbstractOperatorDesc( + &DML_RANDOM_GENERATOR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_NONZERO_COORDINATES: + return AbstractOperatorDesc( + &DML_NONZERO_COORDINATES_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_RESAMPLE_GRAD: + return AbstractOperatorDesc( + &DML_RESAMPLE_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SLICE_GRAD: + return AbstractOperatorDesc( + &DML_SLICE_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ADAM_OPTIMIZER: + return AbstractOperatorDesc( + &DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ROI_ALIGN: + return AbstractOperatorDesc( + &DML_ROI_ALIGN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GATHER_ND1: + return AbstractOperatorDesc( + &DML_GATHER_ND1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_CELU: + return AbstractOperatorDesc( + &DML_ACTIVATION_CELU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_HARDMAX: return AbstractOperatorDesc( &DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index ab5bcbca97..f3f4caab3f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -5,20 +5,12 @@ #include 
"MLOperatorAuthorImpl.h" #include "FusedGraphKernel.h" +#include "GraphKernelHelper.h" using namespace Windows::AI::MachineLearning::Adapter; namespace Dml { - template - static T AlignToPow2(T offset, T alignment) - { - static_assert(std::is_unsigned_v); - assert(alignment != 0); - assert((alignment & (alignment - 1)) == 0); - return (offset + alignment - 1) & ~(alignment - 1); - } - class FusedGraphKernel : public onnxruntime::OpKernel { public: @@ -73,42 +65,16 @@ namespace Dml const uint32_t graphInputCount = kernelInfo.GetInputCount(); - auto gpuGraphInputConstnessGetter = [&kernelInfo, &fusedNodeInputDefs, &transferredInitializerMap](uint32_t index) - { - // Transferred initializers are uploaded to GPU memory - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name()); - if (iter != transferredInitializerMap.end()) - { - return true; - } - - // If an initializer wasn't transferred, the constant input may be available from ORT - const onnxruntime::Tensor* inputTensor = nullptr; - if (!kernelInfo.TryGetConstantInput(index, &inputTensor) || inputTensor == nullptr) - { - return false; - } - - // Check that the constant ORT input is in GPU memory - if (!strcmp(inputTensor->Location().name, onnxruntime::CPU) || - inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || - inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput) - { - return false; - } - - return true; - }; - m_inputsConstant.resize(graphInputCount); for (uint32_t i = 0; i < graphInputCount; ++i) { - m_inputsConstant[i] = gpuGraphInputConstnessGetter(i); + m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputDefs, transferredInitializerMap); } GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( kernelInfo, - m_inputsConstant, + m_inputsConstant.data(), + m_inputsConstant.size(), transferredInitializerMap, graph, fusedNodeInputDefs, @@ -117,116 +83,27 @@ namespace Dml device.Get(), 
m_executionHandle); - // Determine the last input which uses an initializer, so initializers can be freed incrementally - // while processing each input in order. - std::map initializerToLastInputIndexMap; - for (uint32_t i = 0; i < graphInputCount; i++) - { - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); - if (iter != transferredInitializerMap.end()) - { - initializerToLastInputIndexMap[&iter->second] = i; - } - } - - // Walk through each graph edge and mark used inputs - m_inputsUsed.assign(graphInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) - { - m_inputsUsed[edge.GraphInputIndex] = true; - } - // Populate input bindings for operator initialization - std::vector> initInputResources; // For lifetime control + std::vector> initInputResources; // For lifetime control std::vector initInputBindings(graphInputCount); m_nonOwnedGraphInputsFromInitializers.resize(graphInputCount); - std::vector> initializeResourceRefs; - - for (uint32_t i = 0; i < initInputBindings.size(); i++) - { - // If the input isn't actually used by the graph, nothing ever needs to be bound (either for - // initialization or execution). So just throw away the transferred initializer and skip this input. - if (!m_inputsUsed[i]) - { - transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name()); - continue; - } - - // Look for the initializer among those transferred from the graph during partitioning - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); - if (iter != transferredInitializerMap.end()) - { - std::byte* tensorPtr = nullptr; - size_t tensorByteSize = 0; - std::unique_ptr unpackedTensor; - - auto& initializer = iter->second; - - // The tensor may be stored as raw data or in typed fields. 
- if (initializer.has_raw_data()) - { - tensorPtr = (std::byte*)(initializer.raw_data().c_str()); - tensorByteSize = initializer.raw_data().size(); - } - else - { - std::tie(unpackedTensor, tensorByteSize) = UnpackTensor(initializer); - tensorPtr = unpackedTensor.get(); - } - - // Tensor sizes in DML must be a multiple of 4 bytes large. - tensorByteSize = AlignToPow2(tensorByteSize, 4); - - if (!m_inputsConstant[i]) - { - // Store the resource to use during execution - ComPtr defaultBuffer = CreateResource(tensorPtr, tensorByteSize); - m_nonOwnedGraphInputsFromInitializers[i] = defaultBuffer; - initializeResourceRefs.push_back(std::move(defaultBuffer)); - } - else - { - ComPtr initializeInputBuffer; - - // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps - if (m_provider->IsMcdmDevice()) - { - initializeInputBuffer = CreateResource(tensorPtr, tensorByteSize); - } - else - { - initializeInputBuffer = CreateCpuResource(tensorPtr, tensorByteSize); - } - - // Set the binding for operator initialization to the buffer - initInputBindings[i].Buffer = initializeInputBuffer.Get(); - initInputBindings[i].SizeInBytes = tensorByteSize; - initializeResourceRefs.push_back(std::move(initializeInputBuffer)); - } - - // Free the initializer if this is the last usage of it. 
- if (initializerToLastInputIndexMap[&initializer] == i) - { - transferredInitializerMap.erase(iter); - } - } - else if (m_inputsConstant[i]) - { - const onnxruntime::Tensor* inputTensor = nullptr; - THROW_HR_IF(E_UNEXPECTED, !kernelInfo.TryGetConstantInput(i, &inputTensor)); - - uint64_t allocId; - UnwrapTensor(inputTensor, &initInputBindings[i].Buffer, &allocId); - initInputBindings[i].SizeInBytes = initInputBindings[i].Buffer->GetDesc().Width; - - initInputBindings[i].Buffer->Release(); // Avoid holding an additional reference - initInputResources.push_back(initInputBindings[i].Buffer); - } - } - - // All initializers should have been consumed and freed above - assert(transferredInitializerMap.empty()); + std::vector> initializeResourceRefs; + + GraphKernelHelper::PopulateInputBindings( + m_provider.Get(), + m_winmlProvider.Get(), + m_inputsConstant, + kernelInfo, + graphDesc, + fusedNodeInputDefs, + m_inputsUsed, + initInputBindings, + initInputResources, + m_nonOwnedGraphInputsFromInitializers, + initializeResourceRefs, + transferredInitializerMap); + DML_GRAPH_DESC dmlGraphDesc = {}; std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); std::vector dmlGraphNodes(graphDesc.nodes.size()); @@ -234,38 +111,15 @@ namespace Dml std::vector dmlOutputEdges(graphDesc.outputEdges.size()); std::vector dmlIntermediateEdges(graphDesc.intermediateEdges.size()); - for (size_t i = 0; i < graphDesc.nodes.size(); ++i) - { - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{ graphDesc.nodes[i].op.Get() }; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{ DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i] }; - } - - for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) - { - dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i] }; - } - - for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) - { - dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i] }; - } - - for (size_t i = 0; i < 
graphDesc.intermediateEdges.size(); ++i) - { - dmlIntermediateEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i] }; - } - - DML_GRAPH_DESC dmlGraphDesc = {}; - dmlGraphDesc.InputCount = graphInputCount; - dmlGraphDesc.OutputCount = kernelInfo.GetOutputCount(); - dmlGraphDesc.NodeCount = gsl::narrow_cast(dmlGraphNodes.size()); - dmlGraphDesc.Nodes = dmlGraphNodes.data(); - dmlGraphDesc.InputEdgeCount = gsl::narrow_cast(dmlInputEdges.size()); - dmlGraphDesc.InputEdges = dmlInputEdges.data(); - dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast(dmlOutputEdges.size()); - dmlGraphDesc.OutputEdges = dmlOutputEdges.data(); - dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast(dmlIntermediateEdges.size()); - dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); + GraphKernelHelper::ConvertGraphDesc( + graphDesc, + dmlGraphDesc, + kernelInfo, + dmlOperatorGraphNodes, + dmlGraphNodes, + dmlInputEdges, + dmlOutputEdges, + dmlIntermediateEdges); DML_EXECUTION_FLAGS executionFlags = DML_EXECUTION_FLAG_NONE; if (graphDesc.reuseCommandList) @@ -533,10 +387,10 @@ namespace Dml const onnxruntime::Tensor* tensor = kernelContext->Input(i); uint64_t allocId; - UnwrapTensor(tensor, &inputBindings[i].Buffer, &allocId); + GraphKernelHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId); inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId); inputBindings[i].Buffer->Release(); // Avoid holding an additional reference - inputBindings[i].SizeInBytes = AlignToPow2(tensor->SizeInBytes(), 4); + inputBindings[i].SizeInBytes = GraphKernelHelper::AlignToPow2(tensor->SizeInBytes(), 4); inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]}; m_inputBindingAllocIds[i] = allocId; } @@ -570,10 +424,10 @@ namespace Dml ); uint64_t allocId; - UnwrapTensor(tensor, &outputBindings[i].Buffer, &allocId); + GraphKernelHelper::UnwrapTensor(m_winmlProvider.Get(), 
tensor, &outputBindings[i].Buffer, &allocId); outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId); outputBindings[i].Buffer->Release(); // Avoid holding an additional reference - outputBindings[i].SizeInBytes = AlignToPow2(tensor->SizeInBytes(), 4); + outputBindings[i].SizeInBytes = GraphKernelHelper::AlignToPow2(tensor->SizeInBytes(), 4); outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]}; m_outputBindingAllocIds[i] = allocId; } @@ -623,106 +477,6 @@ namespace Dml m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); } - void UnwrapTensor(const onnxruntime::Tensor* tensor, ID3D12Resource** resource, uint64_t* allocId) const - { - IUnknown* allocationUnk = static_cast(const_cast(tensor->DataRaw())); - ComPtr resourceUnk; - m_winmlProvider->GetABIDataInterface(false, allocationUnk, &resourceUnk); - - *allocId = m_winmlProvider->TryGetPooledAllocationId(allocationUnk, 0); - - THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); - } - - ComPtr CreateResource(const std::byte* tensorPtr, size_t tensorByteSize) const - { - ComPtr buffer; - - D3D12_HEAP_PROPERTIES heapProperties = { - D3D12_HEAP_TYPE_DEFAULT, - D3D12_CPU_PAGE_PROPERTY_UNKNOWN, - D3D12_MEMORY_POOL_UNKNOWN, - 0, - 0 - }; - - D3D12_RESOURCE_DESC resourceDesc = { - D3D12_RESOURCE_DIMENSION_BUFFER, - 0, - static_cast((tensorByteSize + 3) & ~3), - 1, - 1, - 1, - DXGI_FORMAT_UNKNOWN, - { 1, 0 }, - D3D12_TEXTURE_LAYOUT_ROW_MAJOR, - D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS - }; - - ComPtr d3dDevice; - THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf())); - - THROW_IF_FAILED(d3dDevice->CreateCommittedResource( - &heapProperties, - D3D12_HEAP_FLAG_NONE, - &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - nullptr, - IID_PPV_ARGS(buffer.GetAddressOf()) - )); - - THROW_IF_FAILED(m_provider->UploadToResource(buffer.Get(), tensorPtr, tensorByteSize)); - - return buffer; - } - - ComPtr CreateCpuResource(const 
std::byte* tensorPtr, size_t tensorByteSize) const - { - ComPtr buffer; - - D3D12_HEAP_PROPERTIES heapProperties = { - D3D12_HEAP_TYPE_CUSTOM, - D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE, - D3D12_MEMORY_POOL_L0, - 0, - 0 - }; - - D3D12_RESOURCE_DESC resourceDesc = { - D3D12_RESOURCE_DIMENSION_BUFFER, - 0, - static_cast((tensorByteSize + 3) & ~3), - 1, - 1, - 1, - DXGI_FORMAT_UNKNOWN, - { 1, 0 }, - D3D12_TEXTURE_LAYOUT_ROW_MAJOR, - D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS - }; - - ComPtr d3dDevice; - THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf())); - - THROW_IF_FAILED(d3dDevice->CreateCommittedResource( - &heapProperties, - D3D12_HEAP_FLAG_NONE, - &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - nullptr, - IID_PPV_ARGS(buffer.GetAddressOf()) - )); - - // Map the buffer and copy the data - void* bufferData = nullptr; - D3D12_RANGE range = {0, tensorByteSize}; - THROW_IF_FAILED(buffer->Map(0, &range, &bufferData)); - memcpy(bufferData, tensorPtr, tensorByteSize); - buffer->Unmap(0, &range); - - return buffer; - } - ComPtr m_compiledExecutionPlanOperator; std::vector m_inputsUsed; const void* m_executionHandle = nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 0dc908f5aa..9cd8f92ae2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -60,7 +60,8 @@ namespace Dml::GraphDescBuilder GraphDesc BuildGraphDesc( const onnxruntime::OpKernelInfo& kernelInfo, - gsl::span isConstGpuGraphInput, + const uint8_t* isConstGpuGraphInput, + const size_t isConstGpuGraphInputCount, std::unordered_map& transferredInitializerMap, const onnxruntime::Graph& graph, const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, @@ -226,7 +227,7 @@ namespace Dml::GraphDescBuilder graphInputEdges.push_back(edge); // 
If this is a constant input, set the appropriate flags on the desc - if (isConstGpuGraphInput[fusedNodeInputIndex]) + if (fusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[fusedNodeInputIndex]) { DmlBufferTensorDesc* tensorDesc = inputTensorDescs[inputIndex]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 58741511d6..0d013e8fe4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -41,7 +41,8 @@ namespace Dml GraphDesc BuildGraphDesc( const onnxruntime::OpKernelInfo& kernelInfo, - gsl::span isConstGpuGraphInput, + const uint8_t* isConstGpuGraphInput, + const size_t isConstGpuGraphInputCount, std::unordered_map& transferredInitializerMap, const onnxruntime::Graph& graph, const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp new file mode 100644 index 0000000000..fa9792b08c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp @@ -0,0 +1,313 @@ +#include "precomp.h" + +#include "GraphKernelHelper.h" + +namespace Dml +{ +namespace GraphKernelHelper +{ + Microsoft::WRL::ComPtr + CreateResource( + Dml::IExecutionProvider* provider, + const std::byte* tensorPtr, + size_t tensorByteSize) + { + Microsoft::WRL::ComPtr buffer; + + D3D12_HEAP_PROPERTIES heapProperties = { + D3D12_HEAP_TYPE_DEFAULT, D3D12_CPU_PAGE_PROPERTY_UNKNOWN, D3D12_MEMORY_POOL_UNKNOWN, 0, 0}; + + D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER, + 0, + static_cast((tensorByteSize + 3) & ~3), + 1, + 1, + 1, + DXGI_FORMAT_UNKNOWN, + {1, 0}, + D3D12_TEXTURE_LAYOUT_ROW_MAJOR, + 
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; + + Microsoft::WRL::ComPtr d3dDevice; + THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf())); + + THROW_IF_FAILED(d3dDevice->CreateCommittedResource( + &heapProperties, + D3D12_HEAP_FLAG_NONE, + &resourceDesc, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_PPV_ARGS(buffer.GetAddressOf()))); + + THROW_IF_FAILED(provider->UploadToResource(buffer.Get(), tensorPtr, tensorByteSize)); + + return buffer; + } + + Microsoft::WRL::ComPtr + CreateCpuResource( + Dml::IExecutionProvider* provider, + const std::byte* tensorPtr, + size_t tensorByteSize) + { + Microsoft::WRL::ComPtr buffer; + + D3D12_HEAP_PROPERTIES heapProperties = { + D3D12_HEAP_TYPE_CUSTOM, D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE, D3D12_MEMORY_POOL_L0, 0, 0}; + + D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER, + 0, + static_cast((tensorByteSize + 3) & ~3), + 1, + 1, + 1, + DXGI_FORMAT_UNKNOWN, + {1, 0}, + D3D12_TEXTURE_LAYOUT_ROW_MAJOR, + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; + + Microsoft::WRL::ComPtr d3dDevice; + THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf())); + + THROW_IF_FAILED(d3dDevice->CreateCommittedResource( + &heapProperties, + D3D12_HEAP_FLAG_NONE, + &resourceDesc, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_PPV_ARGS(buffer.GetAddressOf()))); + + // Map the buffer and copy the data + void* bufferData = nullptr; + D3D12_RANGE range = {0, tensorByteSize}; + THROW_IF_FAILED(buffer->Map(0, &range, &bufferData)); + memcpy(bufferData, tensorPtr, tensorByteSize); + buffer->Unmap(0, &range); + + return buffer; + } + + void UnwrapTensor( + IWinmlExecutionProvider* winmlProvider, + const onnxruntime::Tensor* tensor, + ID3D12Resource** resource, + uint64_t* allocId) + { + IUnknown* allocationUnk = static_cast(const_cast(tensor->DataRaw())); + Microsoft::WRL::ComPtr resourceUnk; + winmlProvider->GetABIDataInterface(false, allocationUnk, &resourceUnk); + + *allocId = 
winmlProvider->TryGetPooledAllocationId(allocationUnk, 0); + + THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); + } + + bool GetGraphInputConstness( + uint32_t index, + const onnxruntime::OpKernelInfo& kernelInfo, + const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + const std::unordered_map& transferredInitializerMap) + { + // Transferred initializers are uploaded to GPU memory + auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name()); + if (iter != transferredInitializerMap.end()) + { + return true; + } + + // If an initializer wasn't transferred, the constant input may be available from ORT + const onnxruntime::Tensor* inputTensor = nullptr; + if (!kernelInfo.TryGetConstantInput(index, &inputTensor) || inputTensor == nullptr) + { + return false; + } + + // Check that the constant ORT input is in GPU memory + if (!strcmp(inputTensor->Location().name, onnxruntime::CPU) || + inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || + inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput) + { + return false; + } + + return true; + }; + + std::vector> PopulateInputBindings( + Dml::IExecutionProvider* provider, + IWinmlExecutionProvider* winmlProvider, + const std::vector& inputsConstant, + const onnxruntime::OpKernelInfo& kernelInfo, + const Dml::GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + _Out_ std::vector& inputsUsed, + _Out_ std::vector& initInputBindings, + _Out_ std::vector>& initInputResources, + _Out_ std::vector>& nonOwnedGraphInputsFromInitializers, + _Out_ std::vector>& initializeResourceRefs, + _Inout_ std::unordered_map& transferredInitializerMap) + { + std::vector> inputRawData; + + const uint32_t graphInputCount = kernelInfo.GetInputCount(); + // Determine the last input which uses an initializer, so initializers can be freed incrementally + // while processing each input in order. 
+ std::map initializerToLastInputIndexMap; + for (uint32_t i = 0; i < graphInputCount; i++) + { + auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); + if (iter != transferredInitializerMap.end()) { + initializerToLastInputIndexMap[&iter->second] = i; + } + } + + // Walk through each graph edge and mark used inputs + inputsUsed.assign(graphInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) { + inputsUsed[edge.GraphInputIndex] = true; + } + for (uint32_t i = 0; i < initInputBindings.size(); i++) + { + // If the input isn't actually used by the graph, nothing ever needs to be bound (either for + // initialization or execution). So just throw away the transferred initializer and skip this input. + if (!inputsUsed[i]) + { + transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name()); + inputRawData.push_back(std::vector()); + continue; + } + + // Look for the initializer among those transferred from the graph during partitioning + auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); + if (iter != transferredInitializerMap.end()) + { + std::byte* tensorPtr = nullptr; + size_t tensorByteSize = 0; + std::unique_ptr unpackedTensor; + + auto& initializer = iter->second; + + // The tensor may be stored as raw data or in typed fields. + if (initializer.has_raw_data()) + { + tensorPtr = (std::byte*)(initializer.raw_data().c_str()); + tensorByteSize = initializer.raw_data().size(); + } + else + { + std::tie(unpackedTensor, tensorByteSize) = UnpackTensor(initializer); + tensorPtr = unpackedTensor.get(); + } + + // Tensor sizes in DML must be a multiple of 4 bytes large. 
+ tensorByteSize = AlignToPow2(tensorByteSize, 4); + + inputRawData.push_back(std::vector(tensorPtr, tensorPtr + tensorByteSize)); + + if (!inputsConstant[i]) + { + // Store the resource to use during execution + ComPtr defaultBuffer = CreateResource(provider, tensorPtr, tensorByteSize); + nonOwnedGraphInputsFromInitializers[i] = defaultBuffer; + initializeResourceRefs.push_back(std::move(defaultBuffer)); + } + else + { + ComPtr initializeInputBuffer; + + // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps + if (provider->IsMcdmDevice()) + { + initializeInputBuffer = CreateResource(provider, tensorPtr, tensorByteSize); + } + else + { + initializeInputBuffer = CreateCpuResource(provider, tensorPtr, tensorByteSize); + } + + // Set the binding for operator initialization to the buffer + initInputBindings[i].Buffer = initializeInputBuffer.Get(); + initInputBindings[i].SizeInBytes = tensorByteSize; + initializeResourceRefs.push_back(std::move(initializeInputBuffer)); + } + + // Free the initializer if this is the last usage of it. 
+ if (initializerToLastInputIndexMap[&initializer] == i) + { + transferredInitializerMap.erase(iter); + } + } + else if (inputsConstant[i]) + { + const onnxruntime::Tensor* inputTensor = nullptr; + THROW_HR_IF(E_UNEXPECTED, !kernelInfo.TryGetConstantInput(i, &inputTensor)); + + const std::byte* tensorData = reinterpret_cast(inputTensor->DataRaw()); + inputRawData.push_back( + std::vector(tensorData, tensorData + inputTensor->SizeInBytes())); + + uint64_t allocId; + UnwrapTensor(winmlProvider, inputTensor, &initInputBindings[i].Buffer, &allocId); + initInputBindings[i].SizeInBytes = initInputBindings[i].Buffer->GetDesc().Width; + + initInputBindings[i].Buffer->Release(); // Avoid holding an additional reference + initInputResources.push_back(initInputBindings[i].Buffer); + } + else + { + inputRawData.push_back(std::vector()); + } + } + + // All initializers should have been consumed and freed above + assert(transferredInitializerMap.empty()); + return inputRawData; + } + + void ConvertGraphDesc( + const Dml::GraphDescBuilder::GraphDesc& graphDesc, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + const onnxruntime::OpKernelInfo& kernelInfo, + _Out_ std::vector& dmlOperatorGraphNodes, + _Out_ std::vector& dmlGraphNodes, + _Out_ std::vector& dmlInputEdges, + _Out_ std::vector& dmlOutputEdges, + _Out_ std::vector& dmlIntermediateEdges) + { + const uint32_t graphInputCount = kernelInfo.GetInputCount(); + + for (size_t i = 0; i < graphDesc.nodes.size(); ++i) + { + dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{graphDesc.nodes[i].op.Get()}; + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + } + + for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) + { + dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]}; + } + + for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) + { + dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]}; + } + + 
for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i) + { + dmlIntermediateEdges[i] = + DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]}; + } + + dmlGraphDesc.InputCount = graphInputCount; + dmlGraphDesc.OutputCount = kernelInfo.GetOutputCount(); + dmlGraphDesc.NodeCount = gsl::narrow_cast(dmlGraphNodes.size()); + dmlGraphDesc.Nodes = dmlGraphNodes.data(); + dmlGraphDesc.InputEdgeCount = gsl::narrow_cast(dmlInputEdges.size()); + dmlGraphDesc.InputEdges = dmlInputEdges.data(); + dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast(dmlOutputEdges.size()); + dmlGraphDesc.OutputEdges = dmlOutputEdges.data(); + dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast(dmlIntermediateEdges.size()); + dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); + } +} // namespace GraphKernelHelper +} // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h new file mode 100644 index 0000000000..b1b2e87cf8 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h @@ -0,0 +1,67 @@ +#include "GraphDescBuilder.h" + +namespace Dml +{ +namespace GraphKernelHelper +{ + using namespace Windows::AI::MachineLearning::Adapter; + + template + static T AlignToPow2(T offset, T alignment) + { + static_assert(std::is_unsigned_v); + assert(alignment != 0); + assert((alignment & (alignment - 1)) == 0); + return (offset + alignment - 1) & ~(alignment - 1); + } + + Microsoft::WRL::ComPtr + CreateResource( + Dml::IExecutionProvider* provider, + const std::byte* tensorPtr, + size_t tensorByteSize); + + Microsoft::WRL::ComPtr + CreateCpuResource( + Dml::IExecutionProvider* provider, + const std::byte* tensorPtr, + size_t tensorByteSize); + + void UnwrapTensor( + IWinmlExecutionProvider* winmlProvider, + const onnxruntime::Tensor* tensor, + ID3D12Resource** 
resource, + uint64_t* allocId); + + bool GetGraphInputConstness( + uint32_t index, + const onnxruntime::OpKernelInfo& kernelInfo, + const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + const std::unordered_map& transferredInitializerMap); + + std::vector> PopulateInputBindings( + Dml::IExecutionProvider* provider, + IWinmlExecutionProvider* winmlProvider, + const std::vector& inputsConstant, + const onnxruntime::OpKernelInfo& kernelInfo, + const Dml::GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + _Out_ std::vector& inputsUsed, + _Out_ std::vector& initInputBindings, + _Out_ std::vector>& initInputResources, + _Out_ std::vector>& nonOwnedGraphInputsFromInitializers, + _Out_ std::vector>& initializeResourceRefs, + _Inout_ std::unordered_map& transferredInitializerMap); + + void ConvertGraphDesc( + const Dml::GraphDescBuilder::GraphDesc& graphDesc, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + const onnxruntime::OpKernelInfo& kernelInfo, + _Out_ std::vector& dmlOperatorGraphNodes, + _Out_ std::vector& dmlGraphNodes, + _Out_ std::vector& dmlInputEdges, + _Out_ std::vector& dmlOutputEdges, + _Out_ std::vector& dmlIntermediateEdges); + +} // namespace GraphKernelHelper +} // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 820e0fa254..a139b5fdda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -139,7 +139,7 @@ namespace Dml } }; - bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats) + bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats, uint32_t supportedDeviceDataTypeMask) { if (arg->Exists()) { @@ -151,16 +151,23 @@ namespace Dml { // TODO: 
Remove this by handling zeroing on the output of fused graph nodes and handling of non-float // types in DML's identity operator, which is used for strided copies. - if (ToMLTensorDataType(static_cast(tensorType.elem_type())) == MLOperatorTensorDataType::UInt64 || - ToMLTensorDataType(static_cast(tensorType.elem_type())) == MLOperatorTensorDataType::Int64) + + MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast(tensorType.elem_type())); + + if (mlDataType == MLOperatorTensorDataType::UInt64 || + mlDataType == MLOperatorTensorDataType::Int64) { - return false; + constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64); + if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit) + { + return false; + } } if (requiresFloatFormats) { - if (ToMLTensorDataType(static_cast(tensorType.elem_type())) != MLOperatorTensorDataType::Float && - ToMLTensorDataType(static_cast(tensorType.elem_type())) != MLOperatorTensorDataType::Float16) + if (mlDataType != MLOperatorTensorDataType::Float && + mlDataType != MLOperatorTensorDataType::Float16) { return false; } @@ -172,14 +179,19 @@ namespace Dml return true; } - bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration) + bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration, uint32_t supportedDeviceDataTypeMask) { for (size_t i = 0; i < node.InputDefs().size(); ++i) { bool isConstantCpuInput = std::find(registration.requiredConstantCpuInputs.begin(), registration.requiredConstantCpuInputs.end(), i) != registration.requiredConstantCpuInputs.end(); - if (!isConstantCpuInput && !NodeArgSupportedInGraph(node.InputDefs()[i], registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs)) + if (!isConstantCpuInput && + !NodeArgSupportedInGraph( + node.InputDefs()[i], + 
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs, + supportedDeviceDataTypeMask + )) { return false; } @@ -187,7 +199,11 @@ namespace Dml for (auto arg : node.OutputDefs()) { - if (!NodeArgSupportedInGraph(arg, registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs)) + if (!NodeArgSupportedInGraph( + arg, + registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs, + supportedDeviceDataTypeMask + )) { return false; } @@ -220,97 +236,155 @@ namespace Dml bool DoesNodeContainSupportedDataTypes( const onnxruntime::Node& node, bool allow64BitInputThroughStrides, - _In_opt_ const std::unordered_map* nodeNameToPartitionMap, // Only used when allow64BitInputThroughStrides is true + _In_opt_ const std::unordered_map* nodeNameToPartitionMap, // Only used when allow64BitInputThroughStrides is true _In_opt_ const InternalRegistrationInfo* regInfo, uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE. ) { THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap); + bool prefer64BitTensorsDirectly = false; + bool supportedWith64BitTensorsVia32BitStrides = false; + bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false; + std::vector constantCpuInputs; + + if (regInfo != nullptr) + { + // Read the operator flags for handling 64-bit tensors and whether it's allowed to fall back + // to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false + // in this particular call, then the operator-specific flags do not matter as the caller has + // disabled 64-bit support. 
+ if (allow64BitInputThroughStrides) + { + prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly; + supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp; + supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp; + } + + // Collect the list of CPU-bound input tensors, needed when checking 64-bit fallback + // or for other data types like int-8 which may be supported for CPU inputs but not + // GPU inputs. + auto inputDefinitions = node.InputDefs(); + for (uint32_t i : regInfo->requiredConstantCpuInputs) + { + if (i < inputDefinitions.size()) + { + constantCpuInputs.push_back(inputDefinitions[i]); + } + } + } + // Assume data types are supported until proven otherwise. bool nodeContainsSupportedDataTypes = true; - // Callback to check each node's data type. + // Callback to check each node's data type against registered operator support. std::function nodeCallback = [&](const onnxruntime::NodeArg& nodeArg, bool isInput) -> void { // Get the tensor element data type for this node, comparing against what the device actually supports. // Use the enumeration from the proto instead of nodeArg.Type() which returns a string. + + // Reject node if undefined data type or non-tensor, as DML cannot handle it. MLOperatorTensorDataType onnxElementType; - if (TryGetTensorDataType(nodeArg, &onnxElementType)) + if (!TryGetTensorDataType(nodeArg, &onnxElementType)) { - DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType); - if (dmlElementType != DML_TENSOR_DATA_TYPE_UNKNOWN) + // We shouldn't have arrived here because (1) no DML operators should have been + // registered which use non-tensor types (2) ONNX validation should have already + // been done, checking for the right kind of inputs and attributes. 
In theory, + // this branch could be reached with a bad custom operator or malformed file. If + // a legitimate case reaches here and DML needs to support a new input/output type + // besides tensors, then remove the assert. + assert(false); + nodeContainsSupportedDataTypes = false; + return; + } + + // Reject node for unknown DML data types. + DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType); + if (dmlElementType == DML_TENSOR_DATA_TYPE_UNKNOWN) + { + nodeContainsSupportedDataTypes = false; + return; + } + + // Succeed if the tensor is CPU-bound, as the CPU-side reading code is generic enough + // to handle multiple types regardless of GPU capability (typically these are just + // scalars or simple 1D arrays). + bool isConstantCpuInput = isInput && std::find(constantCpuInputs.begin(), constantCpuInputs.end(), &nodeArg) != constantCpuInputs.end(); + if (isConstantCpuInput) + { + // Leave nodeContainsSupportedDataTypes alone. + return; + } + + // If this operator implements 64-bit support in terms of strided 32-bit tensors, + // then the data type needs to be remapped, regardless of whether input or output. + // + // Some operators can fairly safely implement 64-bit tensors in terms of + // strided 32-bit tensors regardless of input tensor's execution provider + // because the indices measure along a single axis and should fall within + // the range of an int32/uint32. + // + // Currently all DML kernels outputting int64 and uint64 are expected to + // not *introduce* values out of range, which allows the temporary trick + // using strides to emulate 64 bit tensors to work. If the source is a CPU + // operator, graph input or initializer, it's not safe to assume the input + // can be represented with 32 bits. 
+ // + bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64); + bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask); + if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit) + { + dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType); + + if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp) { - if (((1 << dmlElementType) & supportedDeviceDataTypeMask) == 0) + // Look up the input partition. If it's a graph input or initializer it will be missing + // from the partition map. + const std::string& argName = nodeArg.Name(); + + // If input tensor's data comes from the output of a different execution provider, + // consider it unsafe to apply fallback to. + auto partitionIter = nodeNameToPartitionMap->find(argName); + if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) { nodeContainsSupportedDataTypes = false; + return; } } } + + // Reject node if the data type is unsupported by the device. + if (!((1 << dmlElementType) & supportedDeviceDataTypeMask)) + { + nodeContainsSupportedDataTypes = false; + return; + } + + // Otherwise the node supports the tensor data type. }; // Check whether the node uses any data types which are unsupported by the device. node.ForEachDef(nodeCallback); - // DML kernels supporting int64 and uint64 are expected to not *introduce* values out of range, which allows - // the temporary trick using strides to emulate 64 bit tensors to work. If the source is a CPU operator, - // graph input or initializer, it's not safe to assume the input can be represented with 32 bits. 
- if (regInfo) - { - for (uint32_t i = 0; i < node.InputDefs().size(); ++i) - { - const auto* arg = node.InputDefs()[i]; - MLOperatorTensorDataType onnxElementType; - if (arg->Exists() && TryGetTensorDataType(*arg, &onnxElementType)) - { - if (((onnxElementType == MLOperatorTensorDataType::UInt64) || (onnxElementType == MLOperatorTensorDataType::Int64))) - { - // Look up the input partition. If it's a graph input or initializer it will be missing - // from the map. In this case or if the input comes from a CPU partition, it might be - // out of range. - const std::string& argName = arg->Name(); - // Check if the operator handles the input on the CPU as a constant input - bool isConstantCpuInput = std::find(regInfo->requiredConstantCpuInputs.begin(), regInfo->requiredConstantCpuInputs.end(), i) != - regInfo->requiredConstantCpuInputs.end(); - - if (!isConstantCpuInput) - { - if (!allow64BitInputThroughStrides) - { - nodeContainsSupportedDataTypes = false; - break; - } - - auto partitionIter = nodeNameToPartitionMap->find(argName); - if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) - { - nodeContainsSupportedDataTypes = false; - break; - } - } - } - } - } - } - return nodeContainsSupportedDataTypes; } bool IsNodeSupportedByDml( - const onnxruntime::Node& node, + const onnxruntime::Node& node, const onnxruntime::KernelRegistry& registry, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
const InternalRegistrationInfoMap& internalRegInfoMap, bool allow64BitInputThroughStrides, _In_opt_ const std::unordered_map* nodeNameToPartitionMap - ) + ) { THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap); const onnxruntime::KernelCreateInfo* createInfo; Status st = registry.TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo); - if (!st.IsOK()) { - return false; + if (!st.IsOK()) + { + return false; } auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get()); @@ -337,7 +411,7 @@ namespace Dml // Gets properties of the registration for a node void GetRegistrationProperties( const onnxruntime::GraphViewer& graph, - const onnxruntime::Node& node, + const onnxruntime::Node& node, const std::vector& dmlRegistries, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. const InternalRegistrationInfoMap& internalRegInfoMap, @@ -368,7 +442,9 @@ namespace Dml // which is required for MLGraph compilation. 
const onnxruntime::KernelCreateInfo* createInfo; if (!registry->TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo).IsOK()) - continue; + { + continue; + } auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get()); if (regInfoIter != internalRegInfoMap.end()) @@ -376,7 +452,7 @@ namespace Dml auto internalRegInfo = regInfoIter->second; if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration && - NodeTensorTypesSupportedInGraph(node, *internalRegInfo)) + NodeTensorTypesSupportedInGraph(node, *internalRegInfo, supportedDeviceDataTypeMask)) { bool requiredCpuInputsConstant = true; for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) @@ -912,7 +988,8 @@ namespace Dml std::move(graphNodePropertyMap), registryForPartitionKernels, partitionKernelPrefix, - transferredInitializerMap)); + transferredInitializerMap + )); } return result; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 2c9dc497e1..1b2744ecb4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -61,7 +61,7 @@ namespace Dml ); bool IsNodeSupportedByDml( - const onnxruntime::Node& node, + const onnxruntime::Node& node, const onnxruntime::KernelRegistry& registry, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp index 9667336dd8..112e7279d0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp @@ -87,7 +87,7 @@ namespace Dml if (!IsNodeSupportedByDml( node, *registry, - m_providerImpl->GetSuppportedDeviceDataTypeMask(), + m_providerImpl->GetSupportedDeviceDataTypeMask(), *m_providerImpl->GetInternalRegistrationInfoMap().get(), allow64BitInputThroughStrides, nullptr)) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index a44e11dbc2..2308766d71 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -133,12 +133,21 @@ namespace Dml )); } + void DmlOperator::Initialize( + const MLOperatorKernelCreationContext& kernelInfo, + uint32_t minDimensionCount + ) + { + Initialize(kernelInfo, std::nullopt, std::nullopt, std::nullopt, std::nullopt, minDimensionCount); + } + void DmlOperator::Initialize( const MLOperatorKernelCreationContext& kernelInfo, const std::optional>>& kernelInputIndices, const std::optional>>& kernelOutputIndices, const std::optional> inputShape, - const std::optional> outputShape + const std::optional> outputShape, + uint32_t minDimensionCount ) { if (kernelInputIndices) @@ -179,7 +188,7 @@ namespace Dml TensorAxis::W, TensorAxis::RightAligned, inputShape, - NchwDimensionCount)); + minDimensionCount)); } } @@ -200,7 +209,8 @@ namespace Dml TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, - outputShape)); + 
outputShape, + minDimensionCount)); } } } @@ -373,6 +383,34 @@ namespace Dml )); } + void DmlOperator::Remap64bitDmlDataTypesTo32bit() + { + for (auto& tensor : m_inputTensorDescs) + { + tensor.Remap64bitDmlDataTypeTo32bit(); + } + + for (auto& tensor : m_outputTensorDescs) + { + tensor.Remap64bitDmlDataTypeTo32bit(); + } + } + + void DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded() + { + // Conditionally remap 64-bit data types to strided 32-bit if DML does not + // support 64-bit data types directly on the device. + + uint32_t deviceTypeMask = Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get()); + uint32_t deviceTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_INT64) | (1 << DML_TENSOR_DATA_TYPE_UINT64); + + // If the device doesn't support 64-bit tensors, fall back to 32-bit with strides. + if (!(deviceTypeMask & deviceTypeMask64bit)) + { + Remap64bitDmlDataTypesTo32bit(); + } + } + TensorDesc DmlOperator::CreateTensorDescFromInput( const MLOperatorKernelCreationContext& kernelInfo, uint32_t index, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h index 5c75c1a11a..dedfb41e71 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h @@ -29,12 +29,18 @@ namespace Dml ComPtr m_persistentResourcePoolingUnk; // Controls when the persistent resource is returned to the pool std::optional m_persistentResourceBinding; + void Initialize( + const MLOperatorKernelCreationContext& kernelInfo, + uint32_t minDimensionCount + ); + void Initialize( const MLOperatorKernelCreationContext& kernelInfo, const std::optional>>& kernelInputIndices = std::nullopt, const std::optional>>& kernelOutputIndices = std::nullopt, const std::optional> inputShape = std::nullopt, - const std::optional> outputShape = std::nullopt + const std::optional> 
outputShape = std::nullopt, + uint32_t minDimensionCount = NchwDimensionCount ); bool AllowHalfPrecisionComputation() const; @@ -77,6 +83,11 @@ namespace Dml void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor); + // Remap 64-bit data types to 32-bit via doubled strides. + // These should be called before GetDmlInputDescs or GetDmlOutputDescs. + void Remap64bitDmlDataTypesTo32bit(); + void Remap64bitDmlDataTypesTo32bitIfNeeded(); + TensorDesc CreateTensorDescFromInput( const MLOperatorKernelCreationContext& kernelInfo, uint32_t index, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp index ab9787757b..c0d08f83d0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp @@ -30,6 +30,7 @@ public: switch (operatorType) { case DML_OPERATOR_ACTIVATION_ELU: + case DML_OPERATOR_ACTIVATION_CELU: operatorDesc.elu.Alpha = kernelCreationContext.GetOptionalAttribute(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType)); break; @@ -154,6 +155,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(HardSigmoid, DmlOperatorActivationTempla DML_OP_DEFINE_CREATION_FUNCTION(Tanh, DmlOperatorActivationTemplate); DML_OP_DEFINE_CREATION_FUNCTION(ScaledTanh, DmlOperatorActivationTemplate); DML_OP_DEFINE_CREATION_FUNCTION(Relu, DmlOperatorActivationTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(Celu, DmlOperatorActivationTemplate); DML_OP_DEFINE_CREATION_FUNCTION(LeakyRelu, DmlOperatorActivationTemplate); DML_OP_DEFINE_CREATION_FUNCTION(PRelu, DmlOperatorActivationTemplate); DML_OP_DEFINE_CREATION_FUNCTION(ThresholdedRelu, DmlOperatorActivationTemplate); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcat.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcat.cpp index ce5c35935c..693c387d8c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcat.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcat.cpp @@ -42,7 +42,7 @@ public: for (size_t i = 0; i < m_inputTensorDescs.size(); i++) { // DML doesn't support empty tensors for concat, so we ignore them - if (!OperatorHelper::ContainsEmptyDimensions(m_inputTensorDescs[i].GetDmlSizes())) + if (!OperatorHelper::ContainsEmptyDimensions(m_inputTensorDescs[i].GetSizes())) { inputDescs.push_back(m_inputTensorDescs[i].GetDmlDesc()); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp new file mode 100644 index 0000000000..84d8fc0e34 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorEinSum : public DmlOperator, public EinSumHelper +{ +public: + DmlOperatorEinSum(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion) + : DmlOperator(kernelCreationContext), + EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() + 1 == m_components.size(), "EinSum input tensor count is inconsistent with the equation component count."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "EinSum expects one output tensor."); + + DmlOperator::Initialize(kernelCreationContext); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + static_assert(RecognizedOperatorType::Total == static_cast(8), "Update this switch."); + switch (m_recognizedOperatorType) + { + case RecognizedOperatorType::Multiply: + { + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC operatorDesc = {}; + operatorDesc.ATensor = &inputDescs[0]; + operatorDesc.BTensor = &inputDescs[1]; + operatorDesc.OutputTensor = outputDescs.data(); + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &operatorDesc}, kernelCreationContext); + } + break; + + case RecognizedOperatorType::MatMul: + case RecognizedOperatorType::MatMulTransposeA: + case RecognizedOperatorType::MatMulTransposeB: + { + DML_GEMM_OPERATOR_DESC operatorDesc = {}; + operatorDesc.ATensor = &inputDescs[0]; + operatorDesc.BTensor = &inputDescs[1]; + // No operatorDesc.CTensor + operatorDesc.OutputTensor = &outputDescs[0]; + operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; + operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB) ? 
DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; + operatorDesc.Alpha = 1.0; + operatorDesc.Beta = 0.0; + operatorDesc.FusedActivation = nullptr; + + SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext); + } + break; + + case RecognizedOperatorType::ReduceSum: + { + // Get how many axes are kept in the final output, either 0 or 1 supported + // meaning full reduction or partial with one dimension left. *It could be + // generalized to support any number of output dimensions, but it would need + // to accomodate for Transposition too if the output labels are reordered. + auto keptAxes = m_components.back().GetLabels(m_labelIndices); + assert(keptAxes.size() <= 1); + + // DML expects output rank to match input rank (as if ONNX ReduceSum keepdims=1). + // So replace the existing tensor description with the input sizes, except that + // reduced dimensions have size 1. + std::vector reducedAxes; + std::vector inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); + std::vector outputSizes = inputSizes; + + // Determine which axes are being reduced by taking the opposite of those kept. + uint32_t keptAxesMask = 0; + for (auto axis : keptAxes) + { + keptAxesMask |= (1 << axis); + } + for (uint32_t axis = 0, axisCount = static_cast(outputSizes.size()); axis < axisCount; ++axis) + { + if (~keptAxesMask & (1<(reducedAxes.size()); + + SetDmlOperatorDesc({ DML_OPERATOR_REDUCE, &operatorDesc }, kernelCreationContext); + } + break; + + case RecognizedOperatorType::Transpose: + case RecognizedOperatorType::Identity: + { + if (m_recognizedOperatorType == RecognizedOperatorType::Transpose) + { + // Transpose via input strides. The output tensor is not strided. 
+ assert(m_components.front().GetDimensionCount() == m_components.back().GetDimensionCount()); + auto originalStrides = m_inputTensorDescs.front().GetStrides(); + std::vector inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); + std::vector inputStrides(inputSizes.size()); + + // If there were no strides, compute them based in descending packed order + // based on the input sizes. + if (originalStrides.empty()) + { + Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides); + } + else // Copy the original strides. + { + assert(originalStrides.size() >= inputStrides.size()); + size_t offset = originalStrides.size() - inputStrides.size(); + inputStrides.assign(originalStrides.begin() + offset, originalStrides.end()); + } + + // Remap transposed strides using the component labels from input to output. + auto labelIndices = m_components.back().GetLabels(m_labelIndices); + + std::vector newStrides(inputStrides.size()); + std::vector newSizes(inputStrides.size()); + for (size_t i = 0, dimensionCount = inputStrides.size(); i < dimensionCount; ++i) + { + uint32_t labelIndex = labelIndices[i]; + assert(labelIndex < inputStrides.size()); + newSizes[i] = inputSizes[labelIndex]; + newStrides[i] = inputStrides[labelIndex]; + } + + // Override the initial input tensor with the new strides. + m_inputTensorDescs.front() = TensorDesc(m_inputTensorDescs.front().GetDmlDataType(), newSizes, newStrides, 0); + m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newSizes, std::nullopt, 0); + m_inputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. + m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. 
+ } + + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = inputDescs.data(); + operatorDesc.OutputTensor = outputDescs.data(); + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &operatorDesc}, kernelCreationContext); + } + break; + + default: + return; + } + } +}; + +void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) +{ + *isSupported = false; + + MLOperatorAttributes attributes(context); + EinSumHelper helper(attributes); + auto recognizedOperatorType = helper.GetRecognizedOperatorType(); + + static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast(8), "Verify this test still matches the switch above."); + *isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None); +} + +DML_OP_DEFINE_CREATION_FUNCTION(Einsum12, VersionedKernel); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 78a8d9e417..9de43a793a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -462,6 +462,9 @@ public: } }; +// Same operator signature as 11. 
Only difference is new type support +using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11; + class DmlOperatorElementwisePow : public DmlOperator { public: @@ -700,6 +703,8 @@ DML_OP_DEFINE_CREATION_FUNCTION(Erf, DmlOperatorElementwiseUnary); DML_OP_DEFINE_CREATION_FUNCTION(Less, DmlOperatorElementwiseBinary); +DML_OP_DEFINE_CREATION_FUNCTION(GreaterOrEqual, DmlOperatorElementwiseBinary); +DML_OP_DEFINE_CREATION_FUNCTION(LessOrEqual, DmlOperatorElementwiseBinary); DML_OP_DEFINE_CREATION_FUNCTION(Equal, DmlOperatorElementwiseBinary); DML_OP_DEFINE_CREATION_FUNCTION(And, DmlOperatorElementwiseBinary); DML_OP_DEFINE_CREATION_FUNCTION(Or, DmlOperatorElementwiseBinary); @@ -718,6 +723,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean); // Operators with extra attributes: DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7); DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11); +DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12); DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow); DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear); DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorExpand.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorExpand.cpp index bf886fbd4f..a9f0213337 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorExpand.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorExpand.cpp @@ -25,8 +25,8 @@ public: TensorDesc inputTensorDesc = TensorDesc( kernelCreationContext.GetInputEdgeDescription(0).tensorDataType, - m_outputTensorDescs[0].GetDmlSizes(), - m_inputTensorDescs[0].GetDmlSizes(), + m_outputTensorDescs[0].GetSizes(), + m_inputTensorDescs[0].GetSizes(), TensorAxis::DoNotCoerce, TensorAxis::W, 
TensorAxis::RightAligned, @@ -36,8 +36,8 @@ public: TensorDesc outputTensorDesc = TensorDesc( kernelCreationContext.GetOutputEdgeDescription(0).tensorDataType, - m_outputTensorDescs[0].GetDmlSizes(), - m_outputTensorDescs[0].GetDmlSizes(), + m_outputTensorDescs[0].GetSizes(), + m_outputTensorDescs[0].GetSizes(), TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp index 54ef525b46..6dfdf0cec6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp @@ -16,19 +16,21 @@ public: ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "Gather expects 2 inputs."); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Gather expects 1 output."); - DmlOperator::Initialize(kernelCreationContext); + auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector dataDimensions = tensorShapeDescription.GetInputTensorShape(0); + std::vector indicesDimensions = tensorShapeDescription.GetInputTensorShape(1); + std::vector outputDimensions = tensorShapeDescription.GetOutputTensorShape(0); + + size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()}); + DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast(dimensionCountMax)); + + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); assert(inputDescs.size() == 2); assert(outputDescs.size() == 1); - m_inputTensorDescs[1].ForceUnsignedDataType(); - - auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); - std::vector dataDimensions = 
outputTensorShapeDescription.GetInputTensorShape(0); - std::vector indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1); - ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount()); DML_GATHER_OPERATOR_DESC operatorDesc = {}; @@ -52,20 +54,22 @@ public: ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherElements expects 2 inputs."); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherElements expects 1 output."); - DmlOperator::Initialize(kernelCreationContext); + auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector dataDimensions = tensorShapeDescription.GetInputTensorShape(0); + std::vector indicesDimensions = tensorShapeDescription.GetInputTensorShape(1); + std::vector outputDimensions = tensorShapeDescription.GetOutputTensorShape(0); + + size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()}); + DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast(dimensionCountMax)); + + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); assert(inputDescs.size() == 2); assert(outputDescs.size() == 1); - m_inputTensorDescs[1].ForceUnsignedDataType(); - int32_t signedOnnxAxis = kernelCreationContext.GetOptionalAttribute(AttrName::Axis, 0); - auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); - std::vector dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0); - std::vector indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1); - ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); uint32_t dmlAxis = GetDmlAdjustedAxis(signedOnnxAxis, kernelCreationContext, 
m_inputTensorDescs.front().GetDimensionCount()); DML_GATHER_ELEMENTS_OPERATOR_DESC operatorDesc = {}; @@ -89,29 +93,30 @@ public: ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherND expects 2 inputs."); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherND expects 1 output."); - DmlOperator::Initialize(kernelCreationContext); + auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector dataDimensions = tensorShapeDescription.GetInputTensorShape(0); + std::vector indicesDimensions = tensorShapeDescription.GetInputTensorShape(1); + std::vector outputDimensions = tensorShapeDescription.GetOutputTensorShape(0); + + size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()}); + DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast(dimensionCountMax)); + + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); assert(inputDescs.size() == 2); assert(outputDescs.size() == 1); - m_inputTensorDescs[1].ForceUnsignedDataType(); - - auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); - std::vector dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0); - std::vector indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1); - ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); - ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount); - - DML_GATHER_ND_OPERATOR_DESC operatorDesc = {}; + DML_GATHER_ND1_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = &inputDescs[0]; operatorDesc.IndicesTensor = &inputDescs[1]; operatorDesc.OutputTensor = outputDescs.data(); operatorDesc.InputDimensionCount = static_cast(dataDimensions.size()); operatorDesc.IndicesDimensionCount = static_cast(indicesDimensions.size()); + 
operatorDesc.BatchDimensionCount = m_batchCount; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND, &operatorDesc }; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND1, &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMaxUnpool.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMaxUnpool.cpp index 416fd3620c..e58bb0a674 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMaxUnpool.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMaxUnpool.cpp @@ -22,6 +22,8 @@ public: std::vector> inputIndices = { 0, 1 }; // The 3rd tensor ('output_shape') is not bound, just 'X' and 'I' indices. std::vector> outputIndices = { 0 }; DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices); + DmlOperator::Remap64bitDmlDataTypesTo32bit(); + m_inputTensorDescs[1].ForceUnsignedDataType(); // MaxUnpool accepts uint32_t. 
std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorOneHot.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorOneHot.cpp index ec01a54acc..f1ba681d99 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorOneHot.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorOneHot.cpp @@ -37,7 +37,7 @@ public: TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, - NchwDimensionCount, // minDimensionCount + 1, // minDimensionCount 0 ); @@ -49,10 +49,12 @@ public: TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, - NchwDimensionCount, // minDimensionCount + 1, // minDimensionCount 0 ); + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); + // Adjust the axis so it's in DML's terms rather than the original ONNX indexing. uint32_t dmlAxis = GetDmlAdjustedAxis( m_absoluteAxis, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index 0fcb935416..2598c8c015 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -89,18 +89,7 @@ public: } }; -// A specific type of operation for registration. 
-template -class DmlOperatorPaddingTemplate : public DmlOperatorPadding -{ -public: - DmlOperatorPaddingTemplate(const MLOperatorKernelCreationContext& kernelInfo) - : DmlOperatorPadding(kernelInfo, opsetVersion) - { - } -}; - -DML_OP_DEFINE_CREATION_FUNCTION(Pad7, DmlOperatorPaddingTemplate<7>); -DML_OP_DEFINE_CREATION_FUNCTION(Pad11, DmlOperatorPaddingTemplate<11>); +DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp index ac62584c28..439865d05f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp @@ -108,7 +108,14 @@ public: if (hasOutputIndices || hasDilations) { DML_MAX_POOLING2_OPERATOR_DESC desc = {}; - desc.OutputIndicesTensor = hasOutputIndices ? &outputDescs[1] : nullptr; + + if (hasOutputIndices) + { + DmlOperator::Remap64bitDmlDataTypesTo32bit(); + m_outputTensorDescs[1].ForceUnsignedDataType(); // MaxPool accepts uint32_t. + desc.OutputIndicesTensor = &outputDescs[1]; + } + desc.Dilations = m_kernel.dilations; SetOpDesc(desc); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReduce.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReduce.cpp index f6f623dc4e..72758d743b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReduce.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReduce.cpp @@ -22,17 +22,12 @@ public: ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); DmlOperator::Initialize(kernelInfo); - // Zero the output tensor's memory for ArgMin & ArgMax, which produce INT64 output. 
- if ((function == DML_REDUCE_FUNCTION_ARGMAX) || (function == DML_REDUCE_FUNCTION_ARGMIN)) - { - m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes()); - } - std::vector dmlAxes; std::vector reducedDims = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0); - int dimOffset = gsl::narrow_cast(OperatorHelper::NchwDimensionCount - reducedDims.size()); + int dimOffset = gsl::narrow_cast(m_inputTensorDescs[0].GetDimensionCount() - reducedDims.size()); for (auto& dim : m_axes) { + assert(dim < reducedDims.size()); // ReduceHelperBase already validated this. reducedDims[dim] = 1; dmlAxes.push_back(static_cast(dim + dimOffset)); } @@ -62,15 +57,59 @@ public: std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_REDUCE_OPERATOR_DESC reduceDesc = {}; - reduceDesc.InputTensor = inputDescs.data(); - reduceDesc.OutputTensor = outputDescs.data(); - reduceDesc.Function = function; - reduceDesc.Axes = dmlAxes.data(); - reduceDesc.AxisCount = gsl::narrow_cast(dmlAxes.size()); + // Zero the output tensor's memory for ArgMin & ArgMax, which produce INT64 output. + if (function == DML_REDUCE_FUNCTION_ARGMAX) + { + DML_ARGMAX_OPERATOR_DESC argmaxDesc; + argmaxDesc.AxisDirection = static_cast(m_selectLastIndex); + argmaxDesc.InputTensor = inputDescs.data(); + argmaxDesc.OutputTensor = outputDescs.data(); + argmaxDesc.Axes = dmlAxes.data(); + argmaxDesc.AxisCount = gsl::narrow_cast(dmlAxes.size()); - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_REDUCE, &reduceDesc }; - SetDmlOperatorDesc(opDesc, kernelInfo); + // If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits + // of each element. If the device directly supports 64-bit elements, then no need. 
+ DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); + if (m_outputTensorDescs[0].WasRemapped64bitTo32bit()) + { + m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes()); + } + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMAX, &argmaxDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } + else if (function == DML_REDUCE_FUNCTION_ARGMIN) + { + DML_ARGMIN_OPERATOR_DESC argminDesc; + argminDesc.AxisDirection = static_cast(m_selectLastIndex); + argminDesc.InputTensor = inputDescs.data(); + argminDesc.OutputTensor = outputDescs.data(); + argminDesc.Axes = dmlAxes.data(); + argminDesc.AxisCount = gsl::narrow_cast(dmlAxes.size()); + + // If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits + // of each element. If the device directly supports 64-bit elements, then no need. + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); + if (m_outputTensorDescs[0].WasRemapped64bitTo32bit()) + { + m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes()); + } + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMIN, &argminDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } + else + { + DML_REDUCE_OPERATOR_DESC reduceDesc = {}; + reduceDesc.InputTensor = inputDescs.data(); + reduceDesc.OutputTensor = outputDescs.data(); + reduceDesc.Function = function; + reduceDesc.Axes = dmlAxes.data(); + reduceDesc.AxisCount = gsl::narrow_cast(dmlAxes.size()); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_REDUCE, &reduceDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } } void Compute(const MLOperatorKernelContext& kernelContext) override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp index 095c5aa5b8..a3547188f0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp @@ -251,17 +251,6 @@ public: } }; -// A specific type of operation for registration. -template -struct DmlOperatorResizeTemplate : public DmlOperatorResize -{ -public: - DmlOperatorResizeTemplate(const MLOperatorKernelCreationContext& kernelInfo) - : DmlOperatorResize(kernelInfo, OpsetVersion) - { - } -}; - void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) { *isSupported = false; @@ -304,10 +293,10 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* *isSupported = true; } -DML_OP_DEFINE_CREATION_FUNCTION(Resize10, DmlOperatorResizeTemplate<10>); -DML_OP_DEFINE_CREATION_FUNCTION(Resize11, DmlOperatorResizeTemplate<11>); -DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, DmlOperatorResizeTemplate<7>); -DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, DmlOperatorResizeTemplate<9>); -DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, DmlOperatorResizeTemplate<10>); +DML_OP_DEFINE_CREATION_FUNCTION(Resize10, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReverseSequence.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReverseSequence.cpp index f13756043f..8e2af57299 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReverseSequence.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorReverseSequence.cpp @@ -47,6 +47,9 @@ public: 0 ); + DmlOperator::Remap64bitDmlDataTypesTo32bit(); + m_inputTensorDescs[1].ForceUnsignedDataType(); // DML operator accepts uint32_t. 
+ std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp new file mode 100644 index 0000000000..75e7595a3a --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelper +{ +public: + using Self = DmlOperatorRegionOfInterestAlign; + + DmlOperatorRegionOfInterestAlign(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion) + : DmlOperator(kernelCreationContext), + RoiAlignHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "RoiAlign expects 3 input tensors."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "RoiAlign expects 1 output tensor."); + + DmlOperator::Initialize(kernelCreationContext); + DmlOperator::Remap64bitDmlDataTypesTo32bit(); + m_inputTensorDescs[2].ForceUnsignedDataType(); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + constexpr NameAndIndex mapping[] = + { + {"max", DML_REDUCE_FUNCTION_MAX}, + {"avg", DML_REDUCE_FUNCTION_AVERAGE}, + }; + const std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "avg"); + const auto reductionFunction = MapStringToIndex(mode, mapping); + const float spatialScale = kernelCreationContext.GetOptionalAttribute(AttrName::SpatialScale, 1.0f); + const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute(AttrName::SamplingRatio, 0u); + 
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive."); + + DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = &inputDescs[0]; + operatorDesc.ROITensor = &inputDescs[1]; + operatorDesc.BatchIndicesTensor = &inputDescs[2]; + operatorDesc.OutputTensor = &outputDescs[0]; + operatorDesc.SpatialScaleX = spatialScale; // ONNX uses the same scale for X and Y. + operatorDesc.SpatialScaleY = spatialScale; + operatorDesc.OutOfBoundsInputValue = 0.0f; // ONNX does not specify a value for input elements outside bounds. + operatorDesc.MinimumSamplesPerOutput = (samplesPerOutput == 0) ? 1 : samplesPerOutput; + operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput; + operatorDesc.ReductionFunction = reductionFunction; + operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc }; + + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiPooling.cpp index ac17a2ba16..1d50d3f42d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiPooling.cpp @@ -26,7 +26,7 @@ public: poolingDesc.ROITensor = &inputDescs[1]; poolingDesc.OutputTensor = &outputDescs[0]; poolingDesc.SpatialScale = m_spatialScale; - poolingDesc.PooledSize = { m_pooledSizeH, m_pooledSizeW }; + poolingDesc.PooledSize = { m_outputSizeH, m_outputSizeW }; DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_POOLING, &poolingDesc }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp index b9dfef8ba3..9a7f7de526 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp @@ -23,7 +23,6 @@ public: ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions); ML_CHECK_VALID_ARGUMENT(indicesDimensions == updatesDimensions); ML_CHECK_VALID_ARGUMENT(dataDimensions.size() == indicesDimensions.size()); - ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); // When the indices tensor is empty, Scatter is basically Identity. But since DML doesn't support empty or null // tensors, we have to special-case it outside of DML. @@ -31,6 +30,7 @@ public: { std::vector> kernelInputIndices(1, 0); DmlOperator::Initialize(kernelCreationContext, kernelInputIndices); + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); @@ -49,14 +49,13 @@ public: else { DmlOperator::Initialize(kernelCreationContext); + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); assert(inputDescs.size() == 3); assert(outputDescs.size() == 1); - m_inputTensorDescs[1].ForceUnsignedDataType(); - // Read the axis. 
int onnxAxis = kernelCreationContext.GetOptionalAttribute(AttrName::Axis, 0); uint32_t dmlAxis = GetDmlAdjustedAxis(onnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount()); @@ -89,20 +88,16 @@ public: std::vector updatesDimensions = tensorShapeDescription.GetInputTensorShape(2); std::vector outputDimensions = tensorShapeDescription.GetOutputTensorShape(0); ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions); - ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); - ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount); - ML_CHECK_VALID_ARGUMENT(updatesDimensions.size() <= OperatorHelper::NchwDimensionCount); - ML_CHECK_VALID_ARGUMENT(outputDimensions.size() <= OperatorHelper::NchwDimensionCount); - DmlOperator::Initialize(kernelCreationContext); + size_t dimensionCountMax = std::max({dataDimensions.size(), updatesDimensions.size(), indicesDimensions.size(), outputDimensions.size()}); + DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast(dimensionCountMax)); + DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); assert(inputDescs.size() == 3); assert(outputDescs.size() == 1); - m_inputTensorDescs[1].ForceUnsignedDataType(); - DML_SCATTER_ND_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = &inputDescs[0]; operatorDesc.IndicesTensor = &inputDescs[1]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp index 038cdc6a62..fd6a12ecae 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp @@ -47,23 +47,12 @@ public: } }; -// A specific type of operation for registration. 
-template -class DmlOperatorSliceTemplate : public DmlOperatorSlice -{ -public: - DmlOperatorSliceTemplate(const MLOperatorKernelCreationContext& kernelInfo) - : DmlOperatorSlice(kernelInfo, opsetVersion) - { - } -}; - void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) { *isSupported = (context->GetInputCount() <= 5); } -DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>); -DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>); -DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<11>); +DML_OP_DEFINE_CREATION_FUNCTION(Slice7, VersionedKernel ); +DML_OP_DEFINE_CREATION_FUNCTION(Slice10, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Slice11, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTopk.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTopk.cpp index d0d04dd795..adcc5b07f8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTopk.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTopk.cpp @@ -22,6 +22,8 @@ public: std::vector> inputIndices = { 0 }; // Use only the first tensor. The second tensor is CPU-based. std::vector> outputIndices = { 0, 1 }; DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices); + DmlOperator::Remap64bitDmlDataTypesTo32bit(); + m_outputTensorDescs[1].ForceUnsignedDataType(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); @@ -70,19 +72,8 @@ private: ComPtr m_zeroOperator; }; -// A specific type of operation for registration. 
-template -class DmlOperatorTopKTemplate : public DmlOperatorTopK -{ -public: - DmlOperatorTopKTemplate(const MLOperatorKernelCreationContext& kernelInfo) - : DmlOperatorTopK(kernelInfo, OpsetVersion) - { - } -}; - -DML_OP_DEFINE_CREATION_FUNCTION(TopK7, DmlOperatorTopKTemplate<7>); -DML_OP_DEFINE_CREATION_FUNCTION(TopK10, DmlOperatorTopKTemplate<10>); -DML_OP_DEFINE_CREATION_FUNCTION(TopK11, DmlOperatorTopKTemplate<11>); +DML_OP_DEFINE_CREATION_FUNCTION(TopK7, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(TopK10, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(TopK11, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTranspose.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTranspose.cpp index 24258b3daf..851dc2ec9e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTranspose.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorTranspose.cpp @@ -19,41 +19,25 @@ public: ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() >= 1); DmlOperator::Initialize(kernelInfo); - const MLOperatorEdgeDescription inputEdgeDescription = kernelInfo.GetInputEdgeDescription(0); - const std::vector originalSizes = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0); ML_CHECK_VALID_ARGUMENT(m_permutations.size() == originalSizes.size()); // Calculate strides from original shape. 
ML_CHECK_VALID_ARGUMENT(!originalSizes.empty()); std::vector inputStrides(originalSizes.size()); - inputStrides.back() = 1; - for (int i = gsl::narrow_cast(inputStrides.size()) - 2; i >= 0; i--) - { - inputStrides[i] = inputStrides[i + 1] * gsl::narrow_cast(originalSizes[i + 1]); - } + Dml::GetDescendingPackedStrides(originalSizes, /*out*/ inputStrides); - const int leadingDims = gsl::narrow_cast(m_inputTensorDescs.front().GetDimensionCount() - originalSizes.size()); - - std::vector sizes(m_inputTensorDescs.front().GetDimensionCount()); - std::vector strides(m_inputTensorDescs.front().GetDimensionCount()); - - // Fill leading tensor desc sizes/strides with defaults. - for (int dimDML = 0; dimDML < leadingDims; ++dimDML) - { - sizes[dimDML] = 1; - strides[dimDML] = 0; - } + std::vector sizes(inputStrides.size()); + std::vector strides(inputStrides.size()); // Permute the shape and strides. for (int dimInput = 0, dimCount = gsl::narrow_cast(originalSizes.size()); dimInput < dimCount; ++dimInput) { - int dimDML = dimInput + leadingDims; int dimPermuted = m_permutations[dimInput]; ML_CHECK_VALID_ARGUMENT(gsl::narrow_cast(dimPermuted) < originalSizes.size()); - sizes[dimDML] = gsl::narrow_cast(originalSizes[dimPermuted]); - strides[dimDML] = inputStrides[dimPermuted]; + sizes[dimInput] = originalSizes[dimPermuted]; + strides[dimInput] = inputStrides[dimPermuted]; } // Override the initial tensor descs. The output tensor is not strided. 
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 4557dd568f..3292e9d6fb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -34,22 +34,28 @@ enum class SupportedTensorDataTypes : uint32_t UInt64 = 1<<13, Complex64 = 1<<14, Complex128 = 1<<15, - Int8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, + Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, Int32to64 = UInt32|Int32|UInt64|Int64, Float16to32 = Float16|Float32, // Float64 is not supported by DirectML. - NumericDefault = Int8to32|Float16to32, + NumericDefault = Ints8to32|Float16to32, Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool, Ints8Bit = UInt8|Int8, + Ints16Bit = UInt16|Int16, + Ints32Bit = UInt32|Int32, All = static_cast(-1), }; DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes); -enum class DmGraphSupport +enum class DmlGraphSupport : uint32_t { - Supported = 0, - NotSupported = 1, + Supported = 0, + NotSupported = 1, + SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides. + SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes) + Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support. 
}; +DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport); struct OperatorRegistrationInformation { @@ -63,7 +69,7 @@ struct OperatorRegistrationInformation gsl::span tensorTypeNames; gsl::span supportedTensorDataTypes; - DmGraphSupport DmGraphSupport; + DmlGraphSupport dmlGraphSupport; std::pair, int> requiredConstantCpuInputs = {{}, 0}; @@ -86,6 +92,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool); DML_OP_EXTERN_CREATION_FUNCTION(LpPool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool); DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool); +DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10); DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization); DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization); DML_OP_EXTERN_CREATION_FUNCTION(LRN); @@ -117,8 +124,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(Ceil); DML_OP_EXTERN_CREATION_FUNCTION(Floor); DML_OP_EXTERN_CREATION_FUNCTION(Clip7); DML_OP_EXTERN_CREATION_FUNCTION(Clip11); +DML_OP_EXTERN_CREATION_FUNCTION(Clip12); DML_OP_EXTERN_CREATION_FUNCTION(Greater); DML_OP_EXTERN_CREATION_FUNCTION(Less); +DML_OP_EXTERN_CREATION_FUNCTION(GreaterOrEqual); +DML_OP_EXTERN_CREATION_FUNCTION(LessOrEqual); DML_OP_EXTERN_CREATION_FUNCTION(Equal); DML_OP_EXTERN_CREATION_FUNCTION(Not); DML_OP_EXTERN_CREATION_FUNCTION(And); @@ -133,6 +143,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Mean); DML_OP_EXTERN_CREATION_FUNCTION(Max); DML_OP_EXTERN_CREATION_FUNCTION(Min); DML_OP_EXTERN_CREATION_FUNCTION(ReduceSum); +DML_OP_EXTERN_CREATION_FUNCTION(Einsum12); DML_OP_EXTERN_CREATION_FUNCTION(ReduceMean); DML_OP_EXTERN_CREATION_FUNCTION(ReduceProd); DML_OP_EXTERN_CREATION_FUNCTION(ReduceLogSum); @@ -160,6 +171,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LeakyRelu); DML_OP_EXTERN_CREATION_FUNCTION(PRelu); DML_OP_EXTERN_CREATION_FUNCTION(ThresholdedRelu); DML_OP_EXTERN_CREATION_FUNCTION(Elu); +DML_OP_EXTERN_CREATION_FUNCTION(Celu); DML_OP_EXTERN_CREATION_FUNCTION(Selu); DML_OP_EXTERN_CREATION_FUNCTION(Softmax); DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax); @@ -233,6 +245,7 @@ 
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); DML_OP_EXTERN_QUERY_FUNCTION(Slice); DML_OP_EXTERN_QUERY_FUNCTION(Resize); +DML_OP_EXTERN_QUERY_FUNCTION(EinSum); constexpr static std::array typeNameListDefault = {"T"}; constexpr static std::array typeNameListTwo = { "T1", "T2" }; @@ -240,6 +253,7 @@ constexpr static std::array typeNameListThree = { "T1", "T2", "T constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" }; constexpr static std::array typeNameListTopK = { "T", "I" }; constexpr static std::array typeNameListLogicalComparison = { "T", "T1" }; +constexpr static std::array typeNameListPow12 = {"T", "T1"}; constexpr static std::array typeNameListConstantOfShape = { "T1", "T2" }; constexpr static std::array typeNameListScatterGather = { "T", "Tind" }; constexpr static std::array typeNameListScatterGatherND = { "T" }; // Tind is curiously missing, only allowing 64-bit. @@ -249,15 +263,18 @@ constexpr static std::array typeNameListEyeLike = { "T2" }; constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit }; constexpr static std::array supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; -constexpr static std::array supportedTypeListInt8to32 = {SupportedTensorDataTypes::Int8to32}; +constexpr static std::array supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit}; +constexpr static std::array supportedTypeListInt8to32 = 
{SupportedTensorDataTypes::Ints8to32}; constexpr static std::array supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32}; constexpr static std::array supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault }; constexpr static std::array supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListBool = {SupportedTensorDataTypes::Bool}; +constexpr static std::array supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault}; constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; -constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 }; constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 }; @@ -275,7 +292,7 @@ constexpr static std::array supportedTypeListLogica constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; constexpr static 
std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; -constexpr static std::array supportedTypeListPadWithoutFloat16 = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 }; +constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListQLinearMatMul = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -305,7 +322,7 @@ constexpr auto requiredConstantCpuInputs(Args... args) // Identity operators use Copy, alias their first input, and require floating point formats // for usage in the graph, besides constant inputs. This is because they currently use -// element-wise identity operators in the graph for striding support, but issue actual copies +// element-wise identity operators in the graph for striding support, but issue actual copies // outside the graph. Element-wise identity currently only supports floating point types. #define REG_INFO_ID(version, operatorName, ...) 
\ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, true, true, ##__VA_ARGS__, @@ -325,230 +342,244 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation /// Support query function // Deep Learning Standard Layers - {REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + 
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, LpPool, typeNameListDefault, 
supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. - {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, - {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, - {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, - {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, + {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, GlobalLpPool, typeNameListDefault, 
supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. + {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // Data Reorganization Layers - {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis. 
- {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis. - {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. - {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, - {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 - {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, - {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, 
DmGraphSupport::Supported)}, - {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, - {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, - {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmGraphSupport::Supported)}, - {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, - {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, - {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, - {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmGraphSupport::Supported)}, - {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, + {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. + {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. 
+ {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, + {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + 
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, + {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + 
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)}, // Data reorganization that merely changes the dimensions while keeping the data identical. - {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 11, Flatten, typeNameListDefault, 
supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, // Elementwise - {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_VER( 11, Clip, typeNameListDefault, 
supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, - {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 
7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmGraphSupport::Supported)}, - {REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmGraphSupport::Supported)}, - {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmGraphSupport::Supported)}, - {REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMean, 
typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 11, ArgMax, typeNameListDefault, 
supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmGraphSupport::Supported)}, - {REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)}, - {REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)}, - {REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)}, - {REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)}, - {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)}, - {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)}, - {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)}, - {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)}, - {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)}, - {REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)}, + {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, 
Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, 
DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, 
supportedTypeListQuantize, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )}, + {REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, 
DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, 
+ {REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, + {REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, + {REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, + {REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, + {REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, + {REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, + 
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, + {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, + {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, // Imaging Operators - {REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 9, Upsample, typeNameListDefault, 
supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, // Activation Functions - {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, 
- {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, + {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, 
supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, // Uncategorized - {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmGraphSupport::Supported)}, - 
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmGraphSupport::Supported)}, + {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, - {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported)}, - {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, OneHot, 
typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, // Fused operators - {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedBatchNormalization, 
typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + + {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, + {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, - {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmGraphSupport::Supported)}, - {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 10, 
ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, - - {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, - {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9. - - {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmGraphSupport::NotSupported)}, - {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmGraphSupport::NotSupported)}, - {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmGraphSupport::NotSupported)}, - {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmGraphSupport::NotSupported)}, + {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, + {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9. 
+ {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)}, + {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)}, + {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)}, + {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)}, }; template @@ -572,10 +603,13 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) MLOperatorKernelDescription desc = {}; desc.domain = information.domain; desc.name = information.operatorName; - desc.executionType = MLOperatorExecutionType::D3D12; + desc.executionType = MLOperatorExecutionType::D3D12; // The graph must be configured with operators from only the legacy DML API, or only the new DML API - bool kernelSupportsGraph = (information.DmGraphSupport == DmGraphSupport::Supported); + bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported); + bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly); + bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides); + bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp); desc.options = information.shapeInferenceFunction ? MLOperatorKernelOptions::None : MLOperatorKernelOptions::AllowDynamicInputShapes; @@ -651,6 +685,9 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) kernelSupportsGraph, // supportsGraph information.requiredInputCountForDmlGraphSupport ? 
&(*information.requiredInputCountForDmlGraphSupport) : nullptr, information.requiresFloatFormatsForGraph, + supportedWith64BitTensorsVia32BitStrides, + supportedWith64BitTensorsVia32BitStridesFromAnyEp, + prefer64BitTensorsDirectly, information.requiredConstantCpuInputs.first.data(), static_cast(information.requiredConstantCpuInputs.second) )); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h index 45923d528d..8d43fb7a70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h @@ -11,6 +11,19 @@ class MLOperatorKernelCreationContext; #define DML_OP_EXTERN_CREATION_FUNCTION(operatorName) extern void CALLBACK Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) #define DML_OP_EXTERN_QUERY_FUNCTION(operatorName) extern void CALLBACK Query##operatorName(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported); +// A specific opset version for registration. +// e.g. +// DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel); +template +class VersionedKernel : public BaseClass +{ +public: + VersionedKernel(const MLOperatorKernelCreationContext& kernelInfo) + : BaseClass(kernelInfo, opsetVersion) + { + } +}; + // Declares a callback creation function of the given operator class. // This does not register it, just declares it for usage by registration later. 
// diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index ea498d9621..fe497bec34 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -14,6 +14,7 @@ namespace Dml switch (function) { case DML_OPERATOR_ACTIVATION_ELU: + case DML_OPERATOR_ACTIVATION_CELU: return 1.0f; case DML_OPERATOR_ACTIVATION_LEAKY_RELU: @@ -358,37 +359,45 @@ namespace Dml } } + uint32_t MapStringToIndex(std::string_view mode, gsl::span nameAndIndexList) + { + for (auto& nameAndIndex : nameAndIndexList) + { + if (mode == nameAndIndex.name) // Whole-string match: a prefix such as "n" or an empty mode must not match "nearest". + { + return nameAndIndex.index; + } + } + + ML_INVALID_ARGUMENT("Unknown mode value."); + } + DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode) { // The ONNX modes are "nearest" and "linear." Other modes exist for compatibility, // since Winml supported them in the past. 
- if (mode == "NEAREST" || mode == "nearest" || mode == "nn" || mode == "NN") { - return DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR; - } - else if (mode == "BILINEAR" || mode == "bilinear" || mode == "linear") + + constexpr NameAndIndex mapping[] = { - return DML_INTERPOLATION_MODE_LINEAR; - } - else - { - ML_INVALID_ARGUMENT("Unknown sampling interpolation mode."); - } + {"nearest", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"NEAREST", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"NN", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"nn", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"linear", DML_INTERPOLATION_MODE_LINEAR}, + {"BILINEAR", DML_INTERPOLATION_MODE_LINEAR}, + {"bilinear", DML_INTERPOLATION_MODE_LINEAR}, + }; + return MapStringToIndex(mode, mapping); } DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode) { - if (mode == "DCR") + constexpr NameAndIndex mapping[] = { - return DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW; - } - else if (mode == "CRD") - { - return DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH; - } - else - { - ML_INVALID_ARGUMENT("Unknown depth space mode."); - } + {"DCR", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW}, + {"CRD", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH}, + }; + return MapStringToIndex(mode, mapping); } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index d99811790c..027c2228bb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -57,6 +57,20 @@ namespace Dml void GetDmlAdjustedAxes(/*inout*/ gsl::span axes, uint32_t onnxDimCount, uint32_t dmlDimCount, std::vector& dmlAxes); + struct NameAndIndex + { + const char* name; // Null terminated. 
+ uint32_t index; + }; + + template + T MapStringToIndex(std::string_view mode, gsl::span nameAndIndexList) + { + return static_cast(MapStringToIndex(mode, nameAndIndexList)); + } + + uint32_t MapStringToIndex(std::string_view mode, gsl::span nameAndIndexList); + DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode); DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index b4a286e08e..2b352efb9a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -184,53 +184,6 @@ TensorDesc::TensorDesc( } } - //////////////////////////////////////// - // Handle 64-bit tensors. - - uint64_t endPaddingInBytes = 0; - - if (dataType == MLOperatorTensorDataType::UInt64 || dataType == MLOperatorTensorDataType::Int64) - { - // DirectML doesn't support tensor of int64 because Direct3D doesn't support - // the data type. A workaround is to use strides to fake 64-bit memory access - // while only the lower 32 bits contains the data. This trick obviously doesn't - // work if the data element is genuine 64-bit. It also doesn't work if the data - // element is negative as the signed bit will be incorrectly interpreted. - m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT32; - - // If the strides haven't been calculated yet, initialize them as packed. - if (!useStrides) - { - uint32_t stride = 1; - for (int i = m_bufferTensorDesc.DimensionCount - 1; i >= 0; i--) - { - m_strides[i] = stride; - stride *= m_sizes[i]; - } - } - - // Double the stride values to emulate 64-bit integer support. - for (uint32_t i = 0; i < m_bufferTensorDesc.DimensionCount; ++i) - { - m_strides[i] *= 2; - } - - useStrides = true; - - // The physical size of the tensor will have an extra 4 bytes at the end. 
- // DMLCalcBufferTensorSize calculates the minimum implied size, which is based on the last - // addressable element of the tensor plus the space for the last element. However, the size - // of the last element is now halved from 8 bytes to 4 bytes. - // - // Example: - // Original Tensor: size={2,3}, strides={3,1}, type=int64, size = (1+{1,2}*{3,1})*sizeof(int64) = 6 * 8 = 48 - // Emulated Tensor: size={2,3}, strides={6,2}, type=int32, size = (1+{1,2}*{6,2})*sizeof(int32) = 11 * 4 = 44 - // - // DirectML itself won't read/write the last 4 bytes, but we want the total size to be accurate - // so that the entire region can be zeroed. - endPaddingInBytes = sizeof(uint32_t); - } - if (useStrides) { m_bufferTensorDesc.Strides = m_strides; @@ -239,20 +192,84 @@ TensorDesc::TensorDesc( m_bufferTensorDesc.Flags = DML_TENSOR_FLAG_NONE; m_bufferTensorDesc.GuaranteedBaseOffsetAlignment = guaranteedBaseOffsetAlignment; m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( - m_bufferTensorDesc.DataType, - m_bufferTensorDesc.DimensionCount, - m_sizes, + m_bufferTensorDesc.DataType, + m_bufferTensorDesc.DimensionCount, + m_sizes, useStrides ? m_strides : nullptr + ); + assert(m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)); +} + +void TensorDesc::Remap64bitDmlDataTypeTo32bit() +{ + if (m_bufferTensorDesc.DataType != DML_TENSOR_DATA_TYPE_UINT64 && + m_bufferTensorDesc.DataType != DML_TENSOR_DATA_TYPE_INT64) + { + return; // Nothing to do. + } + + uint64_t endPaddingInBytes = 0; + + // A workaround for older devices is to use strides to fake 64-bit memory access + // while only the lower 32 bits contains the data. This trick obviously doesn't + // work if the data element is genuine 64-bit. It also doesn't work if the data + // element is negative as the signed bit will be incorrectly interpreted. 
+ m_bufferTensorDesc.DataType = Dml::Remap64bitDmlDataTypeTo32bit(m_bufferTensorDesc.DataType); + + // If the strides haven't been calculated yet, initialize them as packed. + if (m_bufferTensorDesc.Strides == nullptr) + { + uint32_t stride = 1; + for (int i = m_bufferTensorDesc.DimensionCount - 1; i >= 0; i--) + { + m_strides[i] = stride; + stride *= m_sizes[i]; + } + } + + // Double the stride values to emulate 64-bit integer support. + for (uint32_t i = 0; i < m_bufferTensorDesc.DimensionCount; ++i) + { + m_strides[i] *= 2; + } + + // The physical size of the tensor will have an extra 4 bytes at the end. + // DMLCalcBufferTensorSize calculates the minimum implied size, which is based on the last + // addressable element of the tensor plus the space for the last element. However, the size + // of the last element is now halved from 8 bytes to 4 bytes. + // + // Example: + // Original Tensor: size={2,3}, strides={3,1}, type=int64, size = (1+{1,2}*{3,1})*sizeof(int64) = 6 * 8 = 48 + // Emulated Tensor: size={2,3}, strides={6,2}, type=int32, size = (1+{1,2}*{6,2})*sizeof(int32) = 11 * 4 = 44 + // + // DirectML itself won't read/write the last 4 bytes, but we want the total size to be accurate + // so that the entire region can be zeroed. 
+ endPaddingInBytes = sizeof(uint32_t); + + m_bufferTensorDesc.Strides = m_strides; + + m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( + m_bufferTensorDesc.DataType, + m_bufferTensorDesc.DimensionCount, + m_sizes, + m_strides ) + endPaddingInBytes; } +bool TensorDesc::WasRemapped64bitTo32bit() const +{ + bool was64BitIntType = (m_mlOperatorTensorDataType == MLOperatorTensorDataType::UInt64 || m_mlOperatorTensorDataType == MLOperatorTensorDataType::Int64); + bool is32BitIntType = (m_bufferTensorDesc.DataType == DML_TENSOR_DATA_TYPE_UINT32 || m_bufferTensorDesc.DataType == DML_TENSOR_DATA_TYPE_INT32); + return was64BitIntType && is32BitIntType; +} + gsl::span TensorDesc::GetStrides() const { if (m_bufferTensorDesc.Strides == nullptr) { return {}; } - return { m_strides, m_strides + m_bufferTensorDesc.DimensionCount }; + return { m_strides, m_strides + m_bufferTensorDesc.DimensionCount }; } DML_TENSOR_DESC TensorDesc::GetDmlDesc() @@ -297,7 +314,8 @@ void TensorDesc::ForceUnsignedDataType() m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT8; break; - // Nothing to do if already unsigned + // Nothing to do if already unsigned. 
+ case DML_TENSOR_DATA_TYPE_UINT64: case DML_TENSOR_DATA_TYPE_UINT32: case DML_TENSOR_DATA_TYPE_UINT16: case DML_TENSOR_DATA_TYPE_UINT8: @@ -307,3 +325,35 @@ void TensorDesc::ForceUnsignedDataType() ML_INVALID_ARGUMENT("Can't coerce unknown or non-integral data type"); } } + +void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment) +{ + ML_CHECK_VALID_ARGUMENT(newDimensionCount <= MaximumDimensionCount); + ML_CHECK_VALID_ARGUMENT(alignment == TensorAxis::RightAligned || alignment == TensorAxis::LeftAligned); + + const uint32_t oldDimensionCount = m_bufferTensorDesc.DimensionCount; + const int32_t difference = static_cast(newDimensionCount - oldDimensionCount); + if (difference == 0) + { + return; + } + + int32_t fillOffset = oldDimensionCount; + int32_t fillCount = std::max(0, difference); + + // alignment == TensorAxis::LeftAligned is the easy case. + // Right alignment needs more work, shifting values over. + if (alignment == TensorAxis::RightAligned) + { + fillOffset = 0; // Fill leading dimensions with 1's starting at the front. 
+ uint32_t moveCount = std::min(newDimensionCount, oldDimensionCount); + memmove(&m_sizes[fillCount], &m_sizes[oldDimensionCount - moveCount], sizeof(m_sizes[0]) * moveCount); + memmove(&m_strides[fillCount], &m_strides[oldDimensionCount - moveCount], sizeof(m_strides[0]) * moveCount); + } + if (fillCount > 0) + { + std::fill(&m_sizes[fillOffset], &m_sizes[fillOffset] + fillCount, 1u); + std::fill(&m_strides[fillOffset], &m_strides[fillOffset] + fillCount, 0u); + } + m_bufferTensorDesc.DimensionCount = newDimensionCount; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index faa28fec78..e9b7d97c48 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -36,11 +36,13 @@ namespace Dml inline DML_TENSOR_DATA_TYPE GetDmlDataType() const { return m_bufferTensorDesc.DataType; } inline MLOperatorTensorDataType GetMlOperatorDataType() const { return m_mlOperatorTensorDataType; } - inline gsl::span GetDmlSizes() const { return m_sizes; } void ForceUnsignedDataType(); + void Remap64bitDmlDataTypeTo32bit(); + bool WasRemapped64bitTo32bit() const; inline bool IsValid() const { return m_tensorType != DML_TENSOR_TYPE_INVALID; } inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; } + void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment); gsl::span GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; } gsl::span GetStrides() const; @@ -58,8 +60,6 @@ namespace Dml DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {}; }; - - class TensorDescBuilder { public: diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index e6f0f995b8..95ce76a9bf 100644 --- 
a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -15,6 +15,7 @@ namespace AttrName static constexpr const char* Axis = "axis"; static constexpr const char* AxisW = "axis_w"; static constexpr const char* BatchAxis = "batch_axis"; + static constexpr const char* BatchDimensions = "batch_dims"; static constexpr const char* Beta = "beta"; static constexpr const char* Bias = "bias"; static constexpr const char* BlockSize = "blocksize"; @@ -32,6 +33,7 @@ namespace AttrName static constexpr const char* Dtype = "dtype"; static constexpr const char* Ends = "ends"; static constexpr const char* Epsilon = "epsilon"; + static constexpr const char* Equation = "equation"; static constexpr const char* ExcludeOutside = "exclude_outside"; static constexpr const char* Exclusive = "exclusive"; static constexpr const char* Exponent = "exponent"; @@ -45,6 +47,7 @@ namespace AttrName static constexpr const char* InputForget = "input_forget"; static constexpr const char* K = "k"; static constexpr const char* KeepDims = "keepdims"; + static constexpr const char* SelectLastIndex = "select_last_index"; static constexpr const char* KernelShape = "kernel_shape"; static constexpr const char* LinearBeforeReset = "linear_before_reset"; static constexpr const char* Lambda = "lambd"; // Deliberate typo to match ONNX spec. 
@@ -57,12 +60,15 @@ namespace AttrName static constexpr const char* NearestMode = "nearest_mode"; static constexpr const char* NormalizeVariance = "normalize_variance"; static constexpr const char* P = "p"; + static constexpr const char* OutputHeight = "output_height"; static constexpr const char* OutputShape = "output_shape"; static constexpr const char* OutputPadding = "output_padding"; + static constexpr const char* OutputWidth = "output_width"; static constexpr const char* Pads = "pads"; static constexpr const char* PooledShape = "pooled_shape"; static constexpr const char* Reverse = "reverse"; static constexpr const char* SampleSize = "sample_size"; + static constexpr const char* SamplingRatio = "sampling_ratio"; static constexpr const char* Scale = "scale"; static constexpr const char* Scales = "scales"; static constexpr const char* Seed = "seed"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h index 34f60e1f6f..c85c1d7e52 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h @@ -20,13 +20,12 @@ }\ } -template T clamp_cast(I input) -{ - return static_cast(std::clamp(input, std::numeric_limits::lowest(), std::numeric_limits::max())); -} - namespace OperatorHelper { + template T clamp_cast(I input) + { + return static_cast(std::clamp(input, std::numeric_limits::lowest(), std::numeric_limits::max())); + } enum TensorAxis { N, C, H, W, DoNotCoerce = UINT_MAX, LeftAligned = INT_MAX, RightAligned = INT_MIN, NoPlacementAdjustment = 0 }; enum BroadcastMode { NoBroadcast, UnidirectionalBroadcast, MultidirectionalBroadcast }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 216dfdcb2b..f1d2a225df 100644 --- 
a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -113,6 +113,9 @@ IMLOperatorRegistryPrivate : public IUnknown bool supportsGraph, const uint32_t* requiredInputCountForGraph = nullptr, bool requiresFloatFormatsForGraph = false, + bool supportedWith64BitTensorsVia32BitStrides = false, + bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false, + bool prefer64BitTensorsDirectly = false, _In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr, uint32_t constantCpuInputCount = 0 ) const noexcept PURE; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 78c045a4e1..9698bbcbf0 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -113,6 +113,36 @@ namespace OperatorHelper } } + float CastFloat16ToFloat32(uint16_t input) + { + // Promote float16m10e5s1 to float32m23e8s1. + // Note this works on machines of both ascending and descending byte + // endianness, so long as float32 and uint32 endianness match. + // It does not work for a few aberrant architectures which store + // float32 and uint32 with opposite endianness. + + const uint32_t float16unsignedValueMask = 0x7FFF; + const uint32_t float16signMask = 0x8000; + const uint32_t float16exponentMask = 0x7C00; + const uint32_t float32exponentMask = 0x7F800000; + + uint32_t float16unsignedValue = input & float16unsignedValueMask; + uint32_t float16sign = input & float16signMask; + uint32_t float16exponent = input & float16exponentMask; + + // Shift mantissa bits left (23 - 10 = 13). + // Adjust exponent bias (127 - 15 = 112, 112 << 23 == 0x38000000). + // Move sign bit to float32 MSB (32 - 16 = 16). 
+ uint32_t float32unsignedValue = (float16unsignedValue << 13) + 0x38000000; + uint32_t float32sign = float16sign << 16; + uint32_t result = (float16exponent == 0) ? (float32unsignedValue & ~float32exponentMask) : // Denormal + (float16exponent == float16exponentMask) ? (float32unsignedValue | float32exponentMask) : // Infinity + float32unsignedValue; // Any other normal value + result |= float32sign; + + return reinterpret_cast(result); + } + int64_t CastToInt64(MLOperatorTensorDataType tensorDataType, const void* p) { switch (tensorDataType) @@ -150,7 +180,7 @@ namespace OperatorHelper case MLOperatorTensorDataType::Int64: return static_cast(*reinterpret_cast(p)); case MLOperatorTensorDataType::String: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::String type is unsupported for reading as an integer."); case MLOperatorTensorDataType::Bool: return static_cast(*reinterpret_cast(p)); - case MLOperatorTensorDataType::Float16: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::Float16 type is unsupported for reading as an integer."); + case MLOperatorTensorDataType::Float16: return static_cast(CastFloat16ToFloat32(*reinterpret_cast(p))); case MLOperatorTensorDataType::Double: return static_cast(*reinterpret_cast(p)); case MLOperatorTensorDataType::UInt32: return static_cast(*reinterpret_cast(p)); case MLOperatorTensorDataType::UInt64: return static_cast(*reinterpret_cast(p)); @@ -673,7 +703,7 @@ namespace OperatorHelper ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1); ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0); int outDimCount = gsl::narrow_cast(inputDimensions.size() + indicesDimensions.size() - 1); - ML_CHECK_VALID_ARGUMENT(outDimCount >= 0 && outDimCount <= NchwDimensionCount); + ML_CHECK_VALID_ARGUMENT(outDimCount >= 0); std::vector outputDimensions(outDimCount, 1); @@ -707,21 +737,27 @@ namespace OperatorHelper { std::vector inputDimensions = shapeInfo.GetInputTensorShape(0); std::vector indicesDimensions = shapeInfo.GetInputTensorShape(1); + 
int32_t batchCount = m_batchCount; // Determine the number of output dimensions. ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1); ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 1); + ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > batchCount); + ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() > batchCount); const uint32_t numberOfCoordinatesPerIndex = indicesDimensions.back(); - ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= numberOfCoordinatesPerIndex); - const uint32_t numberOfOutputDimensionsFromInput = static_cast(inputDimensions.size()) - numberOfCoordinatesPerIndex; - const uint32_t numberOfOutputDimensionsFromIndices = static_cast(indicesDimensions.size()) - 1; // Strip off last dimension. - uint32_t outputDimensionCount = gsl::narrow_cast(numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput); - ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0 && outputDimensionCount <= NchwDimensionCount); + ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= batchCount + numberOfCoordinatesPerIndex); + const uint32_t numberOfOutputDimensionsFromInput = static_cast(inputDimensions.size()) - batchCount - numberOfCoordinatesPerIndex; + const uint32_t numberOfOutputDimensionsFromIndices = static_cast(indicesDimensions.size()) - batchCount - 1; // Strip off last dimension. + uint32_t outputDimensionCount = gsl::narrow_cast(batchCount + numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput); + ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0); - // Form the full expected size by concatenating the prefix part of the indices tensor shape - // with the suffix of the input tensor shape. + // Form the full expected size by concatenating fragments: + // 1 - batch count + // 2 - prefix part of the indices tensor shape + // 3 - suffix of the input tensor shape. 
std::vector outputDimensions; - outputDimensions.assign(indicesDimensions.begin(), indicesDimensions.end() - 1); + outputDimensions.assign(inputDimensions.begin(), inputDimensions.begin() + batchCount); + outputDimensions.insert(outputDimensions.end(), indicesDimensions.begin() + batchCount, indicesDimensions.end() - 1); outputDimensions.insert(outputDimensions.end(), inputDimensions.end() - numberOfOutputDimensionsFromInput, inputDimensions.end()); return { EdgeShapes(std::move(outputDimensions)) }; @@ -782,8 +818,6 @@ namespace OperatorHelper // Dim Offset : 1 std::vector reducedDims = shapeInfo.GetInputTensorShape(0); - ML_CHECK_VALID_ARGUMENT(reducedDims.size() <= NchwDimensionCount); - std::vector reduced(reducedDims.size(), false); for (auto& dim : m_axes) @@ -817,8 +851,6 @@ namespace OperatorHelper void ReduceHelperBase::AdjustAxesAndOutputShape(const std::vector& inputShape) { - ML_CHECK_VALID_ARGUMENT(inputShape.size() <= NchwDimensionCount); - // If axes is not specified, reduce over all the dimensions if (m_axes.empty()) { @@ -826,7 +858,264 @@ namespace OperatorHelper std::iota(m_axes.begin(), m_axes.end(), 0); } } - + + void EinSumHelper::Initialize() + { + ParseEquationComponents(); + m_recognizedOperatorType = DetermineRecognizedOperatorType(); + } + + void EinSumHelper::ParseEquationComponents() + { + // Parse an equation like 'ij,jk->ik' into components {ij, jk, ik} mapping letters to + // numeric indices {(0,1}, {1,2}, {0,2}}. The last component is the output. + + std::map labelMap; + std::set repeatedLabels; + + uint32_t currentLabelIndex = 0; + Component currentComponent = {}; + bool foundOutput = false; + bool reachedEnd = false; + + // Read first to last character in equation, looking for letters, commas, and one arrow. + for (char* token = m_equation.data(); !reachedEnd; ++token) + { + char ch = *token; + + // Only ASCII letters are valid subscript symbols in numpy.einsum(). 
+ if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) + { + // Check whether label already has an index. + const auto [i, inserted] = labelMap.insert({ch, currentLabelIndex}); + if (inserted) + { + ML_CHECK_VALID_ARGUMENT(!foundOutput, "Found label in equation output not matching any label from inputs.") + ++currentLabelIndex; // New label found. + } + else if (!foundOutput) + { + // If label in input already found earlier, then keep track of this later + // to generate the default output in case one is not specified. + repeatedLabels.insert(ch); + } + m_labelIndices.push_back(i->second); + } + else if (ch == ' ') + { + // Ignore spaces. + } + else + { + currentComponent.labelIndexEnd = static_cast(m_labelIndices.size()); + m_components.push_back(currentComponent); + currentComponent.labelIndexBegin = currentComponent.labelIndexEnd; + + switch (ch) + { + case ',': + // Note it's valid for 2 commas to be adjacent, which indicates a scalar and generates + // an empty component. + break; + + case '-': // Start of "->" (must be atomic, no space between them). + ++token; // Skip '-'. + ML_CHECK_VALID_ARGUMENT(*token == '>', "Expected '->' for output.") + ML_CHECK_VALID_ARGUMENT(foundOutput == false, "Only one output arrow '->' is valid.") + foundOutput = true; + break; + + case '.': + // Ellipsis is unsupported. Leave recognized operator as None, deferring to another EP. + m_components.clear(); + return; + + case '\0': + reachedEnd = true; + break; // End of string. + + default: + ML_INVALID_ARGUMENT("Unsupported character in equation string. Must be a-z, A-Z, ',', or '->'."); + } + } + } + + if (!foundOutput) + { + // If no explicit output was given, generate an implicit output by ordering all the + // labels in alphabetic order (by ASCII value consistent with numpy, so Z < a). + // Exclude any labels that occurred more than once, as these cancel out. 
+ + for (auto i : labelMap) + { + if (repeatedLabels.count(i.first) == 0) + { + m_labelIndices.push_back(i.second); + } + } + + // Push the final component, which is the output. + currentComponent.labelIndexEnd = static_cast(m_labelIndices.size()); + m_components.push_back(currentComponent); + } + } + + EinSumHelper::RecognizedOperatorType EinSumHelper::DetermineRecognizedOperatorType() + { + if (m_components.empty()) + { + return RecognizedOperatorType::None; // Parsing may have found unsupported components - treating as unknown. + } + + // std::ranges::equal is not supported yet. + auto equals = [](gsl::span a, gsl::span b) + { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); + }; + + std::array componentRanks; + if (m_components.size() > componentRanks.size()) + { + // No recognized operator takes more than 2 inputs and 1 output. + // EinSum itself is generic and can handle any variable number of inputs, + // but DML's operators expect fixed counts. + return RecognizedOperatorType::None; + } + else if (m_components.size() == 2) + { + auto& inputLabels = m_components[0].GetLabels(m_labelIndices); + auto& outputLabels = m_components[1].GetLabels(m_labelIndices); + if (inputLabels.size() == outputLabels.size()) + { + // Check identity. + if (equals(inputLabels, outputLabels)) + { + // Handles: "->", "i->i", "ij->ij", "ijk->ijk", "ijkl->ijkl" ... + return RecognizedOperatorType::Identity; + } + else // Transpose since a permutation exists. + { + // Handles: "ij->ji", "ijk->kji", "ijkl->lkji", "ijkl->jilk" ... + return RecognizedOperatorType::Transpose; + } + } + else if (outputLabels.empty()) // Scalar output, with all inputs reduced. + { + // Handles: "i->", "ij->", "ijk->", "ijkl->" ... + return RecognizedOperatorType::ReduceSum; + } + } + else if (m_components.size() == 3) + { + // If all components have the same size and label order, then apply elementwise multiplication. 
+ auto& inputALabels = m_components[0].GetLabels(m_labelIndices); + auto& inputBLabels = m_components[1].GetLabels(m_labelIndices); + auto& outputLabels = m_components[2].GetLabels(m_labelIndices); + if (equals(inputALabels, outputLabels) && equals(inputBLabels, outputLabels)) + { + // Handles: "i,i->i", "ij,ij->ij", "ijk,ijk->ijk", "ijkl,ijkl->ijkl" ... + return RecognizedOperatorType::Multiply; + } + } + + // Otherwise check for special cases of dedicated operators... + + struct RecognizedOperatorInfo + { + RecognizedOperatorType recognizedOperatorType; + std::initializer_list componentRanks; + std::initializer_list labelIndices; + }; + + const RecognizedOperatorInfo recognizedOperators[] = { + {RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik + {RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik + {RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik + {RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik + {RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik + {RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik + {RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik + {RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik + {RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik + {RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod) + {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i + {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j + }; + + // For each recognized operator, compare the labels-per-component and label indices. 
+ for (auto& recognizedOperator : recognizedOperators) + { + if (equals(m_labelIndices, recognizedOperator.labelIndices) + && m_components.size() == recognizedOperator.componentRanks.size()) + { + for (size_t i = 0; i < m_components.size(); ++i) + { + componentRanks[i] = m_components[i].GetDimensionCount(); + } + + if (equals(gsl::make_span(componentRanks.data(), m_components.size()), recognizedOperator.componentRanks)) + { + return recognizedOperator.recognizedOperatorType; + } + } + } + + return RecognizedOperatorType::None; + } + + std::vector EinSumHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + assert(!m_components.empty()); // Should have already parsed components. + + uint32_t inputCount = shapeInfo.GetInputCount(); + uint32_t outputCount = shapeInfo.GetOutputCount(); + ML_CHECK_VALID_ARGUMENT(inputCount + 1 == m_components.size(), "Mismatch between input tensor count and string equation component count."); + ML_CHECK_VALID_ARGUMENT(outputCount == 1, "EinSum expects exactly 1 output tensor."); + + std::vector labelSizes(m_labelIndices.size(), INT_MIN); + + // Read every input tensor, comparing labels to ensure consistent sizes from the equation parsed earlier. + for (uint32_t i = 0; i < inputCount; ++i) + { + auto inputShape = shapeInfo.GetInputTensorShape(i); + auto& component = m_components[i]; + auto labelIndices = component.GetLabels(m_labelIndices); + uint32_t dimensionCount = component.GetDimensionCount(); + + ML_CHECK_VALID_ARGUMENT(inputShape.size() == dimensionCount, "Mismatch between input tensor shape and string equation label count."); + + for (uint32_t i = 0; i < dimensionCount; ++i) + { + // If this is the first time seeing this label, then record the size. + // Otherwise any following occurrences of the label must match sizes. + // e.g. Given "ij,ji", both i's and both j's must match dimension sizes. 
+ uint32_t dimensionSize = inputShape[i]; + uint32_t labelIndex = labelIndices[i]; + assert(labelIndex < labelSizes.size()); + + if (labelSizes[labelIndex] == INT_MIN) + { + labelSizes[labelIndex] = dimensionSize; + } + else + { + ML_CHECK_VALID_ARGUMENT(labelSizes[labelIndex] == dimensionSize, "All labels must have the same dimension sizes."); + } + } + } + + // Generate output dimensions from corresponding input tensor labels. + // e.g. Given ij,jk->ik with [2,3] and [3,5], the output is [2,5]. + std::vector outputDimensions; + auto outputLabelIndices = m_components.back().GetLabels(m_labelIndices); + for (auto labelIndex : outputLabelIndices) + { + outputDimensions.push_back(labelSizes[labelIndex]); + } + + return { std::move(EdgeShapes(outputDimensions)) }; + } + std::vector MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 2); @@ -971,7 +1260,6 @@ namespace OperatorHelper std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto outputShape = shapeInfo.GetInputTensorShape(0); - ML_CHECK_VALID_ARGUMENT(outputShape.size() <= NchwDimensionCount); uint32_t inputCount = shapeInfo.GetInputCount(); @@ -1110,8 +1398,25 @@ namespace OperatorHelper { roiShape[0], // number of ROIs inputShape[C], // number of channels - static_cast(m_pooledSizeH), - static_cast(m_pooledSizeW), + static_cast(m_outputSizeH), + static_cast(m_outputSizeW), + }; + + return { std::move(EdgeShapes(outputDimensions)) }; + } + + std::vector RoiAlignHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS); + auto inputShape = shapeInfo.GetInputTensorShape(InputTensors::INPUT); + ML_CHECK_VALID_ARGUMENT(inputShape.size() >= 4, "inputShape must be >= 4."); + + DimensionType outputDimensions[4] = + { + roiShape[0], // number of ROIs + inputShape[C], // number of channels + 
static_cast(m_outputSizeH), + static_cast(m_outputSizeW), }; return { std::move(EdgeShapes(outputDimensions)) }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 0a845bb027..3bd8064dc6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -633,7 +633,7 @@ public: { int dimIndex = axes.empty() ? i : axes[i]; int stride = steps.empty() ? 1 : steps[i]; - ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions."); + ML_CHECK_VALID_ARGUMENT(static_cast(dimIndex) < static_cast(inputDimensions.size()), "'axes' must be valid with within actual input dimensions."); ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0."); // Positive values are offsets from 0. @@ -733,6 +733,7 @@ class ReduceHelperBase { template ReduceHelperBase(const Info_t& info, const Shape_t& shape, bool usingAxes) { m_keepDims = info.GetOptionalAttribute(AttrName::KeepDims, 1); + m_selectLastIndex = info.GetOptionalAttribute(AttrName::SelectLastIndex, 0); if (usingAxes) { m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes); } else { @@ -751,6 +752,7 @@ class ReduceHelperBase { protected: std::vector m_axes; int m_keepDims = 0; + int m_selectLastIndex = 0; }; class ArgMinArgMaxHelper : public ReduceHelperBase { @@ -769,6 +771,70 @@ class ReduceHelper : public ReduceHelperBase { ReduceHelper(const Info_t& info, const Shape_t& shape) : ReduceHelperBase(info, shape, true) {} }; +class EinSumHelper +{ +public: + void Initialize(); + + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. 
+ template + EinSumHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) + { + m_equation = info.GetAttribute(AttrName::Equation); + Initialize(); + } + + EinSumHelper(const MLOperatorAttributes& info) + { + m_equation = info.GetAttribute(AttrName::Equation); + Initialize(); + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + + enum class RecognizedOperatorType + { + None, + Identity, + Multiply, + MatMul, + MatMulTransposeA, + MatMulTransposeB, + ReduceSum, + Transpose, + Total, + }; + + RecognizedOperatorType GetRecognizedOperatorType() const noexcept { return m_recognizedOperatorType; } + +protected: + void ParseEquationComponents(); + RecognizedOperatorType DetermineRecognizedOperatorType(); + +protected: + struct Component + { + uint32_t labelIndexBegin; + uint32_t labelIndexEnd; + + uint32_t GetDimensionCount() const noexcept + { + return labelIndexEnd - labelIndexBegin; + } + gsl::span GetLabels(gsl::span labels) const + { + return labels.subspan(labelIndexBegin, labelIndexEnd - labelIndexBegin); + }; + }; + + std::string m_equation; + std::vector m_labelIndices; // Concatenation of all labels as rebased indices ("ij,ai" -> 0,1,2,0). + std::vector m_components; // All components in order, including inputs and output. + std::vector m_outputDimensions; + RecognizedOperatorType m_recognizedOperatorType = RecognizedOperatorType::None; +}; + class MatMulHelperBase { public: // Info_t is used to obtain attributes which will be used for calculating the output shape later. @@ -975,9 +1041,13 @@ class GatherNdHelper { // Shape_t is used to obtain input shape which will be used for adjusting attribute value. 
template GatherNdHelper(const Info_t& info, const Shape_t& shape) { + m_batchCount = info.GetOptionalAttribute(AttrName::BatchDimensions, 0); } std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +protected: + int32_t m_batchCount; }; class PoolingHelperBase { @@ -1040,26 +1110,51 @@ class PoolingHelper : public PoolingHelperBase { PoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {} }; -class RoiPoolingHelper { - public: - enum InputTensors { INPUT, - ROIS }; +class RoiPoolingHelperBase +{ +public: + enum InputTensors { INPUT, ROIS, BATCH_INDICES }; - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - RoiPoolingHelper(const Info_t& info, const Shape_t& shape) { - std::vector pooledShape = info.GetOptionalAttributeVectorInt32(AttrName::PooledShape); - ML_CHECK_VALID_ARGUMENT(pooledShape.size() == 2, "Pooled shape must be 2."); - m_pooledSizeH = pooledShape[0]; - m_pooledSizeW = pooledShape[1]; - } + RoiPoolingHelperBase() + {} - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - uint32_t m_pooledSizeW; - uint32_t m_pooledSizeH; +protected: + uint32_t m_outputSizeW = 1; + uint32_t m_outputSizeH = 1; +}; + +class RoiPoolingHelper : public RoiPoolingHelperBase +{ +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. 
+ template + RoiPoolingHelper(const Info_t& info, const Shape_t& shape) + { + std::vector pooledShape = info.GetOptionalAttributeVectorInt32(AttrName::PooledShape); + ML_CHECK_VALID_ARGUMENT(pooledShape.size() == 2, "Pooled shape must be 2."); + m_outputSizeH = pooledShape[0]; + m_outputSizeW = pooledShape[1]; + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +}; + +class RoiAlignHelper : public RoiPoolingHelperBase +{ +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + RoiAlignHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) + { + m_outputSizeW = info.GetOptionalAttribute(AttrName::OutputWidth, 1); + m_outputSizeH = info.GetOptionalAttribute(AttrName::OutputHeight, 1); + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; }; class SqueezeHelper { @@ -1330,6 +1425,7 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper; using ShapeInferenceHelper_LpPool = PoolingHelper; using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper; using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper; +using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper; using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper; @@ -1382,8 +1478,11 @@ using ShapeInferenceHelper_Ceil = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Floor = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Clip7 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Clip11 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Clip12 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Greater = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Less = GetBroadcastedOutputShapeHelper; 
+using ShapeInferenceHelper_GreaterOrEqual = GetBroadcastedOutputShapeHelper; +using ShapeInferenceHelper_LessOrEqual = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Equal = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Not = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_And = GetBroadcastedOutputShapeHelper; @@ -1430,6 +1529,7 @@ using ShapeInferenceHelper_ReduceL1 = ReduceHelper; using ShapeInferenceHelper_ReduceL2 = ReduceHelper; using ShapeInferenceHelper_ReduceMax = ReduceHelper; using ShapeInferenceHelper_ReduceMin = ReduceHelper; +using ShapeInferenceHelper_Einsum12 = VersionedOpsetHelper; using ShapeInferenceHelper_ArgMax = ArgMinArgMaxHelper; using ShapeInferenceHelper_ArgMin = ArgMinArgMaxHelper; using ShapeInferenceHelper_Gemm = GemmHelper; @@ -1450,6 +1550,7 @@ using ShapeInferenceHelper_LeakyRelu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_PRelu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_ThresholdedRelu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Elu = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Celu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Selu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Softmax = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LogSoftmax = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h index bb696c8fde..9d5c248559 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h @@ -245,6 +245,24 @@ namespace OperatorHelper static const int sc_sinceVer_Unsqueeze = 11; } // namespace OnnxOperatorSet11 + namespace OnnxOperatorSet12 + { + static const int sc_sinceVer_ArgMin = 12; + static const int sc_sinceVer_ArgMax 
= 12; + static const int sc_sinceVer_Celu = 12; + static const int sc_sinceVer_Clip = 12; + static const int sc_sinceVer_Einsum = 12; + static const int sc_sinceVer_GatherND = 12; + static const int sc_sinceVer_GreaterOrEqual = 12; + static const int sc_sinceVer_LessOrEqual = 12; + static const int sc_sinceVer_MaxPool = 12; + static const int sc_sinceVer_Min = 12; + static const int sc_sinceVer_Max = 12; + static const int sc_sinceVer_Pow = 12; + static const int sc_sinceVer_ReduceMax = 12; + static const int sc_sinceVer_ReduceMin = 12; + } // namespace OnnxOperatorSet12 + namespace MsftOperatorSet1 { static const int sc_sinceVer_FusedConv = 1; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h index 79e2cdba4e..6c47e60e63 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include diff --git a/packages.config b/packages.config index 816872da32..91960d191c 100644 --- a/packages.config +++ b/packages.config @@ -1,5 +1,5 @@  - + diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index b9b667618f..51cc4fc26e 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -288,7 +288,7 @@ def generate_files(list, args): '" target="runtimes\\win-' + args.target_architecture + '\\native" />') files_list.append('') - files_list.append('') if includes_winml: diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp index f292bdff86..a04871b1c5 100644 --- a/winml/adapter/abi_custom_registry_impl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -56,6 +56,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( bool supportsGraph, const uint32_t* 
requiredInputCountForGraph, bool requiresFloatFormatsForGraph, + bool supportedWith64BitTensorsVia32BitStrides, + bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, + bool prefer64BitTensorsDirectly, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept try { #ifdef LAYERING_DONE @@ -79,6 +82,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( supportsGraph, requiredInputCountForGraph, requiresFloatFormatsForGraph, + supportedWith64BitTensorsVia32BitStrides, + supportedWith64BitTensorsVia32BitStridesFromAnyEp, + prefer64BitTensorsDirectly, requiredConstantCpuInputs, constantCpuInputCount); } diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h index 040f8e28c0..e1bfae7f36 100644 --- a/winml/adapter/abi_custom_registry_impl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -29,6 +29,9 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { bool supports_graph, const uint32_t* required_input_count_for_graph = nullptr, bool requires_float_formats_for_graph = false, + bool supports_64bit_directly = false, + bool allows_64bit_via_strides = false, + bool allows_64bit_via_strides_from_any_ep = false, _In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr, uint32_t constant_cpu_input_count = 0) const noexcept override; diff --git a/winml/lib/Api/impl/FeatureCompatibility.h b/winml/lib/Api/impl/FeatureCompatibility.h index 5338623a9e..2966f854b2 100644 --- a/winml/lib/Api/impl/FeatureCompatibility.h +++ b/winml/lib/Api/impl/FeatureCompatibility.h @@ -54,9 +54,10 @@ static std::string ToString(wfc::IVectorView shape) { static std::string ToString( winml::TensorKind kind, wfc::IVectorView shape) { - FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128."); - FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64."); - 
FAIL_FAST_IF_MSG(kind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined."); + // Any unrecognized data type is considered "Undefined". + if (static_cast(kind) >= std::size(SzTensorKind)) { + kind = winml::TensorKind::Undefined; + } std::ostringstream stream; stream << SzTensorKind[static_cast(kind)] << ToString(shape); @@ -73,9 +74,10 @@ static std::string ToString(winml::ITensor value) { static std::string ToString(winml::IMapFeatureDescriptor descriptor) { auto keyKind = descriptor.KeyKind(); - FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128."); - FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64."); - FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined."); + // Any unrecognized data type is considered "Undefined". + if (static_cast(keyKind) >= std::size(SzTensorKind)) { + keyKind = winml::TensorKind::Undefined; + } auto valueDescriptor = descriptor.ValueDescriptor(); std::ostringstream stream; @@ -86,9 +88,10 @@ static std::string ToString(winml::IMapFeatureDescriptor descriptor) { static std::string ToString(winrt::com_ptr<_winml::IMapFeatureValue> value) { winml::TensorKind keyKind; FAIL_FAST_IF_FAILED(value->get_KeyKind(&keyKind)); - FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128."); - FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64."); - FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined."); + // Any unrecognized data type is considered "Undefined". + if (static_cast(keyKind) >= std::size(SzTensorKind)) { + keyKind = winml::TensorKind::Undefined; + } winml::ILearningModelFeatureDescriptor valueDescriptor; FAIL_FAST_IF_FAILED(value->get_ValueDescriptor(&valueDescriptor));