Added Softmax/Hardmax/LogSoftmax-13 (#11772)

* Added Softmax/Hardmax/LogSoftmax-13

* Removed redundant method specifier

Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
This commit is contained in:
sumitsays 2022-06-07 14:31:55 -07:00 committed by GitHub
parent a7fa735286
commit aa3a825816
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 189 additions and 17 deletions

View file

@ -9,10 +9,12 @@ union ActivationOperatorDescUnion
DML_ACTIVATION_ELU_OPERATOR_DESC elu;
DML_ACTIVATION_CELU_OPERATOR_DESC celu;
DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax;
DML_ACTIVATION_HARDMAX1_OPERATOR_DESC hardmax1;
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid;
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu;
DML_ACTIVATION_LINEAR_OPERATOR_DESC linear;
DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC logSoftmax;
DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC logSoftmax1;
DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC parameterizedRelu;
DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC parametricSoftplus;
DML_ACTIVATION_RELU_OPERATOR_DESC relu;
@ -20,6 +22,7 @@ union ActivationOperatorDescUnion
DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC scaledElu;
DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoid;
DML_ACTIVATION_SOFTMAX_OPERATOR_DESC softmax;
DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC softmax1;
DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC softplus;
DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC softsign;
DML_ACTIVATION_TANH_OPERATOR_DESC tanh;
@ -41,11 +44,13 @@ struct ActivationOperatorDesc
case DML_OPERATOR_ACTIVATION_ELU: return { activationType, &params.elu };
case DML_OPERATOR_ACTIVATION_CELU: return { activationType, &params.celu };
case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, &params.hardmax };
case DML_OPERATOR_ACTIVATION_HARDMAX1: return { activationType, &params.hardmax1 };
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, &params.sigmoid };
case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, &params.identity };
case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return { activationType, &params.leakyRelu };
case DML_OPERATOR_ACTIVATION_LINEAR: return { activationType, &params.linear };
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return { activationType, &params.logSoftmax };
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return { activationType, &params.logSoftmax1 };
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return { activationType, &params.parameterizedRelu };
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return { activationType, &params.parametricSoftplus };
case DML_OPERATOR_ACTIVATION_RELU: return { activationType, &params.relu };
@ -53,6 +58,7 @@ struct ActivationOperatorDesc
case DML_OPERATOR_ACTIVATION_SCALED_TANH: return { activationType, &params.scaledTanh };
case DML_OPERATOR_ACTIVATION_SIGMOID: return { activationType, &params.sigmoid };
case DML_OPERATOR_ACTIVATION_SOFTMAX: return { activationType, &params.softmax };
case DML_OPERATOR_ACTIVATION_SOFTMAX1: return { activationType, &params.softmax1 };
case DML_OPERATOR_ACTIVATION_SOFTPLUS: return { activationType, &params.softplus };
case DML_OPERATOR_ACTIVATION_SOFTSIGN: return { activationType, &params.softsign };
case DML_OPERATOR_ACTIVATION_TANH: return { activationType, &params.tanh };

View file

