DirectML GEMM broken in opset 11 and 13 when optional tensor C not provided (#12568)

Set kernel input indices to be fixed to 0,1,2. C input is now optional, so last tensor must be specified.
This commit is contained in:
Sheil Kumar 2022-08-11 16:01:27 -07:00 committed by GitHub
parent 580f2294bc
commit 67f6b7ce29
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,12 +10,13 @@ class DmlOperatorGemm : public DmlOperator, public GemmHelper
{
public:
DmlOperatorGemm(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
: DmlOperator(kernelInfo),
GemmHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
DmlOperator::Initialize(kernelInfo);
auto kernelInputIndices = std::vector<std::optional<uint32_t>> { 0, 1, 2 };
DmlOperator::Initialize(kernelInfo, kernelInputIndices);
bool containsBiasTensor = kernelInfo.IsInputValid(2);