From f780f06240ef407a838d7ea6d11a24cfcc436733 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Fri, 24 Dec 2021 10:02:52 +0800 Subject: [PATCH] ConcatGrad for OpSet13 (#10109) --- .../core/graph/gradient_builder.cc | 96 ++++++++----------- .../test/gradient/gradient_ops_test.cc | 2 + 2 files changed, 44 insertions(+), 54 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 10987beb0d..55b521e13a 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -510,72 +510,60 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) { ORT_ENFORCE(attributes.at("axis").has_i()); auto axis = attributes.at("axis").i(); - std::vector split_attribute(GetSrcNodeInputSize()); - std::vector outputs; - for (int i = 0; i < GetSrcNodeInputSize(); ++i) { - std::vector data_shape; - ORT_ENFORCE(GetShape(I(i), data_shape).IsOK()); - int64_t axis_index = axis < 0 ? static_cast(data_shape.size()) + axis : axis; - if (axis_index >= 0 && axis_index < static_cast(data_shape.size()) && data_shape[axis_index].has_dim_value()) { - split_attribute[i] = data_shape[axis_index].dim_value(); - } else { - ORT_THROW("Error: can't infer split attribute value for ConcatGrad"); - } - outputs.push_back(GI(i)); - } - + std::vector node_outputs; std::vector new_attributes; new_attributes.push_back(MakeAttribute("axis", axis)); - new_attributes.push_back(MakeAttribute("split", split_attribute)); - return std::vector{ - NodeDef("Split", - {GO(0)}, - outputs, - new_attributes)}; + // Split Op before OpSet13 has "split" as attribute, but as input since OpSet13. + if (SrcNodeOpsetVersion() < 13) { + std::vector split_attribute(GetSrcNodeInputSize()); + for (int i = 0; i < GetSrcNodeInputSize(); ++i) { + std::vector data_shape; + ORT_ENFORCE(GetShape(I(i), data_shape).IsOK()); + int64_t axis_index = axis < 0 ? static_cast(data_shape.size()) + axis : axis; + if (axis_index >= 0 && axis_index < static_cast(data_shape.size()) && + data_shape[axis_index].has_dim_value()) { + split_attribute[i] = data_shape[axis_index].dim_value(); + } else { + ORT_THROW("Error: can't infer split attribute value for ConcatGrad"); + } + node_outputs.push_back(GI(i)); + } + + new_attributes.push_back(MakeAttribute("split", split_attribute)); + return std::vector{NodeDef("Split", {GO(0)}, node_outputs, new_attributes)}; + } + + std::vector output; + NodeDef axis_const_node = ConstantScalarNode(axis, {1}, Name(std::to_string(axis) + "_int64")); + ArgDef axis_arg_def = axis_const_node.output_args[0]; + output.emplace_back(axis_const_node); + std::vector split_sizes; + for (int i = 0; i < GetSrcNodeInputSize(); ++i) { + ArgDef shape_arg_def = IA("shape_" + std::to_string(i)); + ArgDef split_size_arg_def = IA("split_size_" + std::to_string(i)); + output.emplace_back(NodeDef("Shape", {I(i)}, {shape_arg_def})); + output.emplace_back( + NodeDef("Gather", {shape_arg_def, axis_arg_def}, {split_size_arg_def}, {MakeAttribute("axis", int64_t(0))})); + split_sizes.emplace_back(split_size_arg_def); + node_outputs.emplace_back(GI(i)); + } + output.emplace_back(NodeDef("Concat", split_sizes, {IA("split_sizes")}, {MakeAttribute("axis", int64_t(0))})); + output.emplace_back(NodeDef("Split", {GO(0), IA("split_sizes")}, node_outputs, new_attributes)); + return output; } IMPLEMENT_GRADIENT_BUILDER(GetConcatTrainingGradient) { auto attributes = SrcNodeAttributes(); ORT_ENFORCE(utils::HasInt(attributes.at("axis"))); auto axis = attributes.at("axis").i(); - - std::vector split_attribute(GetSrcNodeInputSize()); - std::vector outputs; - bool known_shapes = true; - for (int i = 0; i < GetSrcNodeInputSize(); ++i) { - std::vector data_shape; - if (GetShape(I(i), data_shape).IsOK()) { - int64_t rank = static_cast(data_shape.size()); - int64_t axis_index = HandleNegativeAxis(axis, rank); - if (data_shape[axis_index].has_dim_value()) { - split_attribute[i] = data_shape[axis_index].dim_value(); - } else { - known_shapes = false; - } - } else { - known_shapes = false; - } - - outputs.push_back(GI(i)); - } - std::vector new_attributes; new_attributes.push_back(MakeAttribute("axis", axis)); - if (known_shapes) { - new_attributes.push_back(MakeAttribute("split", split_attribute)); - return std::vector{ - NodeDef("Split", - {GO(0)}, - outputs, - new_attributes)}; - } else { - return std::vector{ - NodeDef(OpDef{"SplitTraining", kMSDomain, 1}, - {GO(0), O(1)}, - outputs, - new_attributes)}; + std::vector outputs; + for (int i = 0; i < GetSrcNodeInputSize(); ++i) { + outputs.push_back(GI(i)); } + return std::vector{NodeDef(OpDef{"SplitTraining", kMSDomain, 1}, {GO(0), O(1)}, outputs, new_attributes)}; } IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) { diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 9b95925410..4511336435 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1126,7 +1126,9 @@ static void TestConcatOpGrad(const std::string& op_type, } TEST(GradientCheckerTest, ConcatGrad) { + // Concat's gradient uses Split, and Split Op move "split" attribute to input since OpSet13. TestConcatOpGrad("Concat"); + TestConcatOpGrad("Concat", kOnnxDomain, 13); } TEST(GradientCheckerTest, ConcatTrainingGrad) { /*also test w/o shape inferencing */