@ -24,8 +24,8 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 154;
static constexpr size_t ActivationFunctionCount = 21;
static constexpr auto ValueCount = 157;
static constexpr size_t ActivationFunctionCount = 24;
};
template <>
@ -1011,6 +1011,12 @@ struct OperatorDescTraits<DML_ACTIVATION_HARDMAX_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_HARDMAX1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX1;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC>
{
@ -1041,6 +1047,12 @@ struct OperatorDescTraits<DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC>
{
@ -1083,6 +1095,12 @@ struct OperatorDescTraits<DML_ACTIVATION_SOFTMAX_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX1;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC>
{
@ -1959,6 +1977,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX>
using DescType = DML_ACTIVATION_HARDMAX_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX1>
{
using DescType = DML_ACTIVATION_HARDMAX1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SIGMOID>
{
@ -1989,6 +2013,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX
using DescType = DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1>
{
using DescType = DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU>
{
@ -2031,6 +2061,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX>
using DescType = DML_ACTIVATION_SOFTMAX_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX1>
{
using DescType = DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTPLUS>
{
@ -2360,6 +2396,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_CELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_HARDMAX:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_HARDMAX1:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARDMAX1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_IDENTITY:
@ -2370,6 +2408,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_LINEAR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS:
@ -2384,6 +2424,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_SOFTMAX:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SOFTMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_SOFTMAX1:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_SOFTPLUS:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_SOFTSIGN:

View file

@ -2288,6 +2288,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA {
DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
};
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA {
"DML_OPERATOR_ACTIVATION_HARDMAX1",
DML_OPERATOR_ACTIVATION_HARDMAX1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
4,
DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -2358,6 +2373,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA {
DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
};
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA {
"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1",
DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
4,
DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SlopeTensor", false },
@ -2456,6 +2486,21 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA {
DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
};
constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA {
"DML_OPERATOR_ACTIVATION_SOFTMAX1",
DML_OPERATOR_ACTIVATION_SOFTMAX1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
4,
DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },

View file

@ -1404,6 +1404,15 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARDMAX_OPERATO
OperatorField(&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARDMAX1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
OperatorField(&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC& desc)
{
return {
@ -1444,6 +1453,15 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_LOG_SOFTMAX_OPE
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
OperatorField(&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC& desc)
{
return {
@ -1500,6 +1518,15 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SOFTMAX_OPERATO
OperatorField(&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
OperatorField(&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC& desc)
{
return {
@ -1689,11 +1716,13 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX1: return DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_LINEAR: return DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_RELU: return DML_ACTIVATION_RELU_OPERATOR_SCHEMA;
@ -1701,6 +1730,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ACTIVATION_SCALED_TANH: return DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_SIGMOID: return DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_SOFTMAX: return DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_SOFTMAX1: return DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_SOFTPLUS: return DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_SOFTSIGN: return DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_TANH: return DML_ACTIVATION_TANH_OPERATOR_SCHEMA;
@ -2276,6 +2306,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_HARDMAX_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_HARDMAX1:
return AbstractOperatorDesc(
&DML_ACTIVATION_HARDMAX1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_HARDMAX1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID:
return AbstractOperatorDesc(
&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA,
@ -2296,6 +2330,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1:
return AbstractOperatorDesc(
&DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_LOG_SOFTMAX1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA,
@ -2324,6 +2362,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_SOFTMAX_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_SOFTMAX1:
return AbstractOperatorDesc(
&DML_ACTIVATION_SOFTMAX1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_SOFTPLUS:
return AbstractOperatorDesc(
&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA,

View file

@ -25,7 +25,7 @@ public:
ActivationOperatorDescUnion operatorDesc = {};
int coerceAxis = TensorAxis::DoNotCoerce;
std::vector<uint32_t> dmlAxes;
switch (operatorType)
{
@ -39,7 +39,29 @@ public:
case DML_OPERATOR_ACTIVATION_HARDMAX:
{
const uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0).size());
coerceAxis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 1), onnxDimCount);
int axis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 1), onnxDimCount);
std::vector<int32_t> onnxAxes(onnxDimCount - axis);
std::iota(onnxAxes.begin(), onnxAxes.end(), static_cast<int32_t>(axis));
dmlAxes.resize(onnxDimCount - axis);
GetDmlAdjustedAxes(onnxAxes, onnxDimCount, m_inputTensorDescs.front().GetDimensionCount(), /*out*/ dmlAxes);
operatorDesc.hardmax1.Axes = dmlAxes.data();
operatorDesc.hardmax1.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
}
break;
case DML_OPERATOR_ACTIVATION_SOFTMAX1:
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1:
case DML_OPERATOR_ACTIVATION_HARDMAX1:
{
const uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0).size());
int onnxAxis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, -1), onnxDimCount);
dmlAxes.push_back(GetDmlAdjustedAxis(onnxAxis, onnxDimCount, m_inputTensorDescs.front().GetDimensionCount()));
operatorDesc.hardmax1.Axes = dmlAxes.data();
operatorDesc.hardmax1.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
}
break;
@ -100,12 +122,6 @@ public:
break;
}
if (coerceAxis != TensorAxis::DoNotCoerce)
{
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, coerceAxis);
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, coerceAxis);
}
gsl::span<const uint32_t> outputSizes = m_outputTensorDescs[0].GetSizes();
std::vector<DML_TENSOR_DESC> inputDescs;
std::vector<DML_TENSOR_DESC> outputDescs;
@ -135,9 +151,24 @@ public:
operatorDesc.elu.OutputTensor = outputDescs.data();
}
DML_OPERATOR_DESC opDesc = { operatorType, &operatorDesc };
DML_OPERATOR_DESC opDesc = { remappedOperatorType(operatorType), &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
private:
DML_OPERATOR_TYPE remappedOperatorType(const DML_OPERATOR_TYPE operatorType) const {
switch (operatorType)
{
case DML_OPERATOR_ACTIVATION_HARDMAX:
return DML_OPERATOR_ACTIVATION_HARDMAX1;
case DML_OPERATOR_ACTIVATION_SOFTMAX:
return DML_OPERATOR_ACTIVATION_SOFTMAX1;
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX:
return DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1;
default:
return operatorType;
}
}
};
// A specific type of operation for registration.

View file

@ -179,8 +179,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(Elu);
DML_OP_EXTERN_CREATION_FUNCTION(Celu);
DML_OP_EXTERN_CREATION_FUNCTION(Selu);
DML_OP_EXTERN_CREATION_FUNCTION(Softmax);
DML_OP_EXTERN_CREATION_FUNCTION(Softmax13);
DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax);
DML_OP_EXTERN_CREATION_FUNCTION(LogSoftmax13);
DML_OP_EXTERN_CREATION_FUNCTION(Hardmax);
DML_OP_EXTERN_CREATION_FUNCTION(Hardmax13);
DML_OP_EXTERN_CREATION_FUNCTION(Softsign);
DML_OP_EXTERN_CREATION_FUNCTION(Softplus);
DML_OP_EXTERN_CREATION_FUNCTION(ParametricSoftplus);
@ -642,16 +645,13 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
// TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879
// {REG_INFO( 13, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
// TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879
// {REG_INFO( 13, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
// TODO: Update Softmax-13/LogSoftmax-13/Hardmax-13 family ops behavior to align with other frameworks https://github.com/onnx/onnx/pull/2879
// {REG_INFO( 13, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},

View file

@ -1459,8 +1459,11 @@ using ShapeInferenceHelper_Elu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Celu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Selu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Softmax = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Softmax13 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LogSoftmax = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LogSoftmax13 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Hardmax = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Hardmax13 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Softsign = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Softplus = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_ParametricSoftplus = GetOutputShapeAsInputShapeHelper;

View file

@ -336,6 +336,9 @@ namespace OperatorHelper
static const int sc_sinceVer_Transpose = 13;
static const int sc_sinceVer_Unsqueeze = 13;
static const int sc_sinceVer_ReduseSum = 13;
static const int sc_sinceVer_Softmax = 13;
static const int sc_sinceVer_LogSoftmax = 13;
static const int sc_sinceVer_Hardmax = 13;
} // namespace OnnxOperatorSet13
namespace OnnxOperatorSet14