Add of ReduceMax Gradient (#23501)

This commit is contained in:
Corentin Maravat 2025-01-31 19:37:41 +01:00 committed by GitHub
parent 6bbf1bd948
commit a9d4d08ed1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 86 additions and 0 deletions

View file

@ -2253,6 +2253,66 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {
result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
return result;
}
IMPLEMENT_GRADIENT_BUILDER(GetReduceMaxGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();
bool keepdims = true;
// Check the "keepdims" attribute
if (attributes.find("keepdims") != attributes.end() &&
attributes.at("keepdims").has_i()) {
keepdims = static_cast<bool>(attributes.at("keepdims").i());
}
ArgDef grad = GO(0);
ArgDef reduced_output = O(0);
if (!keepdims) {
size_t numInputs = GetSrcNodeInputSize();
ArgDef unsqueeze_axes_arg;
bool axes_provided = false;
// Handle "axes" as attribute or input
if (attributes.find("axes") != attributes.end()) {
axes_provided = true;
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
if (SrcNodeOpsetVersion() >= 13) {
NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
result.push_back(axes_values_node);
unsqueeze_axes_arg = axes_values_node.output_args[0];
}
} else if (numInputs == 2) {
axes_provided = true;
unsqueeze_axes_arg = I(1);
}
if (axes_provided) {
grad = IA("Unsqueezed_Grad");
reduced_output = IA("Unsqueezed_Output");
if (SrcNodeOpsetVersion() < 13 && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
result.push_back(NodeDef("Unsqueeze", {O(0)}, {reduced_output}, {MakeAttribute("axes", axes_values)}));
} else {
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), unsqueeze_axes_arg}, {grad}));
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {O(0), unsqueeze_axes_arg}, {reduced_output}));
}
}
}
// Step 1: Recreate the boolean mask tensor indicating max positions
result.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
result.push_back(NodeDef("Expand", {reduced_output, IA("Shaped_X")}, {IA("Expanded_Output")}));
result.push_back(NodeDef("Equal", {I(0), IA("Expanded_Output")}, {IA("Mask")}));
// Step 2: Convert the boolean mask to a float tensor (0.0 and 1.0)
result.push_back(NodeDef("Cast", {IA("Mask")}, {IA("Mask_Float")}, {MakeAttribute("to", static_cast<int64_t>(OElemType(0)))}));
// Step 3: Multiply the input gradient by the mask
result.push_back(NodeDef("Mul", {grad, IA("Mask_Float")}, {IA("Masked_Grad")}));
// Step 4: Ensure the output gradient has the same shape as the input
result.push_back(NodeDef("Expand", {IA("Masked_Grad"), IA("Shaped_X")}, {GI(0)}));
return result;
}

View file

@ -95,6 +95,7 @@ DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
DECLARE_GRADIENT_BUILDER(GetAtanGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMaxGradient)
DECLARE_GRADIENT_BUILDER(GetExternalGradient)

View file

@ -127,6 +127,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("ReduceMax", GetReduceMaxGradient);
REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};

View file

@ -3379,6 +3379,30 @@ TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
}
}
TEST(GradientCheckerTest, ReduceMaxGrad) {
// Attribute axes supports negative values from opset 11.
OpDef op_def_11{"ReduceMax", kOnnxDomain, 11};
RunReductionTests(op_def_11, false, true);
OpDef op_def_12{"ReduceMax", kOnnxDomain, 12};
RunReductionTests(op_def_12, false, true);
OpDef op_def_13{"ReduceMax", kOnnxDomain, 13};
RunReductionTests(op_def_13, false, true);
// axes is input from opset 18.
OpDef op_def_18{"ReduceMax", kOnnxDomain, 18};
RunReductionTests(op_def_18, true, true);
OpDef op_def_20{"ReduceMax", kOnnxDomain, 20};
RunReductionTests(op_def_20, true, true);
}
} // namespace test
} // namespace onnxruntime