mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[DML EP] Attention Kernel bug fix (#13879)
### Description - Use same data type as input for mask_index tensor which is used as DML GEMM API's C parameter. - Remove gsl header include as it is already gets included transitively. ### Motivation and Context - Why is this change required? What problem does it solve? Bug found in internal conformance testing. - If it fixes an open issue, please link to the issue here. N/A
This commit is contained in:
parent
4c79977f52
commit
5b16593192
2 changed files with 1 additions and 2 deletions
|
|
@ -146,7 +146,7 @@ public:
|
|||
valueSlicedOperatorDesc.InputWindowStrides = strides.data();
|
||||
const DML_OPERATOR_DESC valueSlicedDesc = { DML_OPERATOR_SLICE1, &valueSlicedOperatorDesc};
|
||||
|
||||
TensorDesc castedMaskIndexTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Float, desiredMaskIndexShape);
|
||||
TensorDesc castedMaskIndexTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredMaskIndexShape);
|
||||
DML_TENSOR_DESC namedCastedMaskIndexTensorDesc = castedMaskIndexTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_CAST_OPERATOR_DESC castMaskIndexOperatorDesc = {};
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
#include "filehelpers.h"
|
||||
#include <fstream>
|
||||
#include <MemoryBuffer.h>
|
||||
#include <gsl/gsl>
|
||||
#include "CustomOperatorProvider.h"
|
||||
#include "CustomOps.h"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue