mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-21 02:18:09 +00:00
Merge pull request #4925 from microsoft/user/dwayner/Iron
ORT DirectML EP for Iron release, ONNX 1.5
This commit is contained in:
commit
040c5fa3e0
60 changed files with 2953 additions and 874 deletions
2
cmake/external/dml.cmake
vendored
2
cmake/external/dml.cmake
vendored
|
|
@ -20,7 +20,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
|
|||
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
|
||||
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
|
||||
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
|
||||
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.2.1.0)
|
||||
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.3.0.0)
|
||||
|
||||
# Restore nuget packages, which will pull down the DirectML redist package
|
||||
add_custom_command(
|
||||
|
|
|
|||
3
onnxruntime/core/providers/dml/.clang-format
Normal file
3
onnxruntime/core/providers/dml/.clang-format
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# Readability matters. Prevent syntax noise in pull requests for people who
|
||||
# accidentally leave enabled the auto-formatting options in Visual Studio.
|
||||
DisableFormat: true
|
||||
|
|
@ -94,11 +94,16 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
const void* executionHandle,
|
||||
DmlGraphNodeCreateInfo* graphNodeCreateInfo
|
||||
)>;
|
||||
|
||||
|
||||
struct GraphNodeFactoryRegistration
|
||||
{
|
||||
GraphNodeFactory factory;
|
||||
std::optional<uint32_t> requiredInputCount;
|
||||
|
||||
// The operator inputs/outputs must be a floating point data type. When true,
|
||||
// if the node's tensor data type is not-floating point, the node is partioned
|
||||
// separately (unless the input/output is a CPU constant input, which is okay,
|
||||
// as those can be read directly by the DML operator in the DML_OPERATOR_DESC).
|
||||
bool requiresFloatFormatsExceptConstInputs = false;
|
||||
};
|
||||
|
||||
|
|
@ -109,6 +114,20 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
std::vector<uint32_t> requiredConstantCpuInputs;
|
||||
std::optional<GraphNodeFactoryRegistration> graphNodeFactoryRegistration;
|
||||
KernelSupportQuery supportQuery;
|
||||
|
||||
// Many ONNX operators use 64-bit tensors, but most DML operators only support
|
||||
// 32-bit indices. This flag indicates to the graph whether it's okay to compute
|
||||
// the result using 32-bit tensors (ignoring the upper bits) via doubled strides.
|
||||
bool supportedWith64BitTensorsVia32BitStrides = false;
|
||||
|
||||
// When true, the input to the current operator may come from any execution
|
||||
// provider. Otherwise it must have come from another DML node to assume it's safe
|
||||
// to use 64-bit to 32-bit striding.
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
|
||||
|
||||
// Operator supports true 64-bit tensors directly, no strides needed.
|
||||
// So fallback to strided 32-bit only occurs when the device lacks 64-bit support.
|
||||
bool prefer64BitTensorsDirectly = false;
|
||||
};
|
||||
|
||||
using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;
|
||||
|
|
|
|||
|
|
@ -334,6 +334,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph,
|
||||
bool requiresFloatFormatsForGraph,
|
||||
bool supportedWith64BitTensorsVia32BitStrides,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
bool prefer64BitTensorsDirectly,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
|
||||
uint32_t constantCpuInputCount) const noexcept try
|
||||
{
|
||||
|
|
@ -456,6 +459,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
{
|
||||
auto regInfo = std::make_shared<InternalRegistrationInfo>();
|
||||
regInfo->requiredConstantCpuInputs = constantCpuInputCapture;
|
||||
regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides;
|
||||
regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp;
|
||||
regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly;
|
||||
|
||||
// Only internal operators support usage in DML graphs
|
||||
if (supportsGraph)
|
||||
|
|
@ -527,8 +533,14 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
else
|
||||
{
|
||||
// Currently unsupported for external operators
|
||||
if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph ||
|
||||
requiresFloatFormatsForGraph || requiredConstantCpuInputs)
|
||||
if (canAliasFirstInput ||
|
||||
supportsGraph ||
|
||||
requiredInputCountForGraph ||
|
||||
requiresFloatFormatsForGraph ||
|
||||
requiredConstantCpuInputs ||
|
||||
supportedWith64BitTensorsVia32BitStrides ||
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
|
||||
prefer64BitTensorsDirectly)
|
||||
{
|
||||
THROW_HR(E_INVALIDARG);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
|
|||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph = nullptr,
|
||||
bool requiresFloatFormatsForGraph = false,
|
||||
bool supportedWith64BitTensorsVia32BitStrides = false,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
|
||||
bool prefer64BitTensorsDirectly = false,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
|
||||
uint32_t constantCpuInputCount = 0) const noexcept override;
|
||||
|
||||
|
|
|
|||
|
|
@ -16,29 +16,42 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp
|
|||
case MLOperatorTensorDataType::UInt16: return DML_TENSOR_DATA_TYPE_UINT16;
|
||||
case MLOperatorTensorDataType::Int16: return DML_TENSOR_DATA_TYPE_INT16;
|
||||
case MLOperatorTensorDataType::Int32: return DML_TENSOR_DATA_TYPE_INT32;
|
||||
case MLOperatorTensorDataType::Int64: return DML_TENSOR_DATA_TYPE_UINT32;
|
||||
case MLOperatorTensorDataType::Int64: return DML_TENSOR_DATA_TYPE_INT64;
|
||||
case MLOperatorTensorDataType::String: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
case MLOperatorTensorDataType::Bool: return DML_TENSOR_DATA_TYPE_UINT8;
|
||||
case MLOperatorTensorDataType::Float16: return DML_TENSOR_DATA_TYPE_FLOAT16;
|
||||
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
case MLOperatorTensorDataType::UInt32: return DML_TENSOR_DATA_TYPE_UINT32;
|
||||
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT32; // Stride is used to access lower 32-bits.
|
||||
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT64;
|
||||
case MLOperatorTensorDataType::Complex64: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
case MLOperatorTensorDataType::Complex128: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
case MLOperatorTensorDataType::Undefined:
|
||||
default: return DML_TENSOR_DATA_TYPE_UNKNOWN;;
|
||||
default: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
};
|
||||
}
|
||||
|
||||
DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dmlElementType) noexcept
|
||||
{
|
||||
switch (dmlElementType)
|
||||
{
|
||||
case DML_TENSOR_DATA_TYPE_UINT64: return DML_TENSOR_DATA_TYPE_UINT32; break;
|
||||
case DML_TENSOR_DATA_TYPE_INT64: return DML_TENSOR_DATA_TYPE_INT32; break;
|
||||
default: return dmlElementType;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsSigned(DML_TENSOR_DATA_TYPE dataType)
|
||||
{
|
||||
switch (dataType)
|
||||
{
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT64: return true;
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT32: return true;
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT16: return true;
|
||||
case DML_TENSOR_DATA_TYPE_UINT64: return false;
|
||||
case DML_TENSOR_DATA_TYPE_UINT32: return false;
|
||||
case DML_TENSOR_DATA_TYPE_UINT16: return false;
|
||||
case DML_TENSOR_DATA_TYPE_UINT8: return false;
|
||||
case DML_TENSOR_DATA_TYPE_INT64: return true;
|
||||
case DML_TENSOR_DATA_TYPE_INT32: return true;
|
||||
case DML_TENSOR_DATA_TYPE_INT16: return true;
|
||||
case DML_TENSOR_DATA_TYPE_INT8: return true;
|
||||
|
|
@ -70,9 +83,14 @@ MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tenso
|
|||
case DML_TENSOR_DATA_TYPE_INT32: return MLOperatorTensorDataType::Int32;
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT16: return MLOperatorTensorDataType::Float16;
|
||||
case DML_TENSOR_DATA_TYPE_UINT32: return MLOperatorTensorDataType::UInt32;
|
||||
case DML_TENSOR_DATA_TYPE_UINT64: return MLOperatorTensorDataType::UInt64;
|
||||
case DML_TENSOR_DATA_TYPE_INT64: return MLOperatorTensorDataType::Int64;
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT64: return MLOperatorTensorDataType::Double;
|
||||
|
||||
default: ML_INVALID_ARGUMENT("Unknown DML_TENSOR_DATA_TYPE.");
|
||||
};
|
||||
}
|
||||
|
||||
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType)
|
||||
{
|
||||
return ComputeElementCountFromDimensions(dimensions) * GetByteSizeFromMlDataType(tensorDataType);
|
||||
|
|
@ -90,4 +108,40 @@ size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor)
|
|||
return ComputeByteSizeFromDimensions(gsl::make_span(dimensions.data(), dimensionCount), tensor.GetTensorDataType());
|
||||
}
|
||||
|
||||
uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
|
||||
{
|
||||
uint32_t deviceTypeMask = 0u;
|
||||
|
||||
// Form the bitmask of all supported data types.
|
||||
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i)
|
||||
{
|
||||
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
|
||||
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};
|
||||
|
||||
THROW_IF_FAILED(dmlDevice->CheckFeatureSupport(
|
||||
DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT,
|
||||
sizeof(dataTypeQuery),
|
||||
&dataTypeQuery,
|
||||
sizeof(dataTypeSupport),
|
||||
&dataTypeSupport
|
||||
));
|
||||
|
||||
deviceTypeMask |= (dataTypeSupport.IsSupported << i);
|
||||
}
|
||||
|
||||
return deviceTypeMask;
|
||||
}
|
||||
|
||||
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides)
|
||||
{
|
||||
assert(sizes.size() == strides.size());
|
||||
|
||||
uint32_t stride = 1;
|
||||
for (size_t i = strides.size(); i-- > 0; )
|
||||
{
|
||||
strides[i] = stride;
|
||||
stride *= sizes[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -10,13 +10,16 @@ namespace Dml
|
|||
{
|
||||
using namespace OperatorHelper;
|
||||
|
||||
static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX;
|
||||
static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX1;
|
||||
|
||||
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataType(MLOperatorTensorDataType tensorDataType);
|
||||
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataType tensorDataType) noexcept;
|
||||
DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dmlElementType) noexcept;
|
||||
MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tensorDataType);
|
||||
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType);
|
||||
size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor);
|
||||
uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice);
|
||||
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides);
|
||||
|
||||
bool IsSigned(DML_TENSOR_DATA_TYPE dataType);
|
||||
|
||||
|
|
@ -40,6 +43,12 @@ namespace Dml
|
|||
UINT elementSizeInBytes = 0;
|
||||
switch (dataType)
|
||||
{
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT64:
|
||||
case DML_TENSOR_DATA_TYPE_UINT64:
|
||||
case DML_TENSOR_DATA_TYPE_INT64:
|
||||
elementSizeInBytes = 8;
|
||||
break;
|
||||
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT32:
|
||||
case DML_TENSOR_DATA_TYPE_UINT32:
|
||||
case DML_TENSOR_DATA_TYPE_INT32:
|
||||
|
|
|
|||
|
|
@ -376,8 +376,9 @@ namespace Dml
|
|||
{
|
||||
assert(!m_closed);
|
||||
|
||||
const size_t sourceSizeInBytes = ComputeByteSizeFromTensor(*src);
|
||||
const size_t dataSizeInBytes = ComputeByteSizeFromTensor(*dst);
|
||||
THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != ComputeByteSizeFromTensor(*src)); // Tensors must be the same size
|
||||
THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != sourceSizeInBytes); // Tensors must be the same size
|
||||
|
||||
if (dataSizeInBytes == 0)
|
||||
{
|
||||
|
|
@ -461,7 +462,7 @@ namespace Dml
|
|||
}
|
||||
CATCH_RETURN();
|
||||
|
||||
uint32_t ExecutionProviderImpl::GetSuppportedDeviceDataTypeMask() const
|
||||
uint32_t ExecutionProviderImpl::GetSupportedDeviceDataTypeMask() const
|
||||
{
|
||||
// The DML provider registers all supported kernels up-front regardless of actual device capability,
|
||||
// but this is problematic later when executing the graph because DirectML will fail to create
|
||||
|
|
@ -470,26 +471,7 @@ namespace Dml
|
|||
// handle them, similar to the fallback in CUDAExecutionProvider::GetCapability for certain RNN/GRU/Conv
|
||||
// attributes.
|
||||
|
||||
uint32_t deviceTypeMask = 0u;
|
||||
|
||||
// Form the bitmask of all supported data types.
|
||||
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i)
|
||||
{
|
||||
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
|
||||
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};
|
||||
|
||||
THROW_IF_FAILED(m_dmlDevice->CheckFeatureSupport(
|
||||
DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT,
|
||||
sizeof(dataTypeQuery),
|
||||
&dataTypeQuery,
|
||||
sizeof(dataTypeSupport),
|
||||
&dataTypeSupport
|
||||
));
|
||||
|
||||
deviceTypeMask |= (dataTypeSupport.IsSupported << i);
|
||||
}
|
||||
|
||||
return deviceTypeMask;
|
||||
return Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get());
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
|
|
@ -498,7 +480,7 @@ namespace Dml
|
|||
const std::vector<const onnxruntime::KernelRegistry*>& registries) const
|
||||
{
|
||||
std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_";
|
||||
uint32_t deviceDataTypeMask = GetSuppportedDeviceDataTypeMask();
|
||||
uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask();
|
||||
|
||||
return PartitionGraph(
|
||||
graph,
|
||||
|
|
|
|||
|
|
@ -16,10 +16,9 @@ using Base = Microsoft::WRL::RuntimeClass<
|
|||
TInterfaces...>;
|
||||
}
|
||||
|
||||
using namespace Microsoft::WRL;
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
using Microsoft::WRL::ComPtr;
|
||||
class PooledUploadHeap;
|
||||
class ReadbackHeap;
|
||||
class ExecutionContext;
|
||||
|
|
@ -87,7 +86,7 @@ namespace Dml
|
|||
const std::vector<const onnxruntime::KernelRegistry*>& registries
|
||||
) const;
|
||||
|
||||
uint32_t GetSuppportedDeviceDataTypeMask() const;
|
||||
uint32_t GetSupportedDeviceDataTypeMask() const;
|
||||
|
||||
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
|
||||
onnxruntime::common::Status CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const;
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ union ActivationOperatorDescUnion
|
|||
{
|
||||
DML_ACTIVATION_IDENTITY_OPERATOR_DESC identity;
|
||||
DML_ACTIVATION_ELU_OPERATOR_DESC elu;
|
||||
DML_ACTIVATION_CELU_OPERATOR_DESC celu;
|
||||
DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax;
|
||||
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid;
|
||||
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu;
|
||||
|
|
@ -36,6 +37,7 @@ struct ActivationOperatorDesc
|
|||
switch (activationType)
|
||||
{
|
||||
case DML_OPERATOR_ACTIVATION_ELU: return { activationType, ¶ms.elu };
|
||||
case DML_OPERATOR_ACTIVATION_CELU: return { activationType, ¶ms.celu };
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, ¶ms.hardmax };
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, ¶ms.sigmoid };
|
||||
case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, ¶ms.identity };
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ struct EnumTraits<DML_TENSOR_TYPE>
|
|||
template <>
|
||||
struct EnumTraits<DML_OPERATOR_TYPE>
|
||||
{
|
||||
static constexpr auto ValueCount = 120;
|
||||
static constexpr size_t ActivationFunctionCount = 19;
|
||||
static constexpr auto ValueCount = 141;
|
||||
static constexpr size_t ActivationFunctionCount = 20;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
@ -62,7 +62,7 @@ struct EnumTraits<DML_CONVOLUTION_DIRECTION>
|
|||
template <>
|
||||
struct EnumTraits<DML_PADDING_MODE>
|
||||
{
|
||||
static constexpr auto ValueCount = 3;
|
||||
static constexpr auto ValueCount = 4;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
@ -86,7 +86,7 @@ struct EnumTraits<DML_FEATURE>
|
|||
template <>
|
||||
struct EnumTraits<DML_FEATURE_LEVEL>
|
||||
{
|
||||
static constexpr auto ValueCount = 2;
|
||||
static constexpr auto ValueCount = 4;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
@ -113,6 +113,12 @@ struct EnumTraits<DML_ROUNDING_MODE>
|
|||
static constexpr auto ValueCount = 3;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct EnumTraits<DML_RANDOM_GENERATOR_TYPE>
|
||||
{
|
||||
static constexpr auto ValueCount = 1;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr auto EnumValueCount = EnumTraits<T>::ValueCount;
|
||||
|
||||
|
|
@ -273,6 +279,18 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -393,6 +411,18 @@ struct OperatorDescTraits<DML_REDUCE_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REDUCE;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ARGMIN_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ARGMIN;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ARGMAX_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ARGMAX;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_AVERAGE_POOLING_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -747,6 +777,12 @@ struct OperatorDescTraits<DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_RESAMPLE1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -771,12 +807,108 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_AND;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_OR;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_XOR;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_NOT;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_COUNT;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_RELU_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_MAX_POOLING_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_RANDOM_GENERATOR_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RANDOM_GENERATOR;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_NONZERO_COORDINATES_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_NONZERO_COORDINATES;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_RESAMPLE_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_SLICE_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ADAM_OPTIMIZER_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ADAM_OPTIMIZER;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ROI_ALIGN_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_GATHER_ND1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_ELU;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_CELU_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_CELU;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_HARDMAX_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -993,6 +1125,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_L
|
|||
using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT>
|
||||
{
|
||||
|
|
@ -1113,6 +1257,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REDUCE>
|
|||
using DescType = DML_REDUCE_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ARGMIN>
|
||||
{
|
||||
using DescType = DML_ARGMIN_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ARGMAX>
|
||||
{
|
||||
using DescType = DML_ARGMAX_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING>
|
||||
{
|
||||
|
|
@ -1467,6 +1623,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZ
|
|||
using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE1>
|
||||
{
|
||||
using DescType = DML_RESAMPLE1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER>
|
||||
{
|
||||
|
|
@ -1491,12 +1653,108 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_CONVO
|
|||
using DescType = DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_AND>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_OR>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_XOR>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_NOT>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_COUNT>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_RELU_GRAD>
|
||||
{
|
||||
using DescType = DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING_GRAD>
|
||||
{
|
||||
using DescType = DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING_GRAD>
|
||||
{
|
||||
using DescType = DML_MAX_POOLING_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RANDOM_GENERATOR>
|
||||
{
|
||||
using DescType = DML_RANDOM_GENERATOR_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_NONZERO_COORDINATES>
|
||||
{
|
||||
using DescType = DML_NONZERO_COORDINATES_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE_GRAD>
|
||||
{
|
||||
using DescType = DML_RESAMPLE_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE_GRAD>
|
||||
{
|
||||
using DescType = DML_SLICE_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ADAM_OPTIMIZER>
|
||||
{
|
||||
using DescType = DML_ADAM_OPTIMIZER_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN>
|
||||
{
|
||||
using DescType = DML_ROI_ALIGN_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND1>
|
||||
{
|
||||
using DescType = DML_GATHER_ND1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
|
||||
{
|
||||
using DescType = DML_ACTIVATION_ELU_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_CELU>
|
||||
{
|
||||
using DescType = DML_ACTIVATION_CELU_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX>
|
||||
{
|
||||
|
|
@ -1652,6 +1910,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR:
|
||||
|
|
@ -1692,6 +1954,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_GEMM_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_REDUCE:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_REDUCE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ARGMIN:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ARGMIN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ARGMAX:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_AVERAGE_POOLING:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_LP_POOLING:
|
||||
|
|
@ -1820,8 +2086,40 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_CONVOLUTION_INTEGER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_AND:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_OR:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_RELU_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_AVERAGE_POOLING_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_MAX_POOLING_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_MAX_POOLING_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_RANDOM_GENERATOR:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_RANDOM_GENERATOR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_NONZERO_COORDINATES:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_NONZERO_COORDINATES_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_RESAMPLE_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_RESAMPLE_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_SLICE_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_SLICE_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ADAM_OPTIMIZER:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ADAM_OPTIMIZER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ROI_ALIGN:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_GATHER_ND1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ND1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_CELU:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_CELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID:
|
||||
|
|
@ -1887,6 +2185,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS";
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN";
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN";
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL";
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL";
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT";
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR";
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR";
|
||||
|
|
@ -1907,6 +2207,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION";
|
||||
case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM";
|
||||
case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE";
|
||||
case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN";
|
||||
case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX";
|
||||
case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING";
|
||||
case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING";
|
||||
case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING";
|
||||
|
|
@ -1971,6 +2273,21 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY";
|
||||
case DML_OPERATOR_CONVOLUTION_INTEGER: return "DML_OPERATOR_CONVOLUTION_INTEGER";
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION";
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return "DML_OPERATOR_ELEMENT_WISE_BIT_AND";
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_OR: return "DML_OPERATOR_ELEMENT_WISE_BIT_OR";
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: return "DML_OPERATOR_ELEMENT_WISE_BIT_XOR";
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: return "DML_OPERATOR_ELEMENT_WISE_BIT_NOT";
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: return "DML_OPERATOR_ELEMENT_WISE_BIT_COUNT";
|
||||
case DML_OPERATOR_ACTIVATION_RELU_GRAD: return "DML_OPERATOR_ACTIVATION_RELU_GRAD";
|
||||
case DML_OPERATOR_AVERAGE_POOLING_GRAD: return "DML_OPERATOR_AVERAGE_POOLING_GRAD";
|
||||
case DML_OPERATOR_MAX_POOLING_GRAD: return "DML_OPERATOR_MAX_POOLING_GRAD";
|
||||
case DML_OPERATOR_RANDOM_GENERATOR: return "DML_OPERATOR_RANDOM_GENERATOR";
|
||||
case DML_OPERATOR_NONZERO_COORDINATES: return "DML_OPERATOR_NONZERO_COORDINATES";
|
||||
case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD";
|
||||
case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD";
|
||||
case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER";
|
||||
case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN";
|
||||
case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1";
|
||||
default:
|
||||
assert(false);
|
||||
return "<unknown>";
|
||||
|
|
|
|||
|
|
@ -296,6 +296,34 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA
|
|||
DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL",
|
||||
DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
3,
|
||||
DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL",
|
||||
DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
3,
|
||||
DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA_FIELDS[2] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -599,6 +627,38 @@ constexpr DML_OPERATOR_SCHEMA DML_REDUCE_OPERATOR_SCHEMA {
|
|||
DML_REDUCE_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ARGMIN_OPERATOR_SCHEMA_FIELDS[5] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ARGMIN_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ARGMIN",
|
||||
DML_OPERATOR_ARGMIN,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
5,
|
||||
DML_ARGMIN_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ARGMAX_OPERATOR_SCHEMA_FIELDS[5] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ARGMAX_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ARGMAX",
|
||||
DML_OPERATOR_ARGMAX,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
5,
|
||||
DML_ARGMAX_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[8] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -1628,7 +1688,7 @@ constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIEL
|
|||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterZeroPointTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -1648,6 +1708,247 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA {
|
|||
DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_BIT_AND",
|
||||
DML_OPERATOR_ELEMENT_WISE_BIT_AND,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
3,
|
||||
DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_BIT_OR",
|
||||
DML_OPERATOR_ELEMENT_WISE_BIT_OR,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
3,
|
||||
DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_BIT_XOR",
|
||||
DML_OPERATOR_ELEMENT_WISE_BIT_XOR,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
3,
|
||||
DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA_FIELDS[2] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_BIT_NOT",
|
||||
DML_OPERATOR_ELEMENT_WISE_BIT_NOT,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
2,
|
||||
DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA_FIELDS[2] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_BIT_COUNT",
|
||||
DML_OPERATOR_ELEMENT_WISE_BIT_COUNT,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
2,
|
||||
DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ACTIVATION_RELU_GRAD",
|
||||
DML_OPERATOR_ACTIVATION_RELU_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
3,
|
||||
DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[8] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_AVERAGE_POOLING_GRAD",
|
||||
DML_OPERATOR_AVERAGE_POOLING_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
8,
|
||||
DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_MAX_POOLING_GRAD",
|
||||
DML_OPERATOR_MAX_POOLING_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
3,
|
||||
DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_RANDOM_GENERATOR_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputStateTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputStateTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Type", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_RANDOM_GENERATOR_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_RANDOM_GENERATOR",
|
||||
DML_OPERATOR_RANDOM_GENERATOR,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
4,
|
||||
DML_RANDOM_GENERATOR_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_NONZERO_COORDINATES_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCountTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCoordinatesTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_NONZERO_COORDINATES_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_NONZERO_COORDINATES",
|
||||
DML_OPERATOR_NONZERO_COORDINATES,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
3,
|
||||
DML_NONZERO_COORDINATES_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD_OPERATOR_SCHEMA_FIELDS[7] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "Scales", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "InputPixelOffsets", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_RESAMPLE_GRAD",
|
||||
DML_OPERATOR_RESAMPLE_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
7,
|
||||
DML_RESAMPLE_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_SLICE_GRAD_OPERATOR_SCHEMA_FIELDS[6] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowOffsets", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowSizes", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT_ARRAY, "InputWindowStrides", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_SLICE_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_SLICE_GRAD",
|
||||
DML_OPERATOR_SLICE_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
6,
|
||||
DML_SLICE_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA_FIELDS[12] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputParametersTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputFirstMomentTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputSecondMomentTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "GradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "TrainingStepTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputParametersTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputFirstMomentTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSecondMomentTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "LearningRate", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta1", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta2", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ADAM_OPTIMIZER",
|
||||
DML_OPERATOR_ADAM_OPTIMIZER,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
12,
|
||||
DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS[11] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ReductionFunction", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleX", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleY", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutOfBoundsInputValue", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinimumSamplesPerOutput", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaximumSamplesPerOutput", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ROI_ALIGN",
|
||||
DML_OPERATOR_ROI_ALIGN,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
11,
|
||||
DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS[6] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BatchDimensionCount", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_GATHER_ND1",
|
||||
DML_OPERATOR_GATHER_ND1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
6,
|
||||
DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -1662,6 +1963,20 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_ELU_OPERATOR_SCHEMA {
|
|||
DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_CELU_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_CELU_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ACTIVATION_CELU",
|
||||
DML_OPERATOR_ACTIVATION_CELU,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
3,
|
||||
DML_ACTIVATION_CELU_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS[2] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
|
|||
|
|
@ -145,6 +145,22 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_
|
|||
OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -328,6 +344,26 @@ inline std::vector<OperatorField> GetFields(const DML_REDUCE_OPERATOR_DESC& desc
|
|||
OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ARGMIN_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
|
||||
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
|
||||
OperatorField(&DML_ARGMIN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ARGMAX_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
|
||||
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
|
||||
OperatorField(&DML_ARGMAX_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -993,6 +1029,157 @@ inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_CONVOLUTI
|
|||
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast<UINT>(desc.GroupCount))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
|
||||
OperatorField(&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_MAX_POOLING_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_RANDOM_GENERATOR_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputStateTensor))),
|
||||
OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputStateTensor))),
|
||||
OperatorField(&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Type))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_NONZERO_COORDINATES_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputCountTensor))),
|
||||
OperatorField(&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputCoordinatesTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_RESAMPLE_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
|
||||
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
|
||||
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const FLOAT*>(desc.Scales), desc.DimensionCount)),
|
||||
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const FLOAT*>(desc.InputPixelOffsets), desc.DimensionCount)),
|
||||
OperatorField(&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const FLOAT*>(desc.OutputPixelOffsets), desc.DimensionCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_SLICE_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
|
||||
OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.InputWindowOffsets), desc.DimensionCount)),
|
||||
OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.InputWindowSizes), desc.DimensionCount)),
|
||||
OperatorField(&DML_SLICE_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const INT*>(desc.InputWindowStrides), desc.DimensionCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ADAM_OPTIMIZER_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputParametersTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputFirstMomentTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputSecondMomentTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.GradientTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.TrainingStepTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputParametersTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputFirstMomentTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputSecondMomentTensor))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.LearningRate))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<FLOAT>(desc.Beta1))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<FLOAT>(desc.Beta2))),
|
||||
OperatorField(&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ROITensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BatchIndicesTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.ReductionFunction))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleX))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleY))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.OutOfBoundsInputValue))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<UINT>(desc.MinimumSamplesPerOutput))),
|
||||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_GATHER_ND1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
|
||||
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.InputDimensionCount))),
|
||||
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.IndicesDimensionCount))),
|
||||
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.BatchDimensionCount))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1001,6 +1188,14 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DE
|
|||
OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<FLOAT>(desc.Alpha))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_CELU_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ACTIVATION_CELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_CELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_CELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<FLOAT>(desc.Alpha))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARDMAX_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1165,6 +1360,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL: return DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA;
|
||||
|
|
@ -1185,6 +1382,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_CONVOLUTION: return DML_CONVOLUTION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GEMM: return DML_GEMM_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_REDUCE: return DML_REDUCE_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA;
|
||||
|
|
@ -1249,7 +1448,23 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_OR: return DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR: return DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT: return DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT: return DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_RELU_GRAD: return DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_AVERAGE_POOLING_GRAD: return DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_MAX_POOLING_GRAD: return DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_RANDOM_GENERATOR: return DML_RANDOM_GENERATOR_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_NONZERO_COORDINATES: return DML_NONZERO_COORDINATES_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_RESAMPLE_GRAD: return DML_RESAMPLE_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_SLICE_GRAD: return DML_SLICE_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ADAM_OPTIMIZER: return DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ROI_ALIGN: return DML_ROI_ALIGN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GATHER_ND1: return DML_GATHER_ND1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA;
|
||||
|
|
@ -1345,6 +1560,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA,
|
||||
|
|
@ -1425,6 +1648,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_REDUCE_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_REDUCE_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ARGMIN:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ARGMIN_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ARGMIN_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ARGMAX:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ARGMAX_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ARGMAX_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_AVERAGE_POOLING:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_AVERAGE_POOLING_OPERATOR_SCHEMA,
|
||||
|
|
@ -1681,10 +1912,74 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_AND:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_OR:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_BIT_OR_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_XOR:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_BIT_XOR_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_NOT:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_BIT_NOT_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_COUNT:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_RELU_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_RELU_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_RELU_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_AVERAGE_POOLING_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_MAX_POOLING_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_MAX_POOLING_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_RANDOM_GENERATOR:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_RANDOM_GENERATOR_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_RANDOM_GENERATOR_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_NONZERO_COORDINATES:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_NONZERO_COORDINATES_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_NONZERO_COORDINATES_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_RESAMPLE_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_RESAMPLE_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_RESAMPLE_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_SLICE_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_SLICE_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_SLICE_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ADAM_OPTIMIZER:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ADAM_OPTIMIZER_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ROI_ALIGN:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ROI_ALIGN_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ROI_ALIGN_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_GATHER_ND1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_GATHER_ND1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_GATHER_ND1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_ELU_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_CELU:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_CELU_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_CELU_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA,
|
||||
|
|
|
|||
|
|
@ -5,20 +5,12 @@
|
|||
|
||||
#include "MLOperatorAuthorImpl.h"
|
||||
#include "FusedGraphKernel.h"
|
||||
#include "GraphKernelHelper.h"
|
||||
|
||||
using namespace Windows::AI::MachineLearning::Adapter;
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
template <typename T>
|
||||
static T AlignToPow2(T offset, T alignment)
|
||||
{
|
||||
static_assert(std::is_unsigned_v<T>);
|
||||
assert(alignment != 0);
|
||||
assert((alignment & (alignment - 1)) == 0);
|
||||
return (offset + alignment - 1) & ~(alignment - 1);
|
||||
}
|
||||
|
||||
class FusedGraphKernel : public onnxruntime::OpKernel
|
||||
{
|
||||
public:
|
||||
|
|
@ -73,42 +65,16 @@ namespace Dml
|
|||
|
||||
const uint32_t graphInputCount = kernelInfo.GetInputCount();
|
||||
|
||||
auto gpuGraphInputConstnessGetter = [&kernelInfo, &fusedNodeInputDefs, &transferredInitializerMap](uint32_t index)
|
||||
{
|
||||
// Transferred initializers are uploaded to GPU memory
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name());
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// If an initializer wasn't transferred, the constant input may be available from ORT
|
||||
const onnxruntime::Tensor* inputTensor = nullptr;
|
||||
if (!kernelInfo.TryGetConstantInput(index, &inputTensor) || inputTensor == nullptr)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that the constant ORT input is in GPU memory
|
||||
if (!strcmp(inputTensor->Location().name, onnxruntime::CPU) ||
|
||||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput ||
|
||||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
m_inputsConstant.resize(graphInputCount);
|
||||
for (uint32_t i = 0; i < graphInputCount; ++i)
|
||||
{
|
||||
m_inputsConstant[i] = gpuGraphInputConstnessGetter(i);
|
||||
m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputDefs, transferredInitializerMap);
|
||||
}
|
||||
|
||||
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
|
||||
kernelInfo,
|
||||
m_inputsConstant,
|
||||
m_inputsConstant.data(),
|
||||
m_inputsConstant.size(),
|
||||
transferredInitializerMap,
|
||||
graph,
|
||||
fusedNodeInputDefs,
|
||||
|
|
@ -117,116 +83,27 @@ namespace Dml
|
|||
device.Get(),
|
||||
m_executionHandle);
|
||||
|
||||
// Determine the last input which uses an initializer, so initializers can be freed incrementally
|
||||
// while processing each input in order.
|
||||
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
|
||||
for (uint32_t i = 0; i < graphInputCount; i++)
|
||||
{
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
initializerToLastInputIndexMap[&iter->second] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Walk through each graph edge and mark used inputs
|
||||
m_inputsUsed.assign(graphInputCount, false);
|
||||
for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges)
|
||||
{
|
||||
m_inputsUsed[edge.GraphInputIndex] = true;
|
||||
}
|
||||
|
||||
// Populate input bindings for operator initialization
|
||||
std::vector<ComPtr<ID3D12Resource>> initInputResources; // For lifetime control
|
||||
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> initInputResources; // For lifetime control
|
||||
std::vector<DML_BUFFER_BINDING> initInputBindings(graphInputCount);
|
||||
m_nonOwnedGraphInputsFromInitializers.resize(graphInputCount);
|
||||
std::vector<ComPtr<ID3D12Resource>> initializeResourceRefs;
|
||||
|
||||
for (uint32_t i = 0; i < initInputBindings.size(); i++)
|
||||
{
|
||||
// If the input isn't actually used by the graph, nothing ever needs to be bound (either for
|
||||
// initialization or execution). So just throw away the transferred initializer and skip this input.
|
||||
if (!m_inputsUsed[i])
|
||||
{
|
||||
transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Look for the initializer among those transferred from the graph during partitioning
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
std::byte* tensorPtr = nullptr;
|
||||
size_t tensorByteSize = 0;
|
||||
std::unique_ptr<std::byte[]> unpackedTensor;
|
||||
|
||||
auto& initializer = iter->second;
|
||||
|
||||
// The tensor may be stored as raw data or in typed fields.
|
||||
if (initializer.has_raw_data())
|
||||
{
|
||||
tensorPtr = (std::byte*)(initializer.raw_data().c_str());
|
||||
tensorByteSize = initializer.raw_data().size();
|
||||
}
|
||||
else
|
||||
{
|
||||
std::tie(unpackedTensor, tensorByteSize) = UnpackTensor(initializer);
|
||||
tensorPtr = unpackedTensor.get();
|
||||
}
|
||||
|
||||
// Tensor sizes in DML must be a multiple of 4 bytes large.
|
||||
tensorByteSize = AlignToPow2<size_t>(tensorByteSize, 4);
|
||||
|
||||
if (!m_inputsConstant[i])
|
||||
{
|
||||
// Store the resource to use during execution
|
||||
ComPtr<ID3D12Resource> defaultBuffer = CreateResource(tensorPtr, tensorByteSize);
|
||||
m_nonOwnedGraphInputsFromInitializers[i] = defaultBuffer;
|
||||
initializeResourceRefs.push_back(std::move(defaultBuffer));
|
||||
}
|
||||
else
|
||||
{
|
||||
ComPtr<ID3D12Resource> initializeInputBuffer;
|
||||
|
||||
// D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
|
||||
if (m_provider->IsMcdmDevice())
|
||||
{
|
||||
initializeInputBuffer = CreateResource(tensorPtr, tensorByteSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
initializeInputBuffer = CreateCpuResource(tensorPtr, tensorByteSize);
|
||||
}
|
||||
|
||||
// Set the binding for operator initialization to the buffer
|
||||
initInputBindings[i].Buffer = initializeInputBuffer.Get();
|
||||
initInputBindings[i].SizeInBytes = tensorByteSize;
|
||||
initializeResourceRefs.push_back(std::move(initializeInputBuffer));
|
||||
}
|
||||
|
||||
// Free the initializer if this is the last usage of it.
|
||||
if (initializerToLastInputIndexMap[&initializer] == i)
|
||||
{
|
||||
transferredInitializerMap.erase(iter);
|
||||
}
|
||||
}
|
||||
else if (m_inputsConstant[i])
|
||||
{
|
||||
const onnxruntime::Tensor* inputTensor = nullptr;
|
||||
THROW_HR_IF(E_UNEXPECTED, !kernelInfo.TryGetConstantInput(i, &inputTensor));
|
||||
|
||||
uint64_t allocId;
|
||||
UnwrapTensor(inputTensor, &initInputBindings[i].Buffer, &allocId);
|
||||
initInputBindings[i].SizeInBytes = initInputBindings[i].Buffer->GetDesc().Width;
|
||||
|
||||
initInputBindings[i].Buffer->Release(); // Avoid holding an additional reference
|
||||
initInputResources.push_back(initInputBindings[i].Buffer);
|
||||
}
|
||||
}
|
||||
|
||||
// All initializers should have been consumed and freed above
|
||||
assert(transferredInitializerMap.empty());
|
||||
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> initializeResourceRefs;
|
||||
|
||||
GraphKernelHelper::PopulateInputBindings(
|
||||
m_provider.Get(),
|
||||
m_winmlProvider.Get(),
|
||||
m_inputsConstant,
|
||||
kernelInfo,
|
||||
graphDesc,
|
||||
fusedNodeInputDefs,
|
||||
m_inputsUsed,
|
||||
initInputBindings,
|
||||
initInputResources,
|
||||
m_nonOwnedGraphInputsFromInitializers,
|
||||
initializeResourceRefs,
|
||||
transferredInitializerMap);
|
||||
|
||||
DML_GRAPH_DESC dmlGraphDesc = {};
|
||||
std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
|
||||
std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
|
||||
|
||||
|
|
@ -234,38 +111,15 @@ namespace Dml
|
|||
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
|
||||
std::vector<DML_GRAPH_EDGE_DESC> dmlIntermediateEdges(graphDesc.intermediateEdges.size());
|
||||
|
||||
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
|
||||
{
|
||||
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{ graphDesc.nodes[i].op.Get() };
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{ DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i] };
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
|
||||
{
|
||||
dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i] };
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i)
|
||||
{
|
||||
dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i] };
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i)
|
||||
{
|
||||
dmlIntermediateEdges[i] = DML_GRAPH_EDGE_DESC{ DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i] };
|
||||
}
|
||||
|
||||
DML_GRAPH_DESC dmlGraphDesc = {};
|
||||
dmlGraphDesc.InputCount = graphInputCount;
|
||||
dmlGraphDesc.OutputCount = kernelInfo.GetOutputCount();
|
||||
dmlGraphDesc.NodeCount = gsl::narrow_cast<uint32_t>(dmlGraphNodes.size());
|
||||
dmlGraphDesc.Nodes = dmlGraphNodes.data();
|
||||
dmlGraphDesc.InputEdgeCount = gsl::narrow_cast<uint32_t>(dmlInputEdges.size());
|
||||
dmlGraphDesc.InputEdges = dmlInputEdges.data();
|
||||
dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast<uint32_t>(dmlOutputEdges.size());
|
||||
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
|
||||
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
|
||||
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
|
||||
GraphKernelHelper::ConvertGraphDesc(
|
||||
graphDesc,
|
||||
dmlGraphDesc,
|
||||
kernelInfo,
|
||||
dmlOperatorGraphNodes,
|
||||
dmlGraphNodes,
|
||||
dmlInputEdges,
|
||||
dmlOutputEdges,
|
||||
dmlIntermediateEdges);
|
||||
|
||||
DML_EXECUTION_FLAGS executionFlags = DML_EXECUTION_FLAG_NONE;
|
||||
if (graphDesc.reuseCommandList)
|
||||
|
|
@ -533,10 +387,10 @@ namespace Dml
|
|||
const onnxruntime::Tensor* tensor = kernelContext->Input<onnxruntime::Tensor>(i);
|
||||
|
||||
uint64_t allocId;
|
||||
UnwrapTensor(tensor, &inputBindings[i].Buffer, &allocId);
|
||||
GraphKernelHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId);
|
||||
inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId);
|
||||
inputBindings[i].Buffer->Release(); // Avoid holding an additional reference
|
||||
inputBindings[i].SizeInBytes = AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
|
||||
inputBindings[i].SizeInBytes = GraphKernelHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
|
||||
inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]};
|
||||
m_inputBindingAllocIds[i] = allocId;
|
||||
}
|
||||
|
|
@ -570,10 +424,10 @@ namespace Dml
|
|||
);
|
||||
|
||||
uint64_t allocId;
|
||||
UnwrapTensor(tensor, &outputBindings[i].Buffer, &allocId);
|
||||
GraphKernelHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId);
|
||||
outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId);
|
||||
outputBindings[i].Buffer->Release(); // Avoid holding an additional reference
|
||||
outputBindings[i].SizeInBytes = AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
|
||||
outputBindings[i].SizeInBytes = GraphKernelHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
|
||||
outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]};
|
||||
m_outputBindingAllocIds[i] = allocId;
|
||||
}
|
||||
|
|
@ -623,106 +477,6 @@ namespace Dml
|
|||
m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get());
|
||||
}
|
||||
|
||||
void UnwrapTensor(const onnxruntime::Tensor* tensor, ID3D12Resource** resource, uint64_t* allocId) const
|
||||
{
|
||||
IUnknown* allocationUnk = static_cast<IUnknown*>(const_cast<void*>(tensor->DataRaw()));
|
||||
ComPtr<IUnknown> resourceUnk;
|
||||
m_winmlProvider->GetABIDataInterface(false, allocationUnk, &resourceUnk);
|
||||
|
||||
*allocId = m_winmlProvider->TryGetPooledAllocationId(allocationUnk, 0);
|
||||
|
||||
THROW_IF_FAILED(resourceUnk->QueryInterface(resource));
|
||||
}
|
||||
|
||||
ComPtr<ID3D12Resource> CreateResource(const std::byte* tensorPtr, size_t tensorByteSize) const
|
||||
{
|
||||
ComPtr<ID3D12Resource> buffer;
|
||||
|
||||
D3D12_HEAP_PROPERTIES heapProperties = {
|
||||
D3D12_HEAP_TYPE_DEFAULT,
|
||||
D3D12_CPU_PAGE_PROPERTY_UNKNOWN,
|
||||
D3D12_MEMORY_POOL_UNKNOWN,
|
||||
0,
|
||||
0
|
||||
};
|
||||
|
||||
D3D12_RESOURCE_DESC resourceDesc = {
|
||||
D3D12_RESOURCE_DIMENSION_BUFFER,
|
||||
0,
|
||||
static_cast<uint64_t>((tensorByteSize + 3) & ~3),
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
DXGI_FORMAT_UNKNOWN,
|
||||
{ 1, 0 },
|
||||
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
|
||||
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS
|
||||
};
|
||||
|
||||
ComPtr<ID3D12Device> d3dDevice;
|
||||
THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf()));
|
||||
|
||||
THROW_IF_FAILED(d3dDevice->CreateCommittedResource(
|
||||
&heapProperties,
|
||||
D3D12_HEAP_FLAG_NONE,
|
||||
&resourceDesc,
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
|
||||
nullptr,
|
||||
IID_PPV_ARGS(buffer.GetAddressOf())
|
||||
));
|
||||
|
||||
THROW_IF_FAILED(m_provider->UploadToResource(buffer.Get(), tensorPtr, tensorByteSize));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
ComPtr<ID3D12Resource> CreateCpuResource(const std::byte* tensorPtr, size_t tensorByteSize) const
|
||||
{
|
||||
ComPtr<ID3D12Resource> buffer;
|
||||
|
||||
D3D12_HEAP_PROPERTIES heapProperties = {
|
||||
D3D12_HEAP_TYPE_CUSTOM,
|
||||
D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE,
|
||||
D3D12_MEMORY_POOL_L0,
|
||||
0,
|
||||
0
|
||||
};
|
||||
|
||||
D3D12_RESOURCE_DESC resourceDesc = {
|
||||
D3D12_RESOURCE_DIMENSION_BUFFER,
|
||||
0,
|
||||
static_cast<uint64_t>((tensorByteSize + 3) & ~3),
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
DXGI_FORMAT_UNKNOWN,
|
||||
{ 1, 0 },
|
||||
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
|
||||
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS
|
||||
};
|
||||
|
||||
ComPtr<ID3D12Device> d3dDevice;
|
||||
THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf()));
|
||||
|
||||
THROW_IF_FAILED(d3dDevice->CreateCommittedResource(
|
||||
&heapProperties,
|
||||
D3D12_HEAP_FLAG_NONE,
|
||||
&resourceDesc,
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
|
||||
nullptr,
|
||||
IID_PPV_ARGS(buffer.GetAddressOf())
|
||||
));
|
||||
|
||||
// Map the buffer and copy the data
|
||||
void* bufferData = nullptr;
|
||||
D3D12_RANGE range = {0, tensorByteSize};
|
||||
THROW_IF_FAILED(buffer->Map(0, &range, &bufferData));
|
||||
memcpy(bufferData, tensorPtr, tensorByteSize);
|
||||
buffer->Unmap(0, &range);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
ComPtr<IDMLCompiledOperator> m_compiledExecutionPlanOperator;
|
||||
std::vector<bool> m_inputsUsed;
|
||||
const void* m_executionHandle = nullptr;
|
||||
|
|
|
|||
|
|
@ -60,7 +60,8 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
GraphDesc BuildGraphDesc(
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
gsl::span<const uint8_t> isConstGpuGraphInput,
|
||||
const uint8_t* isConstGpuGraphInput,
|
||||
const size_t isConstGpuGraphInputCount,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
|
||||
const onnxruntime::Graph& graph,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
|
|
@ -226,7 +227,7 @@ namespace Dml::GraphDescBuilder
|
|||
graphInputEdges.push_back(edge);
|
||||
|
||||
// If this is a constant input, set the appropriate flags on the desc
|
||||
if (isConstGpuGraphInput[fusedNodeInputIndex])
|
||||
if (fusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[fusedNodeInputIndex])
|
||||
{
|
||||
DmlBufferTensorDesc* tensorDesc = inputTensorDescs[inputIndex];
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ namespace Dml
|
|||
|
||||
GraphDesc BuildGraphDesc(
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
gsl::span<const uint8_t> isConstGpuGraphInput,
|
||||
const uint8_t* isConstGpuGraphInput,
|
||||
const size_t isConstGpuGraphInputCount,
|
||||
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
|
||||
const onnxruntime::Graph& graph,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,313 @@
|
|||
#include "precomp.h"
|
||||
|
||||
#include "GraphKernelHelper.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
namespace GraphKernelHelper
|
||||
{
|
||||
Microsoft::WRL::ComPtr<ID3D12Resource>
|
||||
CreateResource(
|
||||
Dml::IExecutionProvider* provider,
|
||||
const std::byte* tensorPtr,
|
||||
size_t tensorByteSize)
|
||||
{
|
||||
Microsoft::WRL::ComPtr<ID3D12Resource> buffer;
|
||||
|
||||
D3D12_HEAP_PROPERTIES heapProperties = {
|
||||
D3D12_HEAP_TYPE_DEFAULT, D3D12_CPU_PAGE_PROPERTY_UNKNOWN, D3D12_MEMORY_POOL_UNKNOWN, 0, 0};
|
||||
|
||||
D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER,
|
||||
0,
|
||||
static_cast<uint64_t>((tensorByteSize + 3) & ~3),
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
DXGI_FORMAT_UNKNOWN,
|
||||
{1, 0},
|
||||
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
|
||||
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS};
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Device> d3dDevice;
|
||||
THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf()));
|
||||
|
||||
THROW_IF_FAILED(d3dDevice->CreateCommittedResource(
|
||||
&heapProperties,
|
||||
D3D12_HEAP_FLAG_NONE,
|
||||
&resourceDesc,
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
|
||||
nullptr,
|
||||
IID_PPV_ARGS(buffer.GetAddressOf())));
|
||||
|
||||
THROW_IF_FAILED(provider->UploadToResource(buffer.Get(), tensorPtr, tensorByteSize));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Resource>
|
||||
CreateCpuResource(
|
||||
Dml::IExecutionProvider* provider,
|
||||
const std::byte* tensorPtr,
|
||||
size_t tensorByteSize)
|
||||
{
|
||||
Microsoft::WRL::ComPtr<ID3D12Resource> buffer;
|
||||
|
||||
D3D12_HEAP_PROPERTIES heapProperties = {
|
||||
D3D12_HEAP_TYPE_CUSTOM, D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE, D3D12_MEMORY_POOL_L0, 0, 0};
|
||||
|
||||
D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER,
|
||||
0,
|
||||
static_cast<uint64_t>((tensorByteSize + 3) & ~3),
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
DXGI_FORMAT_UNKNOWN,
|
||||
{1, 0},
|
||||
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
|
||||
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS};
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Device> d3dDevice;
|
||||
THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf()));
|
||||
|
||||
THROW_IF_FAILED(d3dDevice->CreateCommittedResource(
|
||||
&heapProperties,
|
||||
D3D12_HEAP_FLAG_NONE,
|
||||
&resourceDesc,
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
|
||||
nullptr,
|
||||
IID_PPV_ARGS(buffer.GetAddressOf())));
|
||||
|
||||
// Map the buffer and copy the data
|
||||
void* bufferData = nullptr;
|
||||
D3D12_RANGE range = {0, tensorByteSize};
|
||||
THROW_IF_FAILED(buffer->Map(0, &range, &bufferData));
|
||||
memcpy(bufferData, tensorPtr, tensorByteSize);
|
||||
buffer->Unmap(0, &range);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void UnwrapTensor(
|
||||
IWinmlExecutionProvider* winmlProvider,
|
||||
const onnxruntime::Tensor* tensor,
|
||||
ID3D12Resource** resource,
|
||||
uint64_t* allocId)
|
||||
{
|
||||
IUnknown* allocationUnk = static_cast<IUnknown*>(const_cast<void*>(tensor->DataRaw()));
|
||||
Microsoft::WRL::ComPtr<IUnknown> resourceUnk;
|
||||
winmlProvider->GetABIDataInterface(false, allocationUnk, &resourceUnk);
|
||||
|
||||
*allocId = winmlProvider->TryGetPooledAllocationId(allocationUnk, 0);
|
||||
|
||||
THROW_IF_FAILED(resourceUnk->QueryInterface(resource));
|
||||
}
|
||||
|
||||
bool GetGraphInputConstness(
|
||||
uint32_t index,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
|
||||
{
|
||||
// Transferred initializers are uploaded to GPU memory
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name());
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// If an initializer wasn't transferred, the constant input may be available from ORT
|
||||
const onnxruntime::Tensor* inputTensor = nullptr;
|
||||
if (!kernelInfo.TryGetConstantInput(index, &inputTensor) || inputTensor == nullptr)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that the constant ORT input is in GPU memory
|
||||
if (!strcmp(inputTensor->Location().name, onnxruntime::CPU) ||
|
||||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput ||
|
||||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::vector<std::vector<std::byte>> PopulateInputBindings(
|
||||
Dml::IExecutionProvider* provider,
|
||||
IWinmlExecutionProvider* winmlProvider,
|
||||
const std::vector<uint8_t>& inputsConstant,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
_Out_ std::vector<bool>& inputsUsed,
|
||||
_Out_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
|
||||
_Out_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
|
||||
_Out_ std::vector<ComPtr<ID3D12Resource>>& nonOwnedGraphInputsFromInitializers,
|
||||
_Out_ std::vector<ComPtr<ID3D12Resource>>& initializeResourceRefs,
|
||||
_Inout_ std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
|
||||
{
|
||||
std::vector<std::vector<std::byte>> inputRawData;
|
||||
|
||||
const uint32_t graphInputCount = kernelInfo.GetInputCount();
|
||||
// Determine the last input which uses an initializer, so initializers can be freed incrementally
|
||||
// while processing each input in order.
|
||||
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
|
||||
for (uint32_t i = 0; i < graphInputCount; i++)
|
||||
{
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
|
||||
if (iter != transferredInitializerMap.end()) {
|
||||
initializerToLastInputIndexMap[&iter->second] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Walk through each graph edge and mark used inputs
|
||||
inputsUsed.assign(graphInputCount, false);
|
||||
for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) {
|
||||
inputsUsed[edge.GraphInputIndex] = true;
|
||||
}
|
||||
for (uint32_t i = 0; i < initInputBindings.size(); i++)
|
||||
{
|
||||
// If the input isn't actually used by the graph, nothing ever needs to be bound (either for
|
||||
// initialization or execution). So just throw away the transferred initializer and skip this input.
|
||||
if (!inputsUsed[i])
|
||||
{
|
||||
transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name());
|
||||
inputRawData.push_back(std::vector<std::byte>());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Look for the initializer among those transferred from the graph during partitioning
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
std::byte* tensorPtr = nullptr;
|
||||
size_t tensorByteSize = 0;
|
||||
std::unique_ptr<std::byte[]> unpackedTensor;
|
||||
|
||||
auto& initializer = iter->second;
|
||||
|
||||
// The tensor may be stored as raw data or in typed fields.
|
||||
if (initializer.has_raw_data())
|
||||
{
|
||||
tensorPtr = (std::byte*)(initializer.raw_data().c_str());
|
||||
tensorByteSize = initializer.raw_data().size();
|
||||
}
|
||||
else
|
||||
{
|
||||
std::tie(unpackedTensor, tensorByteSize) = UnpackTensor(initializer);
|
||||
tensorPtr = unpackedTensor.get();
|
||||
}
|
||||
|
||||
// Tensor sizes in DML must be a multiple of 4 bytes large.
|
||||
tensorByteSize = AlignToPow2<size_t>(tensorByteSize, 4);
|
||||
|
||||
inputRawData.push_back(std::vector<std::byte>(tensorPtr, tensorPtr + tensorByteSize));
|
||||
|
||||
if (!inputsConstant[i])
|
||||
{
|
||||
// Store the resource to use during execution
|
||||
ComPtr<ID3D12Resource> defaultBuffer = CreateResource(provider, tensorPtr, tensorByteSize);
|
||||
nonOwnedGraphInputsFromInitializers[i] = defaultBuffer;
|
||||
initializeResourceRefs.push_back(std::move(defaultBuffer));
|
||||
}
|
||||
else
|
||||
{
|
||||
ComPtr<ID3D12Resource> initializeInputBuffer;
|
||||
|
||||
// D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
|
||||
if (provider->IsMcdmDevice())
|
||||
{
|
||||
initializeInputBuffer = CreateResource(provider, tensorPtr, tensorByteSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
initializeInputBuffer = CreateCpuResource(provider, tensorPtr, tensorByteSize);
|
||||
}
|
||||
|
||||
// Set the binding for operator initialization to the buffer
|
||||
initInputBindings[i].Buffer = initializeInputBuffer.Get();
|
||||
initInputBindings[i].SizeInBytes = tensorByteSize;
|
||||
initializeResourceRefs.push_back(std::move(initializeInputBuffer));
|
||||
}
|
||||
|
||||
// Free the initializer if this is the last usage of it.
|
||||
if (initializerToLastInputIndexMap[&initializer] == i)
|
||||
{
|
||||
transferredInitializerMap.erase(iter);
|
||||
}
|
||||
}
|
||||
else if (inputsConstant[i])
|
||||
{
|
||||
const onnxruntime::Tensor* inputTensor = nullptr;
|
||||
THROW_HR_IF(E_UNEXPECTED, !kernelInfo.TryGetConstantInput(i, &inputTensor));
|
||||
|
||||
const std::byte* tensorData = reinterpret_cast<const std::byte*>(inputTensor->DataRaw());
|
||||
inputRawData.push_back(
|
||||
std::vector<std::byte>(tensorData, tensorData + inputTensor->SizeInBytes()));
|
||||
|
||||
uint64_t allocId;
|
||||
UnwrapTensor(winmlProvider, inputTensor, &initInputBindings[i].Buffer, &allocId);
|
||||
initInputBindings[i].SizeInBytes = initInputBindings[i].Buffer->GetDesc().Width;
|
||||
|
||||
initInputBindings[i].Buffer->Release(); // Avoid holding an additional reference
|
||||
initInputResources.push_back(initInputBindings[i].Buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
inputRawData.push_back(std::vector<std::byte>());
|
||||
}
|
||||
}
|
||||
|
||||
// All initializers should have been consumed and freed above
|
||||
assert(transferredInitializerMap.empty());
|
||||
return inputRawData;
|
||||
}
|
||||
|
||||
void ConvertGraphDesc(
|
||||
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
|
||||
_Out_ DML_GRAPH_DESC& dmlGraphDesc,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
_Out_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
|
||||
_Out_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges)
|
||||
{
|
||||
const uint32_t graphInputCount = kernelInfo.GetInputCount();
|
||||
|
||||
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
|
||||
{
|
||||
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{graphDesc.nodes[i].op.Get()};
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
|
||||
{
|
||||
dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]};
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i)
|
||||
{
|
||||
dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]};
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i)
|
||||
{
|
||||
dmlIntermediateEdges[i] =
|
||||
DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]};
|
||||
}
|
||||
|
||||
dmlGraphDesc.InputCount = graphInputCount;
|
||||
dmlGraphDesc.OutputCount = kernelInfo.GetOutputCount();
|
||||
dmlGraphDesc.NodeCount = gsl::narrow_cast<uint32_t>(dmlGraphNodes.size());
|
||||
dmlGraphDesc.Nodes = dmlGraphNodes.data();
|
||||
dmlGraphDesc.InputEdgeCount = gsl::narrow_cast<uint32_t>(dmlInputEdges.size());
|
||||
dmlGraphDesc.InputEdges = dmlInputEdges.data();
|
||||
dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast<uint32_t>(dmlOutputEdges.size());
|
||||
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
|
||||
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
|
||||
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
|
||||
}
|
||||
} // namespace GraphKernelHelper
|
||||
} // namespace Dml
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
#include "GraphDescBuilder.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
namespace GraphKernelHelper
|
||||
{
|
||||
using namespace Windows::AI::MachineLearning::Adapter;
|
||||
|
||||
template <typename T>
|
||||
static T AlignToPow2(T offset, T alignment)
|
||||
{
|
||||
static_assert(std::is_unsigned_v<T>);
|
||||
assert(alignment != 0);
|
||||
assert((alignment & (alignment - 1)) == 0);
|
||||
return (offset + alignment - 1) & ~(alignment - 1);
|
||||
}
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Resource>
|
||||
CreateResource(
|
||||
Dml::IExecutionProvider* provider,
|
||||
const std::byte* tensorPtr,
|
||||
size_t tensorByteSize);
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Resource>
|
||||
CreateCpuResource(
|
||||
Dml::IExecutionProvider* provider,
|
||||
const std::byte* tensorPtr,
|
||||
size_t tensorByteSize);
|
||||
|
||||
void UnwrapTensor(
|
||||
IWinmlExecutionProvider* winmlProvider,
|
||||
const onnxruntime::Tensor* tensor,
|
||||
ID3D12Resource** resource,
|
||||
uint64_t* allocId);
|
||||
|
||||
bool GetGraphInputConstness(
|
||||
uint32_t index,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap);
|
||||
|
||||
std::vector<std::vector<std::byte>> PopulateInputBindings(
|
||||
Dml::IExecutionProvider* provider,
|
||||
IWinmlExecutionProvider* winmlProvider,
|
||||
const std::vector<uint8_t>& inputsConstant,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
|
||||
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
|
||||
_Out_ std::vector<bool>& inputsUsed,
|
||||
_Out_ std::vector<DML_BUFFER_BINDING>& initInputBindings,
|
||||
_Out_ std::vector<ComPtr<ID3D12Resource>>& initInputResources,
|
||||
_Out_ std::vector<ComPtr<ID3D12Resource>>& nonOwnedGraphInputsFromInitializers,
|
||||
_Out_ std::vector<ComPtr<ID3D12Resource>>& initializeResourceRefs,
|
||||
_Inout_ std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap);
|
||||
|
||||
void ConvertGraphDesc(
|
||||
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
|
||||
_Out_ DML_GRAPH_DESC& dmlGraphDesc,
|
||||
const onnxruntime::OpKernelInfo& kernelInfo,
|
||||
_Out_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
|
||||
_Out_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
|
||||
|
||||
} // namespace GraphKernelHelper
|
||||
} // namespace Dml
|
||||
|
|
@ -139,7 +139,7 @@ namespace Dml
|
|||
}
|
||||
};
|
||||
|
||||
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats)
|
||||
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats, uint32_t supportedDeviceDataTypeMask)
|
||||
{
|
||||
if (arg->Exists())
|
||||
{
|
||||
|
|
@ -151,16 +151,23 @@ namespace Dml
|
|||
{
|
||||
// TODO: Remove this by handling zeroing on the output of fused graph nodes and handling of non-float
|
||||
// types in DML's identity operator, which is used for strided copies.
|
||||
if (ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) == MLOperatorTensorDataType::UInt64 ||
|
||||
ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) == MLOperatorTensorDataType::Int64)
|
||||
|
||||
MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type()));
|
||||
|
||||
if (mlDataType == MLOperatorTensorDataType::UInt64 ||
|
||||
mlDataType == MLOperatorTensorDataType::Int64)
|
||||
{
|
||||
return false;
|
||||
constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64);
|
||||
if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (requiresFloatFormats)
|
||||
{
|
||||
if (ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) != MLOperatorTensorDataType::Float &&
|
||||
ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) != MLOperatorTensorDataType::Float16)
|
||||
if (mlDataType != MLOperatorTensorDataType::Float &&
|
||||
mlDataType != MLOperatorTensorDataType::Float16)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
@ -172,14 +179,19 @@ namespace Dml
|
|||
return true;
|
||||
}
|
||||
|
||||
bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration)
|
||||
bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration, uint32_t supportedDeviceDataTypeMask)
|
||||
{
|
||||
for (size_t i = 0; i < node.InputDefs().size(); ++i)
|
||||
{
|
||||
bool isConstantCpuInput = std::find(registration.requiredConstantCpuInputs.begin(), registration.requiredConstantCpuInputs.end(), i) !=
|
||||
registration.requiredConstantCpuInputs.end();
|
||||
|
||||
if (!isConstantCpuInput && !NodeArgSupportedInGraph(node.InputDefs()[i], registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs))
|
||||
if (!isConstantCpuInput &&
|
||||
!NodeArgSupportedInGraph(
|
||||
node.InputDefs()[i],
|
||||
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
|
||||
supportedDeviceDataTypeMask
|
||||
))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
@ -187,7 +199,11 @@ namespace Dml
|
|||
|
||||
for (auto arg : node.OutputDefs())
|
||||
{
|
||||
if (!NodeArgSupportedInGraph(arg, registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs))
|
||||
if (!NodeArgSupportedInGraph(
|
||||
arg,
|
||||
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
|
||||
supportedDeviceDataTypeMask
|
||||
))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
@ -220,97 +236,155 @@ namespace Dml
|
|||
bool DoesNodeContainSupportedDataTypes(
|
||||
const onnxruntime::Node& node,
|
||||
bool allow64BitInputThroughStrides,
|
||||
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap, // Only used when allow64BitInputThroughStrides is true
|
||||
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap, // Only used when allow64BitInputThroughStrides is true
|
||||
_In_opt_ const InternalRegistrationInfo* regInfo,
|
||||
uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
)
|
||||
{
|
||||
THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);
|
||||
|
||||
bool prefer64BitTensorsDirectly = false;
|
||||
bool supportedWith64BitTensorsVia32BitStrides = false;
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
|
||||
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;
|
||||
|
||||
if (regInfo != nullptr)
|
||||
{
|
||||
// Read the operator flags for handling 64-bit tensors and whether it's allowed to fall back
|
||||
// to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false
|
||||
// in this particular call, then the operator-specific flags do not matter as the caller has
|
||||
// disabled 64-bit support.
|
||||
if (allow64BitInputThroughStrides)
|
||||
{
|
||||
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp;
|
||||
supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp;
|
||||
}
|
||||
|
||||
// Collect the list of CPU-bound input tensors, needed when checking 64-bit fallback
|
||||
// or for other data types like int-8 which may be supported for CPU inputs but not
|
||||
// GPU inputs.
|
||||
auto inputDefinitions = node.InputDefs();
|
||||
for (uint32_t i : regInfo->requiredConstantCpuInputs)
|
||||
{
|
||||
if (i < inputDefinitions.size())
|
||||
{
|
||||
constantCpuInputs.push_back(inputDefinitions[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assume data types are supported until proven otherwise.
|
||||
bool nodeContainsSupportedDataTypes = true;
|
||||
|
||||
// Callback to check each node's data type.
|
||||
// Callback to check each node's data type against registered operator support.
|
||||
std::function<void(const onnxruntime::NodeArg& nodeArg, bool isInput)> nodeCallback = [&](const onnxruntime::NodeArg& nodeArg, bool isInput) -> void
|
||||
{
|
||||
// Get the tensor element data type for this node, comparing against what the device actually supports.
|
||||
// Use the enumeration from the proto instead of nodeArg.Type() which returns a string.
|
||||
|
||||
// Reject node if undefined data type or non-tensor, as DML cannot handle it.
|
||||
MLOperatorTensorDataType onnxElementType;
|
||||
if (TryGetTensorDataType(nodeArg, &onnxElementType))
|
||||
if (!TryGetTensorDataType(nodeArg, &onnxElementType))
|
||||
{
|
||||
DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType);
|
||||
if (dmlElementType != DML_TENSOR_DATA_TYPE_UNKNOWN)
|
||||
// We shouldn't have arrived here because (1) no DML operators should have been
|
||||
// registered which use non-tensor types (2) ONNX validation should have already
|
||||
// been done, checking for the right kind of inputs and attributes. In theory,
|
||||
// this branch could be reached with a bad custom operator or malformed file. If
|
||||
// a legitimate case reaches here and DML needs to support a new input/output type
|
||||
// besides tensors, then remove the assert.
|
||||
assert(false);
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Reject node for unknown DML data types.
|
||||
DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType);
|
||||
if (dmlElementType == DML_TENSOR_DATA_TYPE_UNKNOWN)
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Succeed if the tensor is CPU-bound, as the CPU-side reading code is generic enough
|
||||
// to handle multiple types regardless of GPU capability (typically these are just
|
||||
// scalars or simple 1D arrays).
|
||||
bool isConstantCpuInput = isInput && std::find(constantCpuInputs.begin(), constantCpuInputs.end(), &nodeArg) != constantCpuInputs.end();
|
||||
if (isConstantCpuInput)
|
||||
{
|
||||
// Leave nodeContainsSupportedDataTypes alone.
|
||||
return;
|
||||
}
|
||||
|
||||
// If this operator implements 64-bit support in terms of strided 32-bit tensors,
|
||||
// then the data type needs to be remapped, regardless of whether input or output.
|
||||
//
|
||||
// Some operators can fairly safely implement 64-bit tensors in terms of
|
||||
// strided 32-bit tensors regardless of input tensor's execution provider
|
||||
// because the indices measure along a single axis and should fall within
|
||||
// the range of an int32/uint32.
|
||||
//
|
||||
// Currently all DML kernels outputting int64 and uint64 are expected to
|
||||
// not *introduce* values out of range, which allows the temporary trick
|
||||
// using strides to emulate 64 bit tensors to work. If the source is a CPU
|
||||
// operator, graph input or initializer, it's not safe to assume the input
|
||||
// can be represented with 32 bits.
|
||||
//
|
||||
bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64);
|
||||
bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask);
|
||||
if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit)
|
||||
{
|
||||
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
|
||||
|
||||
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
|
||||
{
|
||||
if (((1 << dmlElementType) & supportedDeviceDataTypeMask) == 0)
|
||||
// Look up the input partition. If it's a graph input or initializer it will be missing
|
||||
// from the partition map.
|
||||
const std::string& argName = nodeArg.Name();
|
||||
|
||||
// If input tensor's data comes from the output of a different execution provider,
|
||||
// consider it unsafe to apply fallback to.
|
||||
auto partitionIter = nodeNameToPartitionMap->find(argName);
|
||||
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reject node if the data type is unsupported by the device.
|
||||
if (!((1 << dmlElementType) & supportedDeviceDataTypeMask))
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise the node supports the tensor data type.
|
||||
};
|
||||
|
||||
// Check whether the node uses any data types which are unsupported by the device.
|
||||
node.ForEachDef(nodeCallback);
|
||||
|
||||
// DML kernels supporting int64 and uint64 are expected to not *introduce* values out of range, which allows
|
||||
// the temporary trick using strides to emulate 64 bit tensors to work. If the source is a CPU operator,
|
||||
// graph input or initializer, it's not safe to assume the input can be represented with 32 bits.
|
||||
if (regInfo)
|
||||
{
|
||||
for (uint32_t i = 0; i < node.InputDefs().size(); ++i)
|
||||
{
|
||||
const auto* arg = node.InputDefs()[i];
|
||||
MLOperatorTensorDataType onnxElementType;
|
||||
if (arg->Exists() && TryGetTensorDataType(*arg, &onnxElementType))
|
||||
{
|
||||
if (((onnxElementType == MLOperatorTensorDataType::UInt64) || (onnxElementType == MLOperatorTensorDataType::Int64)))
|
||||
{
|
||||
// Look up the input partition. If it's a graph input or initializer it will be missing
|
||||
// from the map. In this case or if the input comes from a CPU partition, it might be
|
||||
// out of range.
|
||||
const std::string& argName = arg->Name();
|
||||
// Check if the operator handles the input on the CPU as a constant input
|
||||
bool isConstantCpuInput = std::find(regInfo->requiredConstantCpuInputs.begin(), regInfo->requiredConstantCpuInputs.end(), i) !=
|
||||
regInfo->requiredConstantCpuInputs.end();
|
||||
|
||||
if (!isConstantCpuInput)
|
||||
{
|
||||
if (!allow64BitInputThroughStrides)
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
break;
|
||||
}
|
||||
|
||||
auto partitionIter = nodeNameToPartitionMap->find(argName);
|
||||
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nodeContainsSupportedDataTypes;
|
||||
}
|
||||
|
||||
bool IsNodeSupportedByDml(
|
||||
const onnxruntime::Node& node,
|
||||
const onnxruntime::Node& node,
|
||||
const onnxruntime::KernelRegistry& registry,
|
||||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
const InternalRegistrationInfoMap& internalRegInfoMap,
|
||||
bool allow64BitInputThroughStrides,
|
||||
_In_opt_ const std::unordered_map<std::string, GraphPartition*>* nodeNameToPartitionMap
|
||||
)
|
||||
)
|
||||
{
|
||||
THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);
|
||||
|
||||
const onnxruntime::KernelCreateInfo* createInfo;
|
||||
Status st = registry.TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo);
|
||||
if (!st.IsOK()) {
|
||||
return false;
|
||||
if (!st.IsOK())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get());
|
||||
|
|
@ -337,7 +411,7 @@ namespace Dml
|
|||
// Gets properties of the registration for a node
|
||||
void GetRegistrationProperties(
|
||||
const onnxruntime::GraphViewer& graph,
|
||||
const onnxruntime::Node& node,
|
||||
const onnxruntime::Node& node,
|
||||
const std::vector<const onnxruntime::KernelRegistry*>& dmlRegistries,
|
||||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
const InternalRegistrationInfoMap& internalRegInfoMap,
|
||||
|
|
@ -368,7 +442,9 @@ namespace Dml
|
|||
// which is required for MLGraph compilation.
|
||||
const onnxruntime::KernelCreateInfo* createInfo;
|
||||
if (!registry->TryFindKernel(node, onnxruntime::kDmlExecutionProvider, &createInfo).IsOK())
|
||||
continue;
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get());
|
||||
if (regInfoIter != internalRegInfoMap.end())
|
||||
|
|
@ -376,7 +452,7 @@ namespace Dml
|
|||
auto internalRegInfo = regInfoIter->second;
|
||||
|
||||
if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration &&
|
||||
NodeTensorTypesSupportedInGraph(node, *internalRegInfo))
|
||||
NodeTensorTypesSupportedInGraph(node, *internalRegInfo, supportedDeviceDataTypeMask))
|
||||
{
|
||||
bool requiredCpuInputsConstant = true;
|
||||
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
|
||||
|
|
@ -912,7 +988,8 @@ namespace Dml
|
|||
std::move(graphNodePropertyMap),
|
||||
registryForPartitionKernels,
|
||||
partitionKernelPrefix,
|
||||
transferredInitializerMap));
|
||||
transferredInitializerMap
|
||||
));
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ namespace Dml
|
|||
);
|
||||
|
||||
bool IsNodeSupportedByDml(
|
||||
const onnxruntime::Node& node,
|
||||
const onnxruntime::Node& node,
|
||||
const onnxruntime::KernelRegistry& registry,
|
||||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap,
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ namespace Dml
|
|||
if (!IsNodeSupportedByDml(
|
||||
node,
|
||||
*registry,
|
||||
m_providerImpl->GetSuppportedDeviceDataTypeMask(),
|
||||
m_providerImpl->GetSupportedDeviceDataTypeMask(),
|
||||
*m_providerImpl->GetInternalRegistrationInfoMap().get(),
|
||||
allow64BitInputThroughStrides,
|
||||
nullptr))
|
||||
|
|
|
|||
|
|
@ -133,12 +133,21 @@ namespace Dml
|
|||
));
|
||||
}
|
||||
|
||||
void DmlOperator::Initialize(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
uint32_t minDimensionCount
|
||||
)
|
||||
{
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt, std::nullopt, std::nullopt, minDimensionCount);
|
||||
}
|
||||
|
||||
void DmlOperator::Initialize(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices,
|
||||
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices,
|
||||
const std::optional<gsl::span<const uint32_t>> inputShape,
|
||||
const std::optional<gsl::span<const uint32_t>> outputShape
|
||||
const std::optional<gsl::span<const uint32_t>> outputShape,
|
||||
uint32_t minDimensionCount
|
||||
)
|
||||
{
|
||||
if (kernelInputIndices)
|
||||
|
|
@ -179,7 +188,7 @@ namespace Dml
|
|||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
inputShape,
|
||||
NchwDimensionCount));
|
||||
minDimensionCount));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -200,7 +209,8 @@ namespace Dml
|
|||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
outputShape));
|
||||
outputShape,
|
||||
minDimensionCount));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -373,6 +383,34 @@ namespace Dml
|
|||
));
|
||||
}
|
||||
|
||||
void DmlOperator::Remap64bitDmlDataTypesTo32bit()
|
||||
{
|
||||
for (auto& tensor : m_inputTensorDescs)
|
||||
{
|
||||
tensor.Remap64bitDmlDataTypeTo32bit();
|
||||
}
|
||||
|
||||
for (auto& tensor : m_outputTensorDescs)
|
||||
{
|
||||
tensor.Remap64bitDmlDataTypeTo32bit();
|
||||
}
|
||||
}
|
||||
|
||||
void DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded()
|
||||
{
|
||||
// Conditionally remap 64-bit data types to strided 32-bit if DML does not
|
||||
// support 64-bit data types directly on the device.
|
||||
|
||||
uint32_t deviceTypeMask = Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get());
|
||||
uint32_t deviceTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_INT64) | (1 << DML_TENSOR_DATA_TYPE_UINT64);
|
||||
|
||||
// If the device doesn't support 64-bit tensors, fall back to 32-bit with strides.
|
||||
if (!(deviceTypeMask & deviceTypeMask64bit))
|
||||
{
|
||||
Remap64bitDmlDataTypesTo32bit();
|
||||
}
|
||||
}
|
||||
|
||||
TensorDesc DmlOperator::CreateTensorDescFromInput(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
uint32_t index,
|
||||
|
|
|
|||
|
|
@ -29,12 +29,18 @@ namespace Dml
|
|||
ComPtr<IUnknown> m_persistentResourcePoolingUnk; // Controls when the persistent resource is returned to the pool
|
||||
std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
|
||||
|
||||
void Initialize(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
uint32_t minDimensionCount
|
||||
);
|
||||
|
||||
void Initialize(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices = std::nullopt,
|
||||
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices = std::nullopt,
|
||||
const std::optional<gsl::span<const uint32_t>> inputShape = std::nullopt,
|
||||
const std::optional<gsl::span<const uint32_t>> outputShape = std::nullopt
|
||||
const std::optional<gsl::span<const uint32_t>> outputShape = std::nullopt,
|
||||
uint32_t minDimensionCount = NchwDimensionCount
|
||||
);
|
||||
|
||||
bool AllowHalfPrecisionComputation() const;
|
||||
|
|
@ -77,6 +83,11 @@ namespace Dml
|
|||
|
||||
void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor);
|
||||
|
||||
// Remap 64-bit data types to 32-bit via doubled strides.
|
||||
// These should be called before GetDmlInputDescs or GetDmlOutputDescs.
|
||||
void Remap64bitDmlDataTypesTo32bit();
|
||||
void Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
TensorDesc CreateTensorDescFromInput(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
uint32_t index,
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ public:
|
|||
switch (operatorType)
|
||||
{
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
case DML_OPERATOR_ACTIVATION_CELU:
|
||||
operatorDesc.elu.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
|
||||
break;
|
||||
|
||||
|
|
@ -154,6 +155,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(HardSigmoid, DmlOperatorActivationTempla
|
|||
DML_OP_DEFINE_CREATION_FUNCTION(Tanh, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_TANH>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(ScaledTanh, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SCALED_TANH>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Relu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_RELU>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Celu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_CELU>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(LeakyRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_LEAKY_RELU>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(PRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(ThresholdedRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU>);
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ public:
|
|||
for (size_t i = 0; i < m_inputTensorDescs.size(); i++)
|
||||
{
|
||||
// DML doesn't support empty tensors for concat, so we ignore them
|
||||
if (!OperatorHelper::ContainsEmptyDimensions(m_inputTensorDescs[i].GetDmlSizes()))
|
||||
if (!OperatorHelper::ContainsEmptyDimensions(m_inputTensorDescs[i].GetSizes()))
|
||||
{
|
||||
inputDescs.push_back(m_inputTensorDescs[i].GetDmlDesc());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,176 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
|
||||
{
|
||||
public:
|
||||
DmlOperatorEinSum(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion)
|
||||
: DmlOperator(kernelCreationContext),
|
||||
EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() + 1 == m_components.size(), "EinSum input tensor count is inconsistent with the equation component count.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "EinSum expects one output tensor.");
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(8), "Update this switch.");
|
||||
switch (m_recognizedOperatorType)
|
||||
{
|
||||
case RecognizedOperatorType::Multiply:
|
||||
{
|
||||
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.ATensor = &inputDescs[0];
|
||||
operatorDesc.BTensor = &inputDescs[1];
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &operatorDesc}, kernelCreationContext);
|
||||
}
|
||||
break;
|
||||
|
||||
case RecognizedOperatorType::MatMul:
|
||||
case RecognizedOperatorType::MatMulTransposeA:
|
||||
case RecognizedOperatorType::MatMulTransposeB:
|
||||
{
|
||||
DML_GEMM_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.ATensor = &inputDescs[0];
|
||||
operatorDesc.BTensor = &inputDescs[1];
|
||||
// No operatorDesc.CTensor
|
||||
operatorDesc.OutputTensor = &outputDescs[0];
|
||||
operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
|
||||
operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
|
||||
operatorDesc.Alpha = 1.0;
|
||||
operatorDesc.Beta = 0.0;
|
||||
operatorDesc.FusedActivation = nullptr;
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
|
||||
}
|
||||
break;
|
||||
|
||||
case RecognizedOperatorType::ReduceSum:
|
||||
{
|
||||
// Get how many axes are kept in the final output, either 0 or 1 supported
|
||||
// meaning full reduction or partial with one dimension left. *It could be
|
||||
// generalized to support any number of output dimensions, but it would need
|
||||
// to accomodate for Transposition too if the output labels are reordered.
|
||||
auto keptAxes = m_components.back().GetLabels(m_labelIndices);
|
||||
assert(keptAxes.size() <= 1);
|
||||
|
||||
// DML expects output rank to match input rank (as if ONNX ReduceSum keepdims=1).
|
||||
// So replace the existing tensor description with the input sizes, except that
|
||||
// reduced dimensions have size 1.
|
||||
std::vector<uint32_t> reducedAxes;
|
||||
std::vector<uint32_t> inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
|
||||
std::vector<uint32_t> outputSizes = inputSizes;
|
||||
|
||||
// Determine which axes are being reduced by taking the opposite of those kept.
|
||||
uint32_t keptAxesMask = 0;
|
||||
for (auto axis : keptAxes)
|
||||
{
|
||||
keptAxesMask |= (1 << axis);
|
||||
}
|
||||
for (uint32_t axis = 0, axisCount = static_cast<uint32_t>(outputSizes.size()); axis < axisCount; ++axis)
|
||||
{
|
||||
if (~keptAxesMask & (1<<axis))
|
||||
{
|
||||
reducedAxes.push_back(axis);
|
||||
outputSizes[axis] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
m_inputTensorDescs.front() = TensorDesc(m_inputTensorDescs.front().GetDmlDataType(), inputSizes, std::nullopt, 0);
|
||||
m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), outputSizes, std::nullopt, 0);
|
||||
m_inputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
|
||||
m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
|
||||
|
||||
DML_REDUCE_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = inputDescs.data();
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.Function = DML_REDUCE_FUNCTION_SUM;
|
||||
operatorDesc.Axes = reducedAxes.data();
|
||||
operatorDesc.AxisCount = gsl::narrow_cast<uint32_t>(reducedAxes.size());
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_REDUCE, &operatorDesc }, kernelCreationContext);
|
||||
}
|
||||
break;
|
||||
|
||||
case RecognizedOperatorType::Transpose:
|
||||
case RecognizedOperatorType::Identity:
|
||||
{
|
||||
if (m_recognizedOperatorType == RecognizedOperatorType::Transpose)
|
||||
{
|
||||
// Transpose via input strides. The output tensor is not strided.
|
||||
assert(m_components.front().GetDimensionCount() == m_components.back().GetDimensionCount());
|
||||
auto originalStrides = m_inputTensorDescs.front().GetStrides();
|
||||
std::vector<uint32_t> inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
|
||||
std::vector<uint32_t> inputStrides(inputSizes.size());
|
||||
|
||||
// If there were no strides, compute them based in descending packed order
|
||||
// based on the input sizes.
|
||||
if (originalStrides.empty())
|
||||
{
|
||||
Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides);
|
||||
}
|
||||
else // Copy the original strides.
|
||||
{
|
||||
assert(originalStrides.size() >= inputStrides.size());
|
||||
size_t offset = originalStrides.size() - inputStrides.size();
|
||||
inputStrides.assign(originalStrides.begin() + offset, originalStrides.end());
|
||||
}
|
||||
|
||||
// Remap transposed strides using the component labels from input to output.
|
||||
auto labelIndices = m_components.back().GetLabels(m_labelIndices);
|
||||
|
||||
std::vector<uint32_t> newStrides(inputStrides.size());
|
||||
std::vector<uint32_t> newSizes(inputStrides.size());
|
||||
for (size_t i = 0, dimensionCount = inputStrides.size(); i < dimensionCount; ++i)
|
||||
{
|
||||
uint32_t labelIndex = labelIndices[i];
|
||||
assert(labelIndex < inputStrides.size());
|
||||
newSizes[i] = inputSizes[labelIndex];
|
||||
newStrides[i] = inputStrides[labelIndex];
|
||||
}
|
||||
|
||||
// Override the initial input tensor with the new strides.
|
||||
m_inputTensorDescs.front() = TensorDesc(m_inputTensorDescs.front().GetDmlDataType(), newSizes, newStrides, 0);
|
||||
m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newSizes, std::nullopt, 0);
|
||||
m_inputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
|
||||
m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
|
||||
}
|
||||
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = inputDescs.data();
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &operatorDesc}, kernelCreationContext);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
|
||||
{
|
||||
*isSupported = false;
|
||||
|
||||
MLOperatorAttributes attributes(context);
|
||||
EinSumHelper helper(attributes);
|
||||
auto recognizedOperatorType = helper.GetRecognizedOperatorType();
|
||||
|
||||
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(8), "Verify this test still matches the switch above.");
|
||||
*isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None);
|
||||
}
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Einsum12, VersionedKernel<DmlOperatorEinSum, 12>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -462,6 +462,9 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Same operator signature as 11. Only difference is new type support
|
||||
using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11;
|
||||
|
||||
class DmlOperatorElementwisePow : public DmlOperator
|
||||
{
|
||||
public:
|
||||
|
|
@ -700,6 +703,8 @@ DML_OP_DEFINE_CREATION_FUNCTION(Erf, DmlOperatorElementwiseUnary<DM
|
|||
// Binary operators:
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Greater, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Less, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(GreaterOrEqual, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(LessOrEqual, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Equal, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(And, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Or, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC>);
|
||||
|
|
@ -718,6 +723,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
|
|||
// Operators with extra attributes:
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ public:
|
|||
TensorDesc inputTensorDesc =
|
||||
TensorDesc(
|
||||
kernelCreationContext.GetInputEdgeDescription(0).tensorDataType,
|
||||
m_outputTensorDescs[0].GetDmlSizes(),
|
||||
m_inputTensorDescs[0].GetDmlSizes(),
|
||||
m_outputTensorDescs[0].GetSizes(),
|
||||
m_inputTensorDescs[0].GetSizes(),
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
|
|
@ -36,8 +36,8 @@ public:
|
|||
TensorDesc outputTensorDesc =
|
||||
TensorDesc(
|
||||
kernelCreationContext.GetOutputEdgeDescription(0).tensorDataType,
|
||||
m_outputTensorDescs[0].GetDmlSizes(),
|
||||
m_outputTensorDescs[0].GetDmlSizes(),
|
||||
m_outputTensorDescs[0].GetSizes(),
|
||||
m_outputTensorDescs[0].GetSizes(),
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
|
|
|
|||
|
|
@ -16,19 +16,21 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "Gather expects 2 inputs.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Gather expects 1 output.");
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
|
||||
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
|
||||
|
||||
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
|
||||
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
|
||||
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 2);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
|
||||
|
||||
DML_GATHER_OPERATOR_DESC operatorDesc = {};
|
||||
|
|
@ -52,20 +54,22 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherElements expects 2 inputs.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherElements expects 1 output.");
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
|
||||
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
|
||||
|
||||
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
|
||||
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
|
||||
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 2);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
int32_t signedOnnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
|
||||
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
uint32_t dmlAxis = GetDmlAdjustedAxis(signedOnnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
|
||||
|
||||
DML_GATHER_ELEMENTS_OPERATOR_DESC operatorDesc = {};
|
||||
|
|
@ -89,29 +93,30 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherND expects 2 inputs.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherND expects 1 output.");
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
|
||||
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
|
||||
|
||||
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
|
||||
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
|
||||
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 2);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
|
||||
DML_GATHER_ND_OPERATOR_DESC operatorDesc = {};
|
||||
DML_GATHER_ND1_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.IndicesTensor = &inputDescs[1];
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.InputDimensionCount = static_cast<uint32_t>(dataDimensions.size());
|
||||
operatorDesc.IndicesDimensionCount = static_cast<uint32_t>(indicesDimensions.size());
|
||||
operatorDesc.BatchDimensionCount = m_batchCount;
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND, &operatorDesc };
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND1, &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ public:
|
|||
std::vector<std::optional<uint32_t>> inputIndices = { 0, 1 }; // The 3rd tensor ('output_shape') is not bound, just 'X' and 'I' indices.
|
||||
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
|
||||
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bit();
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType(); // MaxUnpool accepts uint32_t.
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ public:
|
|||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
1, // minDimensionCount
|
||||
0
|
||||
);
|
||||
|
||||
|
|
@ -49,10 +49,12 @@ public:
|
|||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
1, // minDimensionCount
|
||||
0
|
||||
);
|
||||
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
// Adjust the axis so it's in DML's terms rather than the original ONNX indexing.
|
||||
uint32_t dmlAxis = GetDmlAdjustedAxis(
|
||||
m_absoluteAxis,
|
||||
|
|
|
|||
|
|
@ -89,18 +89,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// A specific type of operation for registration.
|
||||
template <uint32_t opsetVersion>
|
||||
class DmlOperatorPaddingTemplate : public DmlOperatorPadding
|
||||
{
|
||||
public:
|
||||
DmlOperatorPaddingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperatorPadding(kernelInfo, opsetVersion)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pad7, DmlOperatorPaddingTemplate<7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, DmlOperatorPaddingTemplate<11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel<DmlOperatorPadding, 7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel<DmlOperatorPadding, 11>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -108,7 +108,14 @@ public:
|
|||
if (hasOutputIndices || hasDilations)
|
||||
{
|
||||
DML_MAX_POOLING2_OPERATOR_DESC desc = {};
|
||||
desc.OutputIndicesTensor = hasOutputIndices ? &outputDescs[1] : nullptr;
|
||||
|
||||
if (hasOutputIndices)
|
||||
{
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bit();
|
||||
m_outputTensorDescs[1].ForceUnsignedDataType(); // MaxPool accepts uint32_t.
|
||||
desc.OutputIndicesTensor = &outputDescs[1];
|
||||
}
|
||||
|
||||
desc.Dilations = m_kernel.dilations;
|
||||
SetOpDesc(desc);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,17 +22,12 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
DmlOperator::Initialize(kernelInfo);
|
||||
|
||||
// Zero the output tensor's memory for ArgMin & ArgMax, which produce INT64 output.
|
||||
if ((function == DML_REDUCE_FUNCTION_ARGMAX) || (function == DML_REDUCE_FUNCTION_ARGMIN))
|
||||
{
|
||||
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
|
||||
}
|
||||
|
||||
std::vector<uint32_t> dmlAxes;
|
||||
std::vector<DimensionType> reducedDims = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
|
||||
int dimOffset = gsl::narrow_cast<int>(OperatorHelper::NchwDimensionCount - reducedDims.size());
|
||||
int dimOffset = gsl::narrow_cast<int>(m_inputTensorDescs[0].GetDimensionCount() - reducedDims.size());
|
||||
for (auto& dim : m_axes)
|
||||
{
|
||||
assert(dim < reducedDims.size()); // ReduceHelperBase already validated this.
|
||||
reducedDims[dim] = 1;
|
||||
dmlAxes.push_back(static_cast<uint32_t>(dim + dimOffset));
|
||||
}
|
||||
|
|
@ -62,15 +57,59 @@ public:
|
|||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_REDUCE_OPERATOR_DESC reduceDesc = {};
|
||||
reduceDesc.InputTensor = inputDescs.data();
|
||||
reduceDesc.OutputTensor = outputDescs.data();
|
||||
reduceDesc.Function = function;
|
||||
reduceDesc.Axes = dmlAxes.data();
|
||||
reduceDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
|
||||
// Zero the output tensor's memory for ArgMin & ArgMax, which produce INT64 output.
|
||||
if (function == DML_REDUCE_FUNCTION_ARGMAX)
|
||||
{
|
||||
DML_ARGMAX_OPERATOR_DESC argmaxDesc;
|
||||
argmaxDesc.AxisDirection = static_cast<DML_AXIS_DIRECTION>(m_selectLastIndex);
|
||||
argmaxDesc.InputTensor = inputDescs.data();
|
||||
argmaxDesc.OutputTensor = outputDescs.data();
|
||||
argmaxDesc.Axes = dmlAxes.data();
|
||||
argmaxDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_REDUCE, &reduceDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
// If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits
|
||||
// of each element. If the device directly supports 64-bit elements, then no need.
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
if (m_outputTensorDescs[0].WasRemapped64bitTo32bit())
|
||||
{
|
||||
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
|
||||
}
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMAX, &argmaxDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
else if (function == DML_REDUCE_FUNCTION_ARGMIN)
|
||||
{
|
||||
DML_ARGMIN_OPERATOR_DESC argminDesc;
|
||||
argminDesc.AxisDirection = static_cast<DML_AXIS_DIRECTION>(m_selectLastIndex);
|
||||
argminDesc.InputTensor = inputDescs.data();
|
||||
argminDesc.OutputTensor = outputDescs.data();
|
||||
argminDesc.Axes = dmlAxes.data();
|
||||
argminDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
|
||||
|
||||
// If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits
|
||||
// of each element. If the device directly supports 64-bit elements, then no need.
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
if (m_outputTensorDescs[0].WasRemapped64bitTo32bit())
|
||||
{
|
||||
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
|
||||
}
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMIN, &argminDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_REDUCE_OPERATOR_DESC reduceDesc = {};
|
||||
reduceDesc.InputTensor = inputDescs.data();
|
||||
reduceDesc.OutputTensor = outputDescs.data();
|
||||
reduceDesc.Function = function;
|
||||
reduceDesc.Axes = dmlAxes.data();
|
||||
reduceDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_REDUCE, &reduceDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(const MLOperatorKernelContext& kernelContext) override
|
||||
|
|
|
|||
|
|
@ -251,17 +251,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// A specific type of operation for registration.
|
||||
template <uint32_t OpsetVersion>
|
||||
struct DmlOperatorResizeTemplate : public DmlOperatorResize
|
||||
{
|
||||
public:
|
||||
DmlOperatorResizeTemplate(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperatorResize(kernelInfo, OpsetVersion)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
|
||||
{
|
||||
*isSupported = false;
|
||||
|
|
@ -304,10 +293,10 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool*
|
|||
*isSupported = true;
|
||||
}
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Resize10, DmlOperatorResizeTemplate<10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Resize11, DmlOperatorResizeTemplate<11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, DmlOperatorResizeTemplate<7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, DmlOperatorResizeTemplate<9>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, DmlOperatorResizeTemplate<10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Resize10, VersionedKernel<DmlOperatorResize, 10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel<DmlOperatorResize, 11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel<DmlOperatorResize, 7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel<DmlOperatorResize, 9>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel<DmlOperatorResize, 10>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ public:
|
|||
0
|
||||
);
|
||||
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bit();
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType(); // DML operator accepts uint32_t.
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelper
|
||||
{
|
||||
public:
|
||||
using Self = DmlOperatorRegionOfInterestAlign;
|
||||
|
||||
DmlOperatorRegionOfInterestAlign(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion)
|
||||
: DmlOperator(kernelCreationContext),
|
||||
RoiAlignHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "RoiAlign expects 3 input tensors.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "RoiAlign expects 1 output tensor.");
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bit();
|
||||
m_inputTensorDescs[2].ForceUnsignedDataType();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
constexpr NameAndIndex mapping[] =
|
||||
{
|
||||
{"max", DML_REDUCE_FUNCTION_MAX},
|
||||
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
|
||||
};
|
||||
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
|
||||
const auto reductionFunction = MapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
|
||||
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
|
||||
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
|
||||
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
|
||||
|
||||
DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.ROITensor = &inputDescs[1];
|
||||
operatorDesc.BatchIndicesTensor = &inputDescs[2];
|
||||
operatorDesc.OutputTensor = &outputDescs[0];
|
||||
operatorDesc.SpatialScaleX = spatialScale; // ONNX uses the same scale for X and Y.
|
||||
operatorDesc.SpatialScaleY = spatialScale;
|
||||
operatorDesc.OutOfBoundsInputValue = 0.0f; // ONNX does not specify a value for input elements outside bounds.
|
||||
operatorDesc.MinimumSamplesPerOutput = (samplesPerOutput == 0) ? 1 : samplesPerOutput;
|
||||
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
|
||||
operatorDesc.ReductionFunction = reductionFunction;
|
||||
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
|
||||
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorRegionOfInterestAlign, 10>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -26,7 +26,7 @@ public:
|
|||
poolingDesc.ROITensor = &inputDescs[1];
|
||||
poolingDesc.OutputTensor = &outputDescs[0];
|
||||
poolingDesc.SpatialScale = m_spatialScale;
|
||||
poolingDesc.PooledSize = { m_pooledSizeH, m_pooledSizeW };
|
||||
poolingDesc.PooledSize = { m_outputSizeH, m_outputSizeW };
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_POOLING, &poolingDesc };
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions == updatesDimensions);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() == indicesDimensions.size());
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
|
||||
// When the indices tensor is empty, Scatter is basically Identity. But since DML doesn't support empty or null
|
||||
// tensors, we have to special-case it outside of DML.
|
||||
|
|
@ -31,6 +30,7 @@ public:
|
|||
{
|
||||
std::vector<std::optional<uint32_t>> kernelInputIndices(1, 0);
|
||||
DmlOperator::Initialize(kernelCreationContext, kernelInputIndices);
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
|
@ -49,14 +49,13 @@ public:
|
|||
else
|
||||
{
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 3);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
// Read the axis.
|
||||
int onnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
|
||||
uint32_t dmlAxis = GetDmlAdjustedAxis(onnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
|
||||
|
|
@ -89,20 +88,16 @@ public:
|
|||
std::vector<DimensionType> updatesDimensions = tensorShapeDescription.GetInputTensorShape(2);
|
||||
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(updatesDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(outputDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
size_t dimensionCountMax = std::max({dataDimensions.size(), updatesDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
|
||||
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 3);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
DML_SCATTER_ND_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.IndicesTensor = &inputDescs[1];
|
||||
|
|
|
|||
|
|
@ -47,23 +47,12 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// A specific type of operation for registration.
|
||||
template <uint32_t opsetVersion>
|
||||
class DmlOperatorSliceTemplate : public DmlOperatorSlice
|
||||
{
|
||||
public:
|
||||
DmlOperatorSliceTemplate(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperatorSlice(kernelInfo, opsetVersion)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
|
||||
{
|
||||
*isSupported = (context->GetInputCount() <= 5);
|
||||
}
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, VersionedKernel<DmlOperatorSlice, 7> );
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, VersionedKernel<DmlOperatorSlice, 10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, VersionedKernel<DmlOperatorSlice, 11>);
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ public:
|
|||
std::vector<std::optional<uint32_t>> inputIndices = { 0 }; // Use only the first tensor. The second tensor is CPU-based.
|
||||
std::vector<std::optional<uint32_t>> outputIndices = { 0, 1 };
|
||||
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
|
||||
DmlOperator::Remap64bitDmlDataTypesTo32bit();
|
||||
m_outputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
|
@ -70,19 +72,8 @@ private:
|
|||
ComPtr<IDMLCompiledOperator> m_zeroOperator;
|
||||
};
|
||||
|
||||
// A specific type of operation for registration.
|
||||
template <uint32_t OpsetVersion>
|
||||
class DmlOperatorTopKTemplate : public DmlOperatorTopK
|
||||
{
|
||||
public:
|
||||
DmlOperatorTopKTemplate(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperatorTopK(kernelInfo, OpsetVersion)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(TopK7, DmlOperatorTopKTemplate<7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(TopK10, DmlOperatorTopKTemplate<10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(TopK11, DmlOperatorTopKTemplate<11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(TopK7, VersionedKernel<DmlOperatorTopK, 7 >);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(TopK10, VersionedKernel<DmlOperatorTopK, 10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(TopK11, VersionedKernel<DmlOperatorTopK, 11>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -19,41 +19,25 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() >= 1);
|
||||
DmlOperator::Initialize(kernelInfo);
|
||||
|
||||
const MLOperatorEdgeDescription inputEdgeDescription = kernelInfo.GetInputEdgeDescription(0);
|
||||
|
||||
const std::vector<uint32_t> originalSizes = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(m_permutations.size() == originalSizes.size());
|
||||
|
||||
// Calculate strides from original shape.
|
||||
ML_CHECK_VALID_ARGUMENT(!originalSizes.empty());
|
||||
std::vector<uint32_t> inputStrides(originalSizes.size());
|
||||
inputStrides.back() = 1;
|
||||
for (int i = gsl::narrow_cast<int>(inputStrides.size()) - 2; i >= 0; i--)
|
||||
{
|
||||
inputStrides[i] = inputStrides[i + 1] * gsl::narrow_cast<uint32_t>(originalSizes[i + 1]);
|
||||
}
|
||||
Dml::GetDescendingPackedStrides(originalSizes, /*out*/ inputStrides);
|
||||
|
||||
const int leadingDims = gsl::narrow_cast<int32_t>(m_inputTensorDescs.front().GetDimensionCount() - originalSizes.size());
|
||||
|
||||
std::vector<uint32_t> sizes(m_inputTensorDescs.front().GetDimensionCount());
|
||||
std::vector<uint32_t> strides(m_inputTensorDescs.front().GetDimensionCount());
|
||||
|
||||
// Fill leading tensor desc sizes/strides with defaults.
|
||||
for (int dimDML = 0; dimDML < leadingDims; ++dimDML)
|
||||
{
|
||||
sizes[dimDML] = 1;
|
||||
strides[dimDML] = 0;
|
||||
}
|
||||
std::vector<uint32_t> sizes(inputStrides.size());
|
||||
std::vector<uint32_t> strides(inputStrides.size());
|
||||
|
||||
// Permute the shape and strides.
|
||||
for (int dimInput = 0, dimCount = gsl::narrow_cast<int>(originalSizes.size()); dimInput < dimCount; ++dimInput)
|
||||
{
|
||||
int dimDML = dimInput + leadingDims;
|
||||
int dimPermuted = m_permutations[dimInput];
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(gsl::narrow_cast<size_t>(dimPermuted) < originalSizes.size());
|
||||
sizes[dimDML] = gsl::narrow_cast<int32_t>(originalSizes[dimPermuted]);
|
||||
strides[dimDML] = inputStrides[dimPermuted];
|
||||
sizes[dimInput] = originalSizes[dimPermuted];
|
||||
strides[dimInput] = inputStrides[dimPermuted];
|
||||
}
|
||||
|
||||
// Override the initial tensor descs. The output tensor is not strided.
|
||||
|
|
|
|||
|
|
@ -34,22 +34,28 @@ enum class SupportedTensorDataTypes : uint32_t
|
|||
UInt64 = 1<<13,
|
||||
Complex64 = 1<<14,
|
||||
Complex128 = 1<<15,
|
||||
Int8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
|
||||
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
|
||||
Int32to64 = UInt32|Int32|UInt64|Int64,
|
||||
Float16to32 = Float16|Float32, // Float64 is not supported by DirectML.
|
||||
NumericDefault = Int8to32|Float16to32,
|
||||
NumericDefault = Ints8to32|Float16to32,
|
||||
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
|
||||
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
|
||||
Ints8Bit = UInt8|Int8,
|
||||
Ints16Bit = UInt16|Int16,
|
||||
Ints32Bit = UInt32|Int32,
|
||||
All = static_cast<uint32_t>(-1),
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
|
||||
|
||||
enum class DmGraphSupport
|
||||
enum class DmlGraphSupport : uint32_t
|
||||
{
|
||||
Supported = 0,
|
||||
NotSupported = 1,
|
||||
Supported = 0,
|
||||
NotSupported = 1,
|
||||
SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides.
|
||||
SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes)
|
||||
Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support.
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport);
|
||||
|
||||
struct OperatorRegistrationInformation
|
||||
{
|
||||
|
|
@ -63,7 +69,7 @@ struct OperatorRegistrationInformation
|
|||
|
||||
gsl::span<char const* const> tensorTypeNames;
|
||||
gsl::span<const SupportedTensorDataTypes> supportedTensorDataTypes;
|
||||
DmGraphSupport DmGraphSupport;
|
||||
DmlGraphSupport dmlGraphSupport;
|
||||
|
||||
std::pair<std::array<const uint32_t, 4>, int> requiredConstantCpuInputs = {{}, 0};
|
||||
|
||||
|
|
@ -86,6 +92,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
|
||||
|
|
@ -117,8 +124,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(Ceil);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Floor);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip7);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip12);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Greater);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Less);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(GreaterOrEqual);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(LessOrEqual);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Equal);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Not);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(And);
|
||||
|
|
@ -133,6 +143,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Mean);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Max);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Min);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ReduceSum);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Einsum12);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ReduceMean);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ReduceProd);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ReduceLogSum);
|
||||
|
|
@ -160,6 +171,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LeakyRelu);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(PRelu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ThresholdedRelu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Elu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Celu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Selu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Softmax);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax);
|
||||
|
|
@ -233,6 +245,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
|
|||
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Resize);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(EinSum);
|
||||
|
||||
constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
|
||||
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
|
||||
|
|
@ -240,6 +253,7 @@ constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T
|
|||
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
|
||||
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
|
||||
constexpr static std::array<const char*, 2> typeNameListLogicalComparison = { "T", "T1" };
|
||||
constexpr static std::array<const char*, 2> typeNameListPow12 = {"T", "T1"};
|
||||
constexpr static std::array<const char*, 2> typeNameListConstantOfShape = { "T1", "T2" };
|
||||
constexpr static std::array<const char*, 2> typeNameListScatterGather = { "T", "Tind" };
|
||||
constexpr static std::array<const char*, 1> typeNameListScatterGatherND = { "T" }; // Tind is curiously missing, only allowing 64-bit.
|
||||
|
|
@ -249,15 +263,18 @@ constexpr static std::array<const char*, 1> typeNameListEyeLike = { "T2" };
|
|||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Int8to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListBool = {SupportedTensorDataTypes::Bool};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 };
|
||||
|
|
@ -275,7 +292,7 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogica
|
|||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListPadWithoutFloat16 = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
|
|
@ -305,7 +322,7 @@ constexpr auto requiredConstantCpuInputs(Args... args)
|
|||
|
||||
// Identity operators use Copy, alias their first input, and require floating point formats
|
||||
// for usage in the graph, besides constant inputs. This is because they currently use
|
||||
// element-wise identity operators in the graph for striding support, but issue actual copies
|
||||
// element-wise identity operators in the graph for striding support, but issue actual copies
|
||||
// outside the graph. Element-wise identity currently only supports floating point types.
|
||||
#define REG_INFO_ID(version, operatorName, ...) \
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, true, true, ##__VA_ARGS__,
|
||||
|
|
@ -325,230 +342,244 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
/// Support query function
|
||||
|
||||
// Deep Learning Standard Layers
|
||||
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
|
||||
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
|
||||
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)},
|
||||
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
|
||||
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
|
||||
|
||||
// Data Reorganization Layers
|
||||
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis.
|
||||
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis.
|
||||
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
|
||||
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
|
||||
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
||||
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
|
||||
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
|
||||
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
|
||||
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
|
||||
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
|
||||
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
||||
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
|
||||
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
|
||||
|
||||
// Data reorganization that merely changes the dimensions while keeping the data identical.
|
||||
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
|
||||
// Elementwise
|
||||
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )},
|
||||
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
|
||||
// Imaging Operators
|
||||
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
|
||||
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
|
||||
|
||||
// Activation Functions
|
||||
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
|
||||
// Uncategorized
|
||||
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
|
||||
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
|
||||
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
|
||||
|
||||
// Fused operators
|
||||
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
|
||||
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
|
||||
|
||||
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
|
||||
|
||||
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
|
||||
|
||||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmGraphSupport::NotSupported)},
|
||||
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
|
||||
|
||||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -572,10 +603,13 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
MLOperatorKernelDescription desc = {};
|
||||
desc.domain = information.domain;
|
||||
desc.name = information.operatorName;
|
||||
desc.executionType = MLOperatorExecutionType::D3D12;
|
||||
desc.executionType = MLOperatorExecutionType::D3D12;
|
||||
|
||||
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
|
||||
bool kernelSupportsGraph = (information.DmGraphSupport == DmGraphSupport::Supported);
|
||||
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
|
||||
bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly);
|
||||
bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides);
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp);
|
||||
|
||||
desc.options = information.shapeInferenceFunction ?
|
||||
MLOperatorKernelOptions::None : MLOperatorKernelOptions::AllowDynamicInputShapes;
|
||||
|
|
@ -651,6 +685,9 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
kernelSupportsGraph, // supportsGraph
|
||||
information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr,
|
||||
information.requiresFloatFormatsForGraph,
|
||||
supportedWith64BitTensorsVia32BitStrides,
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
prefer64BitTensorsDirectly,
|
||||
information.requiredConstantCpuInputs.first.data(),
|
||||
static_cast<uint32_t>(information.requiredConstantCpuInputs.second)
|
||||
));
|
||||
|
|
|
|||
|
|
@ -11,6 +11,19 @@ class MLOperatorKernelCreationContext;
|
|||
#define DML_OP_EXTERN_CREATION_FUNCTION(operatorName) extern void CALLBACK Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)
|
||||
#define DML_OP_EXTERN_QUERY_FUNCTION(operatorName) extern void CALLBACK Query##operatorName(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported);
|
||||
|
||||
// A specific opset version for registration.
|
||||
// e.g.
|
||||
// DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorSlice, 10>);
|
||||
template <typename BaseClass, uint32_t opsetVersion>
|
||||
class VersionedKernel : public BaseClass
|
||||
{
|
||||
public:
|
||||
VersionedKernel(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: BaseClass(kernelInfo, opsetVersion)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// Declares a callback creation function of the given operator class.
|
||||
// This does not register it, just declares it for usage by registration later.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ namespace Dml
|
|||
switch (function)
|
||||
{
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
case DML_OPERATOR_ACTIVATION_CELU:
|
||||
return 1.0f;
|
||||
|
||||
case DML_OPERATOR_ACTIVATION_LEAKY_RELU:
|
||||
|
|
@ -358,37 +359,45 @@ namespace Dml
|
|||
}
|
||||
}
|
||||
|
||||
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
|
||||
{
|
||||
for (auto& nameAndIndex : nameAndIndexList)
|
||||
{
|
||||
if (strncmp(nameAndIndex.name, mode.data(), mode.size()) == 0)
|
||||
{
|
||||
return nameAndIndex.index;
|
||||
}
|
||||
}
|
||||
|
||||
ML_INVALID_ARGUMENT("Unknown mode value.");
|
||||
}
|
||||
|
||||
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode)
|
||||
{
|
||||
// The ONNX modes are "nearest" and "linear." Other modes exist for compatibility,
|
||||
// since Winml supported them in the past.
|
||||
if (mode == "NEAREST" || mode == "nearest" || mode == "nn" || mode == "NN") {
|
||||
return DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR;
|
||||
}
|
||||
else if (mode == "BILINEAR" || mode == "bilinear" || mode == "linear")
|
||||
|
||||
constexpr NameAndIndex mapping[] =
|
||||
{
|
||||
return DML_INTERPOLATION_MODE_LINEAR;
|
||||
}
|
||||
else
|
||||
{
|
||||
ML_INVALID_ARGUMENT("Unknown sampling interpolation mode.");
|
||||
}
|
||||
{"nearest", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
|
||||
{"NEAREST", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
|
||||
{"NN", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
|
||||
{"nn", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
|
||||
{"linear", DML_INTERPOLATION_MODE_LINEAR},
|
||||
{"BILINEAR", DML_INTERPOLATION_MODE_LINEAR},
|
||||
{"bilinear", DML_INTERPOLATION_MODE_LINEAR},
|
||||
};
|
||||
return MapStringToIndex<DML_INTERPOLATION_MODE>(mode, mapping);
|
||||
}
|
||||
|
||||
DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode)
|
||||
{
|
||||
if (mode == "DCR")
|
||||
constexpr NameAndIndex mapping[] =
|
||||
{
|
||||
return DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW;
|
||||
}
|
||||
else if (mode == "CRD")
|
||||
{
|
||||
return DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH;
|
||||
}
|
||||
else
|
||||
{
|
||||
ML_INVALID_ARGUMENT("Unknown depth space mode.");
|
||||
}
|
||||
{"DCR", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW},
|
||||
{"CRD", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH},
|
||||
};
|
||||
return MapStringToIndex<DML_DEPTH_SPACE_ORDER>(mode, mapping);
|
||||
}
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -57,6 +57,20 @@ namespace Dml
|
|||
|
||||
void GetDmlAdjustedAxes(/*inout*/ gsl::span<const int32_t> axes, uint32_t onnxDimCount, uint32_t dmlDimCount, std::vector<uint32_t>& dmlAxes);
|
||||
|
||||
struct NameAndIndex
|
||||
{
|
||||
const char* name; // Null terminated.
|
||||
uint32_t index;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
T MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
|
||||
{
|
||||
return static_cast<T>(MapStringToIndex(mode, nameAndIndexList));
|
||||
}
|
||||
|
||||
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList);
|
||||
|
||||
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode);
|
||||
|
||||
DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode);
|
||||
|
|
|
|||
|
|
@ -184,53 +184,6 @@ TensorDesc::TensorDesc(
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
// Handle 64-bit tensors.
|
||||
|
||||
uint64_t endPaddingInBytes = 0;
|
||||
|
||||
if (dataType == MLOperatorTensorDataType::UInt64 || dataType == MLOperatorTensorDataType::Int64)
|
||||
{
|
||||
// DirectML doesn't support tensor of int64 because Direct3D doesn't support
|
||||
// the data type. A workaround is to use strides to fake 64-bit memory access
|
||||
// while only the lower 32 bits contains the data. This trick obviously doesn't
|
||||
// work if the data element is genuine 64-bit. It also doesn't work if the data
|
||||
// element is negative as the signed bit will be incorrectly interpreted.
|
||||
m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT32;
|
||||
|
||||
// If the strides haven't been calculated yet, initialize them as packed.
|
||||
if (!useStrides)
|
||||
{
|
||||
uint32_t stride = 1;
|
||||
for (int i = m_bufferTensorDesc.DimensionCount - 1; i >= 0; i--)
|
||||
{
|
||||
m_strides[i] = stride;
|
||||
stride *= m_sizes[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Double the stride values to emulate 64-bit integer support.
|
||||
for (uint32_t i = 0; i < m_bufferTensorDesc.DimensionCount; ++i)
|
||||
{
|
||||
m_strides[i] *= 2;
|
||||
}
|
||||
|
||||
useStrides = true;
|
||||
|
||||
// The physical size of the tensor will have an extra 4 bytes at the end.
|
||||
// DMLCalcBufferTensorSize calculates the minimum implied size, which is based on the last
|
||||
// addressable element of the tensor plus the space for the last element. However, the size
|
||||
// of the last element is now halved from 8 bytes to 4 bytes.
|
||||
//
|
||||
// Example:
|
||||
// Original Tensor: size={2,3}, strides={3,1}, type=int64, size = (1+{1,2}*{3,1})*sizeof(int64) = 6 * 8 = 48
|
||||
// Emulated Tensor: size={2,3}, strides={6,2}, type=int32, size = (1+{1,2}*{6,2})*sizeof(int32) = 11 * 4 = 44
|
||||
//
|
||||
// DirectML itself won't read/write the last 4 bytes, but we want the total size to be accurate
|
||||
// so that the entire region can be zeroed.
|
||||
endPaddingInBytes = sizeof(uint32_t);
|
||||
}
|
||||
|
||||
if (useStrides)
|
||||
{
|
||||
m_bufferTensorDesc.Strides = m_strides;
|
||||
|
|
@ -239,20 +192,84 @@ TensorDesc::TensorDesc(
|
|||
m_bufferTensorDesc.Flags = DML_TENSOR_FLAG_NONE;
|
||||
m_bufferTensorDesc.GuaranteedBaseOffsetAlignment = guaranteedBaseOffsetAlignment;
|
||||
m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(
|
||||
m_bufferTensorDesc.DataType,
|
||||
m_bufferTensorDesc.DimensionCount,
|
||||
m_sizes,
|
||||
m_bufferTensorDesc.DataType,
|
||||
m_bufferTensorDesc.DimensionCount,
|
||||
m_sizes,
|
||||
useStrides ? m_strides : nullptr
|
||||
);
|
||||
assert(m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType));
|
||||
}
|
||||
|
||||
void TensorDesc::Remap64bitDmlDataTypeTo32bit()
|
||||
{
|
||||
if (m_bufferTensorDesc.DataType != DML_TENSOR_DATA_TYPE_UINT64 &&
|
||||
m_bufferTensorDesc.DataType != DML_TENSOR_DATA_TYPE_INT64)
|
||||
{
|
||||
return; // Nothing to do.
|
||||
}
|
||||
|
||||
uint64_t endPaddingInBytes = 0;
|
||||
|
||||
// A workaround for older devices is to use strides to fake 64-bit memory access
|
||||
// while only the lower 32 bits contains the data. This trick obviously doesn't
|
||||
// work if the data element is genuine 64-bit. It also doesn't work if the data
|
||||
// element is negative as the signed bit will be incorrectly interpreted.
|
||||
m_bufferTensorDesc.DataType = Dml::Remap64bitDmlDataTypeTo32bit(m_bufferTensorDesc.DataType);
|
||||
|
||||
// If the strides haven't been calculated yet, initialize them as packed.
|
||||
if (m_bufferTensorDesc.Strides == nullptr)
|
||||
{
|
||||
uint32_t stride = 1;
|
||||
for (int i = m_bufferTensorDesc.DimensionCount - 1; i >= 0; i--)
|
||||
{
|
||||
m_strides[i] = stride;
|
||||
stride *= m_sizes[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Double the stride values to emulate 64-bit integer support.
|
||||
for (uint32_t i = 0; i < m_bufferTensorDesc.DimensionCount; ++i)
|
||||
{
|
||||
m_strides[i] *= 2;
|
||||
}
|
||||
|
||||
// The physical size of the tensor will have an extra 4 bytes at the end.
|
||||
// DMLCalcBufferTensorSize calculates the minimum implied size, which is based on the last
|
||||
// addressable element of the tensor plus the space for the last element. However, the size
|
||||
// of the last element is now halved from 8 bytes to 4 bytes.
|
||||
//
|
||||
// Example:
|
||||
// Original Tensor: size={2,3}, strides={3,1}, type=int64, size = (1+{1,2}*{3,1})*sizeof(int64) = 6 * 8 = 48
|
||||
// Emulated Tensor: size={2,3}, strides={6,2}, type=int32, size = (1+{1,2}*{6,2})*sizeof(int32) = 11 * 4 = 44
|
||||
//
|
||||
// DirectML itself won't read/write the last 4 bytes, but we want the total size to be accurate
|
||||
// so that the entire region can be zeroed.
|
||||
endPaddingInBytes = sizeof(uint32_t);
|
||||
|
||||
m_bufferTensorDesc.Strides = m_strides;
|
||||
|
||||
m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(
|
||||
m_bufferTensorDesc.DataType,
|
||||
m_bufferTensorDesc.DimensionCount,
|
||||
m_sizes,
|
||||
m_strides
|
||||
) + endPaddingInBytes;
|
||||
}
|
||||
|
||||
bool TensorDesc::WasRemapped64bitTo32bit() const
|
||||
{
|
||||
bool was64BitIntType = (m_mlOperatorTensorDataType == MLOperatorTensorDataType::UInt64 || m_mlOperatorTensorDataType == MLOperatorTensorDataType::Int64);
|
||||
bool is32BitIntType = (m_bufferTensorDesc.DataType == DML_TENSOR_DATA_TYPE_UINT32 || m_bufferTensorDesc.DataType == DML_TENSOR_DATA_TYPE_INT32);
|
||||
return was64BitIntType && is32BitIntType;
|
||||
}
|
||||
|
||||
gsl::span<const uint32_t> TensorDesc::GetStrides() const
|
||||
{
|
||||
if (m_bufferTensorDesc.Strides == nullptr)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
return { m_strides, m_strides + m_bufferTensorDesc.DimensionCount };
|
||||
return { m_strides, m_strides + m_bufferTensorDesc.DimensionCount };
|
||||
}
|
||||
|
||||
DML_TENSOR_DESC TensorDesc::GetDmlDesc()
|
||||
|
|
@ -297,7 +314,8 @@ void TensorDesc::ForceUnsignedDataType()
|
|||
m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT8;
|
||||
break;
|
||||
|
||||
// Nothing to do if already unsigned
|
||||
// Nothing to do if already unsigned.
|
||||
case DML_TENSOR_DATA_TYPE_UINT64:
|
||||
case DML_TENSOR_DATA_TYPE_UINT32:
|
||||
case DML_TENSOR_DATA_TYPE_UINT16:
|
||||
case DML_TENSOR_DATA_TYPE_UINT8:
|
||||
|
|
@ -307,3 +325,35 @@ void TensorDesc::ForceUnsignedDataType()
|
|||
ML_INVALID_ARGUMENT("Can't coerce unknown or non-integral data type");
|
||||
}
|
||||
}
|
||||
|
||||
void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(newDimensionCount <= MaximumDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(alignment == TensorAxis::RightAligned || alignment == TensorAxis::LeftAligned);
|
||||
|
||||
const uint32_t oldDimensionCount = m_bufferTensorDesc.DimensionCount;
|
||||
const int32_t difference = static_cast<int32_t>(newDimensionCount - oldDimensionCount);
|
||||
if (difference == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t fillOffset = oldDimensionCount;
|
||||
int32_t fillCount = std::max(0, difference);
|
||||
|
||||
// alignment == TensorAxis::LeftAligned is the easy case.
|
||||
// Right alignment needs more work, shifting values over.
|
||||
if (alignment == TensorAxis::RightAligned)
|
||||
{
|
||||
fillOffset = 0; // Fill leading dimensions with 1's starting at the front.
|
||||
uint32_t moveCount = std::min(newDimensionCount, oldDimensionCount);
|
||||
memmove(&m_sizes[fillCount], &m_sizes[oldDimensionCount - moveCount], sizeof(m_sizes[0]) * moveCount);
|
||||
memmove(&m_strides[fillCount], &m_strides[oldDimensionCount - moveCount], sizeof(m_strides[0]) * moveCount);
|
||||
}
|
||||
if (fillCount > 0)
|
||||
{
|
||||
std::fill(&m_sizes[fillOffset], &m_sizes[fillOffset] + fillCount, 1u);
|
||||
std::fill(&m_strides[fillOffset], &m_strides[fillOffset] + fillCount, 0u);
|
||||
}
|
||||
m_bufferTensorDesc.DimensionCount = newDimensionCount;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,11 +36,13 @@ namespace Dml
|
|||
|
||||
inline DML_TENSOR_DATA_TYPE GetDmlDataType() const { return m_bufferTensorDesc.DataType; }
|
||||
inline MLOperatorTensorDataType GetMlOperatorDataType() const { return m_mlOperatorTensorDataType; }
|
||||
inline gsl::span<const uint32_t> GetDmlSizes() const { return m_sizes; }
|
||||
void ForceUnsignedDataType();
|
||||
void Remap64bitDmlDataTypeTo32bit();
|
||||
bool WasRemapped64bitTo32bit() const;
|
||||
|
||||
inline bool IsValid() const { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
|
||||
inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; }
|
||||
void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
|
||||
gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
|
||||
gsl::span<const uint32_t> GetStrides() const;
|
||||
|
||||
|
|
@ -58,8 +60,6 @@ namespace Dml
|
|||
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};
|
||||
};
|
||||
|
||||
|
||||
|
||||
class TensorDescBuilder
|
||||
{
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ namespace AttrName
|
|||
static constexpr const char* Axis = "axis";
|
||||
static constexpr const char* AxisW = "axis_w";
|
||||
static constexpr const char* BatchAxis = "batch_axis";
|
||||
static constexpr const char* BatchDimensions = "batch_dims";
|
||||
static constexpr const char* Beta = "beta";
|
||||
static constexpr const char* Bias = "bias";
|
||||
static constexpr const char* BlockSize = "blocksize";
|
||||
|
|
@ -32,6 +33,7 @@ namespace AttrName
|
|||
static constexpr const char* Dtype = "dtype";
|
||||
static constexpr const char* Ends = "ends";
|
||||
static constexpr const char* Epsilon = "epsilon";
|
||||
static constexpr const char* Equation = "equation";
|
||||
static constexpr const char* ExcludeOutside = "exclude_outside";
|
||||
static constexpr const char* Exclusive = "exclusive";
|
||||
static constexpr const char* Exponent = "exponent";
|
||||
|
|
@ -45,6 +47,7 @@ namespace AttrName
|
|||
static constexpr const char* InputForget = "input_forget";
|
||||
static constexpr const char* K = "k";
|
||||
static constexpr const char* KeepDims = "keepdims";
|
||||
static constexpr const char* SelectLastIndex = "select_last_index";
|
||||
static constexpr const char* KernelShape = "kernel_shape";
|
||||
static constexpr const char* LinearBeforeReset = "linear_before_reset";
|
||||
static constexpr const char* Lambda = "lambd"; // Deliberate typo to match ONNX spec.
|
||||
|
|
@ -57,12 +60,15 @@ namespace AttrName
|
|||
static constexpr const char* NearestMode = "nearest_mode";
|
||||
static constexpr const char* NormalizeVariance = "normalize_variance";
|
||||
static constexpr const char* P = "p";
|
||||
static constexpr const char* OutputHeight = "output_height";
|
||||
static constexpr const char* OutputShape = "output_shape";
|
||||
static constexpr const char* OutputPadding = "output_padding";
|
||||
static constexpr const char* OutputWidth = "output_width";
|
||||
static constexpr const char* Pads = "pads";
|
||||
static constexpr const char* PooledShape = "pooled_shape";
|
||||
static constexpr const char* Reverse = "reverse";
|
||||
static constexpr const char* SampleSize = "sample_size";
|
||||
static constexpr const char* SamplingRatio = "sampling_ratio";
|
||||
static constexpr const char* Scale = "scale";
|
||||
static constexpr const char* Scales = "scales";
|
||||
static constexpr const char* Seed = "seed";
|
||||
|
|
|
|||
|
|
@ -20,13 +20,12 @@
|
|||
}\
|
||||
}
|
||||
|
||||
template<typename T, typename I> T clamp_cast(I input)
|
||||
{
|
||||
return static_cast<T>(std::clamp<I>(input, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max()));
|
||||
}
|
||||
|
||||
namespace OperatorHelper
|
||||
{
|
||||
template<typename T, typename I> T clamp_cast(I input)
|
||||
{
|
||||
return static_cast<T>(std::clamp<I>(input, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max()));
|
||||
}
|
||||
enum TensorAxis { N, C, H, W, DoNotCoerce = UINT_MAX, LeftAligned = INT_MAX, RightAligned = INT_MIN, NoPlacementAdjustment = 0 };
|
||||
enum BroadcastMode { NoBroadcast, UnidirectionalBroadcast, MultidirectionalBroadcast };
|
||||
|
||||
|
|
|
|||
|
|
@ -113,6 +113,9 @@ IMLOperatorRegistryPrivate : public IUnknown
|
|||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph = nullptr,
|
||||
bool requiresFloatFormatsForGraph = false,
|
||||
bool supportedWith64BitTensorsVia32BitStrides = false,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
|
||||
bool prefer64BitTensorsDirectly = false,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr,
|
||||
uint32_t constantCpuInputCount = 0
|
||||
) const noexcept PURE;
|
||||
|
|
|
|||
|
|
@ -113,6 +113,36 @@ namespace OperatorHelper
|
|||
}
|
||||
}
|
||||
|
||||
float CastFloat16ToFloat32(uint16_t input)
|
||||
{
|
||||
// Promote float16m10e5s1 to float32m23e8s1.
|
||||
// Note this works on machines of both ascending and descending byte
|
||||
// endianness, so long as float32 and uint32 endianness match.
|
||||
// It does not work for a few abberant architectures which store
|
||||
// float32 and uint32 with opposite endianness.
|
||||
|
||||
const uint32_t float16unsignedValueMask = 0x7FFF;
|
||||
const uint32_t float16signMask = 0x8000;
|
||||
const uint32_t float16exponentMask = 0x7C00;
|
||||
const uint32_t float32exponentMask = 0x7F800000;
|
||||
|
||||
uint32_t float16unsignedValue = input & float16unsignedValueMask;
|
||||
uint32_t float16sign = input & float16signMask;
|
||||
uint32_t float16exponent = input & float16exponentMask;
|
||||
|
||||
// Shift mantissa bits left (23 - 10 = 13).
|
||||
// Adjust exponent bias (127 - 15 = 112, 112 << 23 == 0x38000000).
|
||||
// Move sign bit to float32 MSB (32 - 16 = 16).
|
||||
uint32_t float32unsignedValue = (float16unsignedValue << 13) + 0x38000000;
|
||||
uint32_t float32sign = float16sign << 16;
|
||||
uint32_t result = (float16exponent == 0) ? (float32unsignedValue & ~float32exponentMask) : // Denormal
|
||||
(float16exponent == float16exponentMask) ? (float32unsignedValue | float32exponentMask) : // Infinity
|
||||
float32unsignedValue; // Any other normal value
|
||||
result |= float32sign;
|
||||
|
||||
return reinterpret_cast<float&>(result);
|
||||
}
|
||||
|
||||
int64_t CastToInt64(MLOperatorTensorDataType tensorDataType, const void* p)
|
||||
{
|
||||
switch (tensorDataType)
|
||||
|
|
@ -150,7 +180,7 @@ namespace OperatorHelper
|
|||
case MLOperatorTensorDataType::Int64: return static_cast<double>(*reinterpret_cast<const int64_t*>(p));
|
||||
case MLOperatorTensorDataType::String: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::String type is unsupported for reading as an integer.");
|
||||
case MLOperatorTensorDataType::Bool: return static_cast<double>(*reinterpret_cast<const uint8_t*>(p));
|
||||
case MLOperatorTensorDataType::Float16: ML_INVALID_ARGUMENT("MLOperatorTensorDataType::Float16 type is unsupported for reading as an integer.");
|
||||
case MLOperatorTensorDataType::Float16: return static_cast<double>(CastFloat16ToFloat32(*reinterpret_cast<const uint16_t*>(p)));
|
||||
case MLOperatorTensorDataType::Double: return static_cast<double>(*reinterpret_cast<const double*>(p));
|
||||
case MLOperatorTensorDataType::UInt32: return static_cast<double>(*reinterpret_cast<const uint32_t*>(p));
|
||||
case MLOperatorTensorDataType::UInt64: return static_cast<double>(*reinterpret_cast<const uint64_t*>(p));
|
||||
|
|
@ -673,7 +703,7 @@ namespace OperatorHelper
|
|||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0);
|
||||
int outDimCount = gsl::narrow_cast<int>(inputDimensions.size() + indicesDimensions.size() - 1);
|
||||
ML_CHECK_VALID_ARGUMENT(outDimCount >= 0 && outDimCount <= NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(outDimCount >= 0);
|
||||
|
||||
std::vector<DimensionType> outputDimensions(outDimCount, 1);
|
||||
|
||||
|
|
@ -707,21 +737,27 @@ namespace OperatorHelper
|
|||
{
|
||||
std::vector<DimensionType> inputDimensions = shapeInfo.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = shapeInfo.GetInputTensorShape(1);
|
||||
int32_t batchCount = m_batchCount;
|
||||
|
||||
// Determine the number of output dimensions.
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 1);
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > batchCount);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() > batchCount);
|
||||
const uint32_t numberOfCoordinatesPerIndex = indicesDimensions.back();
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= numberOfCoordinatesPerIndex);
|
||||
const uint32_t numberOfOutputDimensionsFromInput = static_cast<uint32_t>(inputDimensions.size()) - numberOfCoordinatesPerIndex;
|
||||
const uint32_t numberOfOutputDimensionsFromIndices = static_cast<uint32_t>(indicesDimensions.size()) - 1; // Strip off last dimension.
|
||||
uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput);
|
||||
ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0 && outputDimensionCount <= NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= batchCount + numberOfCoordinatesPerIndex);
|
||||
const uint32_t numberOfOutputDimensionsFromInput = static_cast<uint32_t>(inputDimensions.size()) - batchCount - numberOfCoordinatesPerIndex;
|
||||
const uint32_t numberOfOutputDimensionsFromIndices = static_cast<uint32_t>(indicesDimensions.size()) - batchCount - 1; // Strip off last dimension.
|
||||
uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(batchCount + numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput);
|
||||
ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0);
|
||||
|
||||
// Form the full expected size by concatenating the prefix part of the indices tensor shape
|
||||
// with the suffix of the input tensor shape.
|
||||
// Form the full expected size by concatenating fragments:
|
||||
// 1 - batch count
|
||||
// 2 - prefix part of the indices tensor shape
|
||||
// 3 - suffix of the input tensor shape.
|
||||
std::vector<DimensionType> outputDimensions;
|
||||
outputDimensions.assign(indicesDimensions.begin(), indicesDimensions.end() - 1);
|
||||
outputDimensions.assign(inputDimensions.begin(), inputDimensions.begin() + batchCount);
|
||||
outputDimensions.insert(outputDimensions.end(), indicesDimensions.begin() + batchCount, indicesDimensions.end() - 1);
|
||||
outputDimensions.insert(outputDimensions.end(), inputDimensions.end() - numberOfOutputDimensionsFromInput, inputDimensions.end());
|
||||
|
||||
return { EdgeShapes(std::move(outputDimensions)) };
|
||||
|
|
@ -782,8 +818,6 @@ namespace OperatorHelper
|
|||
// Dim Offset : 1
|
||||
|
||||
std::vector<DimensionType> reducedDims = shapeInfo.GetInputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(reducedDims.size() <= NchwDimensionCount);
|
||||
|
||||
std::vector<bool> reduced(reducedDims.size(), false);
|
||||
|
||||
for (auto& dim : m_axes)
|
||||
|
|
@ -817,8 +851,6 @@ namespace OperatorHelper
|
|||
|
||||
void ReduceHelperBase::AdjustAxesAndOutputShape(const std::vector<uint32_t>& inputShape)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(inputShape.size() <= NchwDimensionCount);
|
||||
|
||||
// If axes is not specified, reduce over all the dimensions
|
||||
if (m_axes.empty())
|
||||
{
|
||||
|
|
@ -826,7 +858,264 @@ namespace OperatorHelper
|
|||
std::iota(m_axes.begin(), m_axes.end(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void EinSumHelper::Initialize()
|
||||
{
|
||||
ParseEquationComponents();
|
||||
m_recognizedOperatorType = DetermineRecognizedOperatorType();
|
||||
}
|
||||
|
||||
void EinSumHelper::ParseEquationComponents()
|
||||
{
|
||||
// Parse an equation like 'ij,jk->ik' into components {ij, jk, ik} mapping letters to
|
||||
// numeric indices {(0,1}, {1,2}, {0,2}}. The last component is the output.
|
||||
|
||||
std::map<char, uint32_t> labelMap;
|
||||
std::set<char> repeatedLabels;
|
||||
|
||||
uint32_t currentLabelIndex = 0;
|
||||
Component currentComponent = {};
|
||||
bool foundOutput = false;
|
||||
bool reachedEnd = false;
|
||||
|
||||
// Read first to last character in equation, looking for letters, commas, and one arrow.
|
||||
for (char* token = m_equation.data(); !reachedEnd; ++token)
|
||||
{
|
||||
char ch = *token;
|
||||
|
||||
// Only ASCII letters are valid subscript symbols in numpy.einsum().
|
||||
if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'))
|
||||
{
|
||||
// Check whether label already has an index.
|
||||
const auto [i, inserted] = labelMap.insert({ch, currentLabelIndex});
|
||||
if (inserted)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(!foundOutput, "Found label in equation output not matching any label from inputs.")
|
||||
++currentLabelIndex; // New label found.
|
||||
}
|
||||
else if (!foundOutput)
|
||||
{
|
||||
// If label in input already found earlier, then keep track of this later
|
||||
// to generate the default output in case one is not specified.
|
||||
repeatedLabels.insert(ch);
|
||||
}
|
||||
m_labelIndices.push_back(i->second);
|
||||
}
|
||||
else if (ch == ' ')
|
||||
{
|
||||
// Ignore spaces.
|
||||
}
|
||||
else
|
||||
{
|
||||
currentComponent.labelIndexEnd = static_cast<uint32_t>(m_labelIndices.size());
|
||||
m_components.push_back(currentComponent);
|
||||
currentComponent.labelIndexBegin = currentComponent.labelIndexEnd;
|
||||
|
||||
switch (ch)
|
||||
{
|
||||
case ',':
|
||||
// Note it's valid for 2 commas be adjacent, which indicates a scalar and generates
|
||||
// an empty component.
|
||||
break;
|
||||
|
||||
case '-': // Start of "->" (must be atomic, no space between them).
|
||||
++token; // Skip '-'.
|
||||
ML_CHECK_VALID_ARGUMENT(*token == '>', "Expected '->' for output.")
|
||||
ML_CHECK_VALID_ARGUMENT(foundOutput == false, "Only one output arrow '->' is valid.")
|
||||
foundOutput = true;
|
||||
break;
|
||||
|
||||
case '.':
|
||||
// Ellipsis is unsupported. Leave recognized operator as None, deferring to another EP.
|
||||
m_components.clear();
|
||||
return;
|
||||
|
||||
case '\0':
|
||||
reachedEnd = true;
|
||||
break; // End of string.
|
||||
|
||||
default:
|
||||
ML_INVALID_ARGUMENT("Unsupported character in equation string. Must be a-z, A-Z, ',', or '->'.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundOutput)
|
||||
{
|
||||
// If no explicit output was given, generate an implicit output by ordering all the
|
||||
// labels in alphabetic order (by ASCII value consistent with numpy, so Z < a).
|
||||
// Exclude any labels that occurred more than once, as these cancel out.
|
||||
|
||||
for (auto i : labelMap)
|
||||
{
|
||||
if (repeatedLabels.count(i.first) == 0)
|
||||
{
|
||||
m_labelIndices.push_back(i.second);
|
||||
}
|
||||
}
|
||||
|
||||
// Push the final component, which is the output.
|
||||
currentComponent.labelIndexEnd = static_cast<uint32_t>(m_labelIndices.size());
|
||||
m_components.push_back(currentComponent);
|
||||
}
|
||||
}
|
||||
|
||||
EinSumHelper::RecognizedOperatorType EinSumHelper::DetermineRecognizedOperatorType()
|
||||
{
|
||||
if (m_components.empty())
|
||||
{
|
||||
return RecognizedOperatorType::None; // Parsing may have found unsupported components - treating as unknown.
|
||||
}
|
||||
|
||||
// std::ranges::equal is not supported yet.
|
||||
auto equals = [](gsl::span<const uint32_t> a, gsl::span<const uint32_t> b)
|
||||
{
|
||||
return std::equal(a.begin(), a.end(), b.begin(), b.end());
|
||||
};
|
||||
|
||||
std::array<uint32_t, 3> componentRanks;
|
||||
if (m_components.size() > componentRanks.size())
|
||||
{
|
||||
// No recognized operator takes more than 2 inputs and 1 output.
|
||||
// EinSum itself is generic and can handle any variable number of inputs,
|
||||
// but DML's operators expect fixed counts.
|
||||
return RecognizedOperatorType::None;
|
||||
}
|
||||
else if (m_components.size() == 2)
|
||||
{
|
||||
auto& inputLabels = m_components[0].GetLabels(m_labelIndices);
|
||||
auto& outputLabels = m_components[1].GetLabels(m_labelIndices);
|
||||
if (inputLabels.size() == outputLabels.size())
|
||||
{
|
||||
// Check identity.
|
||||
if (equals(inputLabels, outputLabels))
|
||||
{
|
||||
// Handles: "->", "i->i", "ij->ij", "ijk->ijk", "ijkl->ijkl" ...
|
||||
return RecognizedOperatorType::Identity;
|
||||
}
|
||||
else // Transpose since a permutation exists.
|
||||
{
|
||||
// Handles: "ij->ji", "ijk->kji", "ijkl->lkji", "ijkl->ijkl" ...
|
||||
return RecognizedOperatorType::Transpose;
|
||||
}
|
||||
}
|
||||
else if (outputLabels.empty()) // Scalar output, with all inputs reduced.
|
||||
{
|
||||
// Handles: "i->", "ij->", "ijk->", "ijkl->" ...
|
||||
return RecognizedOperatorType::ReduceSum;
|
||||
}
|
||||
}
|
||||
else if (m_components.size() == 3)
|
||||
{
|
||||
// If all components have the same size and label order, then apply elementwise multiplication.
|
||||
auto& inputALabels = m_components[0].GetLabels(m_labelIndices);
|
||||
auto& inputBLabels = m_components[1].GetLabels(m_labelIndices);
|
||||
auto& outputLabels = m_components[2].GetLabels(m_labelIndices);
|
||||
if (equals(inputALabels, outputLabels) && equals(inputBLabels, outputLabels))
|
||||
{
|
||||
// Handles: "i,i->i", "ij,ij->ij", "ijk,ijk->ijk", "ijkl,ijkl->ijkl" ...
|
||||
return RecognizedOperatorType::Multiply;
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise check for special cases of dedicated operators...
|
||||
|
||||
struct RecognizedOperatorInfo
|
||||
{
|
||||
RecognizedOperatorType recognizedOperatorType;
|
||||
std::initializer_list<const uint32_t> componentRanks;
|
||||
std::initializer_list<const uint32_t> labelIndices;
|
||||
};
|
||||
|
||||
const RecognizedOperatorInfo recognizedOperators[] = {
|
||||
{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
|
||||
{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
|
||||
{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
|
||||
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i
|
||||
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j
|
||||
};
|
||||
|
||||
// For each recognized operator, compare the labels-per-component and label indices.
|
||||
for (auto& recognizedOperator : recognizedOperators)
|
||||
{
|
||||
if (equals(m_labelIndices, recognizedOperator.labelIndices)
|
||||
&& m_components.size() == recognizedOperator.componentRanks.size())
|
||||
{
|
||||
for (size_t i = 0; i < m_components.size(); ++i)
|
||||
{
|
||||
componentRanks[i] = m_components[i].GetDimensionCount();
|
||||
}
|
||||
|
||||
if (equals(gsl::make_span(componentRanks.data(), m_components.size()), recognizedOperator.componentRanks))
|
||||
{
|
||||
return recognizedOperator.recognizedOperatorType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return RecognizedOperatorType::None;
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> EinSumHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
assert(!m_components.empty()); // Should have already parsed components.
|
||||
|
||||
uint32_t inputCount = shapeInfo.GetInputCount();
|
||||
uint32_t outputCount = shapeInfo.GetOutputCount();
|
||||
ML_CHECK_VALID_ARGUMENT(inputCount + 1 == m_components.size(), "Mismatch between input tensor count and string equation component count.");
|
||||
ML_CHECK_VALID_ARGUMENT(outputCount == 1, "EinSum expects exactly 1 output tensor.");
|
||||
|
||||
std::vector<uint32_t> labelSizes(m_labelIndices.size(), INT_MIN);
|
||||
|
||||
// Read every input tensor, comparing labels to ensure consistent sizes from the equation parsed earlier.
|
||||
for (uint32_t i = 0; i < inputCount; ++i)
|
||||
{
|
||||
auto inputShape = shapeInfo.GetInputTensorShape(i);
|
||||
auto& component = m_components[i];
|
||||
auto labelIndices = component.GetLabels(m_labelIndices);
|
||||
uint32_t dimensionCount = component.GetDimensionCount();
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(inputShape.size() == dimensionCount, "Mismatch between input tensor shape and string equation label count.");
|
||||
|
||||
for (uint32_t i = 0; i < dimensionCount; ++i)
|
||||
{
|
||||
// If this is the first time seeing this label, then record the size.
|
||||
// Otherwise any following occurrences of the label must match sizes.
|
||||
// e.g. Given "ij,ji", both i's and both j's must match dimension sizes.
|
||||
uint32_t dimensionSize = inputShape[i];
|
||||
uint32_t labelIndex = labelIndices[i];
|
||||
assert(labelIndex < labelSizes.size());
|
||||
|
||||
if (labelSizes[labelIndex] == INT_MIN)
|
||||
{
|
||||
labelSizes[labelIndex] = dimensionSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(labelSizes[labelIndex] == dimensionSize, "All labels must have the same dimension sizes.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate output dimensions from corresponding input tensor labels.
|
||||
// e.g. Given ij,jk->ij with [2,3] and [3,5], the output is [2,5].
|
||||
std::vector<uint32_t> outputDimensions;
|
||||
auto outputLabelIndices = m_components.back().GetLabels(m_labelIndices);
|
||||
for (auto labelIndex : outputLabelIndices)
|
||||
{
|
||||
outputDimensions.push_back(labelSizes[labelIndex]);
|
||||
}
|
||||
|
||||
return { std::move(EdgeShapes(outputDimensions)) };
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 2);
|
||||
|
|
@ -971,7 +1260,6 @@ namespace OperatorHelper
|
|||
std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
auto outputShape = shapeInfo.GetInputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(outputShape.size() <= NchwDimensionCount);
|
||||
|
||||
uint32_t inputCount = shapeInfo.GetInputCount();
|
||||
|
||||
|
|
@ -1110,8 +1398,25 @@ namespace OperatorHelper
|
|||
{
|
||||
roiShape[0], // number of ROIs
|
||||
inputShape[C], // number of channels
|
||||
static_cast<DimensionType>(m_pooledSizeH),
|
||||
static_cast<DimensionType>(m_pooledSizeW),
|
||||
static_cast<DimensionType>(m_outputSizeH),
|
||||
static_cast<DimensionType>(m_outputSizeW),
|
||||
};
|
||||
|
||||
return { std::move(EdgeShapes(outputDimensions)) };
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> RoiAlignHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS);
|
||||
auto inputShape = shapeInfo.GetInputTensorShape(InputTensors::INPUT);
|
||||
ML_CHECK_VALID_ARGUMENT(inputShape.size() >= 4, "inputShape must be >= 4.");
|
||||
|
||||
DimensionType outputDimensions[4] =
|
||||
{
|
||||
roiShape[0], // number of ROIs
|
||||
inputShape[C], // number of channels
|
||||
static_cast<DimensionType>(m_outputSizeH),
|
||||
static_cast<DimensionType>(m_outputSizeW),
|
||||
};
|
||||
|
||||
return { std::move(EdgeShapes(outputDimensions)) };
|
||||
|
|
|
|||
|
|
@ -633,7 +633,7 @@ public:
|
|||
{
|
||||
int dimIndex = axes.empty() ? i : axes[i];
|
||||
int stride = steps.empty() ? 1 : steps[i];
|
||||
ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions.");
|
||||
ML_CHECK_VALID_ARGUMENT(static_cast<size_t>(dimIndex) < static_cast<size_t>(inputDimensions.size()), "'axes' must be valid with within actual input dimensions.");
|
||||
ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0.");
|
||||
|
||||
// Positive values are offsets from 0.
|
||||
|
|
@ -733,6 +733,7 @@ class ReduceHelperBase {
|
|||
template <typename Info_t, typename Shape_t>
|
||||
ReduceHelperBase(const Info_t& info, const Shape_t& shape, bool usingAxes) {
|
||||
m_keepDims = info.GetOptionalAttribute<int>(AttrName::KeepDims, 1);
|
||||
m_selectLastIndex = info.GetOptionalAttribute<int>(AttrName::SelectLastIndex, 0);
|
||||
if (usingAxes) {
|
||||
m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes);
|
||||
} else {
|
||||
|
|
@ -751,6 +752,7 @@ class ReduceHelperBase {
|
|||
protected:
|
||||
std::vector<int> m_axes;
|
||||
int m_keepDims = 0;
|
||||
int m_selectLastIndex = 0;
|
||||
};
|
||||
|
||||
class ArgMinArgMaxHelper : public ReduceHelperBase {
|
||||
|
|
@ -769,6 +771,70 @@ class ReduceHelper : public ReduceHelperBase {
|
|||
ReduceHelper(const Info_t& info, const Shape_t& shape) : ReduceHelperBase(info, shape, true) {}
|
||||
};
|
||||
|
||||
class EinSumHelper
|
||||
{
|
||||
public:
|
||||
void Initialize();
|
||||
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
EinSumHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
|
||||
{
|
||||
m_equation = info.GetAttribute(AttrName::Equation);
|
||||
Initialize();
|
||||
}
|
||||
|
||||
EinSumHelper(const MLOperatorAttributes& info)
|
||||
{
|
||||
m_equation = info.GetAttribute(AttrName::Equation);
|
||||
Initialize();
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
enum class RecognizedOperatorType
|
||||
{
|
||||
None,
|
||||
Identity,
|
||||
Multiply,
|
||||
MatMul,
|
||||
MatMulTransposeA,
|
||||
MatMulTransposeB,
|
||||
ReduceSum,
|
||||
Transpose,
|
||||
Total,
|
||||
};
|
||||
|
||||
RecognizedOperatorType GetRecognizedOperatorType() const noexcept { return m_recognizedOperatorType; }
|
||||
|
||||
protected:
|
||||
void ParseEquationComponents();
|
||||
RecognizedOperatorType DetermineRecognizedOperatorType();
|
||||
|
||||
protected:
|
||||
struct Component
|
||||
{
|
||||
uint32_t labelIndexBegin;
|
||||
uint32_t labelIndexEnd;
|
||||
|
||||
uint32_t GetDimensionCount() const noexcept
|
||||
{
|
||||
return labelIndexEnd - labelIndexBegin;
|
||||
}
|
||||
gsl::span<const uint32_t> GetLabels(gsl::span<const uint32_t> labels) const
|
||||
{
|
||||
return labels.subspan(labelIndexBegin, labelIndexEnd - labelIndexBegin);
|
||||
};
|
||||
};
|
||||
|
||||
std::string m_equation;
|
||||
std::vector<uint32_t> m_labelIndices; // Concatenation of all labels as rebased indices ("ij,ai" -> 0,1,2,0).
|
||||
std::vector<Component> m_components; // All components in order, including inputs and output.
|
||||
std::vector<uint32_t> m_outputDimensions;
|
||||
RecognizedOperatorType m_recognizedOperatorType = RecognizedOperatorType::None;
|
||||
};
|
||||
|
||||
class MatMulHelperBase {
|
||||
public:
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
|
|
@ -975,9 +1041,13 @@ class GatherNdHelper {
|
|||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
GatherNdHelper(const Info_t& info, const Shape_t& shape) {
|
||||
m_batchCount = info.GetOptionalAttribute<int32_t>(AttrName::BatchDimensions, 0);
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
protected:
|
||||
int32_t m_batchCount;
|
||||
};
|
||||
|
||||
class PoolingHelperBase {
|
||||
|
|
@ -1040,26 +1110,51 @@ class PoolingHelper : public PoolingHelperBase {
|
|||
PoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {}
|
||||
};
|
||||
|
||||
class RoiPoolingHelper {
|
||||
public:
|
||||
enum InputTensors { INPUT,
|
||||
ROIS };
|
||||
class RoiPoolingHelperBase
|
||||
{
|
||||
public:
|
||||
enum InputTensors { INPUT, ROIS, BATCH_INDICES };
|
||||
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
RoiPoolingHelper(const Info_t& info, const Shape_t& shape) {
|
||||
std::vector<int> pooledShape = info.GetOptionalAttributeVectorInt32(AttrName::PooledShape);
|
||||
ML_CHECK_VALID_ARGUMENT(pooledShape.size() == 2, "Pooled shape must be 2.");
|
||||
m_pooledSizeH = pooledShape[0];
|
||||
m_pooledSizeW = pooledShape[1];
|
||||
}
|
||||
RoiPoolingHelperBase()
|
||||
{}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
protected:
|
||||
uint32_t m_pooledSizeW;
|
||||
uint32_t m_pooledSizeH;
|
||||
protected:
|
||||
uint32_t m_outputSizeW = 1;
|
||||
uint32_t m_outputSizeH = 1;
|
||||
};
|
||||
|
||||
class RoiPoolingHelper : public RoiPoolingHelperBase
|
||||
{
|
||||
public:
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
RoiPoolingHelper(const Info_t& info, const Shape_t& shape)
|
||||
{
|
||||
std::vector<int> pooledShape = info.GetOptionalAttributeVectorInt32(AttrName::PooledShape);
|
||||
ML_CHECK_VALID_ARGUMENT(pooledShape.size() == 2, "Pooled shape must be 2.");
|
||||
m_outputSizeH = pooledShape[0];
|
||||
m_outputSizeW = pooledShape[1];
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
};
|
||||
|
||||
class RoiAlignHelper : public RoiPoolingHelperBase
|
||||
{
|
||||
public:
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
RoiAlignHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
|
||||
{
|
||||
m_outputSizeW = info.GetOptionalAttribute<uint32_t>(AttrName::OutputWidth, 1);
|
||||
m_outputSizeH = info.GetOptionalAttribute<uint32_t>(AttrName::OutputHeight, 1);
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
};
|
||||
|
||||
class SqueezeHelper {
|
||||
|
|
@ -1330,6 +1425,7 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper;
|
|||
using ShapeInferenceHelper_LpPool = PoolingHelper;
|
||||
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
|
||||
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
|
||||
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
|
||||
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
|
||||
|
||||
|
|
@ -1382,8 +1478,11 @@ using ShapeInferenceHelper_Ceil = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_Floor = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Clip7 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Clip11 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Clip12 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Greater = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_Less = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_GreaterOrEqual = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_LessOrEqual = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_Equal = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_Not = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_And = GetBroadcastedOutputShapeHelper;
|
||||
|
|
@ -1430,6 +1529,7 @@ using ShapeInferenceHelper_ReduceL1 = ReduceHelper;
|
|||
using ShapeInferenceHelper_ReduceL2 = ReduceHelper;
|
||||
using ShapeInferenceHelper_ReduceMax = ReduceHelper;
|
||||
using ShapeInferenceHelper_ReduceMin = ReduceHelper;
|
||||
using ShapeInferenceHelper_Einsum12 = VersionedOpsetHelper<EinSumHelper, 12>;
|
||||
using ShapeInferenceHelper_ArgMax = ArgMinArgMaxHelper;
|
||||
using ShapeInferenceHelper_ArgMin = ArgMinArgMaxHelper;
|
||||
using ShapeInferenceHelper_Gemm = GemmHelper;
|
||||
|
|
@ -1450,6 +1550,7 @@ using ShapeInferenceHelper_LeakyRelu = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_PRelu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_ThresholdedRelu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Elu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Celu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Selu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Softmax = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_LogSoftmax = GetOutputShapeAsInputShapeHelper;
|
||||
|
|
|
|||
|
|
@ -245,6 +245,24 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Unsqueeze = 11;
|
||||
} // namespace OnnxOperatorSet11
|
||||
|
||||
namespace OnnxOperatorSet12
|
||||
{
|
||||
static const int sc_sinceVer_ArgMin = 12;
|
||||
static const int sc_sinceVer_ArgMax = 12;
|
||||
static const int sc_sinceVer_Celu = 12;
|
||||
static const int sc_sinceVer_Clip = 12;
|
||||
static const int sc_sinceVer_Einsum = 12;
|
||||
static const int sc_sinceVer_GatherND = 12;
|
||||
static const int sc_sinceVer_GreaterOrEqual = 12;
|
||||
static const int sc_sinceVer_LessOrEqual = 12;
|
||||
static const int sc_sinceVer_MaxPool = 12;
|
||||
static const int sc_sinceVer_Min = 12;
|
||||
static const int sc_sinceVer_Max = 12;
|
||||
static const int sc_sinceVer_Pow = 12;
|
||||
static const int sc_sinceVer_ReduceMax = 12;
|
||||
static const int sc_sinceVer_ReduceMin = 12;
|
||||
} // namespace OnnxOperatorSet12
|
||||
|
||||
namespace MsftOperatorSet1
|
||||
{
|
||||
static const int sc_sinceVer_FusedConv = 1;
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <numeric>
|
||||
|
||||
#include <wrl/client.h>
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<packages>
|
||||
<package id="DirectML" version="2.1.0" targetFramework="native" />
|
||||
<package id="DirectML" version="3.0.0" targetFramework="native" />
|
||||
<package id="GoogleTestAdapter" version="0.17.1" targetFramework="net46" />
|
||||
</packages>
|
||||
|
|
|
|||
|
|
@ -288,7 +288,7 @@ def generate_files(list, args):
|
|||
'" target="runtimes\\win-' + args.target_architecture + '\\native" />')
|
||||
files_list.append('<file src=' + '"' + os.path.join(args.native_build_path, 'DirectML.pdb') +
|
||||
'" target="runtimes\\win-' + args.target_architecture + '\\native" />')
|
||||
files_list.append('<file src=' + '"' + os.path.join(args.packages_path, 'DirectML.2.1.0\\LICENSE.txt') +
|
||||
files_list.append('<file src=' + '"' + os.path.join(args.packages_path, 'DirectML.2.1.1\\LICENSE.txt') +
|
||||
'" target="DirectML_LICENSE.txt" />')
|
||||
|
||||
if includes_winml:
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
|
|||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph,
|
||||
bool requiresFloatFormatsForGraph,
|
||||
bool supportedWith64BitTensorsVia32BitStrides,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
bool prefer64BitTensorsDirectly,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
|
||||
uint32_t constantCpuInputCount) const noexcept try {
|
||||
#ifdef LAYERING_DONE
|
||||
|
|
@ -79,6 +82,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
|
|||
supportsGraph,
|
||||
requiredInputCountForGraph,
|
||||
requiresFloatFormatsForGraph,
|
||||
supportedWith64BitTensorsVia32BitStrides,
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
prefer64BitTensorsDirectly,
|
||||
requiredConstantCpuInputs,
|
||||
constantCpuInputCount);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,9 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
|
|||
bool supports_graph,
|
||||
const uint32_t* required_input_count_for_graph = nullptr,
|
||||
bool requires_float_formats_for_graph = false,
|
||||
bool supports_64bit_directly = false,
|
||||
bool allows_64bit_via_strides = false,
|
||||
bool allows_64bit_via_strides_from_any_ep = false,
|
||||
_In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr,
|
||||
uint32_t constant_cpu_input_count = 0) const noexcept override;
|
||||
|
||||
|
|
|
|||
|
|
@ -54,9 +54,10 @@ static std::string ToString(wfc::IVectorView<int64_t> shape) {
|
|||
static std::string ToString(
|
||||
winml::TensorKind kind,
|
||||
wfc::IVectorView<int64_t> shape) {
|
||||
FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128.");
|
||||
FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64.");
|
||||
FAIL_FAST_IF_MSG(kind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined.");
|
||||
// Any unrecognized data type is considered "Undefined".
|
||||
if (static_cast<uint32_t>(kind) >= std::size(SzTensorKind)) {
|
||||
kind = winml::TensorKind::Undefined;
|
||||
}
|
||||
|
||||
std::ostringstream stream;
|
||||
stream << SzTensorKind[static_cast<uint32_t>(kind)] << ToString(shape);
|
||||
|
|
@ -73,9 +74,10 @@ static std::string ToString(winml::ITensor value) {
|
|||
|
||||
static std::string ToString(winml::IMapFeatureDescriptor descriptor) {
|
||||
auto keyKind = descriptor.KeyKind();
|
||||
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128.");
|
||||
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64.");
|
||||
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined.");
|
||||
// Any unrecognized data type is considered "Undefined".
|
||||
if (static_cast<uint32_t>(keyKind) >= std::size(SzTensorKind)) {
|
||||
keyKind = winml::TensorKind::Undefined;
|
||||
}
|
||||
|
||||
auto valueDescriptor = descriptor.ValueDescriptor();
|
||||
std::ostringstream stream;
|
||||
|
|
@ -86,9 +88,10 @@ static std::string ToString(winml::IMapFeatureDescriptor descriptor) {
|
|||
static std::string ToString(winrt::com_ptr<_winml::IMapFeatureValue> value) {
|
||||
winml::TensorKind keyKind;
|
||||
FAIL_FAST_IF_FAILED(value->get_KeyKind(&keyKind));
|
||||
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128.");
|
||||
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64.");
|
||||
FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined.");
|
||||
// Any unrecognized data type is considered "Undefined".
|
||||
if (static_cast<uint32_t>(keyKind) >= std::size(SzTensorKind)) {
|
||||
keyKind = winml::TensorKind::Undefined;
|
||||
}
|
||||
|
||||
winml::ILearningModelFeatureDescriptor valueDescriptor;
|
||||
FAIL_FAST_IF_FAILED(value->get_ValueDescriptor(&valueDescriptor));
|
||||
|
|
|
|||
Loading…
Reference in a new issue