mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
ConcatGrad for OpSet13 (#10109)
This commit is contained in:
parent
05d20343ee
commit
f780f06240
2 changed files with 44 additions and 54 deletions
|
|
@ -510,72 +510,60 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) {
|
|||
ORT_ENFORCE(attributes.at("axis").has_i());
|
||||
auto axis = attributes.at("axis").i();
|
||||
|
||||
std::vector<int64_t> split_attribute(GetSrcNodeInputSize());
|
||||
std::vector<ArgDef> outputs;
|
||||
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
|
||||
std::vector<Dimension> data_shape;
|
||||
ORT_ENFORCE(GetShape(I(i), data_shape).IsOK());
|
||||
int64_t axis_index = axis < 0 ? static_cast<int64_t>(data_shape.size()) + axis : axis;
|
||||
if (axis_index >= 0 && axis_index < static_cast<int64_t>(data_shape.size()) && data_shape[axis_index].has_dim_value()) {
|
||||
split_attribute[i] = data_shape[axis_index].dim_value();
|
||||
} else {
|
||||
ORT_THROW("Error: can't infer split attribute value for ConcatGrad");
|
||||
}
|
||||
outputs.push_back(GI(i));
|
||||
}
|
||||
|
||||
std::vector<ArgDef> node_outputs;
|
||||
std::vector<AttributeProto> new_attributes;
|
||||
new_attributes.push_back(MakeAttribute("axis", axis));
|
||||
new_attributes.push_back(MakeAttribute("split", split_attribute));
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Split",
|
||||
{GO(0)},
|
||||
outputs,
|
||||
new_attributes)};
|
||||
// Split Op before OpSet13 has "split" as attribute, but as input since OpSet13.
|
||||
if (SrcNodeOpsetVersion() < 13) {
|
||||
std::vector<int64_t> split_attribute(GetSrcNodeInputSize());
|
||||
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
|
||||
std::vector<Dimension> data_shape;
|
||||
ORT_ENFORCE(GetShape(I(i), data_shape).IsOK());
|
||||
int64_t axis_index = axis < 0 ? static_cast<int64_t>(data_shape.size()) + axis : axis;
|
||||
if (axis_index >= 0 && axis_index < static_cast<int64_t>(data_shape.size()) &&
|
||||
data_shape[axis_index].has_dim_value()) {
|
||||
split_attribute[i] = data_shape[axis_index].dim_value();
|
||||
} else {
|
||||
ORT_THROW("Error: can't infer split attribute value for ConcatGrad");
|
||||
}
|
||||
node_outputs.push_back(GI(i));
|
||||
}
|
||||
|
||||
new_attributes.push_back(MakeAttribute("split", split_attribute));
|
||||
return std::vector<NodeDef>{NodeDef("Split", {GO(0)}, node_outputs, new_attributes)};
|
||||
}
|
||||
|
||||
std::vector<NodeDef> output;
|
||||
NodeDef axis_const_node = ConstantScalarNode(axis, {1}, Name(std::to_string(axis) + "_int64"));
|
||||
ArgDef axis_arg_def = axis_const_node.output_args[0];
|
||||
output.emplace_back(axis_const_node);
|
||||
std::vector<ArgDef> split_sizes;
|
||||
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
|
||||
ArgDef shape_arg_def = IA("shape_" + std::to_string(i));
|
||||
ArgDef split_size_arg_def = IA("split_size_" + std::to_string(i));
|
||||
output.emplace_back(NodeDef("Shape", {I(i)}, {shape_arg_def}));
|
||||
output.emplace_back(
|
||||
NodeDef("Gather", {shape_arg_def, axis_arg_def}, {split_size_arg_def}, {MakeAttribute("axis", int64_t(0))}));
|
||||
split_sizes.emplace_back(split_size_arg_def);
|
||||
node_outputs.emplace_back(GI(i));
|
||||
}
|
||||
output.emplace_back(NodeDef("Concat", split_sizes, {IA("split_sizes")}, {MakeAttribute("axis", int64_t(0))}));
|
||||
output.emplace_back(NodeDef("Split", {GO(0), IA("split_sizes")}, node_outputs, new_attributes));
|
||||
return output;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetConcatTrainingGradient) {
|
||||
auto attributes = SrcNodeAttributes();
|
||||
ORT_ENFORCE(utils::HasInt(attributes.at("axis")));
|
||||
auto axis = attributes.at("axis").i();
|
||||
|
||||
std::vector<int64_t> split_attribute(GetSrcNodeInputSize());
|
||||
std::vector<ArgDef> outputs;
|
||||
bool known_shapes = true;
|
||||
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
|
||||
std::vector<Dimension> data_shape;
|
||||
if (GetShape(I(i), data_shape).IsOK()) {
|
||||
int64_t rank = static_cast<int64_t>(data_shape.size());
|
||||
int64_t axis_index = HandleNegativeAxis(axis, rank);
|
||||
if (data_shape[axis_index].has_dim_value()) {
|
||||
split_attribute[i] = data_shape[axis_index].dim_value();
|
||||
} else {
|
||||
known_shapes = false;
|
||||
}
|
||||
} else {
|
||||
known_shapes = false;
|
||||
}
|
||||
|
||||
outputs.push_back(GI(i));
|
||||
}
|
||||
|
||||
std::vector<AttributeProto> new_attributes;
|
||||
new_attributes.push_back(MakeAttribute("axis", axis));
|
||||
if (known_shapes) {
|
||||
new_attributes.push_back(MakeAttribute("split", split_attribute));
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Split",
|
||||
{GO(0)},
|
||||
outputs,
|
||||
new_attributes)};
|
||||
} else {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef{"SplitTraining", kMSDomain, 1},
|
||||
{GO(0), O(1)},
|
||||
outputs,
|
||||
new_attributes)};
|
||||
std::vector<ArgDef> outputs;
|
||||
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
|
||||
outputs.push_back(GI(i));
|
||||
}
|
||||
return std::vector<NodeDef>{NodeDef(OpDef{"SplitTraining", kMSDomain, 1}, {GO(0), O(1)}, outputs, new_attributes)};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) {
|
||||
|
|
|
|||
|
|
@ -1126,7 +1126,9 @@ static void TestConcatOpGrad(const std::string& op_type,
|
|||
}
|
||||
|
||||
TEST(GradientCheckerTest, ConcatGrad) {
|
||||
// Concat's gradient uses Split, and Split Op move "split" attribute to input since OpSet13.
|
||||
TestConcatOpGrad("Concat");
|
||||
TestConcatOpGrad("Concat", kOnnxDomain, 13);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ConcatTrainingGrad) { /*also test w/o shape inferencing */
|
||||
|
|
|
|||
Loading…
Reference in a new issue