ConcatGrad for OpSet13 (#10109)

This commit is contained in:
Vincent Wang 2021-12-24 10:02:52 +08:00 committed by GitHub
parent 05d20343ee
commit f780f06240
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 54 deletions

View file

@ -510,72 +510,60 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) {
ORT_ENFORCE(attributes.at("axis").has_i());
auto axis = attributes.at("axis").i();
std::vector<int64_t> split_attribute(GetSrcNodeInputSize());
std::vector<ArgDef> outputs;
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
std::vector<Dimension> data_shape;
ORT_ENFORCE(GetShape(I(i), data_shape).IsOK());
int64_t axis_index = axis < 0 ? static_cast<int64_t>(data_shape.size()) + axis : axis;
if (axis_index >= 0 && axis_index < static_cast<int64_t>(data_shape.size()) && data_shape[axis_index].has_dim_value()) {
split_attribute[i] = data_shape[axis_index].dim_value();
} else {
ORT_THROW("Error: can't infer split attribute value for ConcatGrad");
}
outputs.push_back(GI(i));
}
std::vector<ArgDef> node_outputs;
std::vector<AttributeProto> new_attributes;
new_attributes.push_back(MakeAttribute("axis", axis));
new_attributes.push_back(MakeAttribute("split", split_attribute));
return std::vector<NodeDef>{
NodeDef("Split",
{GO(0)},
outputs,
new_attributes)};
// Split Op before OpSet13 has "split" as attribute, but as input since OpSet13.
if (SrcNodeOpsetVersion() < 13) {
std::vector<int64_t> split_attribute(GetSrcNodeInputSize());
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
std::vector<Dimension> data_shape;
ORT_ENFORCE(GetShape(I(i), data_shape).IsOK());
int64_t axis_index = axis < 0 ? static_cast<int64_t>(data_shape.size()) + axis : axis;
if (axis_index >= 0 && axis_index < static_cast<int64_t>(data_shape.size()) &&
data_shape[axis_index].has_dim_value()) {
split_attribute[i] = data_shape[axis_index].dim_value();
} else {
ORT_THROW("Error: can't infer split attribute value for ConcatGrad");
}
node_outputs.push_back(GI(i));
}
new_attributes.push_back(MakeAttribute("split", split_attribute));
return std::vector<NodeDef>{NodeDef("Split", {GO(0)}, node_outputs, new_attributes)};
}
std::vector<NodeDef> output;
NodeDef axis_const_node = ConstantScalarNode(axis, {1}, Name(std::to_string(axis) + "_int64"));
ArgDef axis_arg_def = axis_const_node.output_args[0];
output.emplace_back(axis_const_node);
std::vector<ArgDef> split_sizes;
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
ArgDef shape_arg_def = IA("shape_" + std::to_string(i));
ArgDef split_size_arg_def = IA("split_size_" + std::to_string(i));
output.emplace_back(NodeDef("Shape", {I(i)}, {shape_arg_def}));
output.emplace_back(
NodeDef("Gather", {shape_arg_def, axis_arg_def}, {split_size_arg_def}, {MakeAttribute("axis", int64_t(0))}));
split_sizes.emplace_back(split_size_arg_def);
node_outputs.emplace_back(GI(i));
}
output.emplace_back(NodeDef("Concat", split_sizes, {IA("split_sizes")}, {MakeAttribute("axis", int64_t(0))}));
output.emplace_back(NodeDef("Split", {GO(0), IA("split_sizes")}, node_outputs, new_attributes));
return output;
}
IMPLEMENT_GRADIENT_BUILDER(GetConcatTrainingGradient) {
auto attributes = SrcNodeAttributes();
ORT_ENFORCE(utils::HasInt(attributes.at("axis")));
auto axis = attributes.at("axis").i();
std::vector<int64_t> split_attribute(GetSrcNodeInputSize());
std::vector<ArgDef> outputs;
bool known_shapes = true;
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
std::vector<Dimension> data_shape;
if (GetShape(I(i), data_shape).IsOK()) {
int64_t rank = static_cast<int64_t>(data_shape.size());
int64_t axis_index = HandleNegativeAxis(axis, rank);
if (data_shape[axis_index].has_dim_value()) {
split_attribute[i] = data_shape[axis_index].dim_value();
} else {
known_shapes = false;
}
} else {
known_shapes = false;
}
outputs.push_back(GI(i));
}
std::vector<AttributeProto> new_attributes;
new_attributes.push_back(MakeAttribute("axis", axis));
if (known_shapes) {
new_attributes.push_back(MakeAttribute("split", split_attribute));
return std::vector<NodeDef>{
NodeDef("Split",
{GO(0)},
outputs,
new_attributes)};
} else {
return std::vector<NodeDef>{
NodeDef(OpDef{"SplitTraining", kMSDomain, 1},
{GO(0), O(1)},
outputs,
new_attributes)};
std::vector<ArgDef> outputs;
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
outputs.push_back(GI(i));
}
return std::vector<NodeDef>{NodeDef(OpDef{"SplitTraining", kMSDomain, 1}, {GO(0), O(1)}, outputs, new_attributes)};
}
IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) {

View file

@ -1126,7 +1126,9 @@ static void TestConcatOpGrad(const std::string& op_type,
}
TEST(GradientCheckerTest, ConcatGrad) {
// Concat's gradient uses Split, and Split Op move "split" attribute to input since OpSet13.
TestConcatOpGrad("Concat");
TestConcatOpGrad("Concat", kOnnxDomain, 13);
}
TEST(GradientCheckerTest, ConcatTrainingGrad) { /*also test w/o shape inferencing */