mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
DML EP register all data types for Where operator (#13443)
### Description Register all datatypes for DML's `Where` operator since DML now supports everything. ### Motivation and Context Some transformer models use the `Where` operator on int64 data, but since DML wasn't supporting it, it needed to fall back to the CPU.
This commit is contained in:
parent
70b73afd36
commit
d5e8d59243
1 changed files with 1 additions and 1 deletions
|
|
@ -322,7 +322,7 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuanti
|
|||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool };
|
||||
|
|
|
|||
Loading…
Reference in a new issue