mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add ReduceMax Gradient (#23501)
This commit is contained in:
parent
6bbf1bd948
commit
a9d4d08ed1
4 changed files with 86 additions and 0 deletions
|
|
@ -2253,6 +2253,66 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {
|
|||
|
||||
result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
|
||||
result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
|
||||
return result;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetReduceMaxGradient) {
  // Gradient of ReduceMax: route dY back only to the element positions of X
  // that equal the reduced maximum (ties all receive the gradient).
  std::vector<NodeDef> nodes;
  auto attrs = SrcNodeAttributes();

  // "keepdims" defaults to 1 in the ONNX spec; honor an explicit override.
  bool keep_dims = true;
  auto kd_it = attrs.find("keepdims");
  if (kd_it != attrs.end() && kd_it->second.has_i()) {
    keep_dims = kd_it->second.i() != 0;
  }

  // By default the incoming gradient / forward output are used as-is; when
  // keepdims=0 they must first be unsqueezed back to X's rank.
  ArgDef dY_arg = GO(0);
  ArgDef max_arg = O(0);

  if (!keep_dims) {
    const auto input_count = GetSrcNodeInputSize();
    const bool has_axes_attr = attrs.find("axes") != attrs.end();
    ArgDef axes_input;
    bool restore_dims = false;

    // "axes" may arrive as an attribute (older opsets) or as a second input.
    if (has_axes_attr) {
      restore_dims = true;
      const std::vector<int64_t> axes = RetrieveValues<int64_t>(attrs.at("axes"));
      if (SrcNodeOpsetVersion() >= 13) {
        // Opset-13 Unsqueeze takes axes as an input tensor, so materialize
        // the attribute values as a constant initializer node.
        NodeDef axes_const = ConstantVectorNode(axes, Name("axes_values"));
        nodes.push_back(axes_const);
        axes_input = axes_const.output_args[0];
      }
    } else if (input_count == 2) {
      restore_dims = true;
      axes_input = I(1);
    }
    // If axes are absent entirely, the reduction collapsed every axis and
    // broadcasting below handles the rank difference without an Unsqueeze.

    if (restore_dims) {
      dY_arg = IA("Unsqueezed_Grad");
      max_arg = IA("Unsqueezed_Output");
      if (SrcNodeOpsetVersion() < 13 && has_axes_attr) {
        // Pre-13 Unsqueeze expects axes as an attribute.
        const std::vector<int64_t> axes = RetrieveValues<int64_t>(attrs.at("axes"));
        nodes.push_back(NodeDef("Unsqueeze", {GO(0)}, {dY_arg}, {MakeAttribute("axes", axes)}));
        nodes.push_back(NodeDef("Unsqueeze", {O(0)}, {max_arg}, {MakeAttribute("axes", axes)}));
      } else {
        nodes.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_input}, {dY_arg}));
        nodes.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {O(0), axes_input}, {max_arg}));
      }
    }
  }

  // Step 1: Recreate the boolean mask tensor indicating max positions.
  nodes.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
  nodes.push_back(NodeDef("Expand", {max_arg, IA("Shaped_X")}, {IA("Expanded_Output")}));
  nodes.push_back(NodeDef("Equal", {I(0), IA("Expanded_Output")}, {IA("Mask")}));
  // Step 2: Convert the boolean mask to the output's element type (0.0 / 1.0).
  nodes.push_back(NodeDef("Cast", {IA("Mask")}, {IA("Mask_Float")}, {MakeAttribute("to", static_cast<int64_t>(OElemType(0)))}));
  // Step 3: Gate the incoming gradient by the mask.
  nodes.push_back(NodeDef("Mul", {dY_arg, IA("Mask_Float")}, {IA("Masked_Grad")}));
  // Step 4: Ensure the output gradient has the same shape as the input.
  nodes.push_back(NodeDef("Expand", {IA("Masked_Grad"), IA("Shaped_X")}, {GI(0)}));

  return nodes;
}
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
|
|||
// Gradient-builder declarations — one per forward op with a registered gradient.
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
DECLARE_GRADIENT_BUILDER(GetAtanGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMaxGradient)

DECLARE_GRADIENT_BUILDER(GetExternalGradient)
|
||||
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
// Map each forward op type to its gradient builder.
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("ReduceMax", GetReduceMaxGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3379,6 +3379,30 @@ TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReduceMaxGrad) {
  // Check ReduceMax gradients across opset versions with distinct schemas.
  // Attribute axes supports negative values from opset 11; axes becomes an
  // input (instead of an attribute) from opset 18.
  struct OpsetCase {
    int version;
    bool axes_as_input;
  };
  const OpsetCase cases[] = {{11, false}, {12, false}, {13, false}, {18, true}, {20, true}};

  for (const auto& c : cases) {
    OpDef op_def{"ReduceMax", kOnnxDomain, c.version};
    RunReductionTests(op_def, c.axes_as_input, true);
  }
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue