|
|
|
|
@ -35,13 +35,13 @@ enum class SupportedTensorDataTypes : uint32_t
|
|
|
|
|
Complex64 = 1<<14,
|
|
|
|
|
Complex128 = 1<<15,
|
|
|
|
|
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
|
|
|
|
|
Int32to64 = UInt32|Int32|UInt64|Int64,
|
|
|
|
|
Ints32to64 = UInt32|Int32|UInt64|Int64,
|
|
|
|
|
UInt8to64 = UInt8|UInt16|UInt32|UInt64,
|
|
|
|
|
Float16to32 = Float16|Float32,
|
|
|
|
|
Float16to64 = Float16|Float32|Float64,
|
|
|
|
|
NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string.
|
|
|
|
|
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
|
|
|
|
|
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool,
|
|
|
|
|
AllScalarsButFloat64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
|
|
|
|
|
Ints8Bit = UInt8|Int8,
|
|
|
|
|
Ints16Bit = UInt16|Int16,
|
|
|
|
|
Ints32Bit = UInt32|Int32,
|
|
|
|
|
@ -56,6 +56,7 @@ enum class DmlGraphSupport : uint32_t
|
|
|
|
|
SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides.
|
|
|
|
|
SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes)
|
|
|
|
|
Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support.
|
|
|
|
|
Support64BitTensorsViaEmulation = 16,// supports int/uint64 tensors via emulation of 32-bit types.
|
|
|
|
|
};
|
|
|
|
|
DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport);
|
|
|
|
|
|
|
|
|
|
@ -254,6 +255,7 @@ constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
|
|
|
|
|
constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T3" };
|
|
|
|
|
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
|
|
|
|
|
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
|
|
|
|
|
constexpr static std::array<const char*, 2> typeNameListMaxPool = { "T", "I" };
|
|
|
|
|
constexpr static std::array<const char*, 2> typeNameListLogicalComparison = { "T", "T1" };
|
|
|
|
|
constexpr static std::array<const char*, 2> typeNameListPow12 = {"T", "T1"};
|
|
|
|
|
constexpr static std::array<const char*, 2> typeNameListConstantOfShape = { "T1", "T2" };
|
|
|
|
|
@ -269,31 +271,32 @@ constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat1
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalarsButFloat64 = { SupportedTensorDataTypes::AllScalarsButFloat64 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListBool = {SupportedTensorDataTypes::Bool};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalarsButFloat64, SupportedTensorDataTypes::AllScalarsButFloat64 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32 /* float32 ROI read by CPU */};
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
|
|
|
|
|
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
|
|
|
|
|
@ -361,10 +364,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|
|
|
|
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
{REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
{REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
{REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
{REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
|
|
|
|
|
|
|
|
|
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
@ -385,43 +388,43 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|
|
|
|
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
|
|
|
|
|
|
|
|
|
|
// Data Reorganization Layers
|
|
|
|
|
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
|
|
|
|
|
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
|
|
|
|
|
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
|
|
|
|
|
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
|
|
|
|
|
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis.
|
|
|
|
|
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis.
|
|
|
|
|
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
|
|
|
|
|
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
|
|
|
|
|
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
|
|
|
|
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
|
|
|
|
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
|
|
|
|
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
|
|
|
|
|
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))},
|
|
|
|
|
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(0))},
|
|
|
|
|
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
|
|
|
|
|
|
|
|
|
|
// Data reorganization that merely changes the dimensions while keeping the data identical.
|
|
|
|
|
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
|
|
|
|
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))},
|
|
|
|
|
|
|
|
|
|
// Elementwise
|
|
|
|
|
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
@ -435,7 +438,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|
|
|
|
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
|
|
|
|
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
|
|
|
|
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
|
|
|
|
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
|
|
|
|
@ -446,7 +449,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|
|
|
|
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
|
|
|
|
@ -523,7 +526,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|
|
|
|
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
|
|
|
|
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
|
|
|
|
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
|
|
|
|
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
|
|
|
|
|
{REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
|
|
|
|
|
|
|
|
|
|
// Activation Functions
|
|
|
|
|
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
@ -554,8 +557,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|
|
|
|
// Uncategorized
|
|
|
|
|
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
|
|
|
|
|
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
|
|
|
|
|
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
|
|
|
|
@ -577,7 +580,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|
|
|
|
|
|
|
|
|
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListUInt8to64, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
|
|
|
|
|
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
|
|
|
|
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
|
|
|
|
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
|
|
|
|
@ -620,6 +623,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|
|
|
|
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
|
|
|
|
|
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
|
|
|
|
|
bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly);
|
|
|
|
|
bool support64BitTensorsViaEmulation = bool(information.dmlGraphSupport & DmlGraphSupport::Support64BitTensorsViaEmulation);
|
|
|
|
|
bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides);
|
|
|
|
|
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp);
|
|
|
|
|
|
|
|
|
|
@ -705,6 +709,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|
|
|
|
supportedWith64BitTensorsVia32BitStrides,
|
|
|
|
|
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
|
|
|
|
prefer64BitTensorsDirectly,
|
|
|
|
|
support64BitTensorsViaEmulation,
|
|
|
|
|
information.requiredConstantCpuInputs.first.data(),
|
|
|
|
|
static_cast<uint32_t>(information.requiredConstantCpuInputs.second)
|
|
|
|
|
));
|
|
|
|
|
|