Fix GatherND and ScatterND.

This commit is contained in:
Dwayne Robinson 2020-03-24 18:02:39 -07:00
parent 271126a86e
commit 5366273110
9 changed files with 286 additions and 136 deletions

View file

@ -25,7 +25,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 107;
static constexpr auto ValueCount = 110;
static constexpr size_t ActivationFunctionCount = 19;
};
@ -688,6 +688,24 @@ struct OperatorDescTraits<DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REVERSE_SUBSEQUENCES;
};
template <>
struct OperatorDescTraits<DML_GATHER_ELEMENTS_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ELEMENTS;
};
template <>
struct OperatorDescTraits<DML_GATHER_ND_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND;
};
template <>
struct OperatorDescTraits<DML_SCATTER_ND_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
{
@ -1330,6 +1348,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES>
using DescType = DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ELEMENTS>
{
using DescType = DML_GATHER_ELEMENTS_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND>
{
using DescType = DML_GATHER_ND_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND>
{
using DescType = DML_SCATTER_ND_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
{
@ -1631,6 +1667,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_REVERSE_SUBSEQUENCES:
return std::invoke(std::forward<Visitor>(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_GATHER_ELEMENTS:
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ELEMENTS_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_GATHER_ND:
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_SCATTER_ND:
return std::invoke(std::forward<Visitor>(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_ELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_HARDMAX:
@ -1768,6 +1810,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE";
case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION";
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES";
case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS";
case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND";
case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND";
default:
assert(false);
return "<unknown>";

View file

@ -1393,6 +1393,53 @@ constexpr DML_OPERATOR_SCHEMA DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA {
DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false },
};
constexpr DML_OPERATOR_SCHEMA DML_GATHER_ELEMENTS_OPERATOR_SCHEMA {
"DML_OPERATOR_GATHER_ELEMENTS",
DML_OPERATOR_GATHER_ELEMENTS,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
4,
DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS[5] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false },
};
constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND_OPERATOR_SCHEMA {
"DML_OPERATOR_GATHER_ND",
DML_OPERATOR_GATHER_ND,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
5,
DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "UpdatesTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false },
};
constexpr DML_OPERATOR_SCHEMA DML_SCATTER_ND_OPERATOR_SCHEMA {
"DML_OPERATOR_SCATTER_ND",
DML_OPERATOR_SCATTER_ND,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
6,
DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },

View file

@ -822,6 +822,36 @@ inline std::vector<OperatorField> GetFields(const DML_REVERSE_SUBSEQUENCES_OPERA
OperatorField(&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
};
}
inline std::vector<OperatorField> GetFields(const DML_GATHER_ELEMENTS_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
};
}
inline std::vector<OperatorField> GetFields(const DML_GATHER_ND_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.InputDimensionCount))),
OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.IndicesDimensionCount))),
};
}
inline std::vector<OperatorField> GetFields(const DML_SCATTER_ND_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.IndicesTensor))),
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.UpdatesTensor))),
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.InputDimensionCount))),
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.IndicesDimensionCount))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
{
return {
@ -1064,6 +1094,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA;
case DML_OPERATOR_CUMULATIVE_SUMMATION: return DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA;
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA;
case DML_OPERATOR_SCATTER_ND: return DML_SCATTER_ND_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA;
@ -1440,6 +1473,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_GATHER_ELEMENTS:
return AbstractOperatorDesc(
&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_GATHER_ELEMENTS_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_GATHER_ND:
return AbstractOperatorDesc(
&DML_GATHER_ND_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_GATHER_ND_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_SCATTER_ND:
return AbstractOperatorDesc(
&DML_SCATTER_ND_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_SCATTER_ND_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_ELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,

View file

@ -43,9 +43,81 @@ public:
}
};
class DmlOperatorGatherElements : public DmlOperator
{
public:
DmlOperatorGatherElements(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherElements expects 2 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherElements expects 1 output.");
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
int32_t signedOnnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
uint32_t dmlAxis = GetDmlAdjustedAxis(signedOnnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
DML_GATHER_ELEMENTS_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.IndicesTensor = &inputDescs[1];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Axis = dmlAxis;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ELEMENTS, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
class DmlOperatorGatherNd : public DmlOperator, public GatherNdHelper
{
public:
DmlOperatorGatherNd(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
GatherNdHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherND expects 2 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherND expects 1 output.");
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
DML_GATHER_ND_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.IndicesTensor = &inputDescs[1];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InputDimensionCount = static_cast<uint32_t>(dataDimensions.size());
operatorDesc.IndicesDimensionCount = static_cast<uint32_t>(indicesDimensions.size());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Gather, DmlOperatorGather);
// TODO:::
DML_OP_DEFINE_CREATION_FUNCTION(GatherElements, DmlOperatorGather);
DML_OP_DEFINE_CREATION_FUNCTION(GatherND, DmlOperatorGather);
DML_OP_DEFINE_CREATION_FUNCTION(GatherElements, DmlOperatorGatherElements);
DML_OP_DEFINE_CREATION_FUNCTION(GatherND, DmlOperatorGatherNd);
} // namespace Dml

View file

@ -19,7 +19,7 @@ public:
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
std::vector<DimensionType> updatesDimensions = tensorShapeDescription.GetInputTensorShape(2);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
ML_CHECK_VALID_ARGUMENT(indicesDimensions == updatesDimensions);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() == indicesDimensions.size());
@ -74,9 +74,51 @@ public:
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Scatter, DmlOperatorScatter);
// TODO:::
class DmlOperatorScatterNd : public DmlOperator
{
public:
DmlOperatorScatterNd(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "ScatterND expects 3 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "ScatterND expects 1 output.");
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
std::vector<DimensionType> updatesDimensions = tensorShapeDescription.GetInputTensorShape(2);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetOutputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(updatesDimensions.size() <= OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(outputDimensions.size() <= OperatorHelper::NchwDimensionCount);
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 3);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
DML_SCATTER_ND_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.IndicesTensor = &inputDescs[1];
operatorDesc.UpdatesTensor = &inputDescs[2];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InputDimensionCount = static_cast<uint32_t>(dataDimensions.size());
operatorDesc.IndicesDimensionCount = static_cast<uint32_t>(indicesDimensions.size());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SCATTER_ND, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(ScatterElements, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatterNd);
} // namespace Dml

View file

@ -202,7 +202,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(Shrink);
DML_OP_EXTERN_CREATION_FUNCTION(OneHot);
DML_OP_EXTERN_CREATION_FUNCTION(EyeLike);
DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool);
DML_OP_EXTERN_CREATION_FUNCTION(Scatter);
DML_OP_EXTERN_CREATION_FUNCTION(Scatter9);
DML_OP_EXTERN_CREATION_FUNCTION(Scatter11);
DML_OP_EXTERN_CREATION_FUNCTION(Resize);
DML_OP_EXTERN_CREATION_FUNCTION(ConstantOfShape);
DML_OP_EXTERN_CREATION_FUNCTION(IsInf);
@ -346,7 +347,8 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, {0})},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
// Data reorganization that merely changes the dimensions while keeping the data identical.

View file

@ -355,8 +355,8 @@ namespace OperatorHelper
std::vector<EdgeShapes> GetOutputShapeAsInputShapeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
assert(shapeInfo.GetInputCount() >= 1);
std::vector<DimensionType> outputDimensions = shapeInfo.GetInputTensorShape(0);
assert(shapeInfo.GetInputCount() > m_inputTensorIndex);
std::vector<DimensionType> outputDimensions = shapeInfo.GetInputTensorShape(m_inputTensorIndex);
return { std::move(outputDimensions) };
}
@ -591,102 +591,26 @@ namespace OperatorHelper
return { EdgeShapes(std::move(outputDimensions)) };
}
// TODO:::
void GatherNDHelper::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
)
{
int32_t signedOnnxAxis = operatorAttributes.GetOptionalAttribute<int>(AttrName::Axis, 0);
uint32_t inputRank = gsl::narrow_cast<int>(inputDimensions.size());
m_axis = HandleNegativeAxis(signedOnnxAxis, inputRank);
}
std::vector<EdgeShapes> GatherNDHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
std::vector<EdgeShapes> GatherNdHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
std::vector<DimensionType> inputDimensions = shapeInfo.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = shapeInfo.GetInputTensorShape(1);
// Determine the number of output dimensions.
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0);
int outDimCount = gsl::narrow_cast<int>(inputDimensions.size() + indicesDimensions.size() - 1);
ML_CHECK_VALID_ARGUMENT(outDimCount > 0 && outDimCount <= NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 1);
const uint32_t numberOfCoordinatesPerIndex = indicesDimensions.back();
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= numberOfCoordinatesPerIndex);
const uint32_t numberOfOutputDimensionsFromInput = static_cast<uint32_t>(inputDimensions.size()) - numberOfCoordinatesPerIndex;
const uint32_t numberOfOutputDimensionsFromIndices = static_cast<uint32_t>(indicesDimensions.size()) - 1; // Strip off last dimension.
uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput);
ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0 && outputDimensionCount <= NchwDimensionCount);
std::vector<DimensionType> outputDimensions(outDimCount, 1);
// The input dimensions following the gather axis determine the final output dimensions.
int outputDim = outDimCount - 1;
int inputDim = gsl::narrow_cast<int>(inputDimensions.size() - 1);
for (; inputDim > m_axis; --outputDim, --inputDim)
{
outputDimensions[outputDim] = inputDimensions[inputDim];
}
// The shape of the index tensor is reflected in the middle dimensions of the output tensor.
int indexDim = gsl::narrow_cast<int>(indicesDimensions.size() - 1);
for (; indexDim >= 0; --outputDim, --indexDim)
{
outputDimensions[outputDim] = indicesDimensions[indexDim];
}
// The gather dimension is skipped for the purposes of sizing because the index values choose slices
// across it. Preceding input dimensions determine the shape of the output's leading dimensions.
inputDim = m_axis - 1;
for (; outputDim >= 0 && inputDim >= 0; --outputDim, --inputDim)
{
outputDimensions[outputDim] = inputDimensions[inputDim];
}
return { EdgeShapes(std::move(outputDimensions)) };
}
// TODO:::
void ScatterNDHelper::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
)
{
int32_t signedOnnxAxis = operatorAttributes.GetOptionalAttribute<int>(AttrName::Axis, 0);
uint32_t inputRank = gsl::narrow_cast<int>(inputDimensions.size());
m_axis = HandleNegativeAxis(signedOnnxAxis, inputRank);
}
std::vector<EdgeShapes> ScatterNDHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
std::vector<DimensionType> inputDimensions = shapeInfo.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = shapeInfo.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1);
ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0);
int outDimCount = gsl::narrow_cast<int>(inputDimensions.size() + indicesDimensions.size() - 1);
ML_CHECK_VALID_ARGUMENT(outDimCount > 0 && outDimCount <= NchwDimensionCount);
std::vector<DimensionType> outputDimensions(outDimCount, 1);
// The input dimensions following the gather axis determine the final output dimensions.
int outputDim = outDimCount - 1;
int inputDim = gsl::narrow_cast<int>(inputDimensions.size() - 1);
for (; inputDim > m_axis; --outputDim, --inputDim)
{
outputDimensions[outputDim] = inputDimensions[inputDim];
}
// The shape of the index tensor is reflected in the middle dimensions of the output tensor.
int indexDim = gsl::narrow_cast<int>(indicesDimensions.size() - 1);
for (; indexDim >= 0; --outputDim, --indexDim)
{
outputDimensions[outputDim] = indicesDimensions[indexDim];
}
// The gather dimension is skipped for the purposes of sizing because the index values choose slices
// across it. Preceding input dimensions determine the shape of the output's leading dimensions.
inputDim = m_axis - 1;
for (; outputDim >= 0 && inputDim >= 0; --outputDim, --inputDim)
{
outputDimensions[outputDim] = inputDimensions[inputDim];
}
// Form the full expected size by concatenating the prefix part of the indices tensor shape
// with the suffix of the input tensor shape.
std::vector<DimensionType> outputDimensions;
outputDimensions.assign(indicesDimensions.begin(), indicesDimensions.end() - 1);
outputDimensions.insert(outputDimensions.end(), inputDimensions.end() - numberOfOutputDimensionsFromInput, inputDimensions.end());
return { EdgeShapes(std::move(outputDimensions)) };
}

View file

@ -927,44 +927,15 @@ class GatherHelper {
int m_axis = 0;
};
class GatherNDHelper {
class GatherNdHelper {
public:
void Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> dataDimensions
);
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
GatherNDHelper(const Info_t& info, const Shape_t& shape) {
Initialize(info, shape.GetInputTensorShape(0));
GatherNdHelper(const Info_t& info, const Shape_t& shape) {
}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
protected:
int m_axis = 0;
};
class ScatterNDHelper {
public:
void Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> dataDimensions
);
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
ScatterNDHelper(const Info_t& info, const Shape_t& shape) {
Initialize(info, shape.GetInputTensorShape(0));
}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
protected:
int m_axis = 0;
};
class PoolingHelperBase {
@ -1291,9 +1262,10 @@ using ShapeInferenceHelper_LSTM = RecurrentHelper;
using ShapeInferenceHelper_Gather = GatherHelper;
using ShapeInferenceHelper_GatherElements = GetOutputShapeAsSpecificInputShapeHelper<1>;
using ShapeInferenceHelper_ScatterElements = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Scatter = ShapeInferenceHelper_ScatterElements;
using ShapeInferenceHelper_GatherND = GatherNDHelper;
using ShapeInferenceHelper_ScatterND = ScatterNDHelper;
using ShapeInferenceHelper_Scatter9 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements.
using ShapeInferenceHelper_Scatter11 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements.
using ShapeInferenceHelper_GatherND = GatherNdHelper;
using ShapeInferenceHelper_ScatterND = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Flatten = FlattenHelper;
using ShapeInferenceHelper_Split = SplitHelper;

View file

@ -224,6 +224,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Resize = 11;
static const int sc_sinceVer_Round = 11;
static const int sc_sinceVer_Scan = 11;
static const int sc_sinceVer_Scatter = 11; // Deprecated alias
static const int sc_sinceVer_ScatterElements = 11;
static const int sc_sinceVer_ScatterND = 11;
static const int sc_sinceVer_Slice = 11;