From 5b16593192e03b728f8b096e04a6a79f802670e0 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Wed, 7 Dec 2022 15:24:27 -0800 Subject: [PATCH] [DML EP] Attention Kernel bug fix (#13879) ### Description - Use same data type as input for mask_index tensor which is used as DML GEMM API's C parameter. - Remove gsl header include as it is already gets included transitively. ### Motivation and Context - Why is this change required? What problem does it solve? Bug found in internal conformance testing. - If it fixes an open issue, please link to the issue here. N/A --- .../DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp | 2 +- winml/test/scenario/cppwinrt/CustomOps.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index c524e6e6d8..63bae80c51 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -146,7 +146,7 @@ public: valueSlicedOperatorDesc.InputWindowStrides = strides.data(); const DML_OPERATOR_DESC valueSlicedDesc = { DML_OPERATOR_SLICE1, &valueSlicedOperatorDesc}; - TensorDesc castedMaskIndexTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Float, desiredMaskIndexShape); + TensorDesc castedMaskIndexTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredMaskIndexShape); DML_TENSOR_DESC namedCastedMaskIndexTensorDesc = castedMaskIndexTensorDesc.GetDmlDesc(); DML_CAST_OPERATOR_DESC castMaskIndexOperatorDesc = {}; diff --git a/winml/test/scenario/cppwinrt/CustomOps.cpp b/winml/test/scenario/cppwinrt/CustomOps.cpp index 926bae30b3..91606f6efb 100644 --- a/winml/test/scenario/cppwinrt/CustomOps.cpp +++ b/winml/test/scenario/cppwinrt/CustomOps.cpp @@ -8,7 +8,6 @@ #include "filehelpers.h" #include #include -#include #include "CustomOperatorProvider.h" #include "CustomOps.h"