From b0699d901c8e8e8dcfe4fd3b72d54ef7bef5bf75 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 15 Nov 2023 13:46:38 +0800 Subject: [PATCH] Support Graph Input and Initializer for GatherToSplit Fusion (#18412) Support graph input and initializer for GatherToSplit fusion. Previously the fusion requires Gather nodes consume some other node which cannot be graph input or initializer. This helps some model training with such case so that we will not have GatherGrad in the final graph. GatherGrad is super inefficient in kernel implementation. --- onnxruntime/core/optimizer/gather_fusion.cc | 65 +++-- .../test/optimizer/graph_transform_test.cc | 223 ++++++++++++++++-- 2 files changed, 243 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index b994028cbc..4903bc1d6b 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,7 +9,8 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { +bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; @@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + InlinedVector node_args; + for (auto node_arg : graph.GetInputs()) { + if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + node_args.push_back(node_arg); + } + } + + for (auto entry : graph.GetAllInitializedTensors()) { + if (graph.GetConsumerNodes(entry.first).size() > 1) { + auto node_arg = graph.GetNodeArg(entry.first); + if (node_arg) { + node_args.push_back(node_arg); + } + } + } + for (auto node_index : node_topology_list) { auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion @@ -73,7 +90,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - auto shape = node.MutableOutputDefs()[0]->Shape(); + node_args.push_back(node.OutputDefs()[0]); + } + + for (const NodeArg* node_arg : node_args) { + auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); @@ -81,11 +102,14 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le bool first_edge = true; int64_t split_axis = 0; int64_t indices_n_dims = -1; - InlinedVector gather_outputs(output_count, nullptr); + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + InlinedVector gather_outputs(consumer_count, nullptr); InlinedVector> nodes_to_fuse; - for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { + for (auto consumer : consumers) { int64_t index, axis, dims; - if (!IsSupportedGather(graph, *it, index, axis, dims)) { + if (!consumer || consumer->InputDefs()[0] != node_arg || + !IsSupportedGather(graph, *consumer, index, axis, dims)) { can_fuse = false; break; } @@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (axis < 0) axis += rank; if (first_edge) { auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(output_count)) { + if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { can_fuse = false; break; } @@ -109,12 +133,12 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le can_fuse = false; break; } - if (index < 0) index += static_cast(output_count); - if (index < 0 || index >= static_cast(output_count) || gather_outputs[static_cast(index)]) { + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(it->Index()); + Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.emplace_back(gather_node); gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; } @@ -122,8 +146,8 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (!can_fuse) continue; ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()); + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); split_output_type.mutable_tensor_type()->set_elem_type(element_type); for (int64_t i = 0; i < rank; ++i) { if (i == split_axis) { @@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le InlinedVector split_outputs; bool add_squeeze_node = indices_n_dims == 0; if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { split_outputs.emplace_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); } } - Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs); + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(node.GetExecutionProviderType()); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. int onnx_opset_version = -1; @@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (onnx_opset_version < 13) { if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } else { if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(output_count)); + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); } if (add_squeeze_node) { @@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b82f3345df..17b26ed7ca 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6910,7 +6910,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -6933,8 +6936,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -6962,8 +6965,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -6991,8 +6994,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -7023,7 +7026,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -7042,8 +7048,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -7063,8 +7069,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -7084,13 +7090,180 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + // OpSet-12 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } + + // OpSet-14 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } + + // OpSet-18 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInitializer({2, 3, 3, 3}, std::vector(54)); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; auto post_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); @@ -7130,8 +7303,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather indices. @@ -7166,8 +7339,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather axis. @@ -7202,8 +7375,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -7250,8 +7423,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14, Tind is int64. @@ -7289,8 +7462,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } }