From beab1ef1bbd2b9c930e5ed005ae0ef2376a9d79b Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Wed, 29 Sep 2021 23:55:26 +0000 Subject: [PATCH] Merged PR 6516718: [DML EP] Direct INT64 support for indices tensor for TopK/MaxPool/MaxUnpool Updated OperatorRegistration.cpp to enable direct INT64 support for indices tensor for onnx operators like TopK/MaxPool/MaxUnpool. --- .../src/Operators/OperatorRegistration.cpp | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 3dbe2c3934..5c982c8ac5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -280,8 +280,9 @@ constexpr static std::array supportedTypeListNumeri constexpr static std::array supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListBool = {SupportedTensorDataTypes::Bool}; constexpr static std::array supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault}; -constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault | SupportedTensorDataTypes::Ints64Bit, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListMaxUnpool = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 }; @@ -370,10 +371,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -567,9 +568,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, - {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, @@ -592,8 +593,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, - {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, - {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9. + {REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))}, + {REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))}, // 11 is identical to 9. {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},