Fix Gather to Split optimizer (#14478)

### Description
The Gather-to-Split optimizer fails when opset == 18. This PR fixes that bug
and extends the unit tests.



### Motivation and Context
The model produced by the optimizer does not follow the ONNX specification
for opset 18.
This commit is contained in:
Xavier Dupré 2023-02-02 22:29:44 +01:00 committed by GitHub
parent 3d8fa4d77b
commit 0bcca7ad45
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 170 additions and 28 deletions

View file

@ -9,7 +9,7 @@
namespace onnxruntime {
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const {
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
@ -19,8 +19,8 @@ bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node
if (!optimizer_utils::IsScalar(input_arg)) return false;
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
if (!tensor_proto) return false;
Initializer init_const{*tensor_proto, graph.ModelPath()};
if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false;
Initializer init_const{*tensor_proto, graph.ModelPath()};
index = *(init_const.data<int64_t>());
axis = 0; // Default value.
auto& attrs = node.GetAttributes();
@ -28,6 +28,7 @@ bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node
auto& axis_attr = attrs.at("axis");
if (utils::HasInt(axis_attr)) axis = axis_attr.i();
}
indices_n_dims = tensor_proto->dims_size();
return true;
}
@ -79,11 +80,19 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
bool can_fuse = true;
bool first_edge = true;
int64_t split_axis = 0;
int64_t indices_n_dims = -1;
InlinedVector<NodeArg*> gather_outputs(output_count, nullptr);
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
int64_t index, axis;
if (!IsSupportedGather(graph, *it, index, axis)) {
int64_t index, axis, dims;
if (!IsSupportedGather(graph, *it, index, axis, dims)) {
can_fuse = false;
break;
}
if (indices_n_dims == -1) {
indices_n_dims = dims;
} else if (indices_n_dims != dims) {
// Not the same number of dimensions (0 or 1) for all scalar indices.
can_fuse = false;
break;
}
@ -125,43 +134,54 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
}
InlinedVector<NodeArg*> split_outputs;
for (size_t i = 0; i < output_count; ++i) {
split_outputs.emplace_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
bool add_squeeze_node = indices_n_dims == 0;
if (add_squeeze_node) {
for (size_t i = 0; i < output_count; ++i) {
split_outputs.emplace_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
}
}
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
{node.MutableOutputDefs()[0]}, split_outputs);
{node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs);
split_node.AddAttribute("axis", split_axis);
split_node.SetExecutionProviderType(node.GetExecutionProviderType());
// Squeeze before and after OpSet-13 have different schemas.
// Squeeze-11, Squeeze-13, Split-13, Split-18 have different schemas.
int onnx_opset_version = -1;
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
}
if (onnx_opset_version < 13) {
for (size_t i = 0; i < output_count; ++i) {
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
if (add_squeeze_node) {
for (size_t i = 0; i < output_count; ++i) {
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
}
}
} else {
ONNX_NAMESPACE::TensorProto axes_initializer_proto;
axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
axes_initializer_proto.add_dims(static_cast<int64_t>(1));
axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
InlinedVector<int64_t> axes_value{split_axis};
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
if (onnx_opset_version >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(output_count));
}
for (size_t i = 0; i < output_count; ++i) {
Node& squeeze_node =
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
if (add_squeeze_node) {
ONNX_NAMESPACE::TensorProto axes_initializer_proto;
axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
axes_initializer_proto.add_dims(static_cast<int64_t>(1));
axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
InlinedVector<int64_t> axes_value{split_axis};
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
for (size_t i = 0; i < output_count; ++i) {
Node& squeeze_node =
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
}
}
}

View file

@ -20,7 +20,7 @@ class GatherToSplitFusion : public GraphTransformer {
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
private:
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const;
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
};
/**

View file

@ -6270,7 +6270,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] ==3); return Status::OK(); };
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
// OpSet-12
{
@ -6325,6 +6325,128 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
// OpSet-18
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
}
// Runs GatherToSplitFusion on a Reshape output consumed by three Gather nodes whose
// index initializers are built with MakeInitializer<int64_t>({1}, {value})
// (presumably one-element 1-D tensors rather than 0-D scalars -- confirm against
// ModelTestBuilder). For opsets 12, 14 and 18 the transformed graph must contain
// no Gather, exactly one Split on axis 2, and no Squeeze.
TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
  auto build_test_case = [&](ModelTestBuilder& builder) {
    auto* input_data = builder.MakeInput<float>({{54}});
    auto* input_shape = builder.MakeInput<int64_t>({{1}});
    auto* reshaped = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
    auto* index_0 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(0)});
    auto* index_1 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
    auto* index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
    auto* gathered_0 = builder.MakeIntermediate();
    auto* gathered_1 = builder.MakeIntermediate();
    auto* gathered_2 = builder.MakeIntermediate();
    auto* transposed_0 = builder.MakeOutput();
    auto* transposed_1 = builder.MakeOutput();
    auto* transposed_2 = builder.MakeOutput();

    const std::vector<int64_t> perm{0, 2, 1};

    builder.AddNode("Reshape", {input_data, input_shape}, {reshaped});
    // The second Gather uses axis -2 while the others use axis 2; the reshaped
    // intermediate is declared with four dimensions.
    builder.AddNode("Gather", {reshaped, index_0}, {gathered_0})
        .AddAttribute("axis", static_cast<int64_t>(2));
    builder.AddNode("Gather", {reshaped, index_1}, {gathered_1})
        .AddAttribute("axis", static_cast<int64_t>(-2));
    builder.AddNode("Gather", {reshaped, index_2}, {gathered_2})
        .AddAttribute("axis", static_cast<int64_t>(2));
    builder.AddNode("Transpose", {gathered_0}, {transposed_0}).AddAttribute("perm", perm);
    builder.AddNode("Transpose", {gathered_1}, {transposed_1}).AddAttribute("perm", perm);
    builder.AddNode("Transpose", {gathered_2}, {transposed_2}).AddAttribute("perm", perm);
  };

  // Before the transformer runs, all three Gather nodes must be present.
  auto pre_graph_checker = [&](Graph& graph) {
    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
    return Status::OK();
  };

  // The expected result is the same for every opset exercised below, so a single
  // checker is shared by all three runs.
  auto post_graph_checker = [&](Graph& graph) {
    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
    for (auto& node : graph.Nodes()) {
      if (node.OpType() != "Split") {
        continue;
      }
      auto& attrs = node.GetAttributes();
      TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
      TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
    }
    return Status::OK();
  };

  // The runs are kept as separate statements (not a loop) so that a failing
  // ASSERT reports a line that identifies the opset.

  // OpSet-12
  {
    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
  }

  // OpSet-14
  {
    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
  }

  // OpSet-18
  {
    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
  }
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {