From aa3a825816555d1f859385a256804dfceefebd2f Mon Sep 17 00:00:00 2001 From: sumitsays Date: Tue, 7 Jun 2022 14:31:55 -0700 Subject: [PATCH] Added Softmax/Hardmax/LogSoftmax-13 (#11772) * Added Softmax/Hardmax/LogSoftmax-13 * Removed redundant method specifier Co-authored-by: Sumit Agarwal --- .../src/External/DirectMLHelpers/ApiHelpers.h | 6 +++ .../src/External/DirectMLHelpers/ApiTraits.h | 46 ++++++++++++++++- .../External/DirectMLHelpers/DirectMLSchema.h | 45 +++++++++++++++++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 42 ++++++++++++++++ .../src/Operators/DmlOperatorActivation.cpp | 49 +++++++++++++++---- .../src/Operators/OperatorRegistration.cpp | 12 ++--- .../dml/OperatorAuthorHelper/OperatorHelper.h | 3 ++ .../OperatorAuthorHelper/OperatorVersions.h | 3 ++ 8 files changed, 189 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h index 28ab7e167f..8c85e4ec1d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h @@ -9,10 +9,12 @@ union ActivationOperatorDescUnion DML_ACTIVATION_ELU_OPERATOR_DESC elu; DML_ACTIVATION_CELU_OPERATOR_DESC celu; DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax; + DML_ACTIVATION_HARDMAX1_OPERATOR_DESC hardmax1; DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid; DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu; DML_ACTIVATION_LINEAR_OPERATOR_DESC linear; DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC logSoftmax; + DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC logSoftmax1; DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC parameterizedRelu; DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC parametricSoftplus; DML_ACTIVATION_RELU_OPERATOR_DESC relu; @@ -20,6 +22,7 @@ union ActivationOperatorDescUnion DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC scaledElu; DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoid; DML_ACTIVATION_SOFTMAX_OPERATOR_DESC softmax; + DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC softmax1; DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC softplus; DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC softsign; DML_ACTIVATION_TANH_OPERATOR_DESC tanh; @@ -41,11 +44,13 @@ struct ActivationOperatorDesc case DML_OPERATOR_ACTIVATION_ELU: return { activationType, ¶ms.elu }; case DML_OPERATOR_ACTIVATION_CELU: return { activationType, ¶ms.celu }; case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, ¶ms.hardmax }; + case DML_OPERATOR_ACTIVATION_HARDMAX1: return { activationType, ¶ms.hardmax1 }; case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, ¶ms.sigmoid }; case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, ¶ms.identity }; case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return { activationType, ¶ms.leakyRelu }; case DML_OPERATOR_ACTIVATION_LINEAR: return { activationType, ¶ms.linear }; case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return { activationType, ¶ms.logSoftmax }; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return { activationType, ¶ms.logSoftmax1 }; case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return { activationType, ¶ms.parameterizedRelu }; case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return { activationType, ¶ms.parametricSoftplus }; case DML_OPERATOR_ACTIVATION_RELU: return { activationType, ¶ms.relu }; @@ -53,6 +58,7 @@ struct ActivationOperatorDesc case DML_OPERATOR_ACTIVATION_SCALED_TANH: return { activationType, ¶ms.scaledTanh }; case DML_OPERATOR_ACTIVATION_SIGMOID: return { activationType, ¶ms.sigmoid }; case DML_OPERATOR_ACTIVATION_SOFTMAX: return { activationType, ¶ms.softmax }; + case DML_OPERATOR_ACTIVATION_SOFTMAX1: return { activationType, ¶ms.softmax1 }; case DML_OPERATOR_ACTIVATION_SOFTPLUS: return { activationType, ¶ms.softplus }; case DML_OPERATOR_ACTIVATION_SOFTSIGN: return { activationType, ¶ms.softsign }; case DML_OPERATOR_ACTIVATION_TANH: return { activationType, ¶ms.tanh }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 1774a8a2b0..7d4c75e6ca 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,8 +24,8 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 154; - static constexpr size_t ActivationFunctionCount = 21; + static constexpr auto ValueCount = 157; + static constexpr size_t ActivationFunctionCount = 24; }; template <> @@ -1011,6 +1011,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX1; +}; + template <> struct OperatorDescTraits { @@ -1041,6 +1047,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1; +}; + template <> struct OperatorDescTraits { @@ -1083,6 +1095,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX1; +}; + template <> struct OperatorDescTraits { @@ -1959,6 +1977,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX> using DescType = DML_ACTIVATION_HARDMAX_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX1> +{ + using DescType = DML_ACTIVATION_HARDMAX1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SIGMOID> { @@ -1989,6 +2013,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX using DescType = DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1> +{ + using DescType = DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU> { @@ -2031,6 +2061,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX> using DescType = DML_ACTIVATION_SOFTMAX_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX1> +{ + using DescType = DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTPLUS> { @@ -2360,6 +2396,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ACTIVATION_CELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_HARDMAX: return std::invoke(std::forward(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_HARDMAX1: + return std::invoke(std::forward(visitor), DML_ACTIVATION_HARDMAX1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return std::invoke(std::forward(visitor), DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_IDENTITY: @@ -2370,6 +2408,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ACTIVATION_LINEAR_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return std::invoke(std::forward(visitor), DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: + return std::invoke(std::forward(visitor), DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: @@ -2384,6 +2424,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_SOFTMAX: return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTMAX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SOFTMAX1: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_SOFTPLUS: return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_SOFTSIGN: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 137ea2b030..a993300291 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -2288,6 +2288,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA { DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_HARDMAX1", + DML_OPERATOR_ACTIVATION_HARDMAX1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -2358,6 +2373,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA { DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1", + DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SlopeTensor", false }, @@ -2456,6 +2486,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA { DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SOFTMAX1", + DML_OPERATOR_ACTIVATION_SOFTMAX1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index e0a705330f..b37389ce17 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1404,6 +1404,15 @@ inline std::vector GetFields(const DML_ACTIVATION_HARDMAX_OPERATO OperatorField(&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ACTIVATION_HARDMAX1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + }; +} inline std::vector GetFields(const DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC& desc) { return { @@ -1444,6 +1453,15 @@ inline std::vector GetFields(const DML_ACTIVATION_LOG_SOFTMAX_OPE OperatorField(&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + }; +} inline std::vector GetFields(const DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC& desc) { return { @@ -1500,6 +1518,15 @@ inline std::vector GetFields(const DML_ACTIVATION_SOFTMAX_OPERATO OperatorField(&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + }; +} inline std::vector GetFields(const DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC& desc) { return { @@ -1689,11 +1716,13 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_HARDMAX1: return DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_LINEAR: return DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_RELU: return DML_ACTIVATION_RELU_OPERATOR_SCHEMA; @@ -1701,6 +1730,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ACTIVATION_SCALED_TANH: return DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SIGMOID: return DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SOFTMAX: return DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SOFTMAX1: return DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SOFTPLUS: return DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SOFTSIGN: return DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_TANH: return DML_ACTIVATION_TANH_OPERATOR_SCHEMA; @@ -2276,6 +2306,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_HARDMAX1: + return AbstractOperatorDesc( + &DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return AbstractOperatorDesc( &DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA, @@ -2296,6 +2330,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: + return AbstractOperatorDesc( + &DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return AbstractOperatorDesc( &DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA, @@ -2324,6 +2362,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SOFTMAX1: + return AbstractOperatorDesc( + &DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_SOFTPLUS: return AbstractOperatorDesc( &DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp index c7c9e50c90..61baa10cfb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorActivation.cpp @@ -25,7 +25,7 @@ public: ActivationOperatorDescUnion operatorDesc = {}; - int coerceAxis = TensorAxis::DoNotCoerce; + std::vector dmlAxes; switch (operatorType) { @@ -39,7 +39,29 @@ public: case DML_OPERATOR_ACTIVATION_HARDMAX: { const uint32_t onnxDimCount = gsl::narrow_cast(kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0).size()); - coerceAxis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute(AttrName::Axis, 1), onnxDimCount); + int axis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute(AttrName::Axis, 1), onnxDimCount); + std::vector onnxAxes(onnxDimCount - axis); + std::iota(onnxAxes.begin(), onnxAxes.end(), static_cast(axis)); + + dmlAxes.resize(onnxDimCount - axis); + GetDmlAdjustedAxes(onnxAxes, onnxDimCount, m_inputTensorDescs.front().GetDimensionCount(), /*out*/ dmlAxes); + + operatorDesc.hardmax1.Axes = dmlAxes.data(); + operatorDesc.hardmax1.AxisCount = gsl::narrow_cast(dmlAxes.size()); + } + break; + + case DML_OPERATOR_ACTIVATION_SOFTMAX1: + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: + case DML_OPERATOR_ACTIVATION_HARDMAX1: + { + const uint32_t onnxDimCount = gsl::narrow_cast(kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0).size()); + int onnxAxis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute(AttrName::Axis, -1), onnxDimCount); + + dmlAxes.push_back(GetDmlAdjustedAxis(onnxAxis, onnxDimCount, m_inputTensorDescs.front().GetDimensionCount())); + + operatorDesc.hardmax1.Axes = dmlAxes.data(); + operatorDesc.hardmax1.AxisCount = gsl::narrow_cast(dmlAxes.size()); } break; @@ -100,12 +122,6 @@ public: break; } - if (coerceAxis != TensorAxis::DoNotCoerce) - { - m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, coerceAxis); - m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, coerceAxis); - } - gsl::span outputSizes = m_outputTensorDescs[0].GetSizes(); std::vector inputDescs; std::vector outputDescs; @@ -135,9 +151,24 @@ public: operatorDesc.elu.OutputTensor = outputDescs.data(); } - DML_OPERATOR_DESC opDesc = { operatorType, &operatorDesc }; + DML_OPERATOR_DESC opDesc = { remappedOperatorType(operatorType), &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); } + +private: + DML_OPERATOR_TYPE remappedOperatorType(const DML_OPERATOR_TYPE operatorType) const { + switch (operatorType) + { + case DML_OPERATOR_ACTIVATION_HARDMAX: + return DML_OPERATOR_ACTIVATION_HARDMAX1; + case DML_OPERATOR_ACTIVATION_SOFTMAX: + return DML_OPERATOR_ACTIVATION_SOFTMAX1; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: + return DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1; + default: + return operatorType; + } + } }; // A specific type of operation for registration. diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index d39f177d37..bf1e4c762c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -179,8 +179,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(Elu); DML_OP_EXTERN_CREATION_FUNCTION(Celu); DML_OP_EXTERN_CREATION_FUNCTION(Selu); DML_OP_EXTERN_CREATION_FUNCTION(Softmax); +DML_OP_EXTERN_CREATION_FUNCTION(Softmax13); DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax); +DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax13); DML_OP_EXTERN_CREATION_FUNCTION(Hardmax); +DML_OP_EXTERN_CREATION_FUNCTION(Hardmax13); DML_OP_EXTERN_CREATION_FUNCTION(Softsign); DML_OP_EXTERN_CREATION_FUNCTION(Softplus); DML_OP_EXTERN_CREATION_FUNCTION(ParametricSoftplus); @@ -642,16 +645,13 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - // TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879 - // {REG_INFO( 13, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - // TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879 - // {REG_INFO( 13, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - // TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879 - // {REG_INFO( 13, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 188eb3dff0..daceecab5c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1459,8 +1459,11 @@ using ShapeInferenceHelper_Elu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Celu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Selu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Softmax = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Softmax13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LogSoftmax = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_LogSoftmax13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Hardmax = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Hardmax13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Softsign = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Softplus = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_ParametricSoftplus = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index ec6e4dca67..598d282bf2 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -336,6 +336,9 @@ namespace OperatorHelper static const int sc_sinceVer_Transpose = 13; static const int sc_sinceVer_Unsqueeze = 13; static const int sc_sinceVer_ReduseSum = 13; + static const int sc_sinceVer_Softmax = 13; + static const int sc_sinceVer_LogSoftmax = 13; + static const int sc_sinceVer_Hardmax = 13; } // namespace OnnxOperatorSet13 namespace OnnxOperatorSet14