From d5e8d59243e98c0da42c8abf5fd268752dcfa5b8 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 25 Oct 2022 22:47:55 -0700 Subject: [PATCH] DML EP register all data types for Where operator (#13443) ### Description Register all datatypes for DML's `Where` operator since DML now supports everything. ### Motivation and Context Some transformer models use the `Where` operator on int64 data, but since DML wasn't supporting it, it needed to fall back to the CPU. --- .../DmlExecutionProvider/src/Operators/OperatorRegistration.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 26797d4e60..b58871a29a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -322,7 +322,7 @@ constexpr static std::array supportedTypeListQuanti constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars }; -constexpr static std::array supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool };