DML constant pow operator

This commit is contained in:
Jeff Bloomfield 2023-06-16 14:46:53 -07:00
parent 1c8da45d04
commit 8b0a55e7cc
3 changed files with 49 additions and 26 deletions

View file

@ -231,10 +231,22 @@ namespace Dml::GraphDescBuilder
ComPtr<IMLOperatorTensor> tensor = nullptr;
auto inputDefs = node.InputDefs();
if (inputIndex < inputDefs.size())
{
const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
tensor = constantCpuGraphInputGetter(arg->Name());
if (tensor == nullptr)
{
bool inputRequiredAsConstant = std::find(
requiredConstantCpuInputs.begin(),
requiredConstantCpuInputs.end(),
inputIndex) != requiredConstantCpuInputs.end();
// A required constant input should never be missing here, since kernel creation is deferred and repeated when required constant inputs are not present.
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
}
}
return tensor;

View file

@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
}
ORT_CATCH_RETURN
}
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
{
@ -1158,28 +1158,19 @@ namespace Windows::AI::MachineLearning::Adapter
{
ORT_TRY
{
*tensor = nullptr;
auto constantInput = m_constantInputGetter(inputIndex);
if (!std::holds_alternative<ComPtr<IMLOperatorTensor>>(constantInput))
{
assert(std::find(
m_requiredConstantCpuInputs.begin(),
m_requiredConstantCpuInputs.end(),
inputIndex) == m_requiredConstantCpuInputs.end());
return S_OK;
}
ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative<ComPtr<IMLOperatorTensor>>(constantInput));
auto tensorWrapper = std::get<ComPtr<IMLOperatorTensor>>(constantInput);
if (tensorWrapper == nullptr)
{
assert(std::find(
m_requiredConstantCpuInputs.begin(),
m_requiredConstantCpuInputs.end(),
inputIndex) == m_requiredConstantCpuInputs.end());
return S_OK;
bool inputRequiredAsConstant = std::find(
m_requiredConstantCpuInputs.begin(),
m_requiredConstantCpuInputs.end(),
inputIndex) != m_requiredConstantCpuInputs.end();
// A required constant input should never be missing here, since kernel creation is deferred and repeated when required constant inputs are not present.
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
}
*tensor = tensorWrapper.Detach();

View file

@ -479,17 +479,37 @@ public:
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1);
if (constExpTensor && constExpTensor->GetTotalElementCount() == 1)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0};
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ExponentTensor = &inputDescs[1];
opDesc.OutputTensor = &outputDescs[0];
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.OutputTensor = &outputDescs[0];
opDesc.Exponent = static_cast<float>(ReadScalarTensorCastToFloat64(*constExpTensor));
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo);
}
else
{
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ExponentTensor = &inputDescs[1];
opDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
}
}
};