mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Fix GatherND and ScatterND.
This commit is contained in:
parent
271126a86e
commit
5366273110
9 changed files with 286 additions and 136 deletions
|
|
@ -25,7 +25,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
|
|||
template <>
|
||||
struct EnumTraits<DML_OPERATOR_TYPE>
|
||||
{
|
||||
static constexpr auto ValueCount = 107;
|
||||
static constexpr auto ValueCount = 110;
|
||||
static constexpr size_t ActivationFunctionCount = 19;
|
||||
};
|
||||
|
||||
|
|
@ -688,6 +688,24 @@ struct OperatorDescTraits<DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REVERSE_SUBSEQUENCES;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_GATHER_ELEMENTS_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ELEMENTS;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_GATHER_ND_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_SCATTER_ND_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -1330,6 +1348,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES>
|
|||
using DescType = DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ELEMENTS>
|
||||
{
|
||||
using DescType = DML_GATHER_ELEMENTS_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND>
|
||||
{
|
||||
using DescType = DML_GATHER_ND_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND>
|
||||
{
|
||||
using DescType = DML_SCATTER_ND_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
|
||||
{
|
||||
|
|
@ -1631,6 +1667,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_REVERSE_SUBSEQUENCES:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_GATHER_ELEMENTS:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ELEMENTS_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_GATHER_ND:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_SCATTER_ND:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX:
|
||||
|
|
@ -1768,6 +1810,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE";
|
||||
case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION";
|
||||
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES";
|
||||
case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS";
|
||||
case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND";
|
||||
case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND";
|
||||
default:
|
||||
assert(false);
|
||||
return "<unknown>";
|
||||
|
|
|
|||
|
|
@ -1393,6 +1393,53 @@ constexpr DML_OPERATOR_SCHEMA DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA {
|
|||
DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_GATHER_ELEMENTS_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_GATHER_ELEMENTS",
|
||||
DML_OPERATOR_GATHER_ELEMENTS,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
4,
|
||||
DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS[5] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_GATHER_ND",
|
||||
DML_OPERATOR_GATHER_ND,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
5,
|
||||
DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS[6] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "UpdatesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_SCATTER_ND_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_SCATTER_ND",
|
||||
DML_OPERATOR_SCATTER_ND,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
6,
|
||||
DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
|
|
|
|||
|
|
@ -822,6 +822,36 @@ inline std::vector<OperatorField> GetFields(const DML_REVERSE_SUBSEQUENCES_OPERA
|
|||
OperatorField(&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_GATHER_ELEMENTS_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
|
||||
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_GATHER_ND_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
|
||||
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.InputDimensionCount))),
|
||||
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.IndicesDimensionCount))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_SCATTER_ND_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
|
||||
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.UpdatesTensor))),
|
||||
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.InputDimensionCount))),
|
||||
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.IndicesDimensionCount))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1064,6 +1094,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_CUMULATIVE_SUMMATION: return DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_SCATTER_ND: return DML_SCATTER_ND_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA;
|
||||
|
|
@ -1440,6 +1473,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_GATHER_ELEMENTS:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_GATHER_ELEMENTS_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_GATHER_ND:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_GATHER_ND_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_GATHER_ND_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_SCATTER_ND:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_SCATTER_ND_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_SCATTER_ND_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,
|
||||
|
|
|
|||
|
|
@ -43,9 +43,81 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class DmlOperatorGatherElements : public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorGatherElements(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherElements expects 2 inputs.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherElements expects 1 output.");
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 2);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
int32_t signedOnnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
|
||||
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
uint32_t dmlAxis = GetDmlAdjustedAxis(signedOnnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
|
||||
|
||||
DML_GATHER_ELEMENTS_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.IndicesTensor = &inputDescs[1];
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.Axis = dmlAxis;
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ELEMENTS, &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
class DmlOperatorGatherNd : public DmlOperator, public GatherNdHelper
|
||||
{
|
||||
public:
|
||||
DmlOperatorGatherNd(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext),
|
||||
GatherNdHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherND expects 2 inputs.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherND expects 1 output.");
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 2);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
|
||||
DML_GATHER_ND_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.IndicesTensor = &inputDescs[1];
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.InputDimensionCount = static_cast<uint32_t>(dataDimensions.size());
|
||||
operatorDesc.IndicesDimensionCount = static_cast<uint32_t>(indicesDimensions.size());
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND, &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Gather, DmlOperatorGather);
|
||||
// TODO:::
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(GatherElements, DmlOperatorGather);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(GatherND, DmlOperatorGather);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(GatherElements, DmlOperatorGatherElements);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(GatherND, DmlOperatorGatherNd);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ public:
|
|||
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
|
||||
std::vector<DimensionType> updatesDimensions = tensorShapeDescription.GetInputTensorShape(2);
|
||||
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions == updatesDimensions);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() == indicesDimensions.size());
|
||||
|
|
@ -74,9 +74,51 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter, DmlOperatorScatter);
|
||||
// TODO:::
|
||||
class DmlOperatorScatterNd : public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorScatterNd(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "ScatterND expects 3 inputs.");
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "ScatterND expects 1 output.");
|
||||
|
||||
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
|
||||
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
|
||||
std::vector<DimensionType> updatesDimensions = tensorShapeDescription.GetInputTensorShape(2);
|
||||
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
|
||||
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(updatesDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(outputDimensions.size() <= OperatorHelper::NchwDimensionCount);
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
assert(inputDescs.size() == 3);
|
||||
assert(outputDescs.size() == 1);
|
||||
|
||||
m_inputTensorDescs[1].ForceUnsignedDataType();
|
||||
|
||||
DML_SCATTER_ND_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.IndicesTensor = &inputDescs[1];
|
||||
operatorDesc.UpdatesTensor = &inputDescs[2];
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.InputDimensionCount = static_cast<uint32_t>(dataDimensions.size());
|
||||
operatorDesc.IndicesDimensionCount = static_cast<uint32_t>(indicesDimensions.size());
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SCATTER_ND, &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(ScatterElements, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatterNd);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -202,7 +202,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(Shrink);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(OneHot);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(EyeLike);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Scatter);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Scatter9);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Scatter11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Resize);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ConstantOfShape);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(IsInf);
|
||||
|
|
@ -346,7 +347,8 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
|
|||
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})},
|
||||
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, {0})},
|
||||
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
|
||||
// Data reorganization that merely changes the dimensions while keeping the data identical.
|
||||
|
|
|
|||
|
|
@ -355,8 +355,8 @@ namespace OperatorHelper
|
|||
|
||||
std::vector<EdgeShapes> GetOutputShapeAsInputShapeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
assert(shapeInfo.GetInputCount() >= 1);
|
||||
std::vector<DimensionType> outputDimensions = shapeInfo.GetInputTensorShape(0);
|
||||
assert(shapeInfo.GetInputCount() > m_inputTensorIndex);
|
||||
std::vector<DimensionType> outputDimensions = shapeInfo.GetInputTensorShape(m_inputTensorIndex);
|
||||
return { std::move(outputDimensions) };
|
||||
}
|
||||
|
||||
|
|
@ -591,102 +591,26 @@ namespace OperatorHelper
|
|||
return { EdgeShapes(std::move(outputDimensions)) };
|
||||
}
|
||||
|
||||
// TODO:::
|
||||
|
||||
void GatherNDHelper::Initialize(
|
||||
const MLOperatorAttributes& operatorAttributes,
|
||||
gsl::span<const DimensionType> inputDimensions
|
||||
)
|
||||
{
|
||||
int32_t signedOnnxAxis = operatorAttributes.GetOptionalAttribute<int>(AttrName::Axis, 0);
|
||||
uint32_t inputRank = gsl::narrow_cast<int>(inputDimensions.size());
|
||||
m_axis = HandleNegativeAxis(signedOnnxAxis, inputRank);
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GatherNDHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
std::vector<EdgeShapes> GatherNdHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
std::vector<DimensionType> inputDimensions = shapeInfo.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = shapeInfo.GetInputTensorShape(1);
|
||||
|
||||
// Determine the number of output dimensions.
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0);
|
||||
int outDimCount = gsl::narrow_cast<int>(inputDimensions.size() + indicesDimensions.size() - 1);
|
||||
ML_CHECK_VALID_ARGUMENT(outDimCount > 0 && outDimCount <= NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 1);
|
||||
const uint32_t numberOfCoordinatesPerIndex = indicesDimensions.back();
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= numberOfCoordinatesPerIndex);
|
||||
const uint32_t numberOfOutputDimensionsFromInput = static_cast<uint32_t>(inputDimensions.size()) - numberOfCoordinatesPerIndex;
|
||||
const uint32_t numberOfOutputDimensionsFromIndices = static_cast<uint32_t>(indicesDimensions.size()) - 1; // Strip off last dimension.
|
||||
uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput);
|
||||
ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0 && outputDimensionCount <= NchwDimensionCount);
|
||||
|
||||
std::vector<DimensionType> outputDimensions(outDimCount, 1);
|
||||
|
||||
// The input dimensions following the gather axis determine the final output dimensions.
|
||||
int outputDim = outDimCount - 1;
|
||||
int inputDim = gsl::narrow_cast<int>(inputDimensions.size() - 1);
|
||||
for (; inputDim > m_axis; --outputDim, --inputDim)
|
||||
{
|
||||
outputDimensions[outputDim] = inputDimensions[inputDim];
|
||||
}
|
||||
|
||||
// The shape of the index tensor is reflected in the middle dimensions of the output tensor.
|
||||
int indexDim = gsl::narrow_cast<int>(indicesDimensions.size() - 1);
|
||||
for (; indexDim >= 0; --outputDim, --indexDim)
|
||||
{
|
||||
outputDimensions[outputDim] = indicesDimensions[indexDim];
|
||||
}
|
||||
|
||||
// The gather dimension is skipped for the purposes of sizing because the index values choose slices
|
||||
// across it. Preceding input dimensions determine the shape of the output's leading dimensions.
|
||||
inputDim = m_axis - 1;
|
||||
for (; outputDim >= 0 && inputDim >= 0; --outputDim, --inputDim)
|
||||
{
|
||||
outputDimensions[outputDim] = inputDimensions[inputDim];
|
||||
}
|
||||
|
||||
return { EdgeShapes(std::move(outputDimensions)) };
|
||||
}
|
||||
|
||||
// TODO:::
|
||||
|
||||
void ScatterNDHelper::Initialize(
|
||||
const MLOperatorAttributes& operatorAttributes,
|
||||
gsl::span<const DimensionType> inputDimensions
|
||||
)
|
||||
{
|
||||
int32_t signedOnnxAxis = operatorAttributes.GetOptionalAttribute<int>(AttrName::Axis, 0);
|
||||
uint32_t inputRank = gsl::narrow_cast<int>(inputDimensions.size());
|
||||
m_axis = HandleNegativeAxis(signedOnnxAxis, inputRank);
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> ScatterNDHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
std::vector<DimensionType> inputDimensions = shapeInfo.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> indicesDimensions = shapeInfo.GetInputTensorShape(1);
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
|
||||
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0);
|
||||
int outDimCount = gsl::narrow_cast<int>(inputDimensions.size() + indicesDimensions.size() - 1);
|
||||
ML_CHECK_VALID_ARGUMENT(outDimCount > 0 && outDimCount <= NchwDimensionCount);
|
||||
|
||||
std::vector<DimensionType> outputDimensions(outDimCount, 1);
|
||||
|
||||
// The input dimensions following the gather axis determine the final output dimensions.
|
||||
int outputDim = outDimCount - 1;
|
||||
int inputDim = gsl::narrow_cast<int>(inputDimensions.size() - 1);
|
||||
for (; inputDim > m_axis; --outputDim, --inputDim)
|
||||
{
|
||||
outputDimensions[outputDim] = inputDimensions[inputDim];
|
||||
}
|
||||
|
||||
// The shape of the index tensor is reflected in the middle dimensions of the output tensor.
|
||||
int indexDim = gsl::narrow_cast<int>(indicesDimensions.size() - 1);
|
||||
for (; indexDim >= 0; --outputDim, --indexDim)
|
||||
{
|
||||
outputDimensions[outputDim] = indicesDimensions[indexDim];
|
||||
}
|
||||
|
||||
// The gather dimension is skipped for the purposes of sizing because the index values choose slices
|
||||
// across it. Preceding input dimensions determine the shape of the output's leading dimensions.
|
||||
inputDim = m_axis - 1;
|
||||
for (; outputDim >= 0 && inputDim >= 0; --outputDim, --inputDim)
|
||||
{
|
||||
outputDimensions[outputDim] = inputDimensions[inputDim];
|
||||
}
|
||||
// Form the full expected size by concatenating the prefix part of the indices tensor shape
|
||||
// with the suffix of the input tensor shape.
|
||||
std::vector<DimensionType> outputDimensions;
|
||||
outputDimensions.assign(indicesDimensions.begin(), indicesDimensions.end() - 1);
|
||||
outputDimensions.insert(outputDimensions.end(), inputDimensions.end() - numberOfOutputDimensionsFromInput, inputDimensions.end());
|
||||
|
||||
return { EdgeShapes(std::move(outputDimensions)) };
|
||||
}
|
||||
|
|
|
|||
|
|
@ -927,44 +927,15 @@ class GatherHelper {
|
|||
int m_axis = 0;
|
||||
};
|
||||
|
||||
class GatherNDHelper {
|
||||
class GatherNdHelper {
|
||||
public:
|
||||
void Initialize(
|
||||
const MLOperatorAttributes& operatorAttributes,
|
||||
gsl::span<const DimensionType> dataDimensions
|
||||
);
|
||||
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
GatherNDHelper(const Info_t& info, const Shape_t& shape) {
|
||||
Initialize(info, shape.GetInputTensorShape(0));
|
||||
GatherNdHelper(const Info_t& info, const Shape_t& shape) {
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
protected:
|
||||
int m_axis = 0;
|
||||
};
|
||||
|
||||
class ScatterNDHelper {
|
||||
public:
|
||||
void Initialize(
|
||||
const MLOperatorAttributes& operatorAttributes,
|
||||
gsl::span<const DimensionType> dataDimensions
|
||||
);
|
||||
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
ScatterNDHelper(const Info_t& info, const Shape_t& shape) {
|
||||
Initialize(info, shape.GetInputTensorShape(0));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
protected:
|
||||
int m_axis = 0;
|
||||
};
|
||||
|
||||
class PoolingHelperBase {
|
||||
|
|
@ -1291,9 +1262,10 @@ using ShapeInferenceHelper_LSTM = RecurrentHelper;
|
|||
using ShapeInferenceHelper_Gather = GatherHelper;
|
||||
using ShapeInferenceHelper_GatherElements = GetOutputShapeAsSpecificInputShapeHelper<1>;
|
||||
using ShapeInferenceHelper_ScatterElements = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Scatter = ShapeInferenceHelper_ScatterElements;
|
||||
using ShapeInferenceHelper_GatherND = GatherNDHelper;
|
||||
using ShapeInferenceHelper_ScatterND = ScatterNDHelper;
|
||||
using ShapeInferenceHelper_Scatter9 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements.
|
||||
using ShapeInferenceHelper_Scatter11 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements.
|
||||
using ShapeInferenceHelper_GatherND = GatherNdHelper;
|
||||
using ShapeInferenceHelper_ScatterND = GetOutputShapeAsInputShapeHelper;
|
||||
|
||||
using ShapeInferenceHelper_Flatten = FlattenHelper;
|
||||
using ShapeInferenceHelper_Split = SplitHelper;
|
||||
|
|
|
|||
|
|
@ -224,6 +224,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Resize = 11;
|
||||
static const int sc_sinceVer_Round = 11;
|
||||
static const int sc_sinceVer_Scan = 11;
|
||||
static const int sc_sinceVer_Scatter = 11; // Deprecated alias
|
||||
static const int sc_sinceVer_ScatterElements = 11;
|
||||
static const int sc_sinceVer_ScatterND = 11;
|
||||
static const int sc_sinceVer_Slice = 11;
|
||||
|
|
|
|||
Loading…
Reference in a new issue