Merged PR 6516718: [DML EP] Direct INT64 support for indices tensor for TopK/MaxPool/MaxUnpool

Updated OperatorRegistration.cpp to enable direct INT64 support for indices tensor for onnx operators like TopK/MaxPool/MaxUnpool.
This commit is contained in:
Sumit Agarwal 2021-09-29 23:55:26 +00:00
parent dae5b1d4ac
commit beab1ef1bb

View file

@ -280,8 +280,9 @@ constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumeri
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListBool = {SupportedTensorDataTypes::Bool};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault | SupportedTensorDataTypes::Ints64Bit, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxUnpool = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 };
@ -370,10 +371,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@ -567,9 +568,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
@ -592,8 +593,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},