Fix unsqueeze for opset 13 for ReduceMean Grad (#10668)

* fix unsqueeze for opset 13 for reducemean grad

* fix input for reduce mean
This commit is contained in:
harshithapv 2022-02-28 09:55:52 -08:00 committed by GitHub
parent eb116595d4
commit 037f08f1ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 6 deletions

View file

@ -1033,10 +1033,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
}
ArgDef grad = GO(0);
if (!keepdims && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
if (!keepdims) {
if (attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unqueezed_Grad");
if (SrcNodeOpsetVersion() < 13) { // axes is attribute for unsqueeze
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
}else{
NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
result.push_back(axes_values_node);
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
}
}
}
result.push_back(NodeDef("Size", {I(0)}, {IA("Sized_X")}));

View file

@ -606,9 +606,11 @@ TEST(GradientCheckerTest, GemmGrad) {
TEST(GradientCheckerTest, ReduceMeanGrad) {
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceMean", kOnnxDomain, 11};
OpDef op_def_opset11{"ReduceMean", kOnnxDomain, 11};
RunReductionTests(op_def_opset11);
RunReductionTests(op_def);
OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
RunReductionTests(op_def_opset13);
}
TEST(GradientCheckerTest, ReduceSumGrad) {