mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Fix unsqueeze for opset 13 for ReduceMean Grad (#10668)
* fix unsqueeze for opset 13 for reducemean grad * fix input for reduce mean
This commit is contained in:
parent
eb116595d4
commit
037f08f1ff
2 changed files with 16 additions and 6 deletions
|
|
@ -1033,10 +1033,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
|
|||
}
|
||||
|
||||
ArgDef grad = GO(0);
|
||||
if (!keepdims && attributes.find("axes") != attributes.end()) {
|
||||
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
|
||||
grad = IA("Unqueezed_Grad");
|
||||
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
|
||||
if (!keepdims) {
|
||||
if (attributes.find("axes") != attributes.end()) {
|
||||
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
|
||||
grad = IA("Unqueezed_Grad");
|
||||
if (SrcNodeOpsetVersion() < 13) { // axes is attribute for unsqueeze
|
||||
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
|
||||
}else{
|
||||
NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
|
||||
result.push_back(axes_values_node);
|
||||
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.push_back(NodeDef("Size", {I(0)}, {IA("Sized_X")}));
|
||||
|
|
|
|||
|
|
@ -606,9 +606,11 @@ TEST(GradientCheckerTest, GemmGrad) {
|
|||
|
||||
TEST(GradientCheckerTest, ReduceMeanGrad) {
|
||||
// Attribute axes supports negative values from opset 11.
|
||||
OpDef op_def{"ReduceMean", kOnnxDomain, 11};
|
||||
OpDef op_def_opset11{"ReduceMean", kOnnxDomain, 11};
|
||||
RunReductionTests(op_def_opset11);
|
||||
|
||||
RunReductionTests(op_def);
|
||||
OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
|
||||
RunReductionTests(op_def_opset13);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReduceSumGrad) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue