mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
DirectML GEMM broken in opset 11 and 13 when optional tensor C not provided (#12568)
Set kernel input indices to be fixed to 0,1,2. C input is now optional, so last tensor must be specified.
This commit is contained in:
parent
580f2294bc
commit
67f6b7ce29
1 changed files with 3 additions and 2 deletions
|
|
@ -10,12 +10,13 @@ class DmlOperatorGemm : public DmlOperator, public GemmHelper
|
|||
{
|
||||
public:
|
||||
DmlOperatorGemm(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperator(kernelInfo),
|
||||
: DmlOperator(kernelInfo),
|
||||
GemmHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
DmlOperator::Initialize(kernelInfo);
|
||||
auto kernelInputIndices = std::vector<std::optional<uint32_t>> { 0, 1, 2 };
|
||||
DmlOperator::Initialize(kernelInfo, kernelInputIndices);
|
||||
|
||||
bool containsBiasTensor = kernelInfo.IsInputValid(2);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue