mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Support Graph Input and Initializer for GatherToSplit Fusion (#18412)
Support graph inputs and initializers in the GatherToSplit fusion. Previously the fusion required the Gather nodes to consume the output of another node, so it could not be applied when the shared data came from a graph input or an initializer. Handling these cases helps the training of models that contain this pattern, because the final graph then no longer contains GatherGrad, whose kernel implementation is very inefficient.
This commit is contained in:
parent
d738ff16ec
commit
b0699d901c
2 changed files with 243 additions and 45 deletions
|
|
@ -9,7 +9,8 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const {
|
||||
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis,
|
||||
int64_t& indices_n_dims) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
|
||||
return false;
|
||||
|
|
@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
InlinedVector<const NodeArg*> node_args;
|
||||
for (auto node_arg : graph.GetInputs()) {
|
||||
if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
|
||||
node_args.push_back(node_arg);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto entry : graph.GetAllInitializedTensors()) {
|
||||
if (graph.GetConsumerNodes(entry.first).size() > 1) {
|
||||
auto node_arg = graph.GetNodeArg(entry.first);
|
||||
if (node_arg) {
|
||||
node_args.push_back(node_arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto* p_node = graph.GetNode(node_index);
|
||||
if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion
|
||||
|
|
@ -73,7 +90,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
size_t output_count = node.GetOutputEdgesCount();
|
||||
if (output_count <= 1) continue;
|
||||
|
||||
auto shape = node.MutableOutputDefs()[0]->Shape();
|
||||
node_args.push_back(node.OutputDefs()[0]);
|
||||
}
|
||||
|
||||
for (const NodeArg* node_arg : node_args) {
|
||||
auto shape = node_arg->Shape();
|
||||
if (!shape) continue;
|
||||
int64_t rank = static_cast<int64_t>(shape->dim_size());
|
||||
|
||||
|
|
@ -81,11 +102,14 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
bool first_edge = true;
|
||||
int64_t split_axis = 0;
|
||||
int64_t indices_n_dims = -1;
|
||||
InlinedVector<NodeArg*> gather_outputs(output_count, nullptr);
|
||||
auto consumers = graph.GetConsumerNodes(node_arg->Name());
|
||||
size_t consumer_count = consumers.size();
|
||||
InlinedVector<NodeArg*> gather_outputs(consumer_count, nullptr);
|
||||
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
|
||||
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
|
||||
for (auto consumer : consumers) {
|
||||
int64_t index, axis, dims;
|
||||
if (!IsSupportedGather(graph, *it, index, axis, dims)) {
|
||||
if (!consumer || consumer->InputDefs()[0] != node_arg ||
|
||||
!IsSupportedGather(graph, *consumer, index, axis, dims)) {
|
||||
can_fuse = false;
|
||||
break;
|
||||
}
|
||||
|
|
@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
if (axis < 0) axis += rank;
|
||||
if (first_edge) {
|
||||
auto dim = shape->dim(static_cast<int>(axis));
|
||||
if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(output_count)) {
|
||||
if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(consumer_count)) {
|
||||
can_fuse = false;
|
||||
break;
|
||||
}
|
||||
|
|
@ -109,12 +133,12 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
can_fuse = false;
|
||||
break;
|
||||
}
|
||||
if (index < 0) index += static_cast<int64_t>(output_count);
|
||||
if (index < 0 || index >= static_cast<int64_t>(output_count) || gather_outputs[static_cast<size_t>(index)]) {
|
||||
if (index < 0) index += static_cast<int64_t>(consumer_count);
|
||||
if (index < 0 || index >= static_cast<int64_t>(consumer_count) || gather_outputs[static_cast<size_t>(index)]) {
|
||||
can_fuse = false;
|
||||
break;
|
||||
}
|
||||
Node& gather_node = *graph.GetNode(it->Index());
|
||||
Node& gather_node = *graph.GetNode(consumer->Index());
|
||||
nodes_to_fuse.emplace_back(gather_node);
|
||||
gather_outputs[static_cast<size_t>(index)] = gather_node.MutableOutputDefs()[0];
|
||||
}
|
||||
|
|
@ -122,8 +146,8 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
if (!can_fuse) continue;
|
||||
|
||||
ONNX_NAMESPACE::TypeProto split_output_type;
|
||||
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
|
||||
node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type());
|
||||
const ONNX_NAMESPACE::TensorProto_DataType element_type =
|
||||
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
|
||||
split_output_type.mutable_tensor_type()->set_elem_type(element_type);
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
if (i == split_axis) {
|
||||
|
|
@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
InlinedVector<NodeArg*> split_outputs;
|
||||
bool add_squeeze_node = indices_n_dims == 0;
|
||||
if (add_squeeze_node) {
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
for (size_t i = 0; i < consumer_count; ++i) {
|
||||
split_outputs.emplace_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
|
||||
}
|
||||
}
|
||||
|
||||
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
|
||||
{node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs);
|
||||
Node& split_node =
|
||||
graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
|
||||
{graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs);
|
||||
split_node.AddAttribute("axis", split_axis);
|
||||
split_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
|
||||
|
||||
// Squeeze-11, Squeeze-13, Split-13, Split-18 have different schemas.
|
||||
int onnx_opset_version = -1;
|
||||
|
|
@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
|
||||
if (onnx_opset_version < 13) {
|
||||
if (add_squeeze_node) {
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
for (size_t i = 0; i < consumer_count; ++i) {
|
||||
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
|
||||
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
|
||||
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
|
||||
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (onnx_opset_version >= 18) {
|
||||
split_node.AddAttribute("num_outputs", static_cast<int64_t>(output_count));
|
||||
split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
|
||||
}
|
||||
|
||||
if (add_squeeze_node) {
|
||||
|
|
@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
|
||||
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
|
||||
|
||||
for (size_t i = 0; i < output_count; ++i) {
|
||||
for (size_t i = 0; i < consumer_count; ++i) {
|
||||
Node& squeeze_node =
|
||||
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
|
||||
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
|
||||
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6910,7 +6910,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
|
|||
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// OpSet-12
|
||||
{
|
||||
|
|
@ -6933,8 +6936,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-14
|
||||
|
|
@ -6962,8 +6965,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-18
|
||||
|
|
@ -6991,8 +6994,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -7023,7 +7026,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
|
|||
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// OpSet-12
|
||||
{
|
||||
|
|
@ -7042,8 +7048,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-14
|
||||
|
|
@ -7063,8 +7069,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-18
|
||||
|
|
@ -7084,13 +7090,180 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
// Checks GatherToSplitFusion on a graph whose three Gather nodes all read the same graph input
// (created with MakeInput) instead of the output of another node. The transformer is run at
// opsets 12, 14 and 18; each run expects the three Gathers to be replaced by one Split followed
// by three Squeeze nodes.
TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) {
|
||||
// Graph under test: one float input built from the dims {2, 3, 3, 3}; three Gather nodes that
// read it with the index values 0, 1 and 2; each Gather output feeds a Transpose that produces
// a graph output. The index initializers are created with an empty shape argument.
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* data_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
|
||||
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
|
||||
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
|
||||
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
|
||||
auto* gather_out_1 = builder.MakeIntermediate();
|
||||
auto* gather_out_2 = builder.MakeIntermediate();
|
||||
auto* gather_out_3 = builder.MakeIntermediate();
|
||||
auto* transpose_out_1 = builder.MakeOutput();
|
||||
auto* transpose_out_2 = builder.MakeOutput();
|
||||
auto* transpose_out_3 = builder.MakeOutput();
|
||||
|
||||
// All three Gathers use the same axis: the second one spells it as -2, which refers to the
// same axis as 2 for rank-4 data.
builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast<int64_t>(2));
|
||||
builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2})
|
||||
.AddAttribute("axis", static_cast<int64_t>(-2));
|
||||
builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast<int64_t>(2));
|
||||
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
};
|
||||
|
||||
// Before the transformer runs, the graph must contain exactly three Gather nodes.
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// OpSet-12
// At this opset the Squeeze axes are checked as an "axes" attribute on the node.
|
||||
{
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
} else if (node.OpType() == "Squeeze") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axes").ints().at(0)));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// A new transformer is created for each run because it is moved into TestGraphTransformer.
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-14
// At this opset the Squeeze axes are checked on the node's second input, which must resolve to
// a constant INT64 initializer whose first element is 2.
|
||||
{
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
} else if (node.OpType() == "Squeeze") {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[1]);
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-18
// Same checks as the OpSet-14 run.
// NOTE(review): the "num_outputs" attribute that the fusion sets on Split for opset >= 18 is
// not asserted here - consider adding a check.
|
||||
{
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
} else if (node.OpType() == "Squeeze") {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[1]);
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
// Checks GatherToSplitFusion on a graph whose three Gather nodes all read the same initializer
// (created with MakeInitializer) instead of the output of another node. The transformer is run
// once, at opset 14, and is expected to replace the three Gathers with one Split followed by
// three Squeeze nodes.
TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
|
||||
// Graph under test: a float initializer built from the dims {2, 3, 3, 3} with 54
// value-initialized elements; three Gather nodes that read it with the index values 0, 1 and 2;
// each Gather output feeds a Transpose that produces a graph output.
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* data_arg = builder.MakeInitializer<float>({2, 3, 3, 3}, std::vector<float>(54));
|
||||
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
|
||||
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
|
||||
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
|
||||
auto* gather_out_1 = builder.MakeIntermediate();
|
||||
auto* gather_out_2 = builder.MakeIntermediate();
|
||||
auto* gather_out_3 = builder.MakeIntermediate();
|
||||
auto* transpose_out_1 = builder.MakeOutput();
|
||||
auto* transpose_out_2 = builder.MakeOutput();
|
||||
auto* transpose_out_3 = builder.MakeOutput();
|
||||
|
||||
// All three Gathers use the same axis: the second one spells it as -2, which refers to the
// same axis as 2 for rank-4 data.
builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast<int64_t>(2));
|
||||
builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2})
|
||||
.AddAttribute("axis", static_cast<int64_t>(-2));
|
||||
builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast<int64_t>(2));
|
||||
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
|
||||
};
|
||||
|
||||
// Before the transformer runs, the graph must contain exactly three Gather nodes.
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// After the transformer runs: no Gather, one Split with axis 2, and three Squeeze nodes whose
// second input resolves to a constant INT64 initializer with first element 2.
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Split") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
|
||||
} else if (node.OpType() == "Squeeze") {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[1]);
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// The transformer is moved into TestGraphTransformer, which runs it at opset 14.
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
|
||||
auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
|
||||
return Status::OK();
|
||||
};
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
|
||||
|
|
@ -7130,8 +7303,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// Invalid Gather indices.
|
||||
|
|
@ -7166,8 +7339,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// Invalid Gather axis.
|
||||
|
|
@ -7202,8 +7375,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -7250,8 +7423,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// OpSet-14, Tind is int64.
|
||||
|
|
@ -7289,8 +7462,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue