diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 666dc9c0f1..4847bc4d8a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -25,7 +25,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 107; + static constexpr auto ValueCount = 110; static constexpr size_t ActivationFunctionCount = 19; }; @@ -688,6 +688,24 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REVERSE_SUBSEQUENCES; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ELEMENTS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND; +}; + template <> struct OperatorDescTraits { @@ -1330,6 +1348,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES> using DescType = DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ELEMENTS> +{ + using DescType = DML_GATHER_ELEMENTS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND> +{ + using DescType = DML_GATHER_ND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND> +{ + using DescType = DML_SCATTER_ND_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -1631,6 +1667,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_REVERSE_SUBSEQUENCES: return std::invoke(std::forward(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER_ELEMENTS: + return std::invoke(std::forward(visitor), DML_GATHER_ELEMENTS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER_ND: + return std::invoke(std::forward(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SCATTER_ND: + return std::invoke(std::forward(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_HARDMAX: @@ -1768,6 +1810,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; + case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; + case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; + case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index b95fbfc95c..ff2a5e0fbd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -1393,6 +1393,53 @@ constexpr DML_OPERATOR_SCHEMA DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA { DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GATHER_ELEMENTS_OPERATOR_SCHEMA { + "DML_OPERATOR_GATHER_ELEMENTS", + DML_OPERATOR_GATHER_ELEMENTS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND_OPERATOR_SCHEMA { + "DML_OPERATOR_GATHER_ND", + DML_OPERATOR_GATHER_ND, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "UpdatesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SCATTER_ND_OPERATOR_SCHEMA { + "DML_OPERATOR_SCATTER_ND", + DML_OPERATOR_SCATTER_ND, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS, +}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index d565590ece..04411a94fe 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -822,6 +822,36 @@ inline std::vector GetFields(const DML_REVERSE_SUBSEQUENCES_OPERA OperatorField(&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), }; } +inline std::vector GetFields(const DML_GATHER_ELEMENTS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + }; +} +inline std::vector GetFields(const DML_GATHER_ND_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InputDimensionCount))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.IndicesDimensionCount))), + }; +} +inline std::vector GetFields(const DML_SCATTER_ND_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.UpdatesTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.InputDimensionCount))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.IndicesDimensionCount))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1064,6 +1094,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_FILL_VALUE_SEQUENCE: return DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA; case DML_OPERATOR_CUMULATIVE_SUMMATION: return DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA; + case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA; + case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA; + case DML_OPERATOR_SCATTER_ND: return DML_SCATTER_ND_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA; @@ -1440,6 +1473,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GATHER_ELEMENTS: + return AbstractOperatorDesc( + &DML_GATHER_ELEMENTS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GATHER_ND: + return AbstractOperatorDesc( + &DML_GATHER_ND_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SCATTER_ND: + return AbstractOperatorDesc( + &DML_SCATTER_ND_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp index 01dd7379d3..54ef525b46 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGather.cpp @@ -43,9 +43,81 @@ public: } }; +class DmlOperatorGatherElements : public DmlOperator +{ +public: + DmlOperatorGatherElements(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherElements expects 2 inputs."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherElements expects 1 output."); + + DmlOperator::Initialize(kernelCreationContext); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + assert(inputDescs.size() == 2); + assert(outputDescs.size() == 1); + + m_inputTensorDescs[1].ForceUnsignedDataType(); + + int32_t signedOnnxAxis = kernelCreationContext.GetOptionalAttribute(AttrName::Axis, 0); + auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0); + std::vector indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1); + ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); + uint32_t dmlAxis = GetDmlAdjustedAxis(signedOnnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount()); + + DML_GATHER_ELEMENTS_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = &inputDescs[0]; + operatorDesc.IndicesTensor = &inputDescs[1]; + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.Axis = dmlAxis; + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ELEMENTS, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } +}; + +class DmlOperatorGatherNd : public DmlOperator, public GatherNdHelper +{ +public: + DmlOperatorGatherNd(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext), + GatherNdHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "GatherND expects 2 inputs."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "GatherND expects 1 output."); + + DmlOperator::Initialize(kernelCreationContext); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + assert(inputDescs.size() == 2); + assert(outputDescs.size() == 1); + + m_inputTensorDescs[1].ForceUnsignedDataType(); + + auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0); + std::vector indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1); + ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); + ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount); + + DML_GATHER_ND_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = &inputDescs[0]; + operatorDesc.IndicesTensor = &inputDescs[1]; + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.InputDimensionCount = static_cast(dataDimensions.size()); + operatorDesc.IndicesDimensionCount = static_cast(indicesDimensions.size()); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER_ND, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } +}; + DML_OP_DEFINE_CREATION_FUNCTION(Gather, DmlOperatorGather); -// TODO::: -DML_OP_DEFINE_CREATION_FUNCTION(GatherElements, DmlOperatorGather); -DML_OP_DEFINE_CREATION_FUNCTION(GatherND, DmlOperatorGather); +DML_OP_DEFINE_CREATION_FUNCTION(GatherElements, DmlOperatorGatherElements); +DML_OP_DEFINE_CREATION_FUNCTION(GatherND, DmlOperatorGatherNd); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp index 97b881d542..b9dfef8ba3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp @@ -19,7 +19,7 @@ public: std::vector dataDimensions = tensorShapeDescription.GetInputTensorShape(0); std::vector indicesDimensions = tensorShapeDescription.GetInputTensorShape(1); std::vector updatesDimensions = tensorShapeDescription.GetInputTensorShape(2); - std::vector outputDimensions = tensorShapeDescription.GetInputTensorShape(0); + std::vector outputDimensions = tensorShapeDescription.GetOutputTensorShape(0); ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions); ML_CHECK_VALID_ARGUMENT(indicesDimensions == updatesDimensions); ML_CHECK_VALID_ARGUMENT(dataDimensions.size() == indicesDimensions.size()); @@ -74,9 +74,51 @@ public: } }; -DML_OP_DEFINE_CREATION_FUNCTION(Scatter, DmlOperatorScatter); -// TODO::: +class DmlOperatorScatterNd : public DmlOperator +{ +public: + DmlOperatorScatterNd(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "ScatterND expects 3 inputs."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "ScatterND expects 1 output."); + + auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector dataDimensions = tensorShapeDescription.GetInputTensorShape(0); + std::vector indicesDimensions = tensorShapeDescription.GetInputTensorShape(1); + std::vector updatesDimensions = tensorShapeDescription.GetInputTensorShape(2); + std::vector outputDimensions = tensorShapeDescription.GetOutputTensorShape(0); + ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions); + ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount); + ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() <= OperatorHelper::NchwDimensionCount); + ML_CHECK_VALID_ARGUMENT(updatesDimensions.size() <= OperatorHelper::NchwDimensionCount); + ML_CHECK_VALID_ARGUMENT(outputDimensions.size() <= OperatorHelper::NchwDimensionCount); + + DmlOperator::Initialize(kernelCreationContext); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + assert(inputDescs.size() == 3); + assert(outputDescs.size() == 1); + + m_inputTensorDescs[1].ForceUnsignedDataType(); + + DML_SCATTER_ND_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = &inputDescs[0]; + operatorDesc.IndicesTensor = &inputDescs[1]; + operatorDesc.UpdatesTensor = &inputDescs[2]; + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.InputDimensionCount = static_cast(dataDimensions.size()); + operatorDesc.IndicesDimensionCount = static_cast(indicesDimensions.size()); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SCATTER_ND, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter); +DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter); DML_OP_DEFINE_CREATION_FUNCTION(ScatterElements, DmlOperatorScatter); -DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatter); +DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatterNd); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index a3770ca2fa..34066a6c88 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -202,7 +202,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(Shrink); DML_OP_EXTERN_CREATION_FUNCTION(OneHot); DML_OP_EXTERN_CREATION_FUNCTION(EyeLike); DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool); -DML_OP_EXTERN_CREATION_FUNCTION(Scatter); +DML_OP_EXTERN_CREATION_FUNCTION(Scatter9); +DML_OP_EXTERN_CREATION_FUNCTION(Scatter11); DML_OP_EXTERN_CREATION_FUNCTION(Resize); DML_OP_EXTERN_CREATION_FUNCTION(ConstantOfShape); DML_OP_EXTERN_CREATION_FUNCTION(IsInf); @@ -346,7 +347,8 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})}, {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, {0})}, {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, - {REG_INFO( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, + {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, + {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, // Data reorganization that merely changes the dimensions while keeping the data identical. diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 2a17a8e288..c18652f29f 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -355,8 +355,8 @@ namespace OperatorHelper std::vector GetOutputShapeAsInputShapeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { - assert(shapeInfo.GetInputCount() >= 1); - std::vector outputDimensions = shapeInfo.GetInputTensorShape(0); + assert(shapeInfo.GetInputCount() > m_inputTensorIndex); + std::vector outputDimensions = shapeInfo.GetInputTensorShape(m_inputTensorIndex); return { std::move(outputDimensions) }; } @@ -591,102 +591,26 @@ namespace OperatorHelper return { EdgeShapes(std::move(outputDimensions)) }; } -// TODO::: - - void GatherNDHelper::Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions - ) - { - int32_t signedOnnxAxis = operatorAttributes.GetOptionalAttribute(AttrName::Axis, 0); - uint32_t inputRank = gsl::narrow_cast(inputDimensions.size()); - m_axis = HandleNegativeAxis(signedOnnxAxis, inputRank); - } - - std::vector GatherNDHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + std::vector GatherNdHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { std::vector inputDimensions = shapeInfo.GetInputTensorShape(0); std::vector indicesDimensions = shapeInfo.GetInputTensorShape(1); + // Determine the number of output dimensions. ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1); - ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0); - int outDimCount = gsl::narrow_cast(inputDimensions.size() + indicesDimensions.size() - 1); - ML_CHECK_VALID_ARGUMENT(outDimCount > 0 && outDimCount <= NchwDimensionCount); + ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 1); + const uint32_t numberOfCoordinatesPerIndex = indicesDimensions.back(); + ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= numberOfCoordinatesPerIndex); + const uint32_t numberOfOutputDimensionsFromInput = static_cast(inputDimensions.size()) - numberOfCoordinatesPerIndex; + const uint32_t numberOfOutputDimensionsFromIndices = static_cast(indicesDimensions.size()) - 1; // Strip off last dimension. + uint32_t outputDimensionCount = gsl::narrow_cast(numberOfOutputDimensionsFromIndices + numberOfOutputDimensionsFromInput); + ML_CHECK_VALID_ARGUMENT(outputDimensionCount > 0 && outputDimensionCount <= NchwDimensionCount); - std::vector outputDimensions(outDimCount, 1); - - // The input dimensions following the gather axis determine the final output dimensions. - int outputDim = outDimCount - 1; - int inputDim = gsl::narrow_cast(inputDimensions.size() - 1); - for (; inputDim > m_axis; --outputDim, --inputDim) - { - outputDimensions[outputDim] = inputDimensions[inputDim]; - } - - // The shape of the index tensor is reflected in the middle dimensions of the output tensor. - int indexDim = gsl::narrow_cast(indicesDimensions.size() - 1); - for (; indexDim >= 0; --outputDim, --indexDim) - { - outputDimensions[outputDim] = indicesDimensions[indexDim]; - } - - // The gather dimension is skipped for the purposes of sizing because the index values choose slices - // across it. Preceding input dimensions determine the shape of the output's leading dimensions. - inputDim = m_axis - 1; - for (; outputDim >= 0 && inputDim >= 0; --outputDim, --inputDim) - { - outputDimensions[outputDim] = inputDimensions[inputDim]; - } - - return { EdgeShapes(std::move(outputDimensions)) }; - } - -// TODO::: - - void ScatterNDHelper::Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions - ) - { - int32_t signedOnnxAxis = operatorAttributes.GetOptionalAttribute(AttrName::Axis, 0); - uint32_t inputRank = gsl::narrow_cast(inputDimensions.size()); - m_axis = HandleNegativeAxis(signedOnnxAxis, inputRank); - } - - std::vector ScatterNDHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const - { - std::vector inputDimensions = shapeInfo.GetInputTensorShape(0); - std::vector indicesDimensions = shapeInfo.GetInputTensorShape(1); - - ML_CHECK_VALID_ARGUMENT(inputDimensions.size() >= 1); - ML_CHECK_VALID_ARGUMENT(indicesDimensions.size() >= 0); - int outDimCount = gsl::narrow_cast(inputDimensions.size() + indicesDimensions.size() - 1); - ML_CHECK_VALID_ARGUMENT(outDimCount > 0 && outDimCount <= NchwDimensionCount); - - std::vector outputDimensions(outDimCount, 1); - - // The input dimensions following the gather axis determine the final output dimensions. - int outputDim = outDimCount - 1; - int inputDim = gsl::narrow_cast(inputDimensions.size() - 1); - for (; inputDim > m_axis; --outputDim, --inputDim) - { - outputDimensions[outputDim] = inputDimensions[inputDim]; - } - - // The shape of the index tensor is reflected in the middle dimensions of the output tensor. - int indexDim = gsl::narrow_cast(indicesDimensions.size() - 1); - for (; indexDim >= 0; --outputDim, --indexDim) - { - outputDimensions[outputDim] = indicesDimensions[indexDim]; - } - - // The gather dimension is skipped for the purposes of sizing because the index values choose slices - // across it. Preceding input dimensions determine the shape of the output's leading dimensions. - inputDim = m_axis - 1; - for (; outputDim >= 0 && inputDim >= 0; --outputDim, --inputDim) - { - outputDimensions[outputDim] = inputDimensions[inputDim]; - } + // Form the full expected size by concatenating the prefix part of the indices tensor shape + // with the suffix of the input tensor shape. + std::vector outputDimensions; + outputDimensions.assign(indicesDimensions.begin(), indicesDimensions.end() - 1); + outputDimensions.insert(outputDimensions.end(), inputDimensions.end() - numberOfOutputDimensionsFromInput, inputDimensions.end()); return { EdgeShapes(std::move(outputDimensions)) }; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 585011ff91..5b8757199f 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -927,44 +927,15 @@ class GatherHelper { int m_axis = 0; }; -class GatherNDHelper { +class GatherNdHelper { public: - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span dataDimensions - ); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template - GatherNDHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); + GatherNdHelper(const Info_t& info, const Shape_t& shape) { } std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - - protected: - int m_axis = 0; -}; - -class ScatterNDHelper { - public: - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span dataDimensions - ); - - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - ScatterNDHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); - } - - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - - protected: - int m_axis = 0; }; class PoolingHelperBase { @@ -1291,9 +1262,10 @@ using ShapeInferenceHelper_LSTM = RecurrentHelper; using ShapeInferenceHelper_Gather = GatherHelper; using ShapeInferenceHelper_GatherElements = GetOutputShapeAsSpecificInputShapeHelper<1>; using ShapeInferenceHelper_ScatterElements = GetOutputShapeAsInputShapeHelper; -using ShapeInferenceHelper_Scatter = ShapeInferenceHelper_ScatterElements; -using ShapeInferenceHelper_GatherND = GatherNDHelper; -using ShapeInferenceHelper_ScatterND = ScatterNDHelper; +using ShapeInferenceHelper_Scatter9 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements. +using ShapeInferenceHelper_Scatter11 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements. +using ShapeInferenceHelper_GatherND = GatherNdHelper; +using ShapeInferenceHelper_ScatterND = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Flatten = FlattenHelper; using ShapeInferenceHelper_Split = SplitHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h index 9d8cb743ea..5fda5b29fe 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h @@ -224,6 +224,7 @@ namespace OperatorHelper static const int sc_sinceVer_Resize = 11; static const int sc_sinceVer_Round = 11; static const int sc_sinceVer_Scan = 11; + static const int sc_sinceVer_Scatter = 11; // Deprecated alias static const int sc_sinceVer_ScatterElements = 11; static const int sc_sinceVer_ScatterND = 11; static const int sc_sinceVer_Slice = 11;