Support Graph Input and Initializer for GatherToSplit Fusion (#18412)

Support graph input and initializer for GatherToSplit fusion. Previously
the fusion requires Gather nodes consume some other node which cannot be
graph input or initializer.

This helps some model training with such case so that we will not have
GatherGrad in the final graph. GatherGrad is super inefficient in kernel
implementation.
This commit is contained in:
Vincent Wang 2023-11-15 13:46:38 +08:00 committed by GitHub
parent d738ff16ec
commit b0699d901c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 243 additions and 45 deletions

View file

@ -9,7 +9,8 @@
namespace onnxruntime {
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const {
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis,
int64_t& indices_n_dims) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<const NodeArg*> node_args;
for (auto node_arg : graph.GetInputs()) {
if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
node_args.push_back(node_arg);
}
}
for (auto entry : graph.GetAllInitializedTensors()) {
if (graph.GetConsumerNodes(entry.first).size() > 1) {
auto node_arg = graph.GetNodeArg(entry.first);
if (node_arg) {
node_args.push_back(node_arg);
}
}
}
for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion
@ -73,7 +90,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
size_t output_count = node.GetOutputEdgesCount();
if (output_count <= 1) continue;
auto shape = node.MutableOutputDefs()[0]->Shape();
node_args.push_back(node.OutputDefs()[0]);
}
for (const NodeArg* node_arg : node_args) {
auto shape = node_arg->Shape();
if (!shape) continue;
int64_t rank = static_cast<int64_t>(shape->dim_size());
@ -81,11 +102,14 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
bool first_edge = true;
int64_t split_axis = 0;
int64_t indices_n_dims = -1;
InlinedVector<NodeArg*> gather_outputs(output_count, nullptr);
auto consumers = graph.GetConsumerNodes(node_arg->Name());
size_t consumer_count = consumers.size();
InlinedVector<NodeArg*> gather_outputs(consumer_count, nullptr);
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
for (auto consumer : consumers) {
int64_t index, axis, dims;
if (!IsSupportedGather(graph, *it, index, axis, dims)) {
if (!consumer || consumer->InputDefs()[0] != node_arg ||
!IsSupportedGather(graph, *consumer, index, axis, dims)) {
can_fuse = false;
break;
}
@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (axis < 0) axis += rank;
if (first_edge) {
auto dim = shape->dim(static_cast<int>(axis));
if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(output_count)) {
if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(consumer_count)) {
can_fuse = false;
break;
}
@ -109,12 +133,12 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
can_fuse = false;
break;
}
if (index < 0) index += static_cast<int64_t>(output_count);
if (index < 0 || index >= static_cast<int64_t>(output_count) || gather_outputs[static_cast<size_t>(index)]) {
if (index < 0) index += static_cast<int64_t>(consumer_count);
if (index < 0 || index >= static_cast<int64_t>(consumer_count) || gather_outputs[static_cast<size_t>(index)]) {
can_fuse = false;
break;
}
Node& gather_node = *graph.GetNode(it->Index());
Node& gather_node = *graph.GetNode(consumer->Index());
nodes_to_fuse.emplace_back(gather_node);
gather_outputs[static_cast<size_t>(index)] = gather_node.MutableOutputDefs()[0];
}
@ -122,8 +146,8 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (!can_fuse) continue;
ONNX_NAMESPACE::TypeProto split_output_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type());
const ONNX_NAMESPACE::TensorProto_DataType element_type =
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
split_output_type.mutable_tensor_type()->set_elem_type(element_type);
for (int64_t i = 0; i < rank; ++i) {
if (i == split_axis) {
@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
InlinedVector<NodeArg*> split_outputs;
bool add_squeeze_node = indices_n_dims == 0;
if (add_squeeze_node) {
for (size_t i = 0; i < output_count; ++i) {
for (size_t i = 0; i < consumer_count; ++i) {
split_outputs.emplace_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
}
}
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
{node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs);
Node& split_node =
graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
{graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs);
split_node.AddAttribute("axis", split_axis);
split_node.SetExecutionProviderType(node.GetExecutionProviderType());
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
// Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas.
int onnx_opset_version = -1;
@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (onnx_opset_version < 13) {
if (add_squeeze_node) {
for (size_t i = 0; i < output_count; ++i) {
for (size_t i = 0; i < consumer_count; ++i) {
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
}
}
} else {
if (onnx_opset_version >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(output_count));
split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
}
if (add_squeeze_node) {
@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
for (size_t i = 0; i < output_count; ++i) {
for (size_t i = 0; i < consumer_count; ++i) {
Node& squeeze_node =
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
}
}
}

View file

@ -6910,7 +6910,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
return Status::OK();
};
// OpSet-12
{
@ -6933,8 +6936,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-14
@ -6962,8 +6965,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-18
@ -6991,8 +6994,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
@ -7023,7 +7026,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
return Status::OK();
};
// OpSet-12
{
@ -7042,8 +7048,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-14
@ -7063,8 +7069,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-18
@ -7084,13 +7090,180 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(-2));
builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
return Status::OK();
};
// OpSet-12
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axes").ints().at(0)));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-14
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-18
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInitializer<float>({2, 3, 3, 3}, std::vector<float>(54));
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(-2));
builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
@ -7130,8 +7303,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// Invalid Gather indices.
@ -7166,8 +7339,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// Invalid Gather axis.
@ -7202,8 +7375,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
@ -7250,8 +7423,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-14, Tind is int64.
@ -7289,8 +7462,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}