diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 0c290d2abf..7b95ae9620 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -231,10 +231,22 @@ namespace Dml::GraphDescBuilder ComPtr tensor = nullptr; auto inputDefs = node.InputDefs(); + if (inputIndex < inputDefs.size()) { const onnxruntime::NodeArg* arg = inputDefs[inputIndex]; tensor = constantCpuGraphInputGetter(arg->Name()); + + if (tensor == nullptr) + { + bool inputRequiredAsConstant = std::find( + requiredConstantCpuInputs.begin(), + requiredConstantCpuInputs.end(), + inputIndex) != requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. + ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); + } } return tensor; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 66e4b85a13..55f7cb4914 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1158,28 +1158,19 @@ namespace Windows::AI::MachineLearning::Adapter { ORT_TRY { - *tensor = nullptr; - auto constantInput = m_constantInputGetter(inputIndex); - if (!std::holds_alternative>(constantInput)) - { - assert(std::find( - m_requiredConstantCpuInputs.begin(), - m_requiredConstantCpuInputs.end(), - inputIndex) == m_requiredConstantCpuInputs.end()); - - return S_OK; - } + ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative>(constantInput)); auto tensorWrapper = std::get>(constantInput); if (tensorWrapper == nullptr) { - assert(std::find( - m_requiredConstantCpuInputs.begin(), - m_requiredConstantCpuInputs.end(), - inputIndex) == m_requiredConstantCpuInputs.end()); - - return S_OK; + bool inputRequiredAsConstant = std::find( + m_requiredConstantCpuInputs.begin(), + m_requiredConstantCpuInputs.end(), + inputIndex) != m_requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. + ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } *tensor = tensorWrapper.Detach(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index c1507cf1dc..4f8bb428b3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,17 +479,37 @@ public: ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) + { + std::vector> kernelInputIndices = {0}; - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); - DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; - opDesc.InputTensor = &inputDescs[0]; - opDesc.ExponentTensor = &inputDescs[1]; - opDesc.OutputTensor = &outputDescs[0]; + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); - SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.OutputTensor = &outputDescs[0]; + opDesc.Exponent = static_cast(ReadScalarTensorCastToFloat64(*constExpTensor)); + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); + } + else + { + Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.ExponentTensor = &inputDescs[1]; + opDesc.OutputTensor = &outputDescs[0]; + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + } } };