mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
DML constant pow operator
This commit is contained in:
parent
1c8da45d04
commit
8b0a55e7cc
3 changed files with 49 additions and 26 deletions
|
|
@ -231,10 +231,22 @@ namespace Dml::GraphDescBuilder
|
|||
ComPtr<IMLOperatorTensor> tensor = nullptr;
|
||||
|
||||
auto inputDefs = node.InputDefs();
|
||||
|
||||
if (inputIndex < inputDefs.size())
|
||||
{
|
||||
const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
|
||||
tensor = constantCpuGraphInputGetter(arg->Name());
|
||||
|
||||
if (tensor == nullptr)
|
||||
{
|
||||
bool inputRequiredAsConstant = std::find(
|
||||
requiredConstantCpuInputs.begin(),
|
||||
requiredConstantCpuInputs.end(),
|
||||
inputIndex) != requiredConstantCpuInputs.end();
|
||||
|
||||
// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
|
||||
}
|
||||
}
|
||||
|
||||
return tensor;
|
||||
|
|
|
|||
|
|
@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
}
|
||||
ORT_CATCH_RETURN
|
||||
}
|
||||
|
||||
|
||||
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
|
||||
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
|
||||
{
|
||||
|
|
@ -1158,28 +1158,19 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
{
|
||||
ORT_TRY
|
||||
{
|
||||
*tensor = nullptr;
|
||||
|
||||
auto constantInput = m_constantInputGetter(inputIndex);
|
||||
if (!std::holds_alternative<ComPtr<IMLOperatorTensor>>(constantInput))
|
||||
{
|
||||
assert(std::find(
|
||||
m_requiredConstantCpuInputs.begin(),
|
||||
m_requiredConstantCpuInputs.end(),
|
||||
inputIndex) == m_requiredConstantCpuInputs.end());
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative<ComPtr<IMLOperatorTensor>>(constantInput));
|
||||
|
||||
auto tensorWrapper = std::get<ComPtr<IMLOperatorTensor>>(constantInput);
|
||||
if (tensorWrapper == nullptr)
|
||||
{
|
||||
assert(std::find(
|
||||
m_requiredConstantCpuInputs.begin(),
|
||||
m_requiredConstantCpuInputs.end(),
|
||||
inputIndex) == m_requiredConstantCpuInputs.end());
|
||||
|
||||
return S_OK;
|
||||
bool inputRequiredAsConstant = std::find(
|
||||
m_requiredConstantCpuInputs.begin(),
|
||||
m_requiredConstantCpuInputs.end(),
|
||||
inputIndex) != m_requiredConstantCpuInputs.end();
|
||||
|
||||
// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
|
||||
}
|
||||
|
||||
*tensor = tensorWrapper.Detach();
|
||||
|
|
|
|||
|
|
@ -479,17 +479,37 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1);
|
||||
if (constExpTensor && constExpTensor->GetTotalElementCount() == 1)
|
||||
{
|
||||
std::vector<std::optional<uint32_t>> kernelInputIndices = {0};
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
|
||||
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.ExponentTensor = &inputDescs[1];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
|
||||
DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
opDesc.Exponent = static_cast<float>(ReadScalarTensorCastToFloat64(*constExpTensor));
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo);
|
||||
}
|
||||
else
|
||||
{
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.ExponentTensor = &inputDescs[1];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue