mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Added Softmax/Hardmax/LogSoftmax-13 (#11772)
* Added Softmax/Hardmax/LogSoftmax-13 * Removed redundant method specifier Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
This commit is contained in:
parent
a7fa735286
commit
aa3a825816
8 changed files with 189 additions and 17 deletions
|
|
@ -9,10 +9,12 @@ union ActivationOperatorDescUnion
|
|||
DML_ACTIVATION_ELU_OPERATOR_DESC elu;
|
||||
DML_ACTIVATION_CELU_OPERATOR_DESC celu;
|
||||
DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax;
|
||||
DML_ACTIVATION_HARDMAX1_OPERATOR_DESC hardmax1;
|
||||
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid;
|
||||
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu;
|
||||
DML_ACTIVATION_LINEAR_OPERATOR_DESC linear;
|
||||
DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC logSoftmax;
|
||||
DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC logSoftmax1;
|
||||
DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC parameterizedRelu;
|
||||
DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC parametricSoftplus;
|
||||
DML_ACTIVATION_RELU_OPERATOR_DESC relu;
|
||||
|
|
@ -20,6 +22,7 @@ union ActivationOperatorDescUnion
|
|||
DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC scaledElu;
|
||||
DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoid;
|
||||
DML_ACTIVATION_SOFTMAX_OPERATOR_DESC softmax;
|
||||
DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC softmax1;
|
||||
DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC softplus;
|
||||
DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC softsign;
|
||||
DML_ACTIVATION_TANH_OPERATOR_DESC tanh;
|
||||
|
|
@ -41,11 +44,13 @@ struct ActivationOperatorDesc
|
|||
case DML_OPERATOR_ACTIVATION_ELU: return { activationType, ¶ms.elu };
|
||||
case DML_OPERATOR_ACTIVATION_CELU: return { activationType, ¶ms.celu };
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, ¶ms.hardmax };
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX1: return { activationType, ¶ms.hardmax1 };
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, ¶ms.sigmoid };
|
||||
case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, ¶ms.identity };
|
||||
case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return { activationType, ¶ms.leakyRelu };
|
||||
case DML_OPERATOR_ACTIVATION_LINEAR: return { activationType, ¶ms.linear };
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return { activationType, ¶ms.logSoftmax };
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return { activationType, ¶ms.logSoftmax1 };
|
||||
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return { activationType, ¶ms.parameterizedRelu };
|
||||
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return { activationType, ¶ms.parametricSoftplus };
|
||||
case DML_OPERATOR_ACTIVATION_RELU: return { activationType, ¶ms.relu };
|
||||
|
|
@ -53,6 +58,7 @@ struct ActivationOperatorDesc
|
|||
case DML_OPERATOR_ACTIVATION_SCALED_TANH: return { activationType, ¶ms.scaledTanh };
|
||||
case DML_OPERATOR_ACTIVATION_SIGMOID: return { activationType, ¶ms.sigmoid };
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX: return { activationType, ¶ms.softmax };
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX1: return { activationType, ¶ms.softmax1 };
|
||||
case DML_OPERATOR_ACTIVATION_SOFTPLUS: return { activationType, ¶ms.softplus };
|
||||
case DML_OPERATOR_ACTIVATION_SOFTSIGN: return { activationType, ¶ms.softsign };
|
||||
case DML_OPERATOR_ACTIVATION_TANH: return { activationType, ¶ms.tanh };
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ struct EnumTraits<DML_TENSOR_TYPE>
|
|||
template <>
|
||||
struct EnumTraits<DML_OPERATOR_TYPE>
|
||||
{
|
||||
static constexpr auto ValueCount = 154;
|
||||
static constexpr size_t ActivationFunctionCount = 21;
|
||||
static constexpr auto ValueCount = 157;
|
||||
static constexpr size_t ActivationFunctionCount = 24;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
@ -1011,6 +1011,12 @@ struct OperatorDescTraits<DML_ACTIVATION_HARDMAX_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_HARDMAX1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -1041,6 +1047,12 @@ struct OperatorDescTraits<DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -1083,6 +1095,12 @@ struct OperatorDescTraits<DML_ACTIVATION_SOFTMAX_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -1959,6 +1977,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX>
|
|||
using DescType = DML_ACTIVATION_HARDMAX_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX1>
|
||||
{
|
||||
using DescType = DML_ACTIVATION_HARDMAX1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SIGMOID>
|
||||
{
|
||||
|
|
@ -1989,6 +2013,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX
|
|||
using DescType = DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1>
|
||||
{
|
||||
using DescType = DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU>
|
||||
{
|
||||
|
|
@ -2031,6 +2061,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX>
|
|||
using DescType = DML_ACTIVATION_SOFTMAX_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX1>
|
||||
{
|
||||
using DescType = DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTPLUS>
|
||||
{
|
||||
|
|
@ -2360,6 +2396,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_CELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARDMAX1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_IDENTITY:
|
||||
|
|
@ -2370,6 +2408,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_LINEAR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS:
|
||||
|
|
@ -2384,6 +2424,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SOFTMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_SOFTPLUS:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_SOFTSIGN:
|
||||
|
|
|
|||
|
|
@ -2288,6 +2288,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA {
|
|||
DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ACTIVATION_HARDMAX1",
|
||||
DML_OPERATOR_ACTIVATION_HARDMAX1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
4,
|
||||
DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -2358,6 +2373,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA {
|
|||
DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1",
|
||||
DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
4,
|
||||
DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SlopeTensor", false },
|
||||
|
|
@ -2456,6 +2486,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA {
|
|||
DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ACTIVATION_SOFTMAX1",
|
||||
DML_OPERATOR_ACTIVATION_SOFTMAX1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
4,
|
||||
DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
|
|||
|
|
@ -1404,6 +1404,15 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARDMAX_OPERATO
|
|||
OperatorField(&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARDMAX1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
|
||||
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1444,6 +1453,15 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_LOG_SOFTMAX_OPE
|
|||
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
|
||||
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1500,6 +1518,15 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SOFTMAX_OPERATO
|
|||
OperatorField(&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
|
||||
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1689,11 +1716,13 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX1: return DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_LINEAR: return DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_RELU: return DML_ACTIVATION_RELU_OPERATOR_SCHEMA;
|
||||
|
|
@ -1701,6 +1730,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_ACTIVATION_SCALED_TANH: return DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_SIGMOID: return DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX: return DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX1: return DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_SOFTPLUS: return DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_SOFTSIGN: return DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_TANH: return DML_ACTIVATION_TANH_OPERATOR_SCHEMA;
|
||||
|
|
@ -2276,6 +2306,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_HARDMAX_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_HARDMAX1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA,
|
||||
|
|
@ -2296,6 +2330,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA,
|
||||
|
|
@ -2324,6 +2362,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_SOFTMAX_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_SOFTPLUS:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ public:
|
|||
|
||||
ActivationOperatorDescUnion operatorDesc = {};
|
||||
|
||||
int coerceAxis = TensorAxis::DoNotCoerce;
|
||||
std::vector<uint32_t> dmlAxes;
|
||||
|
||||
switch (operatorType)
|
||||
{
|
||||
|
|
@ -39,7 +39,29 @@ public:
|
|||
case DML_OPERATOR_ACTIVATION_HARDMAX:
|
||||
{
|
||||
const uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0).size());
|
||||
coerceAxis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 1), onnxDimCount);
|
||||
int axis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 1), onnxDimCount);
|
||||
std::vector<int32_t> onnxAxes(onnxDimCount - axis);
|
||||
std::iota(onnxAxes.begin(), onnxAxes.end(), static_cast<int32_t>(axis));
|
||||
|
||||
dmlAxes.resize(onnxDimCount - axis);
|
||||
GetDmlAdjustedAxes(onnxAxes, onnxDimCount, m_inputTensorDescs.front().GetDimensionCount(), /*out*/ dmlAxes);
|
||||
|
||||
operatorDesc.hardmax1.Axes = dmlAxes.data();
|
||||
operatorDesc.hardmax1.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
|
||||
}
|
||||
break;
|
||||
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX1:
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1:
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX1:
|
||||
{
|
||||
const uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0).size());
|
||||
int onnxAxis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, -1), onnxDimCount);
|
||||
|
||||
dmlAxes.push_back(GetDmlAdjustedAxis(onnxAxis, onnxDimCount, m_inputTensorDescs.front().GetDimensionCount()));
|
||||
|
||||
operatorDesc.hardmax1.Axes = dmlAxes.data();
|
||||
operatorDesc.hardmax1.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
|
||||
}
|
||||
break;
|
||||
|
||||
|
|
@ -100,12 +122,6 @@ public:
|
|||
break;
|
||||
}
|
||||
|
||||
if (coerceAxis != TensorAxis::DoNotCoerce)
|
||||
{
|
||||
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, coerceAxis);
|
||||
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, coerceAxis);
|
||||
}
|
||||
|
||||
gsl::span<const uint32_t> outputSizes = m_outputTensorDescs[0].GetSizes();
|
||||
std::vector<DML_TENSOR_DESC> inputDescs;
|
||||
std::vector<DML_TENSOR_DESC> outputDescs;
|
||||
|
|
@ -135,9 +151,24 @@ public:
|
|||
operatorDesc.elu.OutputTensor = outputDescs.data();
|
||||
}
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { operatorType, &operatorDesc };
|
||||
DML_OPERATOR_DESC opDesc = { remappedOperatorType(operatorType), &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
|
||||
private:
|
||||
DML_OPERATOR_TYPE remappedOperatorType(const DML_OPERATOR_TYPE operatorType) const {
|
||||
switch (operatorType)
|
||||
{
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX:
|
||||
return DML_OPERATOR_ACTIVATION_HARDMAX1;
|
||||
case DML_OPERATOR_ACTIVATION_SOFTMAX:
|
||||
return DML_OPERATOR_ACTIVATION_SOFTMAX1;
|
||||
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX:
|
||||
return DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1;
|
||||
default:
|
||||
return operatorType;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// A specific type of operation for registration.
|
||||
|
|
|
|||
|
|
@ -179,8 +179,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(Elu);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Celu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Selu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Softmax);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Softmax13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Hardmax);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Hardmax13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Softsign);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Softplus);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ParametricSoftplus);
|
||||
|
|
@ -642,16 +645,13 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
// TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879
|
||||
// {REG_INFO( 13, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
// TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879
|
||||
// {REG_INFO( 13, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
// TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879
|
||||
// {REG_INFO( 13, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -1459,8 +1459,11 @@ using ShapeInferenceHelper_Elu = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_Celu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Selu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Softmax = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Softmax13 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_LogSoftmax = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_LogSoftmax13 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Hardmax = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Hardmax13 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Softsign = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Softplus = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_ParametricSoftplus = GetOutputShapeAsInputShapeHelper;
|
||||
|
|
|
|||
|
|
@ -336,6 +336,9 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Transpose = 13;
|
||||
static const int sc_sinceVer_Unsqueeze = 13;
|
||||
static const int sc_sinceVer_ReduseSum = 13;
|
||||
static const int sc_sinceVer_Softmax = 13;
|
||||
static const int sc_sinceVer_LogSoftmax = 13;
|
||||
static const int sc_sinceVer_Hardmax = 13;
|
||||
} // namespace OnnxOperatorSet13
|
||||
|
||||
namespace OnnxOperatorSet14
|
||||
|
|
|
|||
Loading…
Reference in a new issue