mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Fix Gather to Split optimizer (#14478)
### Description The Gather to Split optimizer fails if opset == 18. This PR fixes one bug and extends the unit tests. ### Motivation and Context With opset 18, the model produced by the optimizer does not follow the ONNX specification.
This commit is contained in:
parent
3d8fa4d77b
commit
0bcca7ad45
3 changed files with 170 additions and 28 deletions
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const {
|
||||
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
|
||||
return false;
|
||||
|
|
@ -19,8 +19,8 @@ bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node
|
|||
if (!optimizer_utils::IsScalar(input_arg)) return false;
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
if (!tensor_proto) return false;
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false;
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
index = *(init_const.data<int64_t>());
|
||||
axis = 0; // Default value.
|
||||
auto& attrs = node.GetAttributes();
|
||||
|
|
@ -28,6 +28,7 @@ bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node
|
|||
auto& axis_attr = attrs.at("axis");
|
||||
if (utils::HasInt(axis_attr)) axis = axis_attr.i();
|
||||
}
|
||||
indices_n_dims = tensor_proto->dims_size();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -79,11 +80,19 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
bool can_fuse = true;
|
||||
bool first_edge = true;
|
||||
int64_t split_axis = 0;
|
||||
int64_t indices_n_dims = -1;
|
||||
InlinedVector<NodeArg*> gather_outputs(output_count, nullptr);
|
||||
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
|
||||
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
|
||||
int64_t index, axis;
|
||||
if (!IsSupportedGather(graph, *it, index, axis)) {
|
||||
int64_t index, axis, dims;
|
||||
if (!IsSupportedGather(graph, *it, index, axis, dims)) {
|
||||
can_fuse = false;
|
||||
break;
|
||||
}
|
||||
if (indices_n_dims == -1) {
|
||||
indices_n_dims = dims;
|
||||
} else if (indices_n_dims != dims) {
|
||||
// Not the same number of dimensions (0 or 1) for all scalar indices.
|
||||
can_fuse = false;
|
||||
break;
|
||||
}
|
||||
|
|
@ -125,43 +134,54 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
}
|
||||
|
||||
InlinedVector<NodeArg*> split_outputs;
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
split_outputs.emplace_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
|
||||
bool add_squeeze_node = indices_n_dims == 0;
|
||||
if (add_squeeze_node) {
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
split_outputs.emplace_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
|
||||
}
|
||||
}
|
||||
|
||||
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
|
||||
{node.MutableOutputDefs()[0]}, split_outputs);
|
||||
{node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs);
|
||||
split_node.AddAttribute("axis", split_axis);
|
||||
split_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
|
||||
// Squeeze before and after OpSet-13 have different schemas.
|
||||
// Squeeze-11, Squeeze-13, Split-13, Split-18 have different schemas.
|
||||
int onnx_opset_version = -1;
|
||||
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
|
||||
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
|
||||
}
|
||||
|
||||
if (onnx_opset_version < 13) {
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
|
||||
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
|
||||
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
|
||||
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
if (add_squeeze_node) {
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
|
||||
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
|
||||
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
|
||||
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ONNX_NAMESPACE::TensorProto axes_initializer_proto;
|
||||
axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
|
||||
axes_initializer_proto.add_dims(static_cast<int64_t>(1));
|
||||
axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
InlinedVector<int64_t> axes_value{split_axis};
|
||||
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
|
||||
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
|
||||
if (onnx_opset_version >= 18) {
|
||||
split_node.AddAttribute("num_outputs", static_cast<int64_t>(output_count));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
Node& squeeze_node =
|
||||
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
|
||||
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
|
||||
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
if (add_squeeze_node) {
|
||||
ONNX_NAMESPACE::TensorProto axes_initializer_proto;
|
||||
axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
|
||||
axes_initializer_proto.add_dims(static_cast<int64_t>(1));
|
||||
axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
InlinedVector<int64_t> axes_value{split_axis};
|
||||
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
|
||||
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
|
||||
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
Node& squeeze_node =
|
||||
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
|
||||
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
|
||||
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class GatherToSplitFusion : public GraphTransformer {
|
|||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
private:
|
||||
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const;
|
||||
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -6270,7 +6270,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
|
|||
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] ==3); return Status::OK(); };
|
||||
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
|
||||
|
||||
// OpSet-12
|
||||
{
|
||||
|
|
@ -6325,6 +6325,128 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
|
|||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-18
|
||||
{
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
} else if (node.OpType() == "Squeeze") {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[1]);
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* data_arg = builder.MakeInput<float>({{54}});
|
||||
auto* shape_arg = builder.MakeInput<int64_t>({{1}});
|
||||
auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
|
||||
auto* gather_index_1 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(0)});
|
||||
auto* gather_index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
|
||||
auto* gather_index_3 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
|
||||
auto* gather_out_1 = builder.MakeIntermediate();
|
||||
auto* gather_out_2 = builder.MakeIntermediate();
|
||||
auto* gather_out_3 = builder.MakeIntermediate();
|
||||
auto* transpose_out_1 = builder.MakeOutput();
|
||||
auto* transpose_out_2 = builder.MakeOutput();
|
||||
auto* transpose_out_3 = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
|
||||
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
|
||||
.AddAttribute("axis", static_cast<int64_t>(2));
|
||||
builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
|
||||
.AddAttribute("axis", static_cast<int64_t>(-2));
|
||||
builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
|
||||
.AddAttribute("axis", static_cast<int64_t>(2));
|
||||
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
|
||||
|
||||
// OpSet-12
|
||||
{
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-14
|
||||
{
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-18
|
||||
{
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue