Merge pull request #4925 from microsoft/user/dwayner/Iron

ORT DirectML EP for Iron release, ONNX 1.5
This commit is contained in:
Dwayne Robinson 2020-08-28 12:28:30 -07:00 committed by GitHub
commit 040c5fa3e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
60 changed files with 2953 additions and 874 deletions

View file

@ -20,7 +20,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.2.1.0)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.3.0.0)
# Restore nuget packages, which will pull down the DirectML redist package
add_custom_command(

View file

@ -0,0 +1,3 @@
# Readability matters. Prevent formatting noise in pull requests from people
# who accidentally leave Visual Studio's auto-formatting options enabled.
DisableFormat: true

View file

@ -94,11 +94,16 @@ namespace Windows::AI::MachineLearning::Adapter
const void* executionHandle,
DmlGraphNodeCreateInfo* graphNodeCreateInfo
)>;
// Registration data that accompanies a kernel's graph node factory.
struct GraphNodeFactoryRegistration
{
// Callable of type GraphNodeFactory associated with this registration.
GraphNodeFactory factory;
// Optional input count.
// NOTE(review): presumably the number of inputs a node must have to be usable
// in a graph; unset would then mean "no constraint" - confirm against callers.
std::optional<uint32_t> requiredInputCount;
// The operator inputs/outputs must be a floating point data type. When true,
// if the node's tensor data type is non-floating point, the node is partitioned
// separately (unless the input/output is a CPU constant input, which is okay,
// as those can be read directly by the DML operator in the DML_OPERATOR_DESC).
bool requiresFloatFormatsExceptConstInputs = false;
};
@ -109,6 +114,20 @@ namespace Windows::AI::MachineLearning::Adapter
std::vector<uint32_t> requiredConstantCpuInputs;
std::optional<GraphNodeFactoryRegistration> graphNodeFactoryRegistration;
KernelSupportQuery supportQuery;
// Many ONNX operators use 64-bit tensors, but most DML operators only support
// 32-bit indices. This flag indicates to the graph whether it's okay to compute
// the result using 32-bit tensors (ignoring the upper bits) via doubled strides.
bool supportedWith64BitTensorsVia32BitStrides = false;
// When true, the input to the current operator may come from any execution
// provider. Otherwise it must have come from another DML node to assume it's safe
// to use 64-bit to 32-bit striding.
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
// Operator supports true 64-bit tensors directly, no strides needed.
// So fallback to strided 32-bit only occurs when the device lacks 64-bit support.
bool prefer64BitTensorsDirectly = false;
};
using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;

View file

@ -334,6 +334,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
bool requiresFloatFormatsForGraph,
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept try
{
@ -456,6 +459,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
{
auto regInfo = std::make_shared<InternalRegistrationInfo>();
regInfo->requiredConstantCpuInputs = constantCpuInputCapture;
regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides;
regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp;
regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly;
// Only internal operators support usage in DML graphs
if (supportsGraph)
@ -527,8 +533,14 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
else
{
// Currently unsupported for external operators
if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph ||
requiresFloatFormatsForGraph || requiredConstantCpuInputs)
if (canAliasFirstInput ||
supportsGraph ||
requiredInputCountForGraph ||
requiresFloatFormatsForGraph ||
requiredConstantCpuInputs ||
supportedWith64BitTensorsVia32BitStrides ||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
prefer64BitTensorsDirectly)
{
THROW_HR(E_INVALIDARG);
}

View file

@ -42,6 +42,9 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
bool requiresFloatFormatsForGraph = false,
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0) const noexcept override;

View file

@ -16,29 +16,42 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp
case MLOperatorTensorDataType::UInt16: return DML_TENSOR_DATA_TYPE_UINT16;
case MLOperatorTensorDataType::Int16: return DML_TENSOR_DATA_TYPE_INT16;
case MLOperatorTensorDataType::Int32: return DML_TENSOR_DATA_TYPE_INT32;
case MLOperatorTensorDataType::Int64: return DML_TENSOR_DATA_TYPE_UINT32;
case MLOperatorTensorDataType::Int64: return DML_TENSOR_DATA_TYPE_INT64;
case MLOperatorTensorDataType::String: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Bool: return DML_TENSOR_DATA_TYPE_UINT8;
case MLOperatorTensorDataType::Float16: return DML_TENSOR_DATA_TYPE_FLOAT16;
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::UInt32: return DML_TENSOR_DATA_TYPE_UINT32;
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT32; // Stride is used to access lower 32-bits.
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT64;
case MLOperatorTensorDataType::Complex64: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Complex128: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Undefined:
default: return DML_TENSOR_DATA_TYPE_UNKNOWN;;
default: return DML_TENSOR_DATA_TYPE_UNKNOWN;
};
}
// Maps a 64-bit integer DML data type onto the 32-bit integer type of the same
// signedness (UINT64 -> UINT32, INT64 -> INT32). Every other data type,
// including FLOAT64, is returned unchanged.
DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dmlElementType) noexcept
{
    if (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64)
    {
        return DML_TENSOR_DATA_TYPE_UINT32;
    }

    if (dmlElementType == DML_TENSOR_DATA_TYPE_INT64)
    {
        return DML_TENSOR_DATA_TYPE_INT32;
    }

    // Not a 64-bit integer type: nothing to remap.
    return dmlElementType;
}
bool IsSigned(DML_TENSOR_DATA_TYPE dataType)
{
switch (dataType)
{
case DML_TENSOR_DATA_TYPE_FLOAT64: return true;
case DML_TENSOR_DATA_TYPE_FLOAT32: return true;
case DML_TENSOR_DATA_TYPE_FLOAT16: return true;
case DML_TENSOR_DATA_TYPE_UINT64: return false;
case DML_TENSOR_DATA_TYPE_UINT32: return false;
case DML_TENSOR_DATA_TYPE_UINT16: return false;
case DML_TENSOR_DATA_TYPE_UINT8: return false;
case DML_TENSOR_DATA_TYPE_INT64: return true;
case DML_TENSOR_DATA_TYPE_INT32: return true;
case DML_TENSOR_DATA_TYPE_INT16: return true;
case DML_TENSOR_DATA_TYPE_INT8: return true;
@ -70,9 +83,14 @@ MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tenso
case DML_TENSOR_DATA_TYPE_INT32: return MLOperatorTensorDataType::Int32;
case DML_TENSOR_DATA_TYPE_FLOAT16: return MLOperatorTensorDataType::Float16;
case DML_TENSOR_DATA_TYPE_UINT32: return MLOperatorTensorDataType::UInt32;
case DML_TENSOR_DATA_TYPE_UINT64: return MLOperatorTensorDataType::UInt64;
case DML_TENSOR_DATA_TYPE_INT64: return MLOperatorTensorDataType::Int64;
case DML_TENSOR_DATA_TYPE_FLOAT64: return MLOperatorTensorDataType::Double;
default: ML_INVALID_ARGUMENT("Unknown DML_TENSOR_DATA_TYPE.");
};
}
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType)
{
return ComputeElementCountFromDimensions(dimensions) * GetByteSizeFromMlDataType(tensorDataType);
@ -90,4 +108,40 @@ size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor)
return ComputeByteSizeFromDimensions(gsl::make_span(dimensions.data(), dimensionCount), tensor.GetTensorDataType());
}
// Builds a bitmask of the tensor data types supported by the given DirectML
// device. Bit N of the result is set when the DML_TENSOR_DATA_TYPE whose enum
// value is N is reported as supported by IDMLDevice::CheckFeatureSupport.
//
// The scan runs through DML_TENSOR_DATA_TYPE_INT64 so that the 64-bit types
// (FLOAT64, UINT64, INT64) are represented in the mask. Stopping at
// DML_TENSOR_DATA_TYPE_INT8 would leave those bits permanently clear and make
// every device look as though it lacked 64-bit tensor support.
//
// Raises (through THROW_IF_FAILED) when the feature query itself fails.
uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
{
    uint32_t deviceTypeMask = 0u;

    // Form the bitmask of all supported data types.
    // NOTE(review): assumes DML_TENSOR_DATA_TYPE_INT64 is the highest-valued
    // enumerator in the targeted DirectML header (and is below 32) - confirm
    // whenever the DirectML package version is bumped.
    for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT64; ++i)
    {
        DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
        DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};

        THROW_IF_FAILED(dmlDevice->CheckFeatureSupport(
            DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT,
            sizeof(dataTypeQuery),
            &dataTypeQuery,
            sizeof(dataTypeSupport),
            &dataTypeSupport
        ));

        // Shift an unsigned one rather than the signed BOOL so the mask
        // arithmetic stays in unsigned territory.
        if (dataTypeSupport.IsSupported)
        {
            deviceTypeMask |= (1u << i);
        }
    }

    return deviceTypeMask;
}
// Writes the packed strides for a tensor of the given sizes into `strides`:
// the last dimension gets stride 1 and each earlier dimension gets the product
// of the sizes of all dimensions after it. Both spans must have the same
// length (checked with assert only). The running product is a plain uint32_t
// multiplication; no overflow check is performed.
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides)
{
    assert(sizes.size() == strides.size());

    uint32_t runningProduct = 1;
    size_t axis = strides.size();

    // Walk from the innermost (last) dimension back to the outermost.
    while (axis > 0)
    {
        --axis;
        strides[axis] = runningProduct;
        runningProduct *= sizes[axis];
    }
}
} // namespace Dml

View file

@ -10,13 +10,16 @@ namespace Dml
{
using namespace OperatorHelper;
static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX;
static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX1;
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataType(MLOperatorTensorDataType tensorDataType);
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataType tensorDataType) noexcept;
DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dmlElementType) noexcept;
MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tensorDataType);
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType);
size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor);
uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice);
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides);
bool IsSigned(DML_TENSOR_DATA_TYPE dataType);
@ -40,6 +43,12 @@ namespace Dml
UINT elementSizeInBytes = 0;
switch (dataType)
{
case DML_TENSOR_DATA_TYPE_FLOAT64:
case DML_TENSOR_DATA_TYPE_UINT64:
case DML_TENSOR_DATA_TYPE_INT64:
elementSizeInBytes = 8;
break;
case DML_TENSOR_DATA_TYPE_FLOAT32:
case DML_TENSOR_DATA_TYPE_UINT32:
case DML_TENSOR_DATA_TYPE_INT32:

View file

@ -376,8 +376,9 @@ namespace Dml
{
assert(!m_closed);
const size_t sourceSizeInBytes = ComputeByteSizeFromTensor(*src);
const size_t dataSizeInBytes = ComputeByteSizeFromTensor(*dst);
THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != ComputeByteSizeFromTensor(*src)); // Tensors must be the same size
THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != sourceSizeInBytes); // Tensors must be the same size
if (dataSizeInBytes == 0)
{
@ -461,7 +462,7 @@ namespace Dml
}
CATCH_RETURN();
uint32_t ExecutionProviderImpl::GetSuppportedDeviceDataTypeMask() const
uint32_t ExecutionProviderImpl::GetSupportedDeviceDataTypeMask() const
{
// The DML provider registers all supported kernels up-front regardless of actual device capability,
// but this is problematic later when executing the graph because DirectML will fail to create
@ -470,26 +471,7 @@ namespace Dml
// handle them, similar to the fallback in CUDAExecutionProvider::GetCapability for certain RNN/GRU/Conv
// attributes.
uint32_t deviceTypeMask = 0u;
// Form the bitmask of all supported data types.
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i)
{
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};
THROW_IF_FAILED(m_dmlDevice->CheckFeatureSupport(
DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT,
sizeof(dataTypeQuery),
&dataTypeQuery,
sizeof(dataTypeSupport),
&dataTypeSupport
));
deviceTypeMask |= (dataTypeSupport.IsSupported << i);
}
return deviceTypeMask;
return Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get());
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
@ -498,7 +480,7 @@ namespace Dml
const std::vector<const onnxruntime::KernelRegistry*>& registries) const
{
std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_";
uint32_t deviceDataTypeMask = GetSuppportedDeviceDataTypeMask();
uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask();
return PartitionGraph(
graph,

View file

@ -16,10 +16,9 @@ using Base = Microsoft::WRL::RuntimeClass<
TInterfaces...>;
}
using namespace Microsoft::WRL;
namespace Dml
{
using Microsoft::WRL::ComPtr;
class PooledUploadHeap;
class ReadbackHeap;
class ExecutionContext;
@ -87,7 +86,7 @@ namespace Dml
const std::vector<const onnxruntime::KernelRegistry*>& registries
) const;
uint32_t GetSuppportedDeviceDataTypeMask() const;
uint32_t GetSupportedDeviceDataTypeMask() const;
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
onnxruntime::common::Status CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const;

View file

@ -7,6 +7,7 @@ union ActivationOperatorDescUnion
{
DML_ACTIVATION_IDENTITY_OPERATOR_DESC identity;
DML_ACTIVATION_ELU_OPERATOR_DESC elu;
DML_ACTIVATION_CELU_OPERATOR_DESC celu;
DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax;
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid;
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu;
@ -36,6 +37,7 @@ struct ActivationOperatorDesc
switch (activationType)
{
case DML_OPERATOR_ACTIVATION_ELU: return { activationType, &params.elu };
case DML_OPERATOR_ACTIVATION_CELU: return { activationType, &params.celu };
case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, &params.hardmax };
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, &params.sigmoid };
case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, &params.identity };

View file

@ -24,8 +24,8 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 120;
static constexpr size_t ActivationFunctionCount = 19;
static constexpr auto ValueCount = 141;
static constexpr size_t ActivationFunctionCount = 20;
};
template <>
@ -62,7 +62,7 @@ struct EnumTraits<DML_CONVOLUTION_DIRECTION>
template <>
struct EnumTraits<DML_PADDING_MODE>
{
static constexpr auto ValueCount = 3;
static constexpr auto ValueCount = 4;
};
template <>
@ -86,7 +86,7 @@ struct EnumTraits<DML_FEATURE>
template <>
struct EnumTraits<DML_FEATURE_LEVEL>
{
static constexpr auto ValueCount = 2;
static constexpr auto ValueCount = 4;
};
template <>
@ -113,6 +113,12 @@ struct EnumTraits<DML_ROUNDING_MODE>
static constexpr auto ValueCount = 3;
};
// EnumTraits specialization for DML_RANDOM_GENERATOR_TYPE. ValueCount is what
// the EnumValueCount<T> variable template reports for this enum.
// NOTE(review): presumably this must equal the number of enumerators declared
// for DML_RANDOM_GENERATOR_TYPE in the DirectML header - verify when updating.
template <>
struct EnumTraits<DML_RANDOM_GENERATOR_TYPE>
{
static constexpr auto ValueCount = 1;
};
template <typename T>
constexpr auto EnumValueCount = EnumTraits<T>::ValueCount;
@ -273,6 +279,18 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN;
};
// OperatorDescTraits maps an operator desc struct to its DML_OPERATOR_TYPE
// enumerator (exposed as the static member `Type`). These two specializations
// cover the inclusive comparison operators (>= and <=).
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC>
{
@ -393,6 +411,18 @@ struct OperatorDescTraits<DML_REDUCE_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REDUCE;
};
// Desc-struct -> operator-type mappings (static member `Type`) for the
// ARGMIN and ARGMAX operators.
template <>
struct OperatorDescTraits<DML_ARGMIN_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ARGMIN;
};
template <>
struct OperatorDescTraits<DML_ARGMAX_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ARGMAX;
};
template <>
struct OperatorDescTraits<DML_AVERAGE_POOLING_OPERATOR_DESC>
{
@ -747,6 +777,12 @@ struct OperatorDescTraits<DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1;
};
// Desc-struct -> operator-type mapping (static member `Type`) for the
// RESAMPLE1 operator.
template <>
struct OperatorDescTraits<DML_RESAMPLE1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE1;
};
template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_DESC>
{
@ -771,12 +807,108 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION;
};
// Desc-struct -> operator-type mappings (static member `Type`) for the
// bitwise element-wise operators, the *_GRAD operators, RANDOM_GENERATOR,
// NONZERO_COORDINATES, ADAM_OPTIMIZER, ROI_ALIGN, GATHER_ND1 and the ELU/CELU
// activations. Each specialization pairs one *_OPERATOR_DESC struct with the
// DML_OPERATOR_* enumerator of the same name.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_AND;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_OR;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_XOR;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_NOT;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_COUNT;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_RELU_GRAD;
};
template <>
struct OperatorDescTraits<DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING_GRAD;
};
template <>
struct OperatorDescTraits<DML_MAX_POOLING_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING_GRAD;
};
template <>
struct OperatorDescTraits<DML_RANDOM_GENERATOR_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RANDOM_GENERATOR;
};
template <>
struct OperatorDescTraits<DML_NONZERO_COORDINATES_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_NONZERO_COORDINATES;
};
template <>
struct OperatorDescTraits<DML_RESAMPLE_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE_GRAD;
};
template <>
struct OperatorDescTraits<DML_SLICE_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE_GRAD;
};
template <>
struct OperatorDescTraits<DML_ADAM_OPTIMIZER_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ADAM_OPTIMIZER;
};
template <>
struct OperatorDescTraits<DML_ROI_ALIGN_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN;
};
template <>
struct OperatorDescTraits<DML_GATHER_ND1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND1;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_ELU;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_CELU_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_CELU;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_HARDMAX_OPERATOR_DESC>
{
@ -993,6 +1125,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_L
using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC;
};
// OperatorTypeTraits maps a DML_OPERATOR_TYPE enumerator to its operator desc
// struct (exposed as the member alias `DescType`). These two specializations
// cover the inclusive comparison operators (>= and <=).
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL>
{
using DescType = DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL>
{
using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT>
{
@ -1113,6 +1257,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REDUCE>
using DescType = DML_REDUCE_OPERATOR_DESC;
};
// Operator-type -> desc-struct mappings (member alias `DescType`) for the
// ARGMIN and ARGMAX operators.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ARGMIN>
{
using DescType = DML_ARGMIN_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ARGMAX>
{
using DescType = DML_ARGMAX_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING>
{
@ -1467,6 +1623,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZ
using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC;
};
// Operator-type -> desc-struct mapping (member alias `DescType`) for the
// RESAMPLE1 operator.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE1>
{
using DescType = DML_RESAMPLE1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER>
{
@ -1491,12 +1653,108 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_CONVO
using DescType = DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC;
};
// Operator-type -> desc-struct mappings (member alias `DescType`) for the
// bitwise element-wise operators, the *_GRAD operators, RANDOM_GENERATOR,
// NONZERO_COORDINATES, ADAM_OPTIMIZER, ROI_ALIGN, GATHER_ND1 and the ELU/CELU
// activations. Each specialization pairs one DML_OPERATOR_* enumerator with
// the *_OPERATOR_DESC struct of the same name.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_AND>
{
using DescType = DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_OR>
{
using DescType = DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_XOR>
{
using DescType = DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_NOT>
{
using DescType = DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_COUNT>
{
using DescType = DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_RELU_GRAD>
{
using DescType = DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING_GRAD>
{
using DescType = DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING_GRAD>
{
using DescType = DML_MAX_POOLING_GRAD_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RANDOM_GENERATOR>
{
using DescType = DML_RANDOM_GENERATOR_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_NONZERO_COORDINATES>
{
using DescType = DML_NONZERO_COORDINATES_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE_GRAD>
{
using DescType = DML_RESAMPLE_GRAD_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE_GRAD>
{
using DescType = DML_SLICE_GRAD_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ADAM_OPTIMIZER>
{
using DescType = DML_ADAM_OPTIMIZER_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN>
{
using DescType = DML_ROI_ALIGN_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND1>
{
using DescType = DML_GATHER_ND1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
{
using DescType = DML_ACTIVATION_ELU_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_CELU>
{
using DescType = DML_ACTIVATION_CELU_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX>
{
@ -1652,6 +1910,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR:
@ -1692,6 +1954,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_GEMM_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_REDUCE:
return std::invoke(std::forward<Visitor>(visitor), DML_REDUCE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ARGMIN:
return std::invoke(std::forward<Visitor>(visitor), DML_ARGMIN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ARGMAX:
return std::invoke(std::forward<Visitor>(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_AVERAGE_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_LP_POOLING:
@ -1820,8 +2086,40 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_CONVOLUTION_INTEGER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_BIT_AND:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_BIT_OR:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_RELU_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_AVERAGE_POOLING_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MAX_POOLING_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_MAX_POOLING_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_RANDOM_GENERATOR:
return std::invoke(std::forward<Visitor>(visitor), DML_RANDOM_GENERATOR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_NONZERO_COORDINATES:
return std::invoke(std::forward<Visitor>(visitor), DML_NONZERO_COORDINATES_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_RESAMPLE_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_RESAMPLE_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_SLICE_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_SLICE_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ADAM_OPTIMIZER:
return std::invoke(std::forward<Visitor>(visitor), DML_ADAM_OPTIMIZER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ROI_ALIGN:
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_GATHER_ND1:
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ND1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_ELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_CELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_CELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_HARDMAX:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID:
@ -1887,6 +2185,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS";
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN";
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN";
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL";
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL";
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT";
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR";
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR";
@ -1907,6 +2207,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION";
case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM";
case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE";
case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN";
case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX";
case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING";
case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING";
case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING";
@ -1971,6 +2273,21 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY";
case DML_OPERATOR_CONVOLUTION_INTEGER: return "DML_OPERATOR_CONVOLUTION_INTEGER";
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION";
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return "DML_OPERATOR_ELEMENT_WISE_BIT_AND";
case DML_OPERATOR_ELEMENT_WISE_BIT_OR: return "DML_OPERATOR_ELEMENT_WISE_BIT_OR";
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: return "DML_OPERATOR_ELEMENT_WISE_BIT_XOR";
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: return "DML_OPERATOR_ELEMENT_WISE_BIT_NOT";
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: return "DML_OPERATOR_ELEMENT_WISE_BIT_COUNT";
case DML_OPERATOR_ACTIVATION_RELU_GRAD: return "DML_OPERATOR_ACTIVATION_RELU_GRAD";
case DML_OPERATOR_AVERAGE_POOLING_GRAD: return "DML_OPERATOR_AVERAGE_POOLING_GRAD";
case DML_OPERATOR_MAX_POOLING_GRAD: return "DML_OPERATOR_MAX_POOLING_GRAD";
case DML_OPERATOR_RANDOM_GENERATOR: return "DML_OPERATOR_RANDOM_GENERATOR";
case DML_OPERATOR_NONZERO_COORDINATES: return "DML_OPERATOR_NONZERO_COORDINATES";
case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD";
case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD";
case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER";
case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN";
case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1";
default:
assert(false);
return "<unknown>";

View file

@ -296,6 +296,34 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA
DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL:
// two input tensors (A, B) followed by one output tensor.
// NOTE(review): the trailing bool is presumably the "optional" flag - confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL",
    DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    3, // must match the size of the field table
    DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL:
// two input tensors (A, B) followed by one output tensor.
// NOTE(review): the trailing bool is presumably the "optional" flag - confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL",
    DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    3, // must match the size of the field table
    DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA_FIELDS[2] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -599,6 +627,38 @@ constexpr DML_OPERATOR_SCHEMA DML_REDUCE_OPERATOR_SCHEMA {
DML_REDUCE_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ARGMIN: input and output tensors, then the
// axis count, the axes array, and the axis direction attributes.
constexpr DML_SCHEMA_FIELD DML_ARGMIN_OPERATOR_SCHEMA_FIELDS[5] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ARGMIN_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ARGMIN",
    DML_OPERATOR_ARGMIN,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    5, // must match the size of the field table
    DML_ARGMIN_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ARGMAX: input and output tensors, then the
// axis count, the axes array, and the axis direction attributes.
constexpr DML_SCHEMA_FIELD DML_ARGMAX_OPERATOR_SCHEMA_FIELDS[5] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ARGMAX_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ARGMAX",
    DML_OPERATOR_ARGMAX,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    5, // must match the size of the field table
    DML_ARGMAX_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -1628,7 +1688,7 @@ constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIEL
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -1648,6 +1708,247 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA {
DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_BIT_AND: two input tensors (A, B)
// followed by one output tensor.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Schema record: operator name, operator type, support flags (in-place
// execution is flagged as supported), field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_BIT_AND",
    DML_OPERATOR_ELEMENT_WISE_BIT_AND,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
    3, // must match the size of the field table
    DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_BIT_OR: two input tensors (A, B)
// followed by one output tensor.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Schema record: operator name, operator type, support flags (in-place
// execution is flagged as supported), field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_BIT_OR",
    DML_OPERATOR_ELEMENT_WISE_BIT_OR,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
    3, // must match the size of the field table
    DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_BIT_XOR: two input tensors (A, B)
// followed by one output tensor.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Schema record: operator name, operator type, support flags (in-place
// execution is flagged as supported), field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_BIT_XOR",
    DML_OPERATOR_ELEMENT_WISE_BIT_XOR,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
    3, // must match the size of the field table
    DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_BIT_NOT: one input tensor and one
// output tensor.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA_FIELDS[2] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Schema record: operator name, operator type, support flags (in-place
// execution is flagged as supported), field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_BIT_NOT",
    DML_OPERATOR_ELEMENT_WISE_BIT_NOT,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
    2, // must match the size of the field table
    DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: one input tensor and
// one output tensor.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA_FIELDS[2] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_BIT_COUNT",
    DML_OPERATOR_ELEMENT_WISE_BIT_COUNT,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    2, // must match the size of the field table
    DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ACTIVATION_RELU_GRAD: the input tensor, the
// input gradient tensor, and the output gradient tensor.
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ACTIVATION_RELU_GRAD",
    DML_OPERATOR_ACTIVATION_RELU_GRAD,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    3, // must match the size of the field table
    DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_AVERAGE_POOLING_GRAD: the gradient tensors
// first, then the dimension count, four UINT-array attributes, and the
// include-padding attribute.
constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[8] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_AVERAGE_POOLING_GRAD",
    DML_OPERATOR_AVERAGE_POOLING_GRAD,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    8, // must match the size of the field table
    DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_MAX_POOLING_GRAD: the input tensor, the input
// gradient tensor, and the output gradient tensor.
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_MAX_POOLING_GRAD",
    DML_OPERATOR_MAX_POOLING_GRAD,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    3, // must match the size of the field table
    DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_RANDOM_GENERATOR: an input state tensor, the
// output tensor, an output state tensor, and the generator type attribute.
// NOTE(review): OutputStateTensor is the only entry with a true trailing bool,
// presumably marking it optional - confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_RANDOM_GENERATOR_OPERATOR_SCHEMA_FIELDS[4] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputStateTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputStateTensor", true },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Type", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_RANDOM_GENERATOR_OPERATOR_SCHEMA = {
    "DML_OPERATOR_RANDOM_GENERATOR",
    DML_OPERATOR_RANDOM_GENERATOR,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    4, // must match the size of the field table
    DML_RANDOM_GENERATOR_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_NONZERO_COORDINATES: one input tensor and two
// output tensors (count and coordinates).
constexpr DML_SCHEMA_FIELD DML_NONZERO_COORDINATES_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCountTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCoordinatesTensor", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_NONZERO_COORDINATES_OPERATOR_SCHEMA = {
    "DML_OPERATOR_NONZERO_COORDINATES",
    DML_OPERATOR_NONZERO_COORDINATES,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    3, // must match the size of the field table
    DML_NONZERO_COORDINATES_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_RESAMPLE_GRAD: the gradient tensors, the
// interpolation mode, the dimension count, and three FLOAT-array attributes
// (scales plus input/output pixel offsets).
constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD_OPERATOR_SCHEMA_FIELDS[7] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "Scales", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "InputPixelOffsets", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_RESAMPLE_GRAD",
    DML_OPERATOR_RESAMPLE_GRAD,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    7, // must match the size of the field table
    DML_RESAMPLE_GRAD_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_SLICE_GRAD: the gradient tensors, the dimension
// count, and the input window offsets/sizes (UINT arrays) and strides (INT array).
constexpr DML_SCHEMA_FIELD DML_SLICE_GRAD_OPERATOR_SCHEMA_FIELDS[6] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowOffsets", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowSizes", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT_ARRAY, "InputWindowStrides", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_SLICE_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_SLICE_GRAD",
    DML_OPERATOR_SLICE_GRAD,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    6, // must match the size of the field table
    DML_SLICE_GRAD_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ADAM_OPTIMIZER: five input tensors, three
// output tensors, then four FLOAT attributes (learning rate, beta1, beta2,
// epsilon).
constexpr DML_SCHEMA_FIELD DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA_FIELDS[12] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputParametersTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputFirstMomentTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputSecondMomentTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "GradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "TrainingStepTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputParametersTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputFirstMomentTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSecondMomentTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "LearningRate", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta1", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta2", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false },
};

// Schema record: operator name, operator type, support flags (in-place
// execution is flagged as supported), field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ADAM_OPTIMIZER",
    DML_OPERATOR_ADAM_OPTIMIZER,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
    12, // must match the size of the field table
    DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ROI_ALIGN: three input tensors (input, ROI,
// batch indices), one output tensor, then seven scalar attributes.
constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS[11] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ReductionFunction", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleX", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleY", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutOfBoundsInputValue", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinimumSamplesPerOutput", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaximumSamplesPerOutput", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ROI_ALIGN",
    DML_OPERATOR_ROI_ALIGN,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    11, // must match the size of the field table
    DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_GATHER_ND1: input and indices tensors, the
// output tensor, then three UINT dimension-count attributes.
constexpr DML_SCHEMA_FIELD DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS[6] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BatchDimensionCount", false },
};

// Schema record: operator name, operator type, support flags, field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND1_OPERATOR_SCHEMA = {
    "DML_OPERATOR_GATHER_ND1",
    DML_OPERATOR_GATHER_ND1,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    6, // must match the size of the field table
    DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -1662,6 +1963,20 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_ELU_OPERATOR_SCHEMA {
DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ACTIVATION_CELU: input tensor, output tensor,
// and the FLOAT "Alpha" attribute.
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_CELU_OPERATOR_SCHEMA_FIELDS[3] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false },
};

// Schema record: operator name, operator type, support flags (in-place
// execution is flagged as supported), field count, field table.
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_CELU_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ACTIVATION_CELU",
    DML_OPERATOR_ACTIVATION_CELU,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
    3, // must match the size of the field table
    DML_ACTIVATION_CELU_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS[2] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },

View file

@ -145,6 +145,22 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_
OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
// Converts a greater-than-or-equal operator desc into OperatorField entries,
// pairing each schema field (by index) with the matching desc member.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
// Converts a less-than-or-equal operator desc into OperatorField entries,
// pairing each schema field (by index) with the matching desc member.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC& desc)
{
return {
@ -328,6 +344,26 @@ inline std::vector<OperatorField> GetFields(const DML_REDUCE_OPERATOR_DESC& desc
OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_ARGMIN_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ARGMAX_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
};
}
inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING_OPERATOR_DESC& desc)
{
return {
@ -993,6 +1029,157 @@ inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_CONVOLUTI
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast<UINT>(desc.GroupCount))),
};
}
// Converts a BIT_AND operator desc into OperatorField entries, pairing each
// schema field (by index) with the matching desc tensor.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
// Converts a BIT_OR operator desc into OperatorField entries, pairing each
// schema field (by index) with the matching desc tensor.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
// Converts a BIT_XOR operator desc into OperatorField entries, pairing each
// schema field (by index) with the matching desc tensor.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
// Converts a BIT_NOT operator desc into OperatorField entries (input tensor,
// then output tensor).
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
// Converts a BIT_COUNT operator desc into OperatorField entries (input tensor,
// then output tensor).
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
// Converts a RELU_GRAD operator desc into OperatorField entries (input,
// input gradient, output gradient).
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
    };
}
// Converts an AVERAGE_POOLING_GRAD operator desc into OperatorField entries.
// Every UINT array attribute is read with desc.DimensionCount as its length.
inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC& desc)
{
    const auto& schema = DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
        OperatorField(&schema.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
        OperatorField(&schema.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
        OperatorField(&schema.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
        OperatorField(&schema.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
        OperatorField(&schema.Fields[7], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
    };
}
// Converts a MAX_POOLING_GRAD operator desc into OperatorField entries (input,
// input gradient, output gradient).
inline std::vector<OperatorField> GetFields(const DML_MAX_POOLING_GRAD_OPERATOR_DESC& desc)
{
    const auto& schema = DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
    };
}
// Converts a RANDOM_GENERATOR operator desc into OperatorField entries: the
// input state, the output, the output state, and the generator type as a UINT.
inline std::vector<OperatorField> GetFields(const DML_RANDOM_GENERATOR_OPERATOR_DESC& desc)
{
    const auto& schema = DML_RANDOM_GENERATOR_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputStateTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputStateTensor))),
        OperatorField(&schema.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Type))),
    };
}
// Converts a NONZERO_COORDINATES operator desc into OperatorField entries
// (input, output count, output coordinates).
inline std::vector<OperatorField> GetFields(const DML_NONZERO_COORDINATES_OPERATOR_DESC& desc)
{
    const auto& schema = DML_NONZERO_COORDINATES_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputCountTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputCoordinatesTensor))),
    };
}
inline std::vector<OperatorField> GetFields(const DML_RESAMPLE_GRAD_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const FLOAT*>(desc.Scales), desc.DimensionCount)),
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const FLOAT*>(desc.InputPixelOffsets), desc.DimensionCount)),
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const FLOAT*>(desc.OutputPixelOffsets), desc.DimensionCount)),
};
}
// Converts a SLICE_GRAD operator desc into OperatorField entries. Window
// offsets and sizes are UINT arrays and strides are an INT array; all three
// are read with desc.DimensionCount as their length.
inline std::vector<OperatorField> GetFields(const DML_SLICE_GRAD_OPERATOR_DESC& desc)
{
    const auto& schema = DML_SLICE_GRAD_OPERATOR_SCHEMA;
    return {
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
        OperatorField(&schema.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.InputWindowOffsets), desc.DimensionCount)),
        OperatorField(&schema.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.InputWindowSizes), desc.DimensionCount)),
        OperatorField(&schema.Fields[5], ToOperatorFieldType(static_cast<const INT*>(desc.InputWindowStrides), desc.DimensionCount)),
    };
}
inline std::vector<OperatorField> GetFields(const DML_ADAM_OPTIMIZER_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputParametersTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputFirstMomentTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputSecondMomentTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.GradientTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.TrainingStepTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputParametersTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputFirstMomentTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputSecondMomentTensor))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.LearningRate))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<FLOAT>(desc.Beta1))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<FLOAT>(desc.Beta2))),
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ROITensor))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BatchIndicesTensor))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.ReductionFunction))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleX))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleY))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.OutOfBoundsInputValue))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<UINT>(desc.MinimumSamplesPerOutput))),
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
};
}
inline std::vector<OperatorField> GetFields(const DML_GATHER_ND1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.InputDimensionCount))),
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.IndicesDimensionCount))),
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.BatchDimensionCount))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
{
return {
@ -1001,6 +1188,14 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DE
OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<FLOAT>(desc.Alpha))),
};
}
// Flattens a DML_ACTIVATION_CELU_OPERATOR_DESC into the generic OperatorField list that
// AbstractOperatorDesc is constructed from. Entries follow the schema's field order:
// input tensor, output tensor, then the FLOAT Alpha attribute.
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_CELU_OPERATOR_DESC& desc)
{
    const auto& schema = DML_ACTIVATION_CELU_OPERATOR_SCHEMA;

    return {
        // Tensor descriptions.
        OperatorField(&schema.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
        OperatorField(&schema.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
        // Scalar attribute.
        OperatorField(&schema.Fields[2], ToOperatorFieldType(static_cast<FLOAT>(desc.Alpha))),
    };
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARDMAX_OPERATOR_DESC& desc)
{
return {
@ -1165,6 +1360,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: return DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA;
@ -1185,6 +1382,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_CONVOLUTION: return DML_CONVOLUTION_OPERATOR_SCHEMA;
case DML_OPERATOR_GEMM: return DML_GEMM_OPERATOR_SCHEMA;
case DML_OPERATOR_REDUCE: return DML_REDUCE_OPERATOR_SCHEMA;
case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA;
case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA;
@ -1249,7 +1448,23 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_OR: return DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: return DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: return DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: return DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_RELU_GRAD: return DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_AVERAGE_POOLING_GRAD: return DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING_GRAD: return DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_RANDOM_GENERATOR: return DML_RANDOM_GENERATOR_OPERATOR_SCHEMA;
case DML_OPERATOR_NONZERO_COORDINATES: return DML_NONZERO_COORDINATES_OPERATOR_SCHEMA;
case DML_OPERATOR_RESAMPLE_GRAD: return DML_RESAMPLE_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_SLICE_GRAD: return DML_SLICE_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_ADAM_OPTIMIZER: return DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA;
case DML_OPERATOR_ROI_ALIGN: return DML_ROI_ALIGN_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER_ND1: return DML_GATHER_ND1_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA;
@ -1345,6 +1560,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA,
@ -1425,6 +1648,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_REDUCE_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_REDUCE_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ARGMIN:
return AbstractOperatorDesc(
&DML_ARGMIN_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ARGMIN_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ARGMAX:
return AbstractOperatorDesc(
&DML_ARGMAX_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ARGMAX_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_AVERAGE_POOLING:
return AbstractOperatorDesc(
&DML_AVERAGE_POOLING_OPERATOR_SCHEMA,
@ -1681,10 +1912,74 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_BIT_AND:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_BIT_OR:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_RELU_GRAD:
return AbstractOperatorDesc(
&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_AVERAGE_POOLING_GRAD:
return AbstractOperatorDesc(
&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MAX_POOLING_GRAD:
return AbstractOperatorDesc(
&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MAX_POOLING_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_RANDOM_GENERATOR:
return AbstractOperatorDesc(
&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_RANDOM_GENERATOR_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_NONZERO_COORDINATES:
return AbstractOperatorDesc(
&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_NONZERO_COORDINATES_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_RESAMPLE_GRAD:
return AbstractOperatorDesc(
&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_RESAMPLE_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_SLICE_GRAD:
return AbstractOperatorDesc(
&DML_SLICE_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_SLICE_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ADAM_OPTIMIZER:
return AbstractOperatorDesc(
&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ADAM_OPTIMIZER_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ROI_ALIGN:
return AbstractOperatorDesc(
&DML_ROI_ALIGN_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ROI_ALIGN_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_GATHER_ND1:
return AbstractOperatorDesc(
&DML_GATHER_ND1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_GATHER_ND1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_ELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_ELU_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_CELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_CELU_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_CELU_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_HARDMAX:
return AbstractOperatorDesc(
&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA,

View file

@ -5,20 +5,12 @@
#include "MLOperatorAuthorImpl.h"
#include "FusedGraphKernel.h"
#include "GraphKernelHelper.h"
using namespace Windows::AI::MachineLearning::Adapter;
namespace Dml
{
// Rounds `offset` up to the nearest multiple of `alignment`.
// `alignment` must be a non-zero power of two (checked with assert in debug builds),
// and T must be an unsigned integer type (checked at compile time).
template <typename T>
static T AlignToPow2(T offset, T alignment)
{
    static_assert(std::is_unsigned_v<T>);
    assert(alignment != 0);
    assert((alignment & (alignment - 1)) == 0);

    // For a power of two, (alignment - 1) is a mask of the low bits to clear
    // after bumping the offset past the next boundary.
    const T lowBitsMask = alignment - 1;
    return (offset + lowBitsMask) & ~lowBitsMask;
}
class FusedGraphKernel : public onnxruntime::OpKernel
{
public:
@ -73,42 +65,16 @@ namespace Dml
const uint32_t graphInputCount = kernelInfo.GetInputCount();
auto gpuGraphInputConstnessGetter = [&kernelInfo, &fusedNodeInputDefs, &transferredInitializerMap](uint32_t index)
{
// Transferred initializers are uploaded to GPU memory
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name());
if (iter != transferredInitializerMap.end())
{
return true;
}
// If an initializer wasn't transferred, the constant input may be available from ORT
const onnxruntime::Tensor* inputTensor = nullptr;
if (!kernelInfo.TryGetConstantInput(index, &inputTensor) || inputTensor == nullptr)
{
return false;
}
// Check that the constant ORT input is in GPU memory
if (!strcmp(inputTensor->Location().name, onnxruntime::CPU) ||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput ||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput)
{
return false;
}
return true;
};
m_inputsConstant.resize(graphInputCount);
for (uint32_t i = 0; i < graphInputCount; ++i)
{
m_inputsConstant[i] = gpuGraphInputConstnessGetter(i);
m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputDefs, transferredInitializerMap);
}
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
kernelInfo,
m_inputsConstant,
m_inputsConstant.data(),
m_inputsConstant.size(),
transferredInitializerMap,
graph,
fusedNodeInputDefs,
@ -117,116 +83,27 @@ namespace Dml
device.Get(),
m_executionHandle);
// Determine the last input which uses an initializer, so initializers can be freed incrementally
// while processing each input in order.
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
for (uint32_t i = 0; i < graphInputCount; i++)
{
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
if (iter != transferredInitializerMap.end())
{
initializerToLastInputIndexMap[&iter->second] = i;
}
}
// Walk through each graph edge and mark used inputs
m_inputsUsed.assign(graphInputCount, false);
for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges)
{
m_inputsUsed[edge.GraphInputIndex] = true;
}
// Populate input bindings for operator initialization
std::vector<ComPtr<ID3D12Resource>> initInputResources; // For lifetime control
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> initInputResources; // For lifetime control
std::vector<DML_BUFFER_BINDING> initInputBindings(graphInputCount);
m_nonOwnedGraphInputsFromInitializers.resize(graphInputCount);
std::vector<ComPtr<ID3D12Resource>> initializeResourceRefs;
for (uint32_t i = 0; i < initInputBindings.size(); i++)
{
// If the input isn't actually used by the graph, nothing ever needs to be bound (either for
// initialization or execution). So just throw away the transferred initializer and skip this input.
if (!m_inputsUsed[i])
{
transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name());
continue;
}
// Look for the initializer among those transferred from the graph during partitioning
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
if (iter != transferredInitializerMap.end())
{
std::byte* tensorPtr = nullptr;
size_t tensorByteSize = 0;
std::unique_ptr<std::byte[]> unpackedTensor;
auto& initializer = iter->second;
// The tensor may be stored as raw data or in typed fields.
if (initializer.has_raw_data())
{
tensorPtr = (std::byte*)(initializer.raw_data().c_str());
tensorByteSize = initializer.raw_data().size();
}
else
{
std::tie(unpackedTensor, tensorByteSize) = UnpackTensor(initializer);
tensorPtr = unpackedTensor.get();
}
// Tensor sizes in DML must be a multiple of 4 bytes large.
tensorByteSize = AlignToPow2<size_t>(tensorByteSize, 4);
if (!m_inputsConstant[i])
{
// Store the resource to use during execution
ComPtr<ID3D12Resource> defaultBuffer = CreateResource(tensorPtr, tensorByteSize);
m_nonOwnedGraphInputsFromInitializers[i] = defaultBuffer;
initializeResourceRefs.push_back(std::move(defaultBuffer));
}
else
{
ComPtr<ID3D12Resource> initializeInputBuffer;
// D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
if (m_provider->IsMcdmDevice())
{
initializeInputBuffer = CreateResource(tensorPtr, tensorByteSize);
}
else
{
initializeInputBuffer = CreateCpuResource(tensorPtr, tensorByteSize);
}
// Set the binding for operator initialization to the buffer
initInputBindings[i].Buffer = initializeInputBuffer.Get();
initInputBindings[i].SizeInBytes = tensorByteSize;
initializeResourceRefs.push_back(std::move(initializeInputBuffer));
}
// Free the initializer if this is the last usage of it.
if (initializerToLastInputIndexMap[&initializer] == i)
{
transferredInitializerMap.erase(iter);
}
}
else if (m_inputsConstant[i])
{
const onnxruntime::Tensor* inputTensor = nullptr;
THROW_HR_IF(E_UNEXPECTED, !kernelInfo.TryGetConstantInput(i, &inputTensor));
uint64_t allocId;
UnwrapTensor(inputTensor, &initInputBindings[i].Buffer, &allocId);
initInputBindings[i].SizeInBytes = initInputBindings[i].Buffer->GetDesc().Width;
initInputBindings[i].Buffer->Release(); // Avoid holding an additional reference
initInputResources.push_back(initInputBindings[i].Buffer);
}
}
// All initializers should have been consumed and freed above
assert(transferredInitializerMap.empty());
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> initializeResourceRefs;
GraphKernelHelper::PopulateInputBindings(
m_provider.Get(),
m_winmlProvider.Get(),
m_inputsConstant,
kernelInfo,
graphDesc,
fusedNodeInputDefs,
m_inputsUsed,
initInputBindings,
initInputResources,
m_nonOwnedGraphInputsFromInitializers,
initializeResourceRefs,
transferredInitializerMap);
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
@ -234,38 +111,15 @@ namespace Dml
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
std::vector<DML_GRAPH_EDGE_DESC> dmlIntermediateEdges(graphDesc.intermediateEdges.size());
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
{
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{ graphDesc.nodes[i].op.Get() };
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{ DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i] };
}
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
{
dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i] };
}
for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i)
{
dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i] };
}
for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i)
{
dmlIntermediateEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i] };
}
DML_GRAPH_DESC dmlGraphDesc = {};
dmlGraphDesc.InputCount = graphInputCount;
dmlGraphDesc.OutputCount = kernelInfo.GetOutputCount();
dmlGraphDesc.NodeCount = gsl::narrow_cast<uint32_t>(dmlGraphNodes.size());
dmlGraphDesc.Nodes = dmlGraphNodes.data();
dmlGraphDesc.InputEdgeCount = gsl::narrow_cast<uint32_t>(dmlInputEdges.size());
dmlGraphDesc.InputEdges = dmlInputEdges.data();
dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast<uint32_t>(dmlOutputEdges.size());
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
GraphKernelHelper::ConvertGraphDesc(
graphDesc,
dmlGraphDesc,
kernelInfo,
dmlOperatorGraphNodes,
dmlGraphNodes,
dmlInputEdges,
dmlOutputEdges,
dmlIntermediateEdges);
DML_EXECUTION_FLAGS executionFlags = DML_EXECUTION_FLAG_NONE;
if (graphDesc.reuseCommandList)
@ -533,10 +387,10 @@ namespace Dml
const onnxruntime::Tensor* tensor = kernelContext->Input<onnxruntime::Tensor>(i);
uint64_t allocId;
UnwrapTensor(tensor, &inputBindings[i].Buffer, &allocId);
GraphKernelHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId);
inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId);
inputBindings[i].Buffer->Release(); // Avoid holding an additional reference
inputBindings[i].SizeInBytes = AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
inputBindings[i].SizeInBytes = GraphKernelHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]};
m_inputBindingAllocIds[i] = allocId;
}
@ -570,10 +424,10 @@ namespace Dml
);
uint64_t allocId;
UnwrapTensor(tensor, &outputBindings[i].Buffer, &allocId);
GraphKernelHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId);
outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId);
outputBindings[i].Buffer->Release(); // Avoid holding an additional reference
outputBindings[i].SizeInBytes = AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
outputBindings[i].SizeInBytes = GraphKernelHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]};
m_outputBindingAllocIds[i] = allocId;
}
@ -623,106 +477,6 @@ namespace Dml
m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get());
}
// Resolves the ID3D12Resource behind an ORT tensor and reports its pooled allocation id.
//
// tensor   - tensor whose DataRaw() value is reinterpreted as an IUnknown allocation handle.
// resource - receives the resource via QueryInterface, so the caller owns one reference.
// allocId  - receives the result of TryGetPooledAllocationId for the allocation.
void UnwrapTensor(const onnxruntime::Tensor* tensor, ID3D12Resource** resource, uint64_t* allocId) const
{
    void* rawAllocation = const_cast<void*>(tensor->DataRaw());
    IUnknown* allocation = static_cast<IUnknown*>(rawAllocation);

    ComPtr<IUnknown> abiData;
    m_winmlProvider->GetABIDataInterface(false, allocation, &abiData);

    *allocId = m_winmlProvider->TryGetPooledAllocationId(allocation, 0);

    THROW_IF_FAILED(abiData->QueryInterface(resource));
}
// Creates a default-heap (GPU) buffer for a tensor and passes the tensor bytes to
// m_provider->UploadToResource.
//
// tensorPtr      - source bytes for the upload.
// tensorByteSize - number of bytes handed to the upload; the buffer width is this
//                  value rounded up to a multiple of 4.
// Returns the new buffer, created in the UNORDERED_ACCESS state. Failing HRESULTs
// are surfaced through THROW_IF_FAILED.
ComPtr<ID3D12Resource> CreateResource(const std::byte* tensorPtr, size_t tensorByteSize) const
{
    D3D12_HEAP_PROPERTIES heapProperties = {};
    heapProperties.Type = D3D12_HEAP_TYPE_DEFAULT;
    heapProperties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
    heapProperties.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
    heapProperties.CreationNodeMask = 0;
    heapProperties.VisibleNodeMask = 0;

    D3D12_RESOURCE_DESC resourceDesc = {};
    resourceDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
    resourceDesc.Alignment = 0;
    resourceDesc.Width = static_cast<uint64_t>((tensorByteSize + 3) & ~3);
    resourceDesc.Height = 1;
    resourceDesc.DepthOrArraySize = 1;
    resourceDesc.MipLevels = 1;
    resourceDesc.Format = DXGI_FORMAT_UNKNOWN;
    resourceDesc.SampleDesc.Count = 1;
    resourceDesc.SampleDesc.Quality = 0;
    resourceDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
    resourceDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

    ComPtr<ID3D12Device> device;
    THROW_IF_FAILED(m_provider->GetD3DDevice(device.GetAddressOf()));

    ComPtr<ID3D12Resource> gpuBuffer;
    THROW_IF_FAILED(device->CreateCommittedResource(
        &heapProperties,
        D3D12_HEAP_FLAG_NONE,
        &resourceDesc,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
        nullptr,
        IID_PPV_ARGS(gpuBuffer.GetAddressOf())
    ));

    THROW_IF_FAILED(m_provider->UploadToResource(gpuBuffer.Get(), tensorPtr, tensorByteSize));

    return gpuBuffer;
}
// Creates a buffer in a custom heap (write-combined CPU pages, memory pool L0) and
// fills it by mapping the buffer and copying the tensor bytes directly.
//
// tensorPtr      - source bytes for the copy.
// tensorByteSize - number of bytes copied; the buffer width is this value rounded
//                  up to a multiple of 4.
// Returns the new buffer, created in the UNORDERED_ACCESS state. Failing HRESULTs
// are surfaced through THROW_IF_FAILED.
ComPtr<ID3D12Resource> CreateCpuResource(const std::byte* tensorPtr, size_t tensorByteSize) const
{
    D3D12_HEAP_PROPERTIES heapProperties = {};
    heapProperties.Type = D3D12_HEAP_TYPE_CUSTOM;
    heapProperties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE;
    heapProperties.MemoryPoolPreference = D3D12_MEMORY_POOL_L0;
    heapProperties.CreationNodeMask = 0;
    heapProperties.VisibleNodeMask = 0;

    D3D12_RESOURCE_DESC resourceDesc = {};
    resourceDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
    resourceDesc.Alignment = 0;
    resourceDesc.Width = static_cast<uint64_t>((tensorByteSize + 3) & ~3);
    resourceDesc.Height = 1;
    resourceDesc.DepthOrArraySize = 1;
    resourceDesc.MipLevels = 1;
    resourceDesc.Format = DXGI_FORMAT_UNKNOWN;
    resourceDesc.SampleDesc.Count = 1;
    resourceDesc.SampleDesc.Quality = 0;
    resourceDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
    resourceDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

    ComPtr<ID3D12Device> device;
    THROW_IF_FAILED(m_provider->GetD3DDevice(device.GetAddressOf()));

    ComPtr<ID3D12Resource> cpuVisibleBuffer;
    THROW_IF_FAILED(device->CreateCommittedResource(
        &heapProperties,
        D3D12_HEAP_FLAG_NONE,
        &resourceDesc,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
        nullptr,
        IID_PPV_ARGS(cpuVisibleBuffer.GetAddressOf())
    ));

    // Map the buffer, copy the tensor bytes in, and unmap with the same byte range.
    D3D12_RANGE mappedRange = {0, tensorByteSize};
    void* mappedData = nullptr;
    THROW_IF_FAILED(cpuVisibleBuffer->Map(0, &mappedRange, &mappedData));
    memcpy(mappedData, tensorPtr, tensorByteSize);
    cpuVisibleBuffer->Unmap(0, &mappedRange);

    return cpuVisibleBuffer;
}
ComPtr<IDMLCompiledOperator> m_compiledExecutionPlanOperator;
std::vector<bool> m_inputsUsed;
const void* m_executionHandle = nullptr;

View file

@ -60,7 +60,8 @@ namespace Dml::GraphDescBuilder
GraphDesc BuildGraphDesc(
const onnxruntime::OpKernelInfo& kernelInfo,
gsl::span<const uint8_t> isConstGpuGraphInput,
const uint8_t* isConstGpuGraphInput,
const size_t isConstGpuGraphInputCount,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
@ -226,7 +227,7 @@ namespace Dml::GraphDescBuilder
graphInputEdges.push_back(edge);
// If this is a constant input, set the appropriate flags on the desc
if (isConstGpuGraphInput[fusedNodeInputIndex])
if (fusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[fusedNodeInputIndex])
{
DmlBufferTensorDesc* tensorDesc = inputTensorDescs[inputIndex];

View file

@ -41,7 +41,8 @@ namespace Dml
GraphDesc BuildGraphDesc(
const onnxruntime::OpKernelInfo& kernelInfo,
gsl::span<const uint8_t> isConstGpuGraphInput,
const uint8_t* isConstGpuGraphInput,
const size_t isConstGpuGraphInputCount,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,

View file

@ -0,0 +1,313 @@
#include "precomp.h"
#include "GraphKernelHelper.h"
namespace Dml
{
namespace GraphKernelHelper
{
// Creates a default-heap (GPU) buffer for a tensor and passes the tensor bytes to
// provider->UploadToResource.
//
// provider       - supplies the D3D12 device and performs the upload.
// tensorPtr      - source bytes for the upload.
// tensorByteSize - number of bytes handed to the upload; the buffer width is this
//                  value rounded up to a multiple of 4.
// Returns the new buffer, created in the UNORDERED_ACCESS state. Failing HRESULTs
// are surfaced through THROW_IF_FAILED.
Microsoft::WRL::ComPtr<ID3D12Resource>
CreateResource(
    Dml::IExecutionProvider* provider,
    const std::byte* tensorPtr,
    size_t tensorByteSize)
{
    D3D12_HEAP_PROPERTIES heapProperties = {};
    heapProperties.Type = D3D12_HEAP_TYPE_DEFAULT;
    heapProperties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
    heapProperties.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
    heapProperties.CreationNodeMask = 0;
    heapProperties.VisibleNodeMask = 0;

    D3D12_RESOURCE_DESC resourceDesc = {};
    resourceDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
    resourceDesc.Alignment = 0;
    resourceDesc.Width = static_cast<uint64_t>((tensorByteSize + 3) & ~3);
    resourceDesc.Height = 1;
    resourceDesc.DepthOrArraySize = 1;
    resourceDesc.MipLevels = 1;
    resourceDesc.Format = DXGI_FORMAT_UNKNOWN;
    resourceDesc.SampleDesc.Count = 1;
    resourceDesc.SampleDesc.Quality = 0;
    resourceDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
    resourceDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

    Microsoft::WRL::ComPtr<ID3D12Device> device;
    THROW_IF_FAILED(provider->GetD3DDevice(device.GetAddressOf()));

    Microsoft::WRL::ComPtr<ID3D12Resource> gpuBuffer;
    THROW_IF_FAILED(device->CreateCommittedResource(
        &heapProperties,
        D3D12_HEAP_FLAG_NONE,
        &resourceDesc,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
        nullptr,
        IID_PPV_ARGS(gpuBuffer.GetAddressOf())));

    THROW_IF_FAILED(provider->UploadToResource(gpuBuffer.Get(), tensorPtr, tensorByteSize));

    return gpuBuffer;
}
// Creates a buffer in a custom heap (write-combined CPU pages, memory pool L0) and
// fills it by mapping the buffer and copying the tensor bytes directly.
//
// provider       - supplies the D3D12 device.
// tensorPtr      - source bytes for the copy.
// tensorByteSize - number of bytes copied; the buffer width is this value rounded
//                  up to a multiple of 4.
// Returns the new buffer, created in the UNORDERED_ACCESS state. Failing HRESULTs
// are surfaced through THROW_IF_FAILED.
Microsoft::WRL::ComPtr<ID3D12Resource>
CreateCpuResource(
    Dml::IExecutionProvider* provider,
    const std::byte* tensorPtr,
    size_t tensorByteSize)
{
    D3D12_HEAP_PROPERTIES heapProperties = {};
    heapProperties.Type = D3D12_HEAP_TYPE_CUSTOM;
    heapProperties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE;
    heapProperties.MemoryPoolPreference = D3D12_MEMORY_POOL_L0;
    heapProperties.CreationNodeMask = 0;
    heapProperties.VisibleNodeMask = 0;

    D3D12_RESOURCE_DESC resourceDesc = {};
    resourceDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
    resourceDesc.Alignment = 0;
    resourceDesc.Width = static_cast<uint64_t>((tensorByteSize + 3) & ~3);
    resourceDesc.Height = 1;
    resourceDesc.DepthOrArraySize = 1;
    resourceDesc.MipLevels = 1;
    resourceDesc.Format = DXGI_FORMAT_UNKNOWN;
    resourceDesc.SampleDesc.Count = 1;
    resourceDesc.SampleDesc.Quality = 0;
    resourceDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
    resourceDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

    Microsoft::WRL::ComPtr<ID3D12Device> device;
    THROW_IF_FAILED(provider->GetD3DDevice(device.GetAddressOf()));

    Microsoft::WRL::ComPtr<ID3D12Resource> cpuVisibleBuffer;
    THROW_IF_FAILED(device->CreateCommittedResource(
        &heapProperties,
        D3D12_HEAP_FLAG_NONE,
        &resourceDesc,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
        nullptr,
        IID_PPV_ARGS(cpuVisibleBuffer.GetAddressOf())));

    // Map the buffer, copy the tensor bytes in, and unmap with the same byte range.
    D3D12_RANGE mappedRange = {0, tensorByteSize};
    void* mappedData = nullptr;
    THROW_IF_FAILED(cpuVisibleBuffer->Map(0, &mappedRange, &mappedData));
    memcpy(mappedData, tensorPtr, tensorByteSize);
    cpuVisibleBuffer->Unmap(0, &mappedRange);

    return cpuVisibleBuffer;
}
// Resolves the ID3D12Resource behind an ORT tensor and reports its pooled allocation id.
//
// winmlProvider - used to translate the allocation handle and look up the allocation id.
// tensor        - tensor whose DataRaw() value is reinterpreted as an IUnknown allocation handle.
// resource      - receives the resource via QueryInterface, so the caller owns one reference.
// allocId       - receives the result of TryGetPooledAllocationId for the allocation.
void UnwrapTensor(
    IWinmlExecutionProvider* winmlProvider,
    const onnxruntime::Tensor* tensor,
    ID3D12Resource** resource,
    uint64_t* allocId)
{
    void* rawAllocation = const_cast<void*>(tensor->DataRaw());
    IUnknown* allocation = static_cast<IUnknown*>(rawAllocation);

    Microsoft::WRL::ComPtr<IUnknown> abiData;
    winmlProvider->GetABIDataInterface(false, allocation, &abiData);

    *allocId = winmlProvider->TryGetPooledAllocationId(allocation, 0);

    THROW_IF_FAILED(abiData->QueryInterface(resource));
}
// Decides whether fused-graph input `index` can be treated as a constant GPU input.
//
// Returns true when the input's name is present in transferredInitializerMap, or when
// ORT exposes a non-null constant tensor for the input whose memory location is not
// CPU (by allocator name or by CPU input/output memory type). Returns false otherwise.
bool GetGraphInputConstness(
    uint32_t index,
    const onnxruntime::OpKernelInfo& kernelInfo,
    const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
    const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
{
    // Transferred initializers are uploaded to GPU memory
    const auto& inputName = fusedNodeInputDefs[index]->Name();
    if (transferredInitializerMap.count(inputName) != 0)
    {
        return true;
    }

    // If an initializer wasn't transferred, the constant input may be available from ORT
    const onnxruntime::Tensor* constantTensor = nullptr;
    const bool hasConstantTensor =
        kernelInfo.TryGetConstantInput(index, &constantTensor) && constantTensor != nullptr;
    if (!hasConstantTensor)
    {
        return false;
    }

    // The constant ORT input only qualifies when it is not in CPU memory
    const auto& location = constantTensor->Location();
    const bool isCpuResident =
        strcmp(location.name, onnxruntime::CPU) == 0 ||
        location.mem_type == ::OrtMemType::OrtMemTypeCPUOutput ||
        location.mem_type == ::OrtMemType::OrtMemTypeCPUInput;

    return !isCpuResident;
}
// Prepares the initialization-time input bindings for a fused DML graph.
//
// For every entry of initInputBindings (one per fused-graph input):
//  - inputs not referenced by any graph input edge are skipped and their transferred
//    initializer (if any) is discarded;
//  - inputs backed by a transferred initializer get a new D3D12 buffer holding the
//    initializer bytes. Constant inputs are bound through initInputBindings[i];
//    non-constant ones are stored in nonOwnedGraphInputsFromInitializers[i];
//  - remaining constant inputs are bound to the resource unwrapped from ORT's
//    constant tensor.
//
// Parameters:
//  provider / winmlProvider - used to create resources and unwrap ORT tensors.
//  inputsConstant           - per-input flag; non-zero means constant GPU input.
//  kernelInfo               - supplies the input count and ORT constant inputs.
//  graphDesc                - its inputEdges determine which inputs are used.
//  fusedNodeInputDefs       - supplies the name of each fused-graph input.
//  inputsUsed               - reset to one flag per graph input, true when used.
//  initInputBindings        - must be pre-sized by the caller; Buffer/SizeInBytes
//                             are filled for bound constant inputs.
//  initInputResources       - receives references to unwrapped constant-input buffers.
//  nonOwnedGraphInputsFromInitializers - must be pre-sized by the caller (indexed by input).
//  initializeResourceRefs   - receives references keeping the created buffers alive.
//  transferredInitializerMap - consumed: each initializer is erased after its last use.
//
// Returns one byte vector per entry of initInputBindings: the initializer bytes padded
// with zeros to a multiple of 4, the bytes copied from an ORT constant tensor, or an
// empty vector for unused / non-constant inputs.
std::vector<std::vector<std::byte>> PopulateInputBindings(
    Dml::IExecutionProvider* provider,
    IWinmlExecutionProvider* winmlProvider,
    const std::vector<uint8_t>& inputsConstant,
    const onnxruntime::OpKernelInfo& kernelInfo,
    const Dml::GraphDescBuilder::GraphDesc& graphDesc,
    const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
    _Out_ std::vector<bool>& inputsUsed,
    _Out_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
    _Out_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
    _Out_ std::vector<ComPtr<ID3D12Resource>>& nonOwnedGraphInputsFromInitializers,
    _Out_ std::vector<ComPtr<ID3D12Resource>>& initializeResourceRefs,
    _Inout_ std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
{
    std::vector<std::vector<std::byte>> inputRawData;
    inputRawData.reserve(initInputBindings.size());

    const uint32_t graphInputCount = kernelInfo.GetInputCount();

    // Determine the last input which uses an initializer, so initializers can be freed incrementally
    // while processing each input in order.
    std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
    for (uint32_t i = 0; i < graphInputCount; i++)
    {
        auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
        if (iter != transferredInitializerMap.end())
        {
            initializerToLastInputIndexMap[&iter->second] = i;
        }
    }

    // Walk through each graph edge and mark used inputs
    inputsUsed.assign(graphInputCount, false);
    for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges)
    {
        inputsUsed[edge.GraphInputIndex] = true;
    }

    for (uint32_t i = 0; i < initInputBindings.size(); i++)
    {
        // If the input isn't actually used by the graph, nothing ever needs to be bound (either for
        // initialization or execution). So just throw away the transferred initializer and skip this input.
        if (!inputsUsed[i])
        {
            transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name());
            inputRawData.push_back(std::vector<std::byte>());
            continue;
        }

        // Look for the initializer among those transferred from the graph during partitioning
        auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
        if (iter != transferredInitializerMap.end())
        {
            const std::byte* tensorPtr = nullptr;
            size_t tensorByteSize = 0;
            std::unique_ptr<std::byte[]> unpackedTensor;

            auto& initializer = iter->second;

            // The tensor may be stored as raw data or in typed fields.
            if (initializer.has_raw_data())
            {
                tensorPtr = reinterpret_cast<const std::byte*>(initializer.raw_data().c_str());
                tensorByteSize = initializer.raw_data().size();
            }
            else
            {
                std::tie(unpackedTensor, tensorByteSize) = UnpackTensor(initializer);
                tensorPtr = unpackedTensor.get();
            }

            // Tensor sizes in DML must be a multiple of 4 bytes. The source storage is only
            // guaranteed to hold the unpadded byte count, so copy it into a zero-filled buffer
            // of the padded size and use that buffer as the upload source. Reading the padded
            // size straight from tensorPtr could run up to 3 bytes past the end of the source.
            const size_t sourceByteSize = tensorByteSize;
            tensorByteSize = AlignToPow2<size_t>(tensorByteSize, 4);

            std::vector<std::byte> paddedTensor(tensorByteSize);
            if (sourceByteSize != 0)
            {
                memcpy(paddedTensor.data(), tensorPtr, sourceByteSize);
            }
            inputRawData.push_back(std::move(paddedTensor));

            // Nothing else is appended to inputRawData in this iteration, so this pointer
            // stays valid for the resource creation below.
            const std::byte* paddedTensorPtr = inputRawData.back().data();

            if (!inputsConstant[i])
            {
                // Store the resource to use during execution
                ComPtr<ID3D12Resource> defaultBuffer = CreateResource(provider, paddedTensorPtr, tensorByteSize);
                nonOwnedGraphInputsFromInitializers[i] = defaultBuffer;
                initializeResourceRefs.push_back(std::move(defaultBuffer));
            }
            else
            {
                ComPtr<ID3D12Resource> initializeInputBuffer;

                // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
                if (provider->IsMcdmDevice())
                {
                    initializeInputBuffer = CreateResource(provider, paddedTensorPtr, tensorByteSize);
                }
                else
                {
                    initializeInputBuffer = CreateCpuResource(provider, paddedTensorPtr, tensorByteSize);
                }

                // Set the binding for operator initialization to the buffer
                initInputBindings[i].Buffer = initializeInputBuffer.Get();
                initInputBindings[i].SizeInBytes = tensorByteSize;
                initializeResourceRefs.push_back(std::move(initializeInputBuffer));
            }

            // Free the initializer if this is the last usage of it.
            if (initializerToLastInputIndexMap[&initializer] == i)
            {
                transferredInitializerMap.erase(iter);
            }
        }
        else if (inputsConstant[i])
        {
            const onnxruntime::Tensor* inputTensor = nullptr;
            THROW_HR_IF(E_UNEXPECTED, !kernelInfo.TryGetConstantInput(i, &inputTensor));

            // NOTE(review): UnwrapTensor below treats DataRaw() of this tensor as an IUnknown
            // allocation handle, while this copy treats it as SizeInBytes() of tensor data.
            // Presumably only one interpretation is right for GPU-resident constants; confirm
            // what callers expect in the returned raw data for this case.
            const std::byte* tensorData = reinterpret_cast<const std::byte*>(inputTensor->DataRaw());
            inputRawData.push_back(
                std::vector<std::byte>(tensorData, tensorData + inputTensor->SizeInBytes()));

            uint64_t allocId;
            UnwrapTensor(winmlProvider, inputTensor, &initInputBindings[i].Buffer, &allocId);
            initInputBindings[i].SizeInBytes = initInputBindings[i].Buffer->GetDesc().Width;

            // Take the owning reference first, then drop the one returned by UnwrapTensor, so
            // the resource's reference count never dips in between.
            initInputResources.push_back(initInputBindings[i].Buffer);
            initInputBindings[i].Buffer->Release(); // Avoid holding an additional reference
        }
        else
        {
            inputRawData.push_back(std::vector<std::byte>());
        }
    }

    // All initializers should have been consumed and freed above
    assert(transferredInitializerMap.empty());

    return inputRawData;
}
// Fills a DML_GRAPH_DESC (and the arrays it points into) from a GraphDescBuilder::GraphDesc.
//
// The caller must pre-size dmlOperatorGraphNodes/dmlGraphNodes to graphDesc.nodes.size()
// and each edge vector to the size of the matching graphDesc edge list: entries are
// written by index. dmlGraphDesc stores raw pointers into these vectors and the edge
// wrappers point into graphDesc, so all of them must outlive any use of dmlGraphDesc.
void ConvertGraphDesc(
    const Dml::GraphDescBuilder::GraphDesc& graphDesc,
    _Out_ DML_GRAPH_DESC& dmlGraphDesc,
    const onnxruntime::OpKernelInfo& kernelInfo,
    _Out_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
    _Out_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
    _Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
    _Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
    _Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges)
{
    const uint32_t graphInputCount = kernelInfo.GetInputCount();

    // Wrap every operator in an operator-node desc, then in a generic node desc pointing at it.
    const size_t nodeCount = graphDesc.nodes.size();
    for (size_t nodeIndex = 0; nodeIndex != nodeCount; ++nodeIndex)
    {
        DML_OPERATOR_GRAPH_NODE_DESC& operatorNode = dmlOperatorGraphNodes[nodeIndex];
        operatorNode = DML_OPERATOR_GRAPH_NODE_DESC{graphDesc.nodes[nodeIndex].op.Get()};
        dmlGraphNodes[nodeIndex] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &operatorNode};
    }

    // Wraps each typed edge desc of `typedEdges` in a generic, type-tagged edge desc.
    auto wrapEdges = [](const auto& typedEdges, DML_GRAPH_EDGE_TYPE edgeType, std::vector<DML_GRAPH_EDGE_DESC>& wrappedEdges)
    {
        const size_t edgeCount = typedEdges.size();
        for (size_t edgeIndex = 0; edgeIndex != edgeCount; ++edgeIndex)
        {
            wrappedEdges[edgeIndex] = DML_GRAPH_EDGE_DESC{edgeType, &typedEdges[edgeIndex]};
        }
    };

    wrapEdges(graphDesc.inputEdges, DML_GRAPH_EDGE_TYPE_INPUT, dmlInputEdges);
    wrapEdges(graphDesc.outputEdges, DML_GRAPH_EDGE_TYPE_OUTPUT, dmlOutputEdges);
    wrapEdges(graphDesc.intermediateEdges, DML_GRAPH_EDGE_TYPE_INTERMEDIATE, dmlIntermediateEdges);

    // Counts come from the destination vectors, matching the arrays actually handed to DML.
    dmlGraphDesc.InputCount = graphInputCount;
    dmlGraphDesc.OutputCount = kernelInfo.GetOutputCount();

    dmlGraphDesc.NodeCount = gsl::narrow_cast<uint32_t>(dmlGraphNodes.size());
    dmlGraphDesc.Nodes = dmlGraphNodes.data();

    dmlGraphDesc.InputEdgeCount = gsl::narrow_cast<uint32_t>(dmlInputEdges.size());
    dmlGraphDesc.InputEdges = dmlInputEdges.data();

    dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast<uint32_t>(dmlOutputEdges.size());
    dmlGraphDesc.OutputEdges = dmlOutputEdges.data();

    dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
    dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}
} // namespace GraphKernelHelper
} // namespace Dml

View file

@ -0,0 +1,67 @@
#include "GraphDescBuilder.h"
namespace Dml
{
namespace GraphKernelHelper
{
using namespace Windows::AI::MachineLearning::Adapter;
template <typename T>
static T AlignToPow2(T offset, T alignment)
{
static_assert(std::is_unsigned_v<T>);
assert(alignment != 0);
assert((alignment & (alignment - 1)) == 0);
return (offset + alignment - 1) & ~(alignment - 1);
}
// Creates a D3D12 resource for the tensorByteSize bytes located at tensorPtr.
// PopulateInputBindings picks this variant when provider->IsMcdmDevice() is true,
// because such devices do not support custom heaps.
// NOTE(review): the definition is not visible here; presumably the returned
// buffer is populated with the tensor bytes - confirm.
Microsoft::WRL::ComPtr<ID3D12Resource>
CreateResource(
Dml::IExecutionProvider* provider,
const std::byte* tensorPtr,
size_t tensorByteSize);
// Creates a D3D12 resource for the tensorByteSize bytes located at tensorPtr.
// PopulateInputBindings picks this variant when the device is not an MCDM device.
// NOTE(review): the definition is not visible here; presumably it relies on a
// custom (CPU-visible) heap, which is why MCDM devices take the other path - confirm.
Microsoft::WRL::ComPtr<ID3D12Resource>
CreateCpuResource(
Dml::IExecutionProvider* provider,
const std::byte* tensorPtr,
size_t tensorByteSize);
// Retrieves the ID3D12Resource backing `tensor`, plus an allocation id, through
// the two out-pointers.
// NOTE(review): the definition is not visible here. The call site in
// PopulateInputBindings calls Release() on the returned resource right away to
// "avoid holding an additional reference", which suggests *resource comes back
// with a reference the caller owns - confirm.
void UnwrapTensor(
IWinmlExecutionProvider* winmlProvider,
const onnxruntime::Tensor* tensor,
ID3D12Resource** resource,
uint64_t* allocId);
// NOTE(review): the definition is not visible here. Judging by the name and
// parameters, this presumably reports whether fused-graph input `index` is
// constant (for example backed by an entry in transferredInitializerMap or by a
// constant kernel input) - confirm against the definition.
bool GetGraphInputConstness(
uint32_t index,
const onnxruntime::OpKernelInfo& kernelInfo,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap);
// Builds the operator-initialization input bindings for a fused graph kernel and
// returns raw byte data collected per input.
//
// Only the tail of the definition is visible from here, so this is limited to
// what that tail shows:
//  - initInputBindings[i] receives the buffer and byte size to bind for input i.
//  - Buffers created for transferred initializers are kept alive by moving them
//    into initializeResourceRefs.
//  - For inputs flagged in inputsConstant that are not handled as transferred
//    initializers, the tensor bytes are copied into the returned vector and the
//    tensor's unwrapped resource is appended to initInputResources.
//  - Inputs reaching the final else branch contribute an empty byte vector.
//  - An initializer is erased from transferredInitializerMap at its last use, and
//    the map is asserted to be empty on return.
// NOTE(review): the roles of inputsUsed and nonOwnedGraphInputsFromInitializers
// are set in the part of the definition that is not visible - confirm there.
std::vector<std::vector<std::byte>> PopulateInputBindings(
Dml::IExecutionProvider* provider,
IWinmlExecutionProvider* winmlProvider,
const std::vector<uint8_t>& inputsConstant,
const onnxruntime::OpKernelInfo& kernelInfo,
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
_Out_ std::vector<bool>& inputsUsed,
_Out_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
_Out_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
_Out_ std::vector<ComPtr<ID3D12Resource>>& nonOwnedGraphInputsFromInitializers,
_Out_ std::vector<ComPtr<ID3D12Resource>>& initializeResourceRefs,
_Inout_ std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap);
// Fills dmlGraphDesc and the node/edge arrays it points into from graphDesc.
//
// The output vectors are written by index and are not resized by this function,
// so the caller must size each of them to the corresponding node/edge count of
// graphDesc before calling. dmlGraphDesc stores raw pointers into these vectors,
// and the node/edge entries store raw pointers into graphDesc, so all of them
// must outlive any use of dmlGraphDesc. Input/output counts are taken from
// kernelInfo.
void ConvertGraphDesc(
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
_Out_ DML_GRAPH_DESC& dmlGraphDesc,
const onnxruntime::OpKernelInfo& kernelInfo,
_Out_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
_Out_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
} // namespace GraphKernelHelper
} // namespace Dml

View file

@ -139,7 +139,7 @@ namespace Dml
}
};
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats)
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats, uint32_t supportedDeviceDataTypeMask)
{
if (arg->Exists())
{
@ -151,16 +151,23 @@ namespace Dml
{
// TODO: Remove this by handling zeroing on the output of fused graph nodes and handling of non-float
// types in DML's identity operator, which is used for strided copies.
if (ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) == MLOperatorTensorDataType::UInt64 ||
ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) == MLOperatorTensorDataType::Int64)
MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type()));
if (mlDataType == MLOperatorTensorDataType::UInt64 ||
mlDataType == MLOperatorTensorDataType::Int64)
{
return false;
constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64);
if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit)
{
return false;
}
}
if (requiresFloatFormats)
{
if (ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) != MLOperatorTensorDataType::Float &&
ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) != MLOperatorTensorDataType::Float16)
if (mlDataType != MLOperatorTensorDataType::Float &&
mlDataType != MLOperatorTensorDataType::Float16)
{
return false;
}
@ -172,14 +179,19 @@ namespace Dml
return true;
}
bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration)
bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration, uint32_t supportedDeviceDataTypeMask)
{
for (size_t i = 0; i < node.InputDefs().size(); ++i)
{
bool isConstantCpuInput = std::find(registration.requiredConstantCpuInputs.begin(), registration.requiredConstantCpuInputs.end(), i) !=
registration.requiredConstantCpuInputs.end();
if (!isConstantCpuInput && !NodeArgSupportedInGraph(node.InputDefs()[i], registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs))
if (!isConstantCpuInput &&
!NodeArgSupportedInGraph(
node.InputDefs()[i],
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
supportedDeviceDataTypeMask
))
{
return false;
}
@ -187,7 +199,11 @@ namespace Dml
for (auto arg : node.OutputDefs())
{
if (!NodeArgSupportedInGraph(arg, registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs))
if (!NodeArgSupportedInGraph(
arg,
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
supportedDeviceDataTypeMask
))
{
return false;
}
@ -220,97 +236,155 @@ namespace Dml
bool DoesNodeContainSupportedDataTypes(
const onnxruntime::Node& node,
bool allow64BitInputThroughStrides,
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap, // Only used when allow64BitInputThroughStrides is true
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap, // Only used when allow64BitInputThroughStrides is true
_In_opt_ const InternalRegistrationInfo* regInfo,
uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
)
{
THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);
bool prefer64BitTensorsDirectly = false;
bool supportedWith64BitTensorsVia32BitStrides = false;
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;
if (regInfo != nullptr)
{
// Read the operator flags for handling 64-bit tensors and whether it's allowed to fall back
// to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false
// in this particular call, then the operator-specific flags do not matter as the caller has
// disabled 64-bit support.
if (allow64BitInputThroughStrides)
{
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp;
supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp;
}
// Collect the list of CPU-bound input tensors, needed when checking 64-bit fallback
// or for other data types like int-8 which may be supported for CPU inputs but not
// GPU inputs.
auto inputDefinitions = node.InputDefs();
for (uint32_t i : regInfo->requiredConstantCpuInputs)
{
if (i < inputDefinitions.size())
{
constantCpuInputs.push_back(inputDefinitions[i]);
}
}
}
// Assume data types are supported until proven otherwise.
bool nodeContainsSupportedDataTypes = true;
// Callback to check each node's data type.
// Callback to check each node's data type against registered operator support.
std::function<void(const onnxruntime::NodeArg& nodeArg, bool isInput)> nodeCallback = [&](const onnxruntime::NodeArg& nodeArg, bool isInput) -> void
{
// Get the tensor element data type for this node, comparing against what the device actually supports.
// Use the enumeration from the proto instead of nodeArg.Type() which returns a string.
// Reject node if undefined data type or non-tensor, as DML cannot handle it.
MLOperatorTensorDataType onnxElementType;
if (TryGetTensorDataType(nodeArg, &onnxElementType))
if (!TryGetTensorDataType(nodeArg, &onnxElementType))
{
DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType);
if (dmlElementType != DML_TENSOR_DATA_TYPE_UNKNOWN)
// We shouldn't have arrived here because (1) no DML operators should have been
// registered which use non-tensor types (2) ONNX validation should have already
// been done, checking for the right kind of inputs and attributes. In theory,
// this branch could be reached with a bad custom operator or malformed file. If
// a legitimate case reaches here and DML needs to support a new input/output type
// besides tensors, then remove the assert.
assert(false);
nodeContainsSupportedDataTypes = false;
return;
}
// Reject node for unknown DML data types.
DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType);
if (dmlElementType == DML_TENSOR_DATA_TYPE_UNKNOWN)
{
nodeContainsSupportedDataTypes = false;
return;
}
// Succeed if the tensor is CPU-bound, as the CPU-side reading code is generic enough
// to handle multiple types regardless of GPU capability (typically these are just
// scalars or simple 1D arrays).
bool isConstantCpuInput = isInput && std::find(constantCpuInputs.begin(), constantCpuInputs.end(), &nodeArg) != constantCpuInputs.end();
if (isConstantCpuInput)
{
// Leave nodeContainsSupportedDataTypes alone.
return;
}
// If this operator implements 64-bit support in terms of strided 32-bit tensors,
// then the data type needs to be remapped, regardless of whether input or output.
//
// Some operators can fairly safely implement 64-bit tensors in terms of
// strided 32-bit tensors regardless of input tensor's execution provider
// because the indices measure along a single axis and should fall within
// the range of an int32/uint32.
//
// Currently all DML kernels outputting int64 and uint64 are expected to
// not *introduce* values out of range, which allows the temporary trick
// using strides to emulate 64 bit tensors to work. If the source is a CPU
// operator, graph input or initializer, it's not safe to assume the input
// can be represented with 32 bits.
//
bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64);
bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask);
if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit)
{
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
{
if (((1 << dmlElementType) & supportedDeviceDataTypeMask) == 0)
// Look up the input partition. If it's a graph input or initializer it will be missing
// from the partition map.
const std::string& argName = nodeArg.Name();
// If input tensor's data comes from the output of a different execution provider,
// consider it unsafe to apply fallback to.
auto partitionIter = nodeNameToPartitionMap->find(argName);
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
{
nodeContainsSupportedDataTypes = false;
return;
}
}
}
// Reject node if the data type is unsupported by the device.
if (!((1 << dmlElementType) & supportedDeviceDataTypeMask))
{
nodeContainsSupportedDataTypes = false;
return;
}
// Otherwise the node supports the tensor data type.
};
// Check whether the node uses any data types which are unsupported by the device.
node.ForEachDef(nodeCallback);
// DML kernels supporting int64 and uint64 are expected to not *introduce* values out of range, which allows
// the temporary trick using strides to emulate 64 bit tensors to work. If the source is a CPU operator,
// graph input or initializer, it's not safe to assume the input can be represented with 32 bits.
if (regInfo)
{
for (uint32_t i = 0; i < node.InputDefs().size(); ++i)
{
const auto* arg = node.InputDefs()[i];
MLOperatorTensorDataType onnxElementType;
if (arg->Exists() && TryGetTensorDataType(*arg, &onnxElementType))
{
if (((onnxElementType == MLOperatorTensorDataType::UInt64) || (onnxElementType == MLOperatorTensorDataType::Int64)))
{
// Look up the input partition. If it's a graph input or initializer it will be missing
// from the map. In this case or if the input comes from a CPU partition, it might be
// out of range.
const std::string& argName = arg->Name();
// Check if the operator handles the input on the CPU as a constant input
bool isConstantCpuInput = std::find(regInfo->requiredConstantCpuInputs.begin(), regInfo->requiredConstantCpuInputs.end(), i) !=
regInfo->requiredConstantCpuInputs.end();
if (!isConstantCpuInput)
{
if (!allow64BitInputThroughStrides)
{
nodeContainsSupportedDataTypes = false;
break;
}
auto partitionIter = nodeNameToPartitionMap->find(argName);
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
{
nodeContainsSupportedDataTypes = false;
break;
}
}
}
}
}
}
return nodeContainsSupportedDataTypes;
}
bool IsNodeSupportedByDml(
const onnxruntime::Node& node,
const onnxruntime::Node& node,
const onnxruntime::KernelRegistry& registry,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
const InternalRegistrationInfoMap& internalRegInfoMap,
bool allow64BitInputThroughStrides,
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap
)
)
{
THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);
const onnxruntime::KernelCreateInfo* createInfo;
Status st = registry.TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo);
if (!st.IsOK()) {
return false;
if (!st.IsOK())
{
return false;
}
auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get());
@ -337,7 +411,7 @@ namespace Dml
// Gets properties of the registration for a node
void GetRegistrationProperties(
const onnxruntime::GraphViewer& graph,
const onnxruntime::Node& node,
const onnxruntime::Node& node,
const std::vector<const onnxruntime::KernelRegistry*>& dmlRegistries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
const InternalRegistrationInfoMap& internalRegInfoMap,
@ -368,7 +442,9 @@ namespace Dml
// which is required for MLGraph compilation.
const onnxruntime::KernelCreateInfo* createInfo;
if (!registry->TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo).IsOK())
continue;
{
continue;
}
auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get());
if (regInfoIter != internalRegInfoMap.end())
@ -376,7 +452,7 @@ namespace Dml
auto internalRegInfo = regInfoIter->second;
if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration &&
NodeTensorTypesSupportedInGraph(node, *internalRegInfo))
NodeTensorTypesSupportedInGraph(node, *internalRegInfo, supportedDeviceDataTypeMask))
{
bool requiredCpuInputsConstant = true;
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
@ -912,7 +988,8 @@ namespace Dml
std::move(graphNodePropertyMap),
registryForPartitionKernels,
partitionKernelPrefix,
transferredInitializerMap));
transferredInitializerMap
));
}
return result;

View file

@ -61,7 +61,7 @@ namespace Dml
);
bool IsNodeSupportedByDml(
const onnxruntime::Node& node,
const onnxruntime::Node& node,
const onnxruntime::KernelRegistry& registry,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap,

View file

@ -87,7 +87,7 @@ namespace Dml
if (!IsNodeSupportedByDml(
node,
*registry,
m_providerImpl->GetSuppportedDeviceDataTypeMask(),
m_providerImpl->GetSupportedDeviceDataTypeMask(),
*m_providerImpl->GetInternalRegistrationInfoMap().get(),
allow64BitInputThroughStrides,
nullptr))

View file

@ -133,12 +133,21 @@ namespace Dml
));
}
// Convenience overload: initializes the tensor descriptions with no index
// remapping and no shape overrides, only a caller-chosen minimum dimension count.
void DmlOperator::Initialize(
    const MLOperatorKernelCreationContext& kernelInfo,
    uint32_t minDimensionCount
    )
{
    // Delegate to the full overload, leaving every optional argument unset.
    Initialize(
        kernelInfo,
        /*kernelInputIndices*/ std::nullopt,
        /*kernelOutputIndices*/ std::nullopt,
        /*inputShape*/ std::nullopt,
        /*outputShape*/ std::nullopt,
        minDimensionCount
    );
}
void DmlOperator::Initialize(
const MLOperatorKernelCreationContext& kernelInfo,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices,
const std::optional<gsl::span<const uint32_t>> inputShape,
const std::optional<gsl::span<const uint32_t>> outputShape
const std::optional<gsl::span<const uint32_t>> outputShape,
uint32_t minDimensionCount
)
{
if (kernelInputIndices)
@ -179,7 +188,7 @@ namespace Dml
TensorAxis::W,
TensorAxis::RightAligned,
inputShape,
NchwDimensionCount));
minDimensionCount));
}
}
@ -200,7 +209,8 @@ namespace Dml
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
outputShape));
outputShape,
minDimensionCount));
}
}
}
@ -373,6 +383,34 @@ namespace Dml
));
}
// Unconditionally asks every input and output tensor description of this operator
// to remap its 64-bit DML data type to 32-bit.
void DmlOperator::Remap64bitDmlDataTypesTo32bit()
{
    // The same per-tensor remap applies to both lists; inputs go first, then outputs.
    const auto remapEach = [](auto& tensorDescs)
    {
        for (auto& tensorDesc : tensorDescs)
        {
            tensorDesc.Remap64bitDmlDataTypeTo32bit();
        }
    };

    remapEach(m_inputTensorDescs);
    remapEach(m_outputTensorDescs);
}
// Conditionally remaps this operator's 64-bit tensor descriptions to strided
// 32-bit when the DML device cannot consume 64-bit integer tensors directly.
void DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded()
{
    // Each bit of the mask corresponds to one DML_TENSOR_DATA_TYPE value.
    const uint32_t deviceTypeMask = Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get());
    constexpr uint32_t deviceTypeMask64bit = (1u << DML_TENSOR_DATA_TYPE_INT64) | (1u << DML_TENSOR_DATA_TYPE_UINT64);

    // Fall back to 32-bit with strides unless the device supports *both* INT64 and
    // UINT64. Testing only for "neither bit set" would skip the remap on a device
    // that supports just one of the two, handing it tensors of the unsupported
    // type. Requiring both bits also matches the 64-bit check the graph
    // partitioner applies in NodeArgSupportedInGraph.
    if ((deviceTypeMask & deviceTypeMask64bit) != deviceTypeMask64bit)
    {
        Remap64bitDmlDataTypesTo32bit();
    }
}
TensorDesc DmlOperator::CreateTensorDescFromInput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,

View file

@ -29,12 +29,18 @@ namespace Dml
ComPtr<IUnknown> m_persistentResourcePoolingUnk; // Controls when the persistent resource is returned to the pool
std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
void Initialize(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t minDimensionCount
);
void Initialize(
const MLOperatorKernelCreationContext& kernelInfo,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices = std::nullopt,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices = std::nullopt,
const std::optional<gsl::span<const uint32_t>> inputShape = std::nullopt,
const std::optional<gsl::span<const uint32_t>> outputShape = std::nullopt
const std::optional<gsl::span<const uint32_t>> outputShape = std::nullopt,
uint32_t minDimensionCount = NchwDimensionCount
);
bool AllowHalfPrecisionComputation() const;
@ -77,6 +83,11 @@ namespace Dml
void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor);
// Remap 64-bit data types to 32-bit via doubled strides.
// These should be called before GetDmlInputDescs or GetDmlOutputDescs.
void Remap64bitDmlDataTypesTo32bit();
void Remap64bitDmlDataTypesTo32bitIfNeeded();
TensorDesc CreateTensorDescFromInput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,

View file

@ -30,6 +30,7 @@ public:
switch (operatorType)
{
case DML_OPERATOR_ACTIVATION_ELU:
case DML_OPERATOR_ACTIVATION_CELU:
operatorDesc.elu.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
break;
@ -154,6 +155,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(HardSigmoid, DmlOperatorActivationTempla
DML_OP_DEFINE_CREATION_FUNCTION(Tanh, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_TANH>);
DML_OP_DEFINE_CREATION_FUNCTION(ScaledTanh, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SCALED_TANH>);
DML_OP_DEFINE_CREATION_FUNCTION(Relu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_RELU>);
DML_OP_DEFINE_CREATION_FUNCTION(Celu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_CELU>);
DML_OP_DEFINE_CREATION_FUNCTION(LeakyRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_LEAKY_RELU>);
DML_OP_DEFINE_CREATION_FUNCTION(PRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU>);
DML_OP_DEFINE_CREATION_FUNCTION(ThresholdedRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU>);

View file

@ -42,7 +42,7 @@ public:
for (size_t i = 0; i < m_inputTensorDescs.size(); i++)
{
// DML doesn't support empty tensors for concat, so we ignore them
if (!OperatorHelper::ContainsEmptyDimensions(m_inputTensorDescs[i].GetDmlSizes()))
if (!OperatorHelper::ContainsEmptyDimensions(m_inputTensorDescs[i].GetSizes()))
{
inputDescs.push_back(m_inputTensorDescs[i].GetDmlDesc());
}

View file

@ -0,0 +1,176 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
// DML kernel for EinSum. The equation is classified into a recognized operator
// type (m_recognizedOperatorType), and the constructor maps that type onto a
// single DML operator: element-wise multiply, GEMM (optionally transposing A or
// B), reduce-sum, or element-wise identity (optionally with transposed input
// strides). Unrecognized equations leave the operator desc unset.
class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
{
public:
DmlOperatorEinSum(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion)
: DmlOperator(kernelCreationContext),
EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion)
{
// m_components holds one entry per input plus one for the output, hence the +1.
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() + 1 == m_components.size(), "EinSum input tensor count is inconsistent with the equation component count.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "EinSum expects one output tensor.");
DmlOperator::Initialize(kernelCreationContext);
// NOTE(review): these descs are captured before the ReduceSum/Transpose cases
// below replace m_inputTensorDescs/m_outputTensorDescs. Those cases call
// GetDmlDesc() afterwards "to refresh the DML view", so the captured descs
// presumably alias storage inside the TensorDesc objects - confirm before
// reordering anything here.
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(8), "Update this switch.");
switch (m_recognizedOperatorType)
{
case RecognizedOperatorType::Multiply:
{
// Two-input element-wise product.
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC operatorDesc = {};
operatorDesc.ATensor = &inputDescs[0];
operatorDesc.BTensor = &inputDescs[1];
operatorDesc.OutputTensor = outputDescs.data();
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &operatorDesc}, kernelCreationContext);
}
break;
case RecognizedOperatorType::MatMul:
case RecognizedOperatorType::MatMulTransposeA:
case RecognizedOperatorType::MatMulTransposeB:
{
// All three matrix-multiply variants map to GEMM without a C tensor
// (Alpha = 1, Beta = 0); only the transpose flags differ.
DML_GEMM_OPERATOR_DESC operatorDesc = {};
operatorDesc.ATensor = &inputDescs[0];
operatorDesc.BTensor = &inputDescs[1];
// No operatorDesc.CTensor
operatorDesc.OutputTensor = &outputDescs[0];
operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
operatorDesc.Alpha = 1.0;
operatorDesc.Beta = 0.0;
operatorDesc.FusedActivation = nullptr;
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
}
break;
case RecognizedOperatorType::ReduceSum:
{
// Get how many axes are kept in the final output, either 0 or 1 supported
// meaning full reduction or partial with one dimension left. *It could be
// generalized to support any number of output dimensions, but it would need
// to accommodate for Transposition too if the output labels are reordered.
auto keptAxes = m_components.back().GetLabels(m_labelIndices);
assert(keptAxes.size() <= 1);
// DML expects output rank to match input rank (as if ONNX ReduceSum keepdims=1).
// So replace the existing tensor description with the input sizes, except that
// reduced dimensions have size 1.
std::vector<uint32_t> reducedAxes;
std::vector<uint32_t> inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
std::vector<uint32_t> outputSizes = inputSizes;
// Determine which axes are being reduced by taking the opposite of those kept.
uint32_t keptAxesMask = 0;
for (auto axis : keptAxes)
{
keptAxesMask |= (1 << axis);
}
for (uint32_t axis = 0, axisCount = static_cast<uint32_t>(outputSizes.size()); axis < axisCount; ++axis)
{
if (~keptAxesMask & (1<<axis))
{
reducedAxes.push_back(axis);
outputSizes[axis] = 1;
}
}
// Rebuild both tensor descs at the input's own rank, without explicit strides.
m_inputTensorDescs.front() = TensorDesc(m_inputTensorDescs.front().GetDmlDataType(), inputSizes, std::nullopt, 0);
m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), outputSizes, std::nullopt, 0);
m_inputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
DML_REDUCE_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Function = DML_REDUCE_FUNCTION_SUM;
operatorDesc.Axes = reducedAxes.data();
operatorDesc.AxisCount = gsl::narrow_cast<uint32_t>(reducedAxes.size());
SetDmlOperatorDesc({ DML_OPERATOR_REDUCE, &operatorDesc }, kernelCreationContext);
}
break;
case RecognizedOperatorType::Transpose:
case RecognizedOperatorType::Identity:
{
// Both cases end in an element-wise identity copy; Transpose additionally
// permutes the input's sizes and strides first.
if (m_recognizedOperatorType == RecognizedOperatorType::Transpose)
{
// Transpose via input strides. The output tensor is not strided.
assert(m_components.front().GetDimensionCount() == m_components.back().GetDimensionCount());
auto originalStrides = m_inputTensorDescs.front().GetStrides();
std::vector<uint32_t> inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
std::vector<uint32_t> inputStrides(inputSizes.size());
// If there were no strides, compute them based in descending packed order
// based on the input sizes.
if (originalStrides.empty())
{
Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides);
}
else // Copy the original strides.
{
// Keep only the trailing strides, one per dimension of the input shape.
assert(originalStrides.size() >= inputStrides.size());
size_t offset = originalStrides.size() - inputStrides.size();
inputStrides.assign(originalStrides.begin() + offset, originalStrides.end());
}
// Remap transposed strides using the component labels from input to output.
auto labelIndices = m_components.back().GetLabels(m_labelIndices);
std::vector<uint32_t> newStrides(inputStrides.size());
std::vector<uint32_t> newSizes(inputStrides.size());
for (size_t i = 0, dimensionCount = inputStrides.size(); i < dimensionCount; ++i)
{
// Output dimension i takes the size and stride of the input dimension
// selected by the output component's i-th label.
uint32_t labelIndex = labelIndices[i];
assert(labelIndex < inputStrides.size());
newSizes[i] = inputSizes[labelIndex];
newStrides[i] = inputStrides[labelIndex];
}
// Override the initial input tensor with the new strides.
m_inputTensorDescs.front() = TensorDesc(m_inputTensorDescs.front().GetDmlDataType(), newSizes, newStrides, 0);
m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newSizes, std::nullopt, 0);
m_inputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
}
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &operatorDesc}, kernelCreationContext);
}
break;
default:
// Unrecognized equation: no DML operator desc is set. QueryEinSum reports
// such equations (RecognizedOperatorType::None) as unsupported.
return;
}
}
};
// Support query for EinSum: reports the node as supported only when the helper
// recognizes its equation as one of the operator types the kernel can map to DML.
void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
    using RecognizedOperatorType = EinSumHelper::RecognizedOperatorType;
    static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(8), "Verify this test still matches the switch above.");

    // Start from "unsupported" so the out-param is already set before any helper work runs.
    *isSupported = false;

    MLOperatorAttributes attributes(context);
    EinSumHelper helper(attributes);

    // Anything other than None has a matching case in the kernel's switch.
    const bool isRecognized = (helper.GetRecognizedOperatorType() != RecognizedOperatorType::None);
    *isSupported = isRecognized;
}
// Defines the kernel creation function named Einsum12 for DmlOperatorEinSum.
// NOTE(review): VersionedKernel<DmlOperatorEinSum, 12> presumably supplies 12 as
// the constructor's opsetVersion argument - confirm against its definition.
DML_OP_DEFINE_CREATION_FUNCTION(Einsum12, VersionedKernel<DmlOperatorEinSum, 12>);
} // namespace Dml

View file

@ -462,6 +462,9 @@ public:
}
};
// Same operator signature as 11. Only difference is new type support
using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11;
class DmlOperatorElementwisePow : public DmlOperator
{
public:
@ -700,6 +703,8 @@ DML_OP_DEFINE_CREATION_FUNCTION(Erf, DmlOperatorElementwiseUnary<DM
// Binary operators:
DML_OP_DEFINE_CREATION_FUNCTION(Greater, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Less, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(GreaterOrEqual, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(LessOrEqual, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Equal, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(And, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Or, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC>);
@ -718,6 +723,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
// Operators with extra attributes:
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);

View file

@ -25,8 +25,8 @@ public:
TensorDesc inputTensorDesc =
TensorDesc(
kernelCreationContext.GetInputEdgeDescription(0).tensorDataType,
m_outputTensorDescs[0].GetDmlSizes(),
m_inputTensorDescs[0].GetDmlSizes(),
m_outputTensorDescs[0].GetSizes(),
m_inputTensorDescs[0].GetSizes(),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
@ -36,8 +36,8 @@ public:
TensorDesc outputTensorDesc =
TensorDesc(
kernelCreationContext.GetOutputEdgeDescription(0).tensorDataType,
m_outputTensorDescs[0].GetDmlSizes(),
m_outputTensorDescs[0].GetDmlSizes(),
m_outputTensorDescs[0].GetSizes(),
m_outputTensorDescs[0].GetSizes(),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,

View file

@ -16,19 +16,21 @@ public:
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "Gather expects 2 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Gather expects 1 output.");
DmlOperator::Initialize(kernelCreationContext);
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
DML_GATHER_OPERATOR_DESC operatorDesc = {};
@ -52,20 +54,22 @@ public:
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherElements expects 2 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherElements expects 1 output.");
DmlOperator::Initialize(kernelCreationContext);
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
int32_t signedOnnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
uint32_t dmlAxis = GetDmlAdjustedAxis(signedOnnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
DML_GATHER_ELEMENTS_OPERATOR_DESC operatorDesc = {};
@ -89,29 +93,30 @@ public:
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherND expects 2 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherND expects 1 output.");
DmlOperator::Initialize(kernelCreationContext);
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
DML_GATHER_ND_OPERATOR_DESC operatorDesc = {};
DML_GATHER_ND1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.IndicesTensor = &inputDescs[1];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InputDimensionCount = static_cast<uint32_t>(dataDimensions.size());
operatorDesc.IndicesDimensionCount = static_cast<uint32_t>(indicesDimensions.size());
operatorDesc.BatchDimensionCount = m_batchCount;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND, &operatorDesc };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND1, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};

View file

@ -22,6 +22,8 @@ public:
std::vector<std::optional<uint32_t>> inputIndices = { 0, 1 }; // The 3rd tensor ('output_shape') is not bound, just 'X' and 'I' indices.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
DmlOperator::Remap64bitDmlDataTypesTo32bit();
m_inputTensorDescs[1].ForceUnsignedDataType(); // MaxUnpool accepts uint32_t.
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

View file

@ -37,7 +37,7 @@ public:
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
1, // minDimensionCount
0
);
@ -49,10 +49,12 @@ public:
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
1, // minDimensionCount
0
);
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
// Adjust the axis so it's in DML's terms rather than the original ONNX indexing.
uint32_t dmlAxis = GetDmlAdjustedAxis(
m_absoluteAxis,

View file

@ -89,18 +89,7 @@ public:
}
};
// A specific type of operation for registration.
template <uint32_t opsetVersion>
class DmlOperatorPaddingTemplate : public DmlOperatorPadding
{
public:
DmlOperatorPaddingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorPadding(kernelInfo, opsetVersion)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Pad7, DmlOperatorPaddingTemplate<7>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, DmlOperatorPaddingTemplate<11>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel<DmlOperatorPadding, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel<DmlOperatorPadding, 11>);
} // namespace Dml

View file

@ -108,7 +108,14 @@ public:
if (hasOutputIndices || hasDilations)
{
DML_MAX_POOLING2_OPERATOR_DESC desc = {};
desc.OutputIndicesTensor = hasOutputIndices ? &outputDescs[1] : nullptr;
if (hasOutputIndices)
{
DmlOperator::Remap64bitDmlDataTypesTo32bit();
m_outputTensorDescs[1].ForceUnsignedDataType(); // MaxPool accepts uint32_t.
desc.OutputIndicesTensor = &outputDescs[1];
}
desc.Dilations = m_kernel.dilations;
SetOpDesc(desc);
}

View file

@ -22,17 +22,12 @@ public:
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
DmlOperator::Initialize(kernelInfo);
// Zero the output tensor's memory for ArgMin & ArgMax, which produce INT64 output.
if ((function == DML_REDUCE_FUNCTION_ARGMAX) || (function == DML_REDUCE_FUNCTION_ARGMIN))
{
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
}
std::vector<uint32_t> dmlAxes;
std::vector<DimensionType> reducedDims = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
int dimOffset = gsl::narrow_cast<int>(OperatorHelper::NchwDimensionCount - reducedDims.size());
int dimOffset = gsl::narrow_cast<int>(m_inputTensorDescs[0].GetDimensionCount() - reducedDims.size());
for (auto& dim : m_axes)
{
assert(dim < reducedDims.size()); // ReduceHelperBase already validated this.
reducedDims[dim] = 1;
dmlAxes.push_back(static_cast<uint32_t>(dim + dimOffset));
}
@ -62,15 +57,59 @@ public:
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_REDUCE_OPERATOR_DESC reduceDesc = {};
reduceDesc.InputTensor = inputDescs.data();
reduceDesc.OutputTensor = outputDescs.data();
reduceDesc.Function = function;
reduceDesc.Axes = dmlAxes.data();
reduceDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
// Zero the output tensor's memory for ArgMin & ArgMax, which produce INT64 output.
if (function == DML_REDUCE_FUNCTION_ARGMAX)
{
DML_ARGMAX_OPERATOR_DESC argmaxDesc;
argmaxDesc.AxisDirection = static_cast<DML_AXIS_DIRECTION>(m_selectLastIndex);
argmaxDesc.InputTensor = inputDescs.data();
argmaxDesc.OutputTensor = outputDescs.data();
argmaxDesc.Axes = dmlAxes.data();
argmaxDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_REDUCE, &reduceDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
// If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits
// of each element. If the device directly supports 64-bit elements, then no need.
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
if (m_outputTensorDescs[0].WasRemapped64bitTo32bit())
{
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
}
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMAX, &argmaxDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
else if (function == DML_REDUCE_FUNCTION_ARGMIN)
{
DML_ARGMIN_OPERATOR_DESC argminDesc;
argminDesc.AxisDirection = static_cast<DML_AXIS_DIRECTION>(m_selectLastIndex);
argminDesc.InputTensor = inputDescs.data();
argminDesc.OutputTensor = outputDescs.data();
argminDesc.Axes = dmlAxes.data();
argminDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
// If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits
// of each element. If the device directly supports 64-bit elements, then no need.
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
if (m_outputTensorDescs[0].WasRemapped64bitTo32bit())
{
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
}
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMIN, &argminDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
else
{
DML_REDUCE_OPERATOR_DESC reduceDesc = {};
reduceDesc.InputTensor = inputDescs.data();
reduceDesc.OutputTensor = outputDescs.data();
reduceDesc.Function = function;
reduceDesc.Axes = dmlAxes.data();
reduceDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_REDUCE, &reduceDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
}
void Compute(const MLOperatorKernelContext& kernelContext) override

View file

@ -251,17 +251,6 @@ public:
}
};
// A specific type of operation for registration.
template <uint32_t OpsetVersion>
struct DmlOperatorResizeTemplate : public DmlOperatorResize
{
public:
DmlOperatorResizeTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorResize(kernelInfo, OpsetVersion)
{
}
};
void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
*isSupported = false;
@ -304,10 +293,10 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool*
*isSupported = true;
}
DML_OP_DEFINE_CREATION_FUNCTION(Resize10, DmlOperatorResizeTemplate<10>);
DML_OP_DEFINE_CREATION_FUNCTION(Resize11, DmlOperatorResizeTemplate<11>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, DmlOperatorResizeTemplate<7>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, DmlOperatorResizeTemplate<9>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, DmlOperatorResizeTemplate<10>);
DML_OP_DEFINE_CREATION_FUNCTION(Resize10, VersionedKernel<DmlOperatorResize, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel<DmlOperatorResize, 11>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel<DmlOperatorResize, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel<DmlOperatorResize, 9>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel<DmlOperatorResize, 10>);
} // namespace Dml

View file

@ -47,6 +47,9 @@ public:
0
);
DmlOperator::Remap64bitDmlDataTypesTo32bit();
m_inputTensorDescs[1].ForceUnsignedDataType(); // DML operator accepts uint32_t.
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

View file

@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelper
{
public:
using Self = DmlOperatorRegionOfInterestAlign;
DmlOperatorRegionOfInterestAlign(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion)
: DmlOperator(kernelCreationContext),
RoiAlignHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "RoiAlign expects 3 input tensors.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "RoiAlign expects 1 output tensor.");
DmlOperator::Initialize(kernelCreationContext);
DmlOperator::Remap64bitDmlDataTypesTo32bit();
m_inputTensorDescs[2].ForceUnsignedDataType();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
constexpr NameAndIndex mapping[] =
{
{"max", DML_REDUCE_FUNCTION_MAX},
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
};
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
const auto reductionFunction = MapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ROITensor = &inputDescs[1];
operatorDesc.BatchIndicesTensor = &inputDescs[2];
operatorDesc.OutputTensor = &outputDescs[0];
operatorDesc.SpatialScaleX = spatialScale; // ONNX uses the same scale for X and Y.
operatorDesc.SpatialScaleY = spatialScale;
operatorDesc.OutOfBoundsInputValue = 0.0f; // ONNX does not specify a value for input elements outside bounds.
operatorDesc.MinimumSamplesPerOutput = (samplesPerOutput == 0) ? 1 : samplesPerOutput;
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
operatorDesc.ReductionFunction = reductionFunction;
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorRegionOfInterestAlign, 10>);
} // namespace Dml

View file

@ -26,7 +26,7 @@ public:
poolingDesc.ROITensor = &inputDescs[1];
poolingDesc.OutputTensor = &outputDescs[0];
poolingDesc.SpatialScale = m_spatialScale;
poolingDesc.PooledSize = { m_pooledSizeH, m_pooledSizeW };
poolingDesc.PooledSize = { m_outputSizeH, m_outputSizeW };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_POOLING, &poolingDesc };

View file

@ -23,7 +23,6 @@ public:
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
ML_CHECK_VALID_ARGUMENT(indicesDimensions == updatesDimensions);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() == indicesDimensions.size());
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
// When the indices tensor is empty, Scatter is basically Identity. But since DML doesn't support empty or null
// tensors, we have to special-case it outside of DML.
@ -31,6 +30,7 @@ public:
{
std::vector<std::optional<uint32_t>> kernelInputIndices(1, 0);
DmlOperator::Initialize(kernelCreationContext, kernelInputIndices);
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
@ -49,14 +49,13 @@ public:
else
{
DmlOperator::Initialize(kernelCreationContext);
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 3);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
// Read the axis.
int onnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
uint32_t dmlAxis = GetDmlAdjustedAxis(onnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
@ -89,20 +88,16 @@ public:
std::vector<DimensionType> updatesDimensions = tensorShapeDescription.GetInputTensorShape(2);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(updatesDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(outputDimensions.size() <= OperatorHelper::NchwDimensionCount);
DmlOperator::Initialize(kernelCreationContext);
size_t dimensionCountMax = std::max({dataDimensions.size(), updatesDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 3);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
DML_SCATTER_ND_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.IndicesTensor = &inputDescs[1];

View file

@ -47,23 +47,12 @@ public:
}
};
// A specific type of operation for registration.
template <uint32_t opsetVersion>
class DmlOperatorSliceTemplate : public DmlOperatorSlice
{
public:
DmlOperatorSliceTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorSlice(kernelInfo, opsetVersion)
{
}
};
// Support query for Slice: writes true to *isSupported when the node has at
// most five inputs, and false otherwise.
void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
    const bool inputCountWithinLimit = (context->GetInputCount() <= 5);
    *isSupported = inputCountWithinLimit;
}
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<11>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, VersionedKernel<DmlOperatorSlice, 7> );
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, VersionedKernel<DmlOperatorSlice, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, VersionedKernel<DmlOperatorSlice, 11>);
} // namespace Dml

View file

@ -22,6 +22,8 @@ public:
std::vector<std::optional<uint32_t>> inputIndices = { 0 }; // Use only the first tensor. The second tensor is CPU-based.
std::vector<std::optional<uint32_t>> outputIndices = { 0, 1 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
DmlOperator::Remap64bitDmlDataTypesTo32bit();
m_outputTensorDescs[1].ForceUnsignedDataType();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
@ -70,19 +72,8 @@ private:
ComPtr<IDMLCompiledOperator> m_zeroOperator;
};
// A specific type of operation for registration.
template <uint32_t OpsetVersion>
class DmlOperatorTopKTemplate : public DmlOperatorTopK
{
public:
DmlOperatorTopKTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorTopK(kernelInfo, OpsetVersion)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(TopK7, DmlOperatorTopKTemplate<7>);
DML_OP_DEFINE_CREATION_FUNCTION(TopK10, DmlOperatorTopKTemplate<10>);
DML_OP_DEFINE_CREATION_FUNCTION(TopK11, DmlOperatorTopKTemplate<11>);
DML_OP_DEFINE_CREATION_FUNCTION(TopK7, VersionedKernel<DmlOperatorTopK, 7 >);
DML_OP_DEFINE_CREATION_FUNCTION(TopK10, VersionedKernel<DmlOperatorTopK, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(TopK11, VersionedKernel<DmlOperatorTopK, 11>);
} // namespace Dml

View file

@ -19,41 +19,25 @@ public:
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() >= 1);
DmlOperator::Initialize(kernelInfo);
const MLOperatorEdgeDescription inputEdgeDescription = kernelInfo.GetInputEdgeDescription(0);
const std::vector<uint32_t> originalSizes = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(m_permutations.size() == originalSizes.size());
// Calculate strides from original shape.
ML_CHECK_VALID_ARGUMENT(!originalSizes.empty());
std::vector<uint32_t> inputStrides(originalSizes.size());
inputStrides.back() = 1;
for (int i = gsl::narrow_cast<int>(inputStrides.size()) - 2; i >= 0; i--)
{
inputStrides[i] = inputStrides[i + 1] * gsl::narrow_cast<uint32_t>(originalSizes[i + 1]);
}
Dml::GetDescendingPackedStrides(originalSizes, /*out*/ inputStrides);
const int leadingDims = gsl::narrow_cast<int32_t>(m_inputTensorDescs.front().GetDimensionCount() - originalSizes.size());
std::vector<uint32_t> sizes(m_inputTensorDescs.front().GetDimensionCount());
std::vector<uint32_t> strides(m_inputTensorDescs.front().GetDimensionCount());
// Fill leading tensor desc sizes/strides with defaults.
for (int dimDML = 0; dimDML < leadingDims; ++dimDML)
{
sizes[dimDML] = 1;
strides[dimDML] = 0;
}
std::vector<uint32_t> sizes(inputStrides.size());
std::vector<uint32_t> strides(inputStrides.size());
// Permute the shape and strides.
for (int dimInput = 0, dimCount = gsl::narrow_cast<int>(originalSizes.size()); dimInput < dimCount; ++dimInput)
{
int dimDML = dimInput + leadingDims;
int dimPermuted = m_permutations[dimInput];
ML_CHECK_VALID_ARGUMENT(gsl::narrow_cast<size_t>(dimPermuted) < originalSizes.size());
sizes[dimDML] = gsl::narrow_cast<int32_t>(originalSizes[dimPermuted]);
strides[dimDML] = inputStrides[dimPermuted];
sizes[dimInput] = originalSizes[dimPermuted];
strides[dimInput] = inputStrides[dimPermuted];
}
// Override the initial tensor descs. The output tensor is not strided.

View file

@ -34,22 +34,28 @@ enum class SupportedTensorDataTypes : uint32_t
UInt64 = 1<<13,
Complex64 = 1<<14,
Complex128 = 1<<15,
Int8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
Int32to64 = UInt32|Int32|UInt64|Int64,
Float16to32 = Float16|Float32, // Float64 is not supported by DirectML.
NumericDefault = Int8to32|Float16to32,
NumericDefault = Ints8to32|Float16to32,
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
Ints8Bit = UInt8|Int8,
Ints16Bit = UInt16|Int16,
Ints32Bit = UInt32|Int32,
All = static_cast<uint32_t>(-1),
};
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
enum class DmGraphSupport
enum class DmlGraphSupport : uint32_t
{
Supported = 0,
NotSupported = 1,
Supported = 0,
NotSupported = 1,
SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides.
SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes)
Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support.
};
DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport);
struct OperatorRegistrationInformation
{
@ -63,7 +69,7 @@ struct OperatorRegistrationInformation
gsl::span<char const* const> tensorTypeNames;
gsl::span<const SupportedTensorDataTypes> supportedTensorDataTypes;
DmGraphSupport DmGraphSupport;
DmlGraphSupport dmlGraphSupport;
std::pair<std::array<const uint32_t, 4>, int> requiredConstantCpuInputs = {{}, 0};
@ -86,6 +92,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool);
DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
@ -117,8 +124,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(Ceil);
DML_OP_EXTERN_CREATION_FUNCTION(Floor);
DML_OP_EXTERN_CREATION_FUNCTION(Clip7);
DML_OP_EXTERN_CREATION_FUNCTION(Clip11);
DML_OP_EXTERN_CREATION_FUNCTION(Clip12);
DML_OP_EXTERN_CREATION_FUNCTION(Greater);
DML_OP_EXTERN_CREATION_FUNCTION(Less);
DML_OP_EXTERN_CREATION_FUNCTION(GreaterOrEqual);
DML_OP_EXTERN_CREATION_FUNCTION(LessOrEqual);
DML_OP_EXTERN_CREATION_FUNCTION(Equal);
DML_OP_EXTERN_CREATION_FUNCTION(Not);
DML_OP_EXTERN_CREATION_FUNCTION(And);
@ -133,6 +143,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Mean);
DML_OP_EXTERN_CREATION_FUNCTION(Max);
DML_OP_EXTERN_CREATION_FUNCTION(Min);
DML_OP_EXTERN_CREATION_FUNCTION(ReduceSum);
DML_OP_EXTERN_CREATION_FUNCTION(Einsum12);
DML_OP_EXTERN_CREATION_FUNCTION(ReduceMean);
DML_OP_EXTERN_CREATION_FUNCTION(ReduceProd);
DML_OP_EXTERN_CREATION_FUNCTION(ReduceLogSum);
@ -160,6 +171,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LeakyRelu);
DML_OP_EXTERN_CREATION_FUNCTION(PRelu);
DML_OP_EXTERN_CREATION_FUNCTION(ThresholdedRelu);
DML_OP_EXTERN_CREATION_FUNCTION(Elu);
DML_OP_EXTERN_CREATION_FUNCTION(Celu);
DML_OP_EXTERN_CREATION_FUNCTION(Selu);
DML_OP_EXTERN_CREATION_FUNCTION(Softmax);
DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax);
@ -233,6 +245,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
DML_OP_EXTERN_QUERY_FUNCTION(Resize);
DML_OP_EXTERN_QUERY_FUNCTION(EinSum);
constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
@ -240,6 +253,7 @@ constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
constexpr static std::array<const char*, 2> typeNameListLogicalComparison = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameListPow12 = {"T", "T1"};
constexpr static std::array<const char*, 2> typeNameListConstantOfShape = { "T1", "T2" };
constexpr static std::array<const char*, 2> typeNameListScatterGather = { "T", "Tind" };
constexpr static std::array<const char*, 1> typeNameListScatterGatherND = { "T" }; // Tind is curiously missing, only allowing 64-bit.
@ -249,15 +263,18 @@ constexpr static std::array<const char*, 1> typeNameListEyeLike = { "T2" };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Int8to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListBool = {SupportedTensorDataTypes::Bool};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 };
@ -275,7 +292,7 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogica
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListPadWithoutFloat16 = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
@ -305,7 +322,7 @@ constexpr auto requiredConstantCpuInputs(Args... args)
// Registration arguments for identity-like operators (in this table: Identity, Flatten,
// Squeeze, Unsqueeze, Reshape).
//
// Identity operators use Copy, alias their first input, and require floating point formats
// for usage in the graph, besides constant inputs. This is because they currently use
// element-wise identity operators in the graph for striding support, but issue actual copies
// outside the graph. Element-wise identity currently only supports floating point types.
//
// Expands to, in order: the stringized operator name, the operator's sc_sinceVer_ constant
// from OnnxOperatorSet<version>, onnxruntime::kOnnxDomain, CreateCopy, the operator's
// ShapeInferenceFunction, two `true` flags, and then any trailing arguments supplied by
// the caller.
// NOTE(review): the two `true` flags presumably mean "aliases first input" and
// "requires float formats except const inputs" - confirm against the field order of
// OperatorRegistrationInformation.
#define REG_INFO_ID(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, true, true, ##__VA_ARGS__,
@ -325,230 +342,244 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
/// Support query function
// Deep Learning Standard Layers
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)},
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)},
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)},
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))},
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
// Data Reorganization Layers
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis.
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmGraphSupport::Supported)},
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmGraphSupport::Supported)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
// Data reorganization that merely changes the dimensions while keeping the data identical.
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
// Elementwise
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmGraphSupport::Supported)},
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmGraphSupport::Supported)},
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmGraphSupport::Supported)},
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmGraphSupport::Supported)},
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )},
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
// Imaging Operators
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
// Activation Functions
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
// Uncategorized
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmGraphSupport::Supported)},
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmGraphSupport::Supported)},
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported)},
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
// Fused operators
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmGraphSupport::Supported)},
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmGraphSupport::NotSupported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmGraphSupport::NotSupported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmGraphSupport::NotSupported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmGraphSupport::NotSupported)},
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
};
template<typename T>
@ -572,10 +603,13 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
MLOperatorKernelDescription desc = {};
desc.domain = information.domain;
desc.name = information.operatorName;
desc.executionType = MLOperatorExecutionType::D3D12;
desc.executionType = MLOperatorExecutionType::D3D12;
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
bool kernelSupportsGraph = (information.DmGraphSupport == DmGraphSupport::Supported);
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly);
bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides);
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp);
desc.options = information.shapeInferenceFunction ?
MLOperatorKernelOptions::None : MLOperatorKernelOptions::AllowDynamicInputShapes;
@ -651,6 +685,9 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
kernelSupportsGraph, // supportsGraph
information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr,
information.requiresFloatFormatsForGraph,
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,
information.requiredConstantCpuInputs.first.data(),
static_cast<uint32_t>(information.requiredConstantCpuInputs.second)
));

View file

@ -11,6 +11,19 @@ class MLOperatorKernelCreationContext;
// Declares (as extern) the kernel creation callback named Create<operatorName>,
// which receives the kernel creation context and returns the new kernel through opKernel.
// The expansion has no trailing semicolon, so the use site supplies it.
#define DML_OP_EXTERN_CREATION_FUNCTION(operatorName) extern void CALLBACK Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)
// Declares (as extern) the support query callback named Query<operatorName>,
// which reports through isSupported. Unlike the creation macro above, this
// expansion already ends with a semicolon.
#define DML_OP_EXTERN_QUERY_FUNCTION(operatorName) extern void CALLBACK Query##operatorName(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported);
// Adapter that pins a kernel class to one specific opset version for registration.
// The wrapped BaseClass must offer a constructor taking (creation context, opset version);
// this wrapper forwards the context and supplies the version from its template argument.
//
// e.g.
//      DML_OP_DEFINE_CREATION_FUNCTION(Slice10, VersionedKernel<DmlOperatorSlice, 10>);
template <typename BaseClass, uint32_t opsetVersion>
class VersionedKernel : public BaseClass
{
public:
    // Builds the underlying kernel, passing along the compile-time opset version.
    VersionedKernel(const MLOperatorKernelCreationContext& creationContext)
    :   BaseClass(creationContext, opsetVersion)
    {
    }
};
// Declares a callback creation function of the given operator class.
// This does not register it, just declares it for usage by registration later.
//

View file

@ -14,6 +14,7 @@ namespace Dml
switch (function)
{
case DML_OPERATOR_ACTIVATION_ELU:
case DML_OPERATOR_ACTIVATION_CELU:
return 1.0f;
case DML_OPERATOR_ACTIVATION_LEAKY_RELU:
@ -358,37 +359,45 @@ namespace Dml
}
}
// Looks up `mode` in a table of (name, index) pairs and returns the index of the
// first entry whose name equals `mode` exactly (case sensitive).
//
// mode             - The string to look up. Need not be null terminated.
// nameAndIndexList - Candidate entries; each name must be a null terminated string.
//
// If no entry matches, the failure is reported through ML_INVALID_ARGUMENT.
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
{
    for (const auto& nameAndIndex : nameAndIndexList)
    {
        // Compare the whole strings. Comparing only the first mode.size() characters
        // (as strncmp would) accepts any prefix of a table name - including the empty
        // string, which would silently select the first entry.
        if (std::string_view(nameAndIndex.name) == mode)
        {
            return nameAndIndex.index;
        }
    }

    ML_INVALID_ARGUMENT("Unknown mode value.");
}
// Translates an interpolation mode attribute string into the DML interpolation mode.
// Unrecognized strings are rejected by MapStringToIndex.
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode)
{
    // The ONNX modes are "nearest" and "linear." Other modes exist for compatibility,
    // since Winml supported them in the past.
    constexpr NameAndIndex mapping[] =
    {
        {"nearest", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
        {"NEAREST", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
        {"NN", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
        {"nn", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
        {"linear", DML_INTERPOLATION_MODE_LINEAR},
        {"BILINEAR", DML_INTERPOLATION_MODE_LINEAR},
        {"bilinear", DML_INTERPOLATION_MODE_LINEAR},
    };
    return MapStringToIndex<DML_INTERPOLATION_MODE>(mode, mapping);
}
// Translates a depth/space mode attribute string ("DCR" or "CRD") into the
// corresponding DML depth-space ordering. Unrecognized strings are rejected
// by MapStringToIndex.
DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode)
{
    constexpr NameAndIndex mapping[] =
    {
        {"DCR", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW},
        {"CRD", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH},
    };
    return MapStringToIndex<DML_DEPTH_SPACE_ORDER>(mode, mapping);
}
} // namespace Dml

View file

@ -57,6 +57,20 @@ namespace Dml
void GetDmlAdjustedAxes(/*inout*/ gsl::span<const int32_t> axes, uint32_t onnxDimCount, uint32_t dmlDimCount, std::vector<uint32_t>& dmlAxes);
// One row of a string lookup table consumed by MapStringToIndex: a name paired
// with the numeric value returned when that name is matched. Kept as a plain
// aggregate so tables of it can be brace-initialized.
struct NameAndIndex
{
const char* name; // Null terminated.
uint32_t index; // Value returned for this name (callers may cast it to an enum type).
};
// Returns the index paired with `mode` in nameAndIndexList.
// Declared ahead of the typed wrapper below: the wrapper's call has no
// template-dependent arguments, so under standard two-phase name lookup this
// overload must already be visible where the template is defined.
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList);

// Typed convenience wrapper: performs the same lookup and casts the resulting
// index to T (typically an enum whose values were stored in the table).
template<typename T>
T MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
{
    return static_cast<T>(MapStringToIndex(mode, nameAndIndexList));
}
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode);
DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode);

View file

@ -184,53 +184,6 @@ TensorDesc::TensorDesc(
}
}
////////////////////////////////////////
// Handle 64-bit tensors.
uint64_t endPaddingInBytes = 0;
if (dataType == MLOperatorTensorDataType::UInt64 || dataType == MLOperatorTensorDataType::Int64)
{
// DirectML doesn't support tensor of int64 because Direct3D doesn't support
// the data type. A workaround is to use strides to fake 64-bit memory access
// while only the lower 32 bits contains the data. This trick obviously doesn't
// work if the data element is genuine 64-bit. It also doesn't work if the data
// element is negative as the signed bit will be incorrectly interpreted.
m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT32;
// If the strides haven't been calculated yet, initialize them as packed.
if (!useStrides)
{
uint32_t stride = 1;
for (int i = m_bufferTensorDesc.DimensionCount - 1; i >= 0; i--)
{
m_strides[i] = stride;
stride *= m_sizes[i];
}
}
// Double the stride values to emulate 64-bit integer support.
for (uint32_t i = 0; i < m_bufferTensorDesc.DimensionCount; ++i)
{
m_strides[i] *= 2;
}
useStrides = true;
// The physical size of the tensor will have an extra 4 bytes at the end.
// DMLCalcBufferTensorSize calculates the minimum implied size, which is based on the last
// addressable element of the tensor plus the space for the last element. However, the size
// of the last element is now halved from 8 bytes to 4 bytes.
//
// Example:
// Original Tensor: size={2,3}, strides={3,1}, type=int64, size = (1+{1,2}*{3,1})*sizeof(int64) = 6 * 8 = 48
// Emulated Tensor: size={2,3}, strides={6,2}, type=int32, size = (1+{1,2}*{6,2})*sizeof(int32) = 11 * 4 = 44
//
// DirectML itself won't read/write the last 4 bytes, but we want the total size to be accurate
// so that the entire region can be zeroed.
endPaddingInBytes = sizeof(uint32_t);
}
if (useStrides)
{
m_bufferTensorDesc.Strides = m_strides;
@ -239,20 +192,84 @@ TensorDesc::TensorDesc(
m_bufferTensorDesc.Flags = DML_TENSOR_FLAG_NONE;
m_bufferTensorDesc.GuaranteedBaseOffsetAlignment = guaranteedBaseOffsetAlignment;
m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(
m_bufferTensorDesc.DataType,
m_bufferTensorDesc.DimensionCount,
m_sizes,
m_bufferTensorDesc.DataType,
m_bufferTensorDesc.DimensionCount,
m_sizes,
useStrides ? m_strides : nullptr
);
assert(m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType));
}
void TensorDesc::Remap64bitDmlDataTypeTo32bit()
{
if (m_bufferTensorDesc.DataType != DML_TENSOR_DATA_TYPE_UINT64 &&
m_bufferTensorDesc.DataType != DML_TENSOR_DATA_TYPE_INT64)
{
return; // Nothing to do.
}
uint64_t endPaddingInBytes = 0;
// A workaround for older devices is to use strides to fake 64-bit memory access
// while only the lower 32 bits contains the data. This trick obviously doesn't
// work if the data element is genuine 64-bit. It also doesn't work if the data
// element is negative as the signed bit will be incorrectly interpreted.
m_bufferTensorDesc.DataType = Dml::Remap64bitDmlDataTypeTo32bit(m_bufferTensorDesc.DataType);
// If the strides haven't been calculated yet, initialize them as packed.
if (m_bufferTensorDesc.Strides == nullptr)
{
uint32_t stride = 1;
for (int i = m_bufferTensorDesc.DimensionCount - 1; i >= 0; i--)
{
m_strides[i] = stride;
stride *= m_sizes[i];
}
}
// Double the stride values to emulate 64-bit integer support.
for (uint32_t i = 0; i < m_bufferTensorDesc.DimensionCount; ++i)
{
m_strides[i] *= 2;
}
// The physical size of the tensor will have an extra 4 bytes at the end.
// DMLCalcBufferTensorSize calculates the minimum implied size, which is based on the last
// addressable element of the tensor plus the space for the last element. However, the size
// of the last element is now halved from 8 bytes to 4 bytes.
//
// Example:
// Original Tensor: size={2,3}, strides={3,1}, type=int64, size = (1+{1,2}*{3,1})*sizeof(int64) = 6 * 8 = 48
// Emulated Tensor: size={2,3}, strides={6,2}, type=int32, size = (1+{1,2}*{6,2})*sizeof(int32) = 11 * 4 = 44
//
// DirectML itself won't read/write the last 4 bytes, but we want the total size to be accurate
// so that the entire region can be zeroed.
endPaddingInBytes = sizeof(uint32_t);
m_bufferTensorDesc.Strides = m_strides;
m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(
m_bufferTensorDesc.DataType,
m_bufferTensorDesc.DimensionCount,
m_sizes,
m_strides
) + endPaddingInBytes;
}
bool TensorDesc::WasRemapped64bitTo32bit() const
{
bool was64BitIntType = (m_mlOperatorTensorDataType == MLOperatorTensorDataType::UInt64 || m_mlOperatorTensorDataType == MLOperatorTensorDataType::Int64);
bool is32BitIntType = (m_bufferTensorDesc.DataType == DML_TENSOR_DATA_TYPE_UINT32 || m_bufferTensorDesc.DataType == DML_TENSOR_DATA_TYPE_INT32);
return was64BitIntType && is32BitIntType;
}
// Returns the explicit strides of this tensor, one per dimension.
// An empty span is returned when the description carries no explicit strides
// (the Strides pointer of the buffer description is null).
gsl::span<const uint32_t> TensorDesc::GetStrides() const
{
    if (m_bufferTensorDesc.Strides == nullptr)
    {
        return {};
    }

    return { m_strides, m_strides + m_bufferTensorDesc.DimensionCount };
}
DML_TENSOR_DESC TensorDesc::GetDmlDesc()
@ -297,7 +314,8 @@ void TensorDesc::ForceUnsignedDataType()
m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT8;
break;
// Nothing to do if already unsigned
// Nothing to do if already unsigned.
case DML_TENSOR_DATA_TYPE_UINT64:
case DML_TENSOR_DATA_TYPE_UINT32:
case DML_TENSOR_DATA_TYPE_UINT16:
case DML_TENSOR_DATA_TYPE_UINT8:
@ -307,3 +325,35 @@ void TensorDesc::ForceUnsignedDataType()
ML_INVALID_ARGUMENT("Can't coerce unknown or non-integral data type");
}
}
// Changes the rank of the tensor description to newDimensionCount.
//
// - LeftAligned: existing dimensions keep their positions; added dimensions are appended
//   at the end, and shrinking drops trailing dimensions.
// - RightAligned: existing dimensions are shifted so the trailing ones stay at the end;
//   added dimensions are inserted at the front, and shrinking drops leading dimensions.
//
// Added dimensions get size 1 and stride 0. Any other alignment value, or a count above
// MaximumDimensionCount, is rejected by the argument checks.
void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
{
    ML_CHECK_VALID_ARGUMENT(newDimensionCount <= MaximumDimensionCount);
    ML_CHECK_VALID_ARGUMENT(alignment == TensorAxis::RightAligned || alignment == TensorAxis::LeftAligned);

    const uint32_t previousCount = m_bufferTensorDesc.DimensionCount;
    if (newDimensionCount == previousCount)
    {
        return; // Rank is unchanged, nothing to move or fill.
    }

    // Number of dimensions being added (zero when the rank shrinks).
    const uint32_t addedCount = (newDimensionCount > previousCount) ? (newDimensionCount - previousCount) : 0;
    const bool isRightAligned = (alignment == TensorAxis::RightAligned);

    if (isRightAligned)
    {
        // Keep the trailing dimensions and slide them to their new position.
        // Source and destination may overlap, hence memmove.
        const uint32_t keptCount = std::min(newDimensionCount, previousCount);
        const uint32_t sourceOffset = previousCount - keptCount;
        memmove(m_sizes + addedCount, m_sizes + sourceOffset, sizeof(m_sizes[0]) * keptCount);
        memmove(m_strides + addedCount, m_strides + sourceOffset, sizeof(m_strides[0]) * keptCount);
    }

    if (addedCount > 0)
    {
        // Right alignment pads at the front; left alignment pads after the old dimensions.
        const uint32_t padOffset = isRightAligned ? 0 : previousCount;
        std::fill(m_sizes + padOffset, m_sizes + padOffset + addedCount, 1u);
        std::fill(m_strides + padOffset, m_strides + padOffset + addedCount, 0u);
    }

    m_bufferTensorDesc.DimensionCount = newDimensionCount;
}

View file

@ -36,11 +36,13 @@ namespace Dml
inline DML_TENSOR_DATA_TYPE GetDmlDataType() const { return m_bufferTensorDesc.DataType; }
inline MLOperatorTensorDataType GetMlOperatorDataType() const { return m_mlOperatorTensorDataType; }
inline gsl::span<const uint32_t> GetDmlSizes() const { return m_sizes; }
void ForceUnsignedDataType();
void Remap64bitDmlDataTypeTo32bit();
bool WasRemapped64bitTo32bit() const;
inline bool IsValid() const { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; }
void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
gsl::span<const uint32_t> GetStrides() const;
@ -58,8 +60,6 @@ namespace Dml
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};
};
class TensorDescBuilder
{
public:

View file

@ -15,6 +15,7 @@ namespace AttrName
static constexpr const char* Axis = "axis";
static constexpr const char* AxisW = "axis_w";
static constexpr const char* BatchAxis = "batch_axis";
static constexpr const char* BatchDimensions = "batch_dims";
static constexpr const char* Beta = "beta";
static constexpr const char* Bias = "bias";
static constexpr const char* BlockSize = "blocksize";
@ -32,6 +33,7 @@ namespace AttrName
static constexpr const char* Dtype = "dtype";
static constexpr const char* Ends = "ends";
static constexpr const char* Epsilon = "epsilon";
static constexpr const char* Equation = "equation";
static constexpr const char* ExcludeOutside = "exclude_outside";
static constexpr const char* Exclusive = "exclusive";
static constexpr const char* Exponent = "exponent";
@ -45,6 +47,7 @@ namespace AttrName
static constexpr const char* InputForget = "input_forget";
static constexpr const char* K = "k";
static constexpr const char* KeepDims = "keepdims";
static constexpr const char* SelectLastIndex = "select_last_index";
static constexpr const char* KernelShape = "kernel_shape";
static constexpr const char* LinearBeforeReset = "linear_before_reset";
static constexpr const char* Lambda = "lambd"; // Deliberate typo to match ONNX spec.
@ -57,12 +60,15 @@ namespace AttrName
static constexpr const char* NearestMode = "nearest_mode";
static constexpr const char* NormalizeVariance = "normalize_variance";
static constexpr const char* P = "p";
static constexpr const char* OutputHeight = "output_height";
static constexpr const char* OutputShape = "output_shape";
static constexpr const char* OutputPadding = "output_padding";
static constexpr const char* OutputWidth = "output_width";
static constexpr const char* Pads = "pads";
static constexpr const char* PooledShape = "pooled_shape";
static constexpr const char* Reverse = "reverse";
static constexpr const char* SampleSize = "sample_size";
static constexpr const char* SamplingRatio = "sampling_ratio";
static constexpr const char* Scale = "scale";
static constexpr const char* Scales = "scales";
static constexpr const char* Seed = "seed";

View file

@ -20,13 +20,12 @@
}\
}
// Converts 'input' to type T, saturating at T's lowest/max value instead of wrapping.
//
// The range check is done in a type wide enough to hold both T's limits and the input,
// so it stays correct when T's range does not fit in I (e.g. uint64_t -> int32_t, or
// int32_t -> int64_t). Converting T's limits to I first would wrap the bounds in those cases.
// A NaN input converted to an integer type yields 0; NaN converted to a floating point
// type stays NaN.
template<typename T, typename I> T clamp_cast(I input)
{
    using TargetLimits = std::numeric_limits<T>;
    using SourceLimits = std::numeric_limits<I>;

    if constexpr (TargetLimits::is_integer && SourceLimits::is_integer)
    {
        if constexpr (SourceLimits::is_signed)
        {
            if (input < 0)
            {
                if constexpr (TargetLimits::is_signed)
                {
                    // Both signed: compare in the widest signed built-in type.
                    const bool isBelowRange = static_cast<long long>(input) < static_cast<long long>(TargetLimits::lowest());
                    return isBelowRange ? TargetLimits::lowest() : static_cast<T>(input);
                }
                else
                {
                    return TargetLimits::lowest(); // Unsigned targets bottom out at zero.
                }
            }
        }

        // Input is non-negative here, so an unsigned comparison is exact for every integer type.
        const bool isAboveRange = static_cast<unsigned long long>(input) > static_cast<unsigned long long>(TargetLimits::max());
        return isAboveRange ? TargetLimits::max() : static_cast<T>(input);
    }
    else
    {
        if constexpr (TargetLimits::is_integer)
        {
            if (input != input) // NaN has no integer equivalent.
            {
                return T{};
            }
        }

        // At least one side is floating point. Use >= / <= so that a limit which rounded
        // upward when converted (e.g. INT32_MAX -> 2^31) still saturates rather than overflowing.
        const long double wideInput = static_cast<long double>(input);
        if (wideInput <= static_cast<long double>(TargetLimits::lowest()))
        {
            return TargetLimits::lowest();
        }
        if (wideInput >= static_cast<long double>(TargetLimits::max()))
        {
            return TargetLimits::max();
        }
        return static_cast<T>(input);
    }
}
namespace OperatorHelper
{
// Converts 'input' to type T, saturating at T's lowest/max value instead of wrapping.
//
// The range check is done in a type wide enough to hold both T's limits and the input,
// so it stays correct when T's range does not fit in I (e.g. uint64_t -> int32_t, or
// int32_t -> int64_t). Converting T's limits to I first would wrap the bounds in those cases.
// A NaN input converted to an integer type yields 0; NaN converted to a floating point
// type stays NaN.
template<typename T, typename I> T clamp_cast(I input)
{
    using TargetLimits = std::numeric_limits<T>;
    using SourceLimits = std::numeric_limits<I>;

    if constexpr (TargetLimits::is_integer && SourceLimits::is_integer)
    {
        if constexpr (SourceLimits::is_signed)
        {
            if (input < 0)
            {
                if constexpr (TargetLimits::is_signed)
                {
                    // Both signed: compare in the widest signed built-in type.
                    const bool isBelowRange = static_cast<long long>(input) < static_cast<long long>(TargetLimits::lowest());
                    return isBelowRange ? TargetLimits::lowest() : static_cast<T>(input);
                }
                else
                {
                    return TargetLimits::lowest(); // Unsigned targets bottom out at zero.
                }
            }
        }

        // Input is non-negative here, so an unsigned comparison is exact for every integer type.
        const bool isAboveRange = static_cast<unsigned long long>(input) > static_cast<unsigned long long>(TargetLimits::max());
        return isAboveRange ? TargetLimits::max() : static_cast<T>(input);
    }
    else
    {
        if constexpr (TargetLimits::is_integer)
        {
            if (input != input) // NaN has no integer equivalent.
            {
                return T{};
            }
        }

        // At least one side is floating point. Use >= / <= so that a limit which rounded
        // upward when converted (e.g. INT32_MAX -> 2^31) still saturates rather than overflowing.
        const long double wideInput = static_cast<long double>(input);
        if (wideInput <= static_cast<long double>(TargetLimits::lowest()))
        {
            return TargetLimits::lowest();
        }
        if (wideInput >= static_cast<long double>(TargetLimits::max()))
        {
            return TargetLimits::max();
        }
        return static_cast<T>(input);
    }
}
// Axis identifiers and placement sentinels for tensor dimensions.
// N, C, H, W take the implicit values 0..3 and are used as dimension indices.
// NoPlacementAdjustment shares the value 0 with N. DoNotCoerce, LeftAligned and
// RightAligned are sentinel values rather than indices; LeftAligned/RightAligned are the
// only values TensorDesc::SetDimensionCount accepts for its alignment argument.
enum TensorAxis { N, C, H, W, DoNotCoerce = UINT_MAX, LeftAligned = INT_MAX, RightAligned = INT_MIN, NoPlacementAdjustment = 0 };
// Selects how operand shapes may be broadcast against each other.
// NOTE(review): presumably mirrors the ONNX notions of no / unidirectional /
// multidirectional broadcasting - confirm against the callers, which are not visible here.
enum BroadcastMode { NoBroadcast, UnidirectionalBroadcast, MultidirectionalBroadcast };

View file

@ -113,6 +113,9 @@ IMLOperatorRegistryPrivate : public IUnknown
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
bool requiresFloatFormatsForGraph = false,
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
_In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0
) const noexcept PURE;

View file

@ -113,6 +113,36 @@ namespace OperatorHelper
}
}
// Converts the raw bits of an IEEE half precision value (float16m10e5s1) to float32.
// Handles zero, denormals, normal values, infinity and NaN; the sign is always preserved.
float CastFloat16ToFloat32(uint16_t input)
{
    // Promote float16m10e5s1 to float32m23e8s1.
    // Note this works on machines of both ascending and descending byte
    // endianness, so long as float32 and uint32 endianness match.
    // It does not work for a few aberrant architectures which store
    // float32 and uint32 with opposite endianness.

    const uint32_t float16unsignedValueMask = 0x7FFF;
    const uint32_t float16signMask          = 0x8000;
    const uint32_t float16exponentMask      = 0x7C00;
    const uint32_t float32exponentMask      = 0x7F800000;

    const uint32_t float16unsignedValue = input & float16unsignedValueMask;
    const uint32_t float16sign          = input & float16signMask;
    const uint32_t float16exponent      = input & float16exponentMask;

    if (float16exponent == 0)
    {
        // Zero or denormal: there is no implicit leading 1, so the value is exactly
        // mantissa * 2^-24. Every such value is a normal float32 (or zero), so it is
        // computed arithmetically. Merely clearing the float32 exponent bits would yield
        // a float32 denormal that is 2^112 times too small.
        const float magnitude = static_cast<float>(float16unsignedValue) * 5.9604644775390625e-8f; // 2^-24
        return (float16sign != 0) ? -magnitude : magnitude;
    }

    // Shift mantissa bits left (23 - 10 = 13).
    // Adjust exponent bias (127 - 15 = 112, 112 << 23 == 0x38000000).
    // Move sign bit to float32 MSB (32 - 16 = 16).
    const uint32_t float32unsignedValue = (float16unsignedValue << 13) + 0x38000000;
    const uint32_t float32sign = float16sign << 16;

    uint32_t result = (float16exponent == float16exponentMask)
                    ? (float32unsignedValue | float32exponentMask) // Infinity or NaN (mantissa kept)
                    : float32unsignedValue;                        // Any other normal value
    result |= float32sign;

    return reinterpret_cast<float&>(result);
}
int64_t CastToInt64(MLOperatorTensorDataType tensorDataType, const void* p)
{
switch (tensorDataType)
@ -150,7 +180,7 @@ namespace OperatorHelper
case MLOperatorTensorDataType::Int64: return static_cast<double>(*reinterpret_cast<const int64_t*>(p));
case MLOperatorTensorDataType::String: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::String type is unsupported for reading as an integer.");
case MLOperatorTensorDataType::Bool: return static_cast<double>(*reinterpret_cast<const uint8_t*>(p));
case MLOperatorTensorDataType::Float16: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::Float16 type is unsupported for reading as an integer.");
case MLOperatorTensorDataType::Float16: return static_cast<double>(CastFloat16ToFloat32(*reinterpret_cast<const uint16_t*>(p)));
case MLOperatorTensorDataType::Double: return static_cast<double>(*reinterpret_cast<const double*>(p));
case MLOperatorTensorDataType::UInt32: return static_cast<double>(*reinterpret_cast<const uint32_t*>(p));
case MLOperatorTensorDataType::UInt64: return static_cast<double>(*reinterpret_cast<const uint64_t*>(p));
@ -673,7 +703,7 @@ namespace OperatorHelper
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0);
int outDimCount = gsl::narrow_cast<int>(inputDimensions.size() + indicesDimensions.size() - 1);
ML_CHECK_VALID_ARGUMENT(outDimCount >= 0 && outDimCount <= NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(outDimCount >= 0);
std::vector<DimensionType> outputDimensions(outDimCount, 1);
@ -707,21 +737,27 @@ namespace OperatorHelper
{
std::vector<DimensionType> inputDimensions = shapeInfo.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = shapeInfo.GetInputTensorShape(1);
int32_t batchCount = m_batchCount;
// Determine the number of output dimensions.
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 1);
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > batchCount);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() > batchCount);
const uint32_t numberOfCoordinatesPerIndex = indicesDimensions.back();
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= numberOfCoordinatesPerIndex);
const uint32_t numberOfOutputDimensionsFromInput = static_cast<uint32_t>(inputDimensions.size()) - numberOfCoordinatesPerIndex;
const uint32_t numberOfOutputDimensionsFromIndices = static_cast<uint32_t>(indicesDimensions.size()) - 1; // Strip off last dimension.
uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput);
ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0 && outputDimensionCount <= NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= batchCount + numberOfCoordinatesPerIndex);
const uint32_t numberOfOutputDimensionsFromInput = static_cast<uint32_t>(inputDimensions.size()) - batchCount - numberOfCoordinatesPerIndex;
const uint32_t numberOfOutputDimensionsFromIndices = static_cast<uint32_t>(indicesDimensions.size()) - batchCount - 1; // Strip off last dimension.
uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(batchCount + numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput);
ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0);
// Form the full expected size by concatenating the prefix part of the indices tensor shape
// with the suffix of the input tensor shape.
// Form the full expected size by concatenating fragments:
// 1 - batch count
// 2 - prefix part of the indices tensor shape
// 3 - suffix of the input tensor shape.
std::vector<DimensionType> outputDimensions;
outputDimensions.assign(indicesDimensions.begin(), indicesDimensions.end() - 1);
outputDimensions.assign(inputDimensions.begin(), inputDimensions.begin() + batchCount);
outputDimensions.insert(outputDimensions.end(), indicesDimensions.begin() + batchCount, indicesDimensions.end() - 1);
outputDimensions.insert(outputDimensions.end(), inputDimensions.end() - numberOfOutputDimensionsFromInput, inputDimensions.end());
return { EdgeShapes(std::move(outputDimensions)) };
@ -782,8 +818,6 @@ namespace OperatorHelper
// Dim Offset : 1
std::vector<DimensionType> reducedDims = shapeInfo.GetInputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(reducedDims.size() <= NchwDimensionCount);
std::vector<bool> reduced(reducedDims.size(), false);
for (auto& dim : m_axes)
@ -817,8 +851,6 @@ namespace OperatorHelper
void ReduceHelperBase::AdjustAxesAndOutputShape(const std::vector<uint32_t>& inputShape)
{
ML_CHECK_VALID_ARGUMENT(inputShape.size() <= NchwDimensionCount);
// If axes is not specified, reduce over all the dimensions
if (m_axes.empty())
{
@ -826,7 +858,264 @@ namespace OperatorHelper
std::iota(m_axes.begin(), m_axes.end(), 0);
}
}
// Parses m_equation into components and then classifies the equation.
// The order matters: classification reads the components and label indices
// produced by the parse step.
void EinSumHelper::Initialize()
{
ParseEquationComponents();
m_recognizedOperatorType = DetermineRecognizedOperatorType();
}
// Tokenizes m_equation and fills m_labelIndices and m_components.
//
// Each distinct letter is assigned the next free numeric index in order of first
// appearance; m_labelIndices receives those indices for every letter read, and each
// component records the [begin, end) range of m_labelIndices it covers. The last
// component pushed is the output. If the equation has no "->", an implicit output is
// synthesized from the labels that appeared exactly once. If an ellipsis ('.') is found,
// m_components is cleared and parsing stops so the equation is treated as unrecognized.
// Malformed input is reported through the ML_CHECK_VALID_ARGUMENT / ML_INVALID_ARGUMENT macros.
void EinSumHelper::ParseEquationComponents()
{
// Parse an equation like 'ij,jk->ik' into components {ij, jk, ik} mapping letters to
// numeric indices {{0,1}, {1,2}, {0,2}}. The last component is the output.
std::map<char, uint32_t> labelMap;
std::set<char> repeatedLabels;
uint32_t currentLabelIndex = 0;
Component currentComponent = {};
bool foundOutput = false;
bool reachedEnd = false;
// Read first to last character in equation, looking for letters, commas, and one arrow.
// The terminating '\0' is processed too, since it closes the final component.
for (char* token = m_equation.data(); !reachedEnd; ++token)
{
char ch = *token;
// Only ASCII letters are valid subscript symbols in numpy.einsum().
if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'))
{
// Check whether label already has an index.
const auto [i, inserted] = labelMap.insert({ch, currentLabelIndex});
if (inserted)
{
ML_CHECK_VALID_ARGUMENT(!foundOutput, "Found label in equation output not matching any label from inputs.")
++currentLabelIndex; // New label found.
}
else if (!foundOutput)
{
// If label in input already found earlier, then keep track of this later
// to generate the default output in case one is not specified.
repeatedLabels.insert(ch);
}
m_labelIndices.push_back(i->second);
}
else if (ch == ' ')
{
// Ignore spaces.
}
else
{
// Any other character ends the current component: close it, and start the next
// one at the current end of m_labelIndices.
currentComponent.labelIndexEnd = static_cast<uint32_t>(m_labelIndices.size());
m_components.push_back(currentComponent);
currentComponent.labelIndexBegin = currentComponent.labelIndexEnd;
switch (ch)
{
case ',':
// Note it's valid for 2 commas be adjacent, which indicates a scalar and generates
// an empty component.
break;
case '-': // Start of "->" (must be atomic, no space between them).
++token; // Skip '-'.
ML_CHECK_VALID_ARGUMENT(*token == '>', "Expected '->' for output.")
ML_CHECK_VALID_ARGUMENT(foundOutput == false, "Only one output arrow '->' is valid.")
foundOutput = true;
break;
case '.':
// Ellipsis is unsupported. Leave recognized operator as None, deferring to another EP.
m_components.clear();
return;
case '\0':
reachedEnd = true;
break; // End of string.
default:
ML_INVALID_ARGUMENT("Unsupported character in equation string. Must be a-z, A-Z, ',', or '->'.");
}
}
}
if (!foundOutput)
{
// If no explicit output was given, generate an implicit output by ordering all the
// labels in alphabetic order (by ASCII value consistent with numpy, so Z < a).
// Exclude any labels that occurred more than once, as these cancel out.
// labelMap is keyed by character, so iterating it yields the labels in that order.
for (auto i : labelMap)
{
if (repeatedLabels.count(i.first) == 0)
{
m_labelIndices.push_back(i.second);
}
}
// Push the final component, which is the output.
currentComponent.labelIndexEnd = static_cast<uint32_t>(m_labelIndices.size());
m_components.push_back(currentComponent);
}
}
// Classifies the parsed equation as one of the operator patterns in RecognizedOperatorType.
//
// Reads m_components and m_labelIndices (filled by ParseEquationComponents) and returns:
// - None when nothing was parsed, when there are more than 3 components, or when no pattern matches;
// - Identity / Transpose / ReduceSum for single-input equations with matching label count
//   or an empty output;
// - Multiply when both inputs and the output use identical labels;
// - otherwise the first entry of the fixed pattern table whose label indices and
//   per-component ranks match exactly.
EinSumHelper::RecognizedOperatorType EinSumHelper::DetermineRecognizedOperatorType()
{
    if (m_components.empty())
    {
        return RecognizedOperatorType::None; // Parsing may have found unsupported components - treating as unknown.
    }

    // std::ranges::equal is not supported yet.
    auto equals = [](gsl::span<const uint32_t> a, gsl::span<const uint32_t> b)
    {
        return std::equal(a.begin(), a.end(), b.begin(), b.end());
    };

    std::array<uint32_t, 3> componentRanks;
    if (m_components.size() > componentRanks.size())
    {
        // No recognized operator takes more than 2 inputs and 1 output.
        // EinSum itself is generic and can handle any variable number of inputs,
        // but DML's operators expect fixed counts.
        return RecognizedOperatorType::None;
    }
    else if (m_components.size() == 2)
    {
        // GetLabels returns a span by value, so hold it by value: binding a non-const
        // lvalue reference to that temporary is not valid standard C++.
        const auto inputLabels = m_components[0].GetLabels(m_labelIndices);
        const auto outputLabels = m_components[1].GetLabels(m_labelIndices);
        if (inputLabels.size() == outputLabels.size())
        {
            // Check identity.
            if (equals(inputLabels, outputLabels))
            {
                // Handles: "->", "i->i", "ij->ij", "ijk->ijk", "ijkl->ijkl" ...
                return RecognizedOperatorType::Identity;
            }
            else // Transpose since a permutation exists.
            {
                // Handles: "ij->ji", "ijk->kji", "ijkl->lkji", "ijkl->jilk" ...
                return RecognizedOperatorType::Transpose;
            }
        }
        else if (outputLabels.empty()) // Scalar output, with all inputs reduced.
        {
            // Handles: "i->", "ij->", "ijk->", "ijkl->" ...
            return RecognizedOperatorType::ReduceSum;
        }
    }
    else if (m_components.size() == 3)
    {
        // If all components have the same size and label order, then apply elementwise multiplication.
        const auto inputALabels = m_components[0].GetLabels(m_labelIndices);
        const auto inputBLabels = m_components[1].GetLabels(m_labelIndices);
        const auto outputLabels = m_components[2].GetLabels(m_labelIndices);
        if (equals(inputALabels, outputLabels) && equals(inputBLabels, outputLabels))
        {
            // Handles: "i,i->i", "ij,ij->ij", "ijk,ijk->ijk", "ijkl,ijkl->ijkl" ...
            return RecognizedOperatorType::Multiply;
        }
    }

    // Otherwise check for special cases of dedicated operators...

    struct RecognizedOperatorInfo
    {
        RecognizedOperatorType recognizedOperatorType;
        std::initializer_list<const uint32_t> componentRanks;
        std::initializer_list<const uint32_t> labelIndices;
    };

    const RecognizedOperatorInfo recognizedOperators[] = {
        {RecognizedOperatorType::MatMul,               {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
        {RecognizedOperatorType::MatMul,               {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
        {RecognizedOperatorType::MatMul,               {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
        {RecognizedOperatorType::MatMulTransposeA,     {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
        {RecognizedOperatorType::MatMulTransposeA,     {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
        {RecognizedOperatorType::MatMulTransposeA,     {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
        {RecognizedOperatorType::MatMulTransposeB,     {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
        {RecognizedOperatorType::MatMulTransposeB,     {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
        {RecognizedOperatorType::MatMulTransposeB,     {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
        {RecognizedOperatorType::MatMulTransposeB,     {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
        {RecognizedOperatorType::ReduceSum,            {2,1  },{0,1, 0}}, // ij->i
        {RecognizedOperatorType::ReduceSum,            {2,1  },{0,1, 1}}, // ij->j
    };

    // For each recognized operator, compare the labels-per-component and label indices.
    for (auto& recognizedOperator : recognizedOperators)
    {
        if (equals(m_labelIndices, recognizedOperator.labelIndices)
        &&  m_components.size() == recognizedOperator.componentRanks.size())
        {
            for (size_t i = 0; i < m_components.size(); ++i)
            {
                componentRanks[i] = m_components[i].GetDimensionCount();
            }

            if (equals(gsl::make_span(componentRanks.data(), m_components.size()), recognizedOperator.componentRanks))
            {
                return recognizedOperator.recognizedOperatorType;
            }
        }
    }

    return RecognizedOperatorType::None;
}
// Computes the single output shape of the EinSum from the input tensor shapes.
//
// Each input must have as many dimensions as its equation component has labels, and every
// occurrence of a label must agree on its dimension size. The output shape is then the
// recorded size of each output label, in output order. Violations are reported through
// ML_CHECK_VALID_ARGUMENT.
std::vector<EdgeShapes> EinSumHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
    assert(!m_components.empty()); // Should have already parsed components.

    uint32_t inputCount  = shapeInfo.GetInputCount();
    uint32_t outputCount = shapeInfo.GetOutputCount();
    ML_CHECK_VALID_ARGUMENT(inputCount + 1 == m_components.size(), "Mismatch between input tensor count and string equation component count.");
    ML_CHECK_VALID_ARGUMENT(outputCount == 1, "EinSum expects exactly 1 output tensor.");

    // Marker for "no size recorded yet for this label". Kept as an explicit unsigned
    // constant so the comparisons below are unsigned-vs-unsigned.
    constexpr uint32_t unsetLabelSize = static_cast<uint32_t>(INT_MIN);
    std::vector<uint32_t> labelSizes(m_labelIndices.size(), unsetLabelSize);

    // Read every input tensor, comparing labels to ensure consistent sizes from the equation parsed earlier.
    for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
    {
        auto inputShape = shapeInfo.GetInputTensorShape(inputIndex);
        auto& component = m_components[inputIndex];
        auto labelIndices = component.GetLabels(m_labelIndices);
        uint32_t dimensionCount = component.GetDimensionCount();

        ML_CHECK_VALID_ARGUMENT(inputShape.size() == dimensionCount, "Mismatch between input tensor shape and string equation label count.");

        for (uint32_t dimensionIndex = 0; dimensionIndex < dimensionCount; ++dimensionIndex)
        {
            // If this is the first time seeing this label, then record the size.
            // Otherwise any following occurrences of the label must match sizes.
            // e.g. Given "ij,ji", both i's and both j's must match dimension sizes.
            uint32_t dimensionSize = inputShape[dimensionIndex];
            uint32_t labelIndex = labelIndices[dimensionIndex];
            assert(labelIndex < labelSizes.size());

            if (labelSizes[labelIndex] == unsetLabelSize)
            {
                labelSizes[labelIndex] = dimensionSize;
            }
            else
            {
                ML_CHECK_VALID_ARGUMENT(labelSizes[labelIndex] == dimensionSize, "All labels must have the same dimension sizes.");
            }
        }
    }

    // Generate output dimensions from corresponding input tensor labels.
    // e.g. Given ij,jk->ik with [2,3] and [3,5], the output is [2,5].
    std::vector<uint32_t> outputDimensions;
    auto outputLabelIndices = m_components.back().GetLabels(m_labelIndices);
    for (auto labelIndex : outputLabelIndices)
    {
        outputDimensions.push_back(labelSizes[labelIndex]);
    }

    return { EdgeShapes(std::move(outputDimensions)) };
}
std::vector<EdgeShapes> MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 2);
@ -971,7 +1260,6 @@ namespace OperatorHelper
std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto outputShape = shapeInfo.GetInputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(outputShape.size() <= NchwDimensionCount);
uint32_t inputCount = shapeInfo.GetInputCount();
@ -1110,8 +1398,25 @@ namespace OperatorHelper
{
roiShape[0], // number of ROIs
inputShape[C], // number of channels
static_cast<DimensionType>(m_pooledSizeH),
static_cast<DimensionType>(m_pooledSizeW),
static_cast<DimensionType>(m_outputSizeH),
static_cast<DimensionType>(m_outputSizeW),
};
return { std::move(EdgeShapes(outputDimensions)) };
}
std::vector<EdgeShapes> RoiAlignHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS);
auto inputShape = shapeInfo.GetInputTensorShape(InputTensors::INPUT);
ML_CHECK_VALID_ARGUMENT(inputShape.size() >= 4, "inputShape must be >= 4.");
DimensionType outputDimensions[4] =
{
roiShape[0], // number of ROIs
inputShape[C], // number of channels
static_cast<DimensionType>(m_outputSizeH),
static_cast<DimensionType>(m_outputSizeW),
};
return { std::move(EdgeShapes(outputDimensions)) };

View file

@ -633,7 +633,7 @@ public:
{
int dimIndex = axes.empty() ? i : axes[i];
int stride = steps.empty() ? 1 : steps[i];
ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions.");
ML_CHECK_VALID_ARGUMENT(static_cast<size_t>(dimIndex) < static_cast<size_t>(inputDimensions.size()), "'axes' must be valid with within actual input dimensions.");
ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0.");
// Positive values are offsets from 0.
@ -733,6 +733,7 @@ class ReduceHelperBase {
template <typename Info_t, typename Shape_t>
ReduceHelperBase(const Info_t& info, const Shape_t& shape, bool usingAxes) {
m_keepDims = info.GetOptionalAttribute<int>(AttrName::KeepDims, 1);
m_selectLastIndex = info.GetOptionalAttribute<int>(AttrName::SelectLastIndex, 0);
if (usingAxes) {
m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes);
} else {
@ -751,6 +752,7 @@ class ReduceHelperBase {
protected:
std::vector<int> m_axes;
int m_keepDims = 0;
int m_selectLastIndex = 0;
};
class ArgMinArgMaxHelper : public ReduceHelperBase {
@ -769,6 +771,70 @@ class ReduceHelper : public ReduceHelperBase {
ReduceHelper(const Info_t& info, const Shape_t& shape) : ReduceHelperBase(info, shape, true) {}
};
// Shape inference helper for the EinSum operator.
// On construction it reads the "equation" attribute, parses it into per-tensor components
// of label indices, and classifies the equation as one of a fixed set of recognized
// operator patterns (or None).
class EinSumHelper
{
public:
// Parses m_equation and determines the recognized operator type. Called by the constructors.
void Initialize();
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
// Neither 'shape' nor 'opsetVersion' is read by this constructor.
template <typename Info_t, typename Shape_t>
EinSumHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
{
m_equation = info.GetAttribute(AttrName::Equation);
Initialize();
}
// Constructs the helper from operator attributes alone (no shape information).
EinSumHelper(const MLOperatorAttributes& info)
{
m_equation = info.GetAttribute(AttrName::Equation);
Initialize();
}
// Returns the output shape implied by the equation and the actual input tensor shapes.
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
// Operator patterns the equation can be classified as. None means no pattern matched
// (or the equation used an unsupported feature).
// NOTE(review): Total looks like an enumerator-count marker - confirm with its users.
enum class RecognizedOperatorType
{
None,
Identity,
Multiply,
MatMul,
MatMulTransposeA,
MatMulTransposeB,
ReduceSum,
Transpose,
Total,
};
RecognizedOperatorType GetRecognizedOperatorType() const noexcept { return m_recognizedOperatorType; }
protected:
// Fills m_labelIndices and m_components from m_equation.
void ParseEquationComponents();
// Classifies the parsed components; expects ParseEquationComponents to have run.
RecognizedOperatorType DetermineRecognizedOperatorType();
protected:
// One term of the equation (an input or the output), stored as a half-open
// [labelIndexBegin, labelIndexEnd) range into m_labelIndices.
struct Component
{
uint32_t labelIndexBegin;
uint32_t labelIndexEnd;
// Number of labels in this component, i.e. the rank of the corresponding tensor.
uint32_t GetDimensionCount() const noexcept
{
return labelIndexEnd - labelIndexBegin;
}
// Returns the slice of 'labels' belonging to this component.
gsl::span<const uint32_t> GetLabels(gsl::span<const uint32_t> labels) const
{
return labels.subspan(labelIndexBegin, labelIndexEnd - labelIndexBegin);
};
};
std::string m_equation;
std::vector<uint32_t> m_labelIndices; // Concatenation of all labels as rebased indices ("ij,ai" -> 0,1,2,0).
std::vector<Component> m_components; // All components in order, including inputs and output.
std::vector<uint32_t> m_outputDimensions; // Not referenced by the member functions visible here.
RecognizedOperatorType m_recognizedOperatorType = RecognizedOperatorType::None;
};
class MatMulHelperBase {
public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
@ -975,9 +1041,13 @@ class GatherNdHelper {
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
GatherNdHelper(const Info_t& info, const Shape_t& shape) {
m_batchCount = info.GetOptionalAttribute<int32_t>(AttrName::BatchDimensions, 0);
}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
protected:
int32_t m_batchCount;
};
class PoolingHelperBase {
@ -1040,26 +1110,51 @@ class PoolingHelper : public PoolingHelperBase {
PoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {}
};
class RoiPoolingHelper {
public:
enum InputTensors { INPUT,
ROIS };
class RoiPoolingHelperBase
{
public:
enum InputTensors { INPUT, ROIS, BATCH_INDICES };
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
RoiPoolingHelper(const Info_t& info, const Shape_t& shape) {
std::vector<int> pooledShape = info.GetOptionalAttributeVectorInt32(AttrName::PooledShape);
ML_CHECK_VALID_ARGUMENT(pooledShape.size() == 2, "Pooled shape must be 2.");
m_pooledSizeH = pooledShape[0];
m_pooledSizeW = pooledShape[1];
}
RoiPoolingHelperBase()
{}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
protected:
uint32_t m_pooledSizeW;
uint32_t m_pooledSizeH;
protected:
uint32_t m_outputSizeW = 1;
uint32_t m_outputSizeH = 1;
};
// Shape helper for ROI pooling operators whose output height/width come from the
// two-element "pooled_shape" attribute, given as (height, width).
class RoiPoolingHelper : public RoiPoolingHelperBase
{
public:
    // Info_t supplies the operator attributes read below.
    // Shape_t supplies input shape information; it is not read by this constructor.
    template <typename Info_t, typename Shape_t>
    RoiPoolingHelper(const Info_t& info, const Shape_t& shape)
    {
        const std::vector<int> pooledShape = info.GetOptionalAttributeVectorInt32(AttrName::PooledShape);
        ML_CHECK_VALID_ARGUMENT(pooledShape.size() == 2, "Pooled shape must be 2.");

        // The attribute is ordered height first, then width.
        const int pooledHeight = pooledShape[0];
        const int pooledWidth = pooledShape[1];
        m_outputSizeH = pooledHeight;
        m_outputSizeW = pooledWidth;
    }

    // Returns the pooled output shape for the given inputs.
    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
// Shape helper for RoiAlign: the output width/height come from the optional
// "output_width" / "output_height" attributes, each defaulting to 1 when absent.
class RoiAlignHelper : public RoiPoolingHelperBase
{
public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
// Neither 'shape' nor 'opsetVersion' is read by this constructor.
template <typename Info_t, typename Shape_t>
RoiAlignHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
{
m_outputSizeW = info.GetOptionalAttribute<uint32_t>(AttrName::OutputWidth, 1);
m_outputSizeH = info.GetOptionalAttribute<uint32_t>(AttrName::OutputHeight, 1);
}
// Returns the output shape for the given inputs.
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
class SqueezeHelper {
@ -1330,6 +1425,7 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper;
using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
@ -1382,8 +1478,11 @@ using ShapeInferenceHelper_Ceil = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Floor = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Clip7 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Clip11 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Clip12 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Greater = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Less = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_GreaterOrEqual = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_LessOrEqual = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Equal = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Not = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_And = GetBroadcastedOutputShapeHelper;
@ -1430,6 +1529,7 @@ using ShapeInferenceHelper_ReduceL1 = ReduceHelper;
using ShapeInferenceHelper_ReduceL2 = ReduceHelper;
using ShapeInferenceHelper_ReduceMax = ReduceHelper;
using ShapeInferenceHelper_ReduceMin = ReduceHelper;
using ShapeInferenceHelper_Einsum12 = VersionedOpsetHelper<EinSumHelper, 12>;
using ShapeInferenceHelper_ArgMax = ArgMinArgMaxHelper;
using ShapeInferenceHelper_ArgMin = ArgMinArgMaxHelper;
using ShapeInferenceHelper_Gemm = GemmHelper;
@ -1450,6 +1550,7 @@ using ShapeInferenceHelper_LeakyRelu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_PRelu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_ThresholdedRelu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Elu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Celu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Selu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Softmax = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LogSoftmax = GetOutputShapeAsInputShapeHelper;

View file

@ -245,6 +245,24 @@ namespace OperatorHelper
static const int sc_sinceVer_Unsqueeze = 11;
} // namespace OnnxOperatorSet11
// "Since version" constants for the operators listed under ONNX opset 12.
// Each sc_sinceVer_<Op> holds the opset version number (12) recorded for that operator.
namespace OnnxOperatorSet12
{
static const int sc_sinceVer_ArgMin = 12;
static const int sc_sinceVer_ArgMax = 12;
static const int sc_sinceVer_Celu = 12;
static const int sc_sinceVer_Clip = 12;
static const int sc_sinceVer_Einsum = 12;
static const int sc_sinceVer_GatherND = 12;
static const int sc_sinceVer_GreaterOrEqual = 12;
static const int sc_sinceVer_LessOrEqual = 12;
static const int sc_sinceVer_MaxPool = 12;
static const int sc_sinceVer_Min = 12;
static const int sc_sinceVer_Max = 12;
static const int sc_sinceVer_Pow = 12;
static const int sc_sinceVer_ReduceMax = 12;
static const int sc_sinceVer_ReduceMin = 12;
} // namespace OnnxOperatorSet12
namespace MsftOperatorSet1
{
static const int sc_sinceVer_FusedConv = 1;

View file

@ -7,6 +7,8 @@
#include <cassert>
#include <chrono>
#include <vector>
#include <map>
#include <set>
#include <numeric>
#include <wrl/client.h>

View file

@ -1,5 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="DirectML" version="2.1.0" targetFramework="native" />
<package id="DirectML" version="3.0.0" targetFramework="native" />
<package id="GoogleTestAdapter" version="0.17.1" targetFramework="net46" />
</packages>

View file

@ -288,7 +288,7 @@ def generate_files(list, args):
'" target="runtimes\\win-' + args.target_architecture + '\\native" />')
files_list.append('<file src=' + '"' + os.path.join(args.native_build_path, 'DirectML.pdb') +
'" target="runtimes\\win-' + args.target_architecture + '\\native" />')
files_list.append('<file src=' + '"' + os.path.join(args.packages_path, 'DirectML.2.1.0\\LICENSE.txt') +
files_list.append('<file src=' + '"' + os.path.join(args.packages_path, 'DirectML.2.1.1\\LICENSE.txt') +
'" target="DirectML_LICENSE.txt" />')
if includes_winml:

View file

@ -56,6 +56,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
bool requiresFloatFormatsForGraph,
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept try {
#ifdef LAYERING_DONE
@ -79,6 +82,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
supportsGraph,
requiredInputCountForGraph,
requiresFloatFormatsForGraph,
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,
requiredConstantCpuInputs,
constantCpuInputCount);
}

View file

@ -29,6 +29,9 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
bool supports_graph,
const uint32_t* required_input_count_for_graph = nullptr,
bool requires_float_formats_for_graph = false,
bool supports_64bit_directly = false,
bool allows_64bit_via_strides = false,
bool allows_64bit_via_strides_from_any_ep = false,
_In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr,
uint32_t constant_cpu_input_count = 0) const noexcept override;

View file

@ -54,9 +54,10 @@ static std::string ToString(wfc::IVectorView<int64_t> shape) {
static std::string ToString(
winml::TensorKind kind,
wfc::IVectorView<int64_t> shape) {
FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128.");
FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64.");
FAIL_FAST_IF_MSG(kind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined.");
// Any unrecognized data type is considered "Undefined".
if (static_cast<uint32_t>(kind) >= std::size(SzTensorKind)) {
kind = winml::TensorKind::Undefined;
}
std::ostringstream stream;
stream << SzTensorKind[static_cast<uint32_t>(kind)] << ToString(shape);
@ -73,9 +74,10 @@ static std::string ToString(winml::ITensor value) {
static std::string ToString(winml::IMapFeatureDescriptor descriptor) {
auto keyKind = descriptor.KeyKind();
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128.");
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64.");
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined.");
// Any unrecognized data type is considered "Undefined".
if (static_cast<uint32_t>(keyKind) >= std::size(SzTensorKind)) {
keyKind = winml::TensorKind::Undefined;
}
auto valueDescriptor = descriptor.ValueDescriptor();
std::ostringstream stream;
@ -86,9 +88,10 @@ static std::string ToString(winml::IMapFeatureDescriptor descriptor) {
static std::string ToString(winrt::com_ptr<_winml::IMapFeatureValue> value) {
winml::TensorKind keyKind;
FAIL_FAST_IF_FAILED(value->get_KeyKind(&keyKind));
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128.");
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64.");
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined.");
// Any unrecognized data type is considered "Undefined".
if (static_cast<uint32_t>(keyKind) >= std::size(SzTensorKind)) {
keyKind = winml::TensorKind::Undefined;
}
winml::ILearningModelFeatureDescriptor valueDescriptor;
FAIL_FAST_IF_FAILED(value->get_ValueDescriptor(&valueDescriptor));