[DML EP] Attention Kernel bug fix (#13879)

### Description
- Use same data type as input for mask_index tensor which is used as DML
GEMM API's C parameter.
- Remove gsl header include as it is already gets included transitively.



### Motivation and Context
- Why is this change required? What problem does it solve?
Bug found in internal conformance testing.
- If it fixes an open issue, please link to the issue here.
N/A
This commit is contained in:
Sumit Agarwal 2022-12-07 15:24:27 -08:00 committed by GitHub
parent 4c79977f52
commit 5b16593192
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 1 additions and 2 deletions

View file

@ -146,7 +146,7 @@ public:
valueSlicedOperatorDesc.InputWindowStrides = strides.data();
const DML_OPERATOR_DESC valueSlicedDesc = { DML_OPERATOR_SLICE1, &valueSlicedOperatorDesc};
TensorDesc castedMaskIndexTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Float, desiredMaskIndexShape);
TensorDesc castedMaskIndexTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredMaskIndexShape);
DML_TENSOR_DESC namedCastedMaskIndexTensorDesc = castedMaskIndexTensorDesc.GetDmlDesc();
DML_CAST_OPERATOR_DESC castMaskIndexOperatorDesc = {};

View file

@ -8,7 +8,6 @@
#include "filehelpers.h"
#include <fstream>
#include <MemoryBuffer.h>
#include <gsl/gsl>
#include "CustomOperatorProvider.h"
#include "CustomOps.h"