Merge GatherToSplitFusion and #19218 to a General Fusion (#19600)

#19218 tried to fuse Gather/Slice to Split, but the logic has problem.
Scalar value or 1-dim value of indices in Gather node will produce
different result, scalar value will produce a result tensor by removing
the axis dim, will 1-dim indices value will keep that dim, even when the
dim value is 1. For example,

Node
    |-> Gather(indices=[0], axis=axis)
    |-> Gather(indices=[1], axis=axis)
    |-> Slice(index=2, axis=axis)
is same as
Node
   |-> Split(axis=axis)

But
Node
    |-> Gather(indices=0, axis=axis)
    |-> Gather(indices=1, axis=axis)
    |-> Slice(index=2, axis=axis)
is same as
Node
    |-> Split(axis=axis)
        ||-> Squeeze(axis=axis)
        ||-> Squeeze(axis=axis)
        ||->

Previous PR doesn't take such case related to Squeeze/Unsqueeze into
account.

This PR merges #19218 and GatherToSplitFusion to a general fusion, which
relaxes the limit the number of Gather and Slice node number, check all
Gather and Slice consumers, if the indices of Gather and start/end of
Slice can cover the specific dim of the input tensor, then we can fuse
them to a Split, and adding Squeeze if necessary according to the dim
count of the indices tensor in Gather.

@rui-ren, please check if the fix can still be applied to your model.
This commit is contained in:
Vincent Wang 2024-02-29 13:45:58 +08:00 committed by GitHub
parent 7455dd1f32
commit d2e6dd25ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 390 additions and 954 deletions

View file

@ -9,55 +9,144 @@
namespace onnxruntime {
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis,
int64_t& indices_n_dims) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
namespace {
static int64_t GetGatherAxis(const Node& node, int64_t rank) {
int64_t axis = 0;
auto& attrs = node.GetAttributes();
if (attrs.find("axis") != attrs.end()) {
auto& axis_attr = attrs.at("axis");
if (utils::HasInt(axis_attr)) {
axis = axis_attr.i();
if (axis < 0) axis += rank;
}
}
return axis;
}
static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_arg, int64_t& value, int64_t& rank) {
if (!optimizer_utils::IsScalar(node_arg)) return false;
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name());
if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false;
Initializer init_const{*tensor_proto, graph.ModelPath()};
value = *(init_const.data<int64_t>());
rank = tensor_proto->dims_size();
return true;
}
static bool GetSliceAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) {
if (node.InputDefs().size() < 4) return false;
int64_t unused = 0;
if (!GetScalarInt64Initializer(graph, *node.InputDefs()[3], axis, unused)) return false;
if (axis < 0) axis += rank;
return true;
}
static bool GetAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) {
if (node.OpType() == "Gather") {
axis = GetGatherAxis(node, rank);
return true;
}
if (node.OpType() == "Slice") {
return GetSliceAxis(graph, node, rank, axis);
}
return false;
}
} // namespace
bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t rank,
int64_t target_axis, int64_t dim_size, InlinedVector<bool>& consumed,
int64_t& start, bool& need_squeeze) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
}
const NodeArg& input_arg = *(node.InputDefs()[1]);
if (!optimizer_utils::IsScalar(input_arg)) return false;
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
if (!tensor_proto) return false;
if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false;
Initializer init_const{*tensor_proto, graph.ModelPath()};
index = *(init_const.data<int64_t>());
axis = 0; // Default value.
auto& attrs = node.GetAttributes();
if (attrs.find("axis") != attrs.end()) {
auto& axis_attr = attrs.at("axis");
if (utils::HasInt(axis_attr)) axis = axis_attr.i();
if (GetGatherAxis(node, rank) != target_axis) return false;
// Require the indices input to be a scalar tensor for now. Normally if not, the exporter will choose Slice.
// We can relax this later if needed.
int64_t indices_n_dims = 0;
if (!GetScalarInt64Initializer(graph, *(node.InputDefs()[1]), start, indices_n_dims)) return false;
if (start < 0) start += dim_size;
if (start < 0 || start >= dim_size || consumed[static_cast<size_t>(start)]) return false;
consumed[static_cast<size_t>(start)] = true;
need_squeeze = indices_n_dims == 0;
return true;
}
bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis,
int64_t dim_size, InlinedVector<bool>& consumed, int64_t& start,
int64_t& end) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
}
int64_t axis = 0;
if (!GetSliceAxis(graph, node, rank, axis) || axis != target_axis) return false;
int64_t unused = 0;
if (!GetScalarInt64Initializer(graph, *node.InputDefs()[1], start, unused) ||
!GetScalarInt64Initializer(graph, *node.InputDefs()[2], end, unused)) {
return false;
}
// Handling start and end according to schema definition.
if (start < 0) start += dim_size;
if (end < 0) end += dim_size;
if (start < 0)
start = 0;
else if (start > dim_size)
start = dim_size;
if (end < 0)
end = 0;
else if (end > dim_size)
end = dim_size;
if (start >= end) return false;
if (node.InputDefs().size() >= 5) {
int64_t step = 0;
if (!GetScalarInt64Initializer(graph, *node.InputDefs()[4], step, unused) || step != 1) return false;
}
for (int64_t i = start; i < end; ++i) {
if (consumed[static_cast<size_t>(i)]) return false;
consumed[static_cast<size_t>(i)] = true;
}
indices_n_dims = tensor_proto->dims_size();
return true;
}
/*
GatherToSplitFusion is to fuse:
Node -> Gather(index=0, axis=axis)
|-> Gather(index=1, axis=axis)
|-> Gather(index=2, axis=axis)
GatherSliceToSplitFusion is to fuse:
Node -> Gather(indices=0, axis=axis)
|-> Gather(indices=[1], axis=axis)
|-> Slice(start=2, end=3, axes=[axis])
|...
To
Node -> Split -> Squeeze(axis=axis)
|-> Squeeze(axis=axis)
|-> Squeeze(axis=axis)
|->
|->
|...
So that we can use one kernel to finish the job.
The fusion requires that the indices of Gather nodes and start/end of Slice nodes are not overlapping and cover
all the elements in the target axis. Step of Slice node should be 1.
*/
Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
// Squeeze, Gather, Slice and Split have different schemas before and after OpSet 13.
// To make code simple, support OpSet >= 13 only.
int onnx_opset_version = -1;
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
}
if (onnx_opset_version < 13) return Status::OK();
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<const NodeArg*> node_args;
InlinedVector<const NodeArg*> candidate_args;
for (auto node_arg : graph.GetInputs()) {
if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
node_args.push_back(node_arg);
candidate_args.push_back(node_arg);
}
}
@ -65,7 +154,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (graph.GetConsumerNodes(entry.first).size() > 1) {
auto node_arg = graph.GetNodeArg(entry.first);
if (node_arg) {
node_args.push_back(node_arg);
candidate_args.push_back(node_arg);
}
}
}
@ -90,129 +179,108 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
size_t output_count = node.GetOutputEdgesCount();
if (output_count <= 1) continue;
node_args.push_back(node.OutputDefs()[0]);
candidate_args.push_back(node.OutputDefs()[0]);
}
for (const NodeArg* node_arg : node_args) {
for (const NodeArg* node_arg : candidate_args) {
auto shape = node_arg->Shape();
if (!shape) continue;
int64_t rank = static_cast<int64_t>(shape->dim_size());
bool can_fuse = true;
bool first_edge = true;
int64_t split_axis = 0;
int64_t indices_n_dims = -1;
auto consumers = graph.GetConsumerNodes(node_arg->Name());
size_t consumer_count = consumers.size();
InlinedVector<NodeArg*> gather_outputs(consumer_count, nullptr);
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
InlinedVector<const Node*> condidate_consumers;
for (auto consumer : consumers) {
int64_t index, axis, dims;
if (!consumer || consumer->InputDefs()[0] != node_arg ||
!IsSupportedGather(graph, *consumer, index, axis, dims)) {
can_fuse = false;
break;
if (consumer && consumer->InputDefs()[0] == node_arg &&
(consumer->OpType() == "Gather" || consumer->OpType() == "Slice")) {
condidate_consumers.emplace_back(consumer);
}
if (indices_n_dims == -1) {
indices_n_dims = dims;
} else if (indices_n_dims != dims) {
// Not the same number of dimensions (0 or 1) for all scalar indices.
can_fuse = false;
break;
}
if (axis < 0) axis += rank;
if (first_edge) {
auto dim = shape->dim(static_cast<int>(axis));
if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(consumer_count)) {
can_fuse = false;
break;
}
split_axis = axis;
first_edge = false;
} else if (axis != split_axis) {
can_fuse = false;
break;
}
if (index < 0) index += static_cast<int64_t>(consumer_count);
if (index < 0 || index >= static_cast<int64_t>(consumer_count) || gather_outputs[static_cast<size_t>(index)]) {
can_fuse = false;
break;
}
Node& gather_node = *graph.GetNode(consumer->Index());
nodes_to_fuse.emplace_back(gather_node);
gather_outputs[static_cast<size_t>(index)] = gather_node.MutableOutputDefs()[0];
}
if (!can_fuse) continue;
ONNX_NAMESPACE::TypeProto split_output_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type =
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
split_output_type.mutable_tensor_type()->set_elem_type(element_type);
for (int64_t i = 0; i < rank; ++i) {
if (i == split_axis) {
split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);
if (condidate_consumers.size() < 2) continue;
int64_t axis = 0;
if (!GetAxis(graph, *condidate_consumers[0], rank, axis)) continue;
auto dim = shape->dim(static_cast<int>(axis));
if (!utils::HasDimValue(dim)) continue;
int64_t dim_size = dim.dim_value();
InlinedVector<bool> consumed(static_cast<size_t>(dim_size), false);
bool can_fuse = true;
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
InlinedVector<int64_t> starts;
InlinedHashMap<int64_t, std::tuple<NodeArg*, int64_t, bool>> output_info_map;
for (auto consumer : condidate_consumers) {
if (!consumer || consumer->InputDefs()[0] != node_arg) {
can_fuse = false;
break;
}
int64_t start = 0, end = 0;
bool need_squeeze = false;
if (IsSupportedGather(graph, *consumer, rank, axis, dim_size, consumed, start, need_squeeze)) {
Node& gather_node = *graph.GetNode(consumer->Index());
nodes_to_fuse.emplace_back(gather_node);
starts.emplace_back(start);
output_info_map[start] = std::make_tuple(gather_node.MutableOutputDefs()[0], 1, need_squeeze);
} else if (IsSupportedSlice(graph, *consumer, rank, axis, dim_size, consumed, start, end)) {
Node& slice_node = *graph.GetNode(consumer->Index());
nodes_to_fuse.emplace_back(slice_node);
starts.emplace_back(start);
output_info_map[start] = std::make_tuple(slice_node.MutableOutputDefs()[0], end - start, false);
} else {
*(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
can_fuse = false;
break;
}
}
if (!can_fuse || std::find(consumed.begin(), consumed.end(), false) != consumed.end()) continue;
std::sort(starts.begin(), starts.end());
InlinedVector<NodeArg*> split_outputs;
bool add_squeeze_node = indices_n_dims == 0;
if (add_squeeze_node) {
for (size_t i = 0; i < consumer_count; ++i) {
split_outputs.emplace_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
}
}
Node& split_node =
graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
{graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs);
split_node.AddAttribute("axis", split_axis);
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
// Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas.
int onnx_opset_version = -1;
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
}
if (onnx_opset_version < 13) {
if (add_squeeze_node) {
for (size_t i = 0; i < consumer_count; ++i) {
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
InlinedVector<int64_t> split_values;
for (int64_t start : starts) {
auto& output_info = output_info_map[start];
NodeArg* original_output_arg = std::get<0>(output_info);
int64_t split_value = std::get<1>(output_info);
split_values.emplace_back(split_value);
if (std::get<2>(output_info)) {
ONNX_NAMESPACE::TypeProto split_output_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type =
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
split_output_type.mutable_tensor_type()->set_elem_type(element_type);
for (int64_t i = 0; i < rank; ++i) {
if (i == axis) {
split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(split_value);
} else {
*(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
}
}
}
} else {
if (onnx_opset_version >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
}
if (add_squeeze_node) {
NodeArg* split_output_arg =
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split_output"), &split_output_type);
ONNX_NAMESPACE::TensorProto axes_initializer_proto;
axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
axes_initializer_proto.set_name(graph.GenerateNodeName("squeeze_axes"));
axes_initializer_proto.add_dims(static_cast<int64_t>(1));
axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
InlinedVector<int64_t> axes_value{split_axis};
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
axes_initializer_proto.add_int64_data(axis);
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
for (size_t i = 0; i < consumer_count; ++i) {
Node& squeeze_node =
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
}
Node& squeeze_node =
graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes",
{split_output_arg, axes_arg}, {original_output_arg});
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
split_outputs.emplace_back(split_output_arg);
} else {
split_outputs.emplace_back(original_output_arg);
}
}
for (Node& n : nodes_to_fuse) {
graph_utils::RemoveNodeOutputEdges(graph, n);
graph.RemoveNode(n.Index());
ONNX_NAMESPACE::TensorProto split_initializer_proto;
split_initializer_proto.set_name(graph.GenerateNodeName("splits"));
split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());
NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
{graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs);
split_node.AddAttribute("axis", axis);
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
for (Node& node : nodes_to_fuse) {
graph_utils::RemoveNodeOutputEdges(graph, node);
graph.RemoveNode(node.Index());
}
modified = true;

View file

@ -8,19 +8,23 @@
namespace onnxruntime {
/**
@Class GatherToSplitFusion
@Class GatherSliceToSplitFusion
Fuse multiple Gather nodes that comsuming one output to one Split node.
Fuse multiple Gather/Slice nodes that comsuming one output to one Split node.
*/
class GatherToSplitFusion : public GraphTransformer {
class GatherSliceToSplitFusion : public GraphTransformer {
public:
GatherToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {}
GatherSliceToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
private:
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
InlinedVector<bool>& consumed, int64_t& start, bool& need_squeeze) const;
bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
InlinedVector<bool>& consumed, int64_t& start, int64_t& end) const;
};
/**

View file

@ -1,344 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/gather_slice_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
namespace onnxruntime {
bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index,
int64_t& axis, int64_t& indices_n_dims) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
}
const NodeArg& input_arg = *(node.InputDefs()[1]);
if (!optimizer_utils::IsScalar(input_arg)) return false;
const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name());
if (!indices_init) return false;
if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false;
// get the index value
Initializer init_const(*indices_init, graph.ModelPath());
index = *(init_const.data<int64_t>());
// get attributes value
axis = 0;
auto& attrs = node.GetAttributes();
if (attrs.find("axis") != attrs.end()) {
auto& axis_attr = attrs.at("axis");
if (utils::HasInt(axis_attr)) axis = axis_attr.i();
}
indices_n_dims = indices_init->dims_size();
return true;
}
bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node,
InlinedVector<int64_t>& starts,
InlinedVector<int64_t>& ends,
InlinedVector<int64_t>& axes,
InlinedVector<int64_t>& steps) const {
// check the version of Slice ops
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
}
// get the opset version
int onnx_opset_version = -1;
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
}
// If Slice op of opset version 1
if (onnx_opset_version == 1) {
if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) ||
!graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) ||
starts.size() != ends.size()) {
return false;
}
if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) {
return false;
}
}
// If Slice op of opset version >= 10
if (onnx_opset_version >= 10) {
// node inputs include: starts - ends - axes - steps
// return a pointer to the corresponding NodeArg if input of the node at the index exists
auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* {
const auto& input_defs = node.InputDefs();
const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr;
return (input == nullptr || !input->Exists()) ? nullptr : input;
};
// return a pointer to the initializer if it is constant; otherwise, a nullptr
auto get_initializer_if_constant =
[&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* {
const NodeArg* input = get_input_if_exists(input_index);
return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr;
};
// return the initialization data if it is constant
auto get_initializer_data =
[&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector<int64_t> {
Initializer init(*slice_initializer, graph.ModelPath());
if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) {
int32_t* init_data = init.data<int32_t>();
return InlinedVector<int64_t>(init_data, init_data + init.size());
}
if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) {
int64_t* init_data = init.data<int64_t>();
return InlinedVector<int64_t>(init_data, init_data + init.size());
}
return {};
};
// starts and ends inputs have to exist, be constants and be of the same size.
const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1);
const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2);
const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3);
const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4);
if (!starts_init || !ends_init || !axes_init || !steps_init) {
return false;
}
starts = get_initializer_data(starts_init);
ends = get_initializer_data(ends_init);
axes = get_initializer_data(axes_init);
steps = get_initializer_data(steps_init);
if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) {
return false;
}
if (axes_init->dims_size() != 1 || static_cast<size_t>(axes_init->dims().Get(0)) != starts.size()) {
return false;
}
// if steps exists, it should be constant and all value should be 1
if (steps.size() != starts.size()) {
return false;
}
for (int64_t step : steps) {
if (step != 1) {
return false;
}
}
}
return true;
}
/*
GatherToSplitFusion is to fuse:
Node
|-> Gather(index=0, axis=axis)
|-> Gather(index=1, axis=axis)
|-> Slice(index=2, axis=axis)
To
Node
|-> Split(index=0)
So that we can use one kernel to finish the job.
*/
Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<const NodeArg*> output_args;
// Iterate the topological order and get Reshape ops
for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (p_node == nullptr) continue;
Node& node = *p_node;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
// Currently only catch after Reshape ops, optimize in the future
if (node.OpType() != "Reshape") continue;
size_t output_count = node.GetOutputEdgesCount();
// We only catch 1 scenario for Multi Query Attention for now.
// |---> Gather
// Reshape |---> Gather
// |---> Slice
// |... or (other ops)
// Get the output into node args
if (output_count < 3) continue;
output_args.push_back(node.OutputDefs()[0]);
}
// iterate the children of Reshape node
for (const NodeArg* node_arg : output_args) {
auto shape = node_arg->Shape();
if (!shape) continue;
auto consumers = graph.GetConsumerNodes(node_arg->Name());
size_t consumer_count = consumers.size();
// get the tensor rank
int64_t rank = static_cast<int64_t>(shape->dim_size());
bool can_fuse = true;
bool first_edge = true;
int64_t split_axis = 0;
int64_t indices_n_dims = -1;
// Fuse 2 Gathers and 1 slice to Split
// Get those outputs as Split outputs
InlinedVector<NodeArg*> split_outputs(3);
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
size_t gather_node_count = 2, slice_node_count = 0;
// find the nodes to be merged
for (auto consumer : consumers) {
int64_t index, axis, dims;
InlinedVector<int64_t> starts, ends, axes, steps;
bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims);
bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps);
if ((!consumer || consumer->InputDefs()[0] != node_arg) ||
(!IsSupportedGatherOps && !IsSupportedSliceOps)) {
break;
}
if (IsSupportedGatherOps) {
if (indices_n_dims == -1) {
indices_n_dims = dims;
} else if (indices_n_dims != dims) {
// Not the same number of dimensions (0 or 1) for all scalar indices.
can_fuse = false;
break;
}
if (axis < 0) axis += rank;
if (first_edge) {
auto dim = shape->dim(static_cast<int>(axis));
// dim.dim_value() = 73
if (!utils::HasDimValue(dim)) {
can_fuse = false;
break;
}
split_axis = axis;
first_edge = false;
} else if (axis != split_axis) {
can_fuse = false;
break;
}
if (index < 0) index += static_cast<int64_t>(consumer_count);
if (index < 0 || index >= static_cast<int64_t>(consumer_count)) {
can_fuse = false;
break;
}
Node& gather_node = *graph.GetNode(consumer->Index());
nodes_to_fuse.push_back(gather_node);
NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0];
split_outputs[gather_node_count--] = gather_output_args;
}
// check the Slice Ops
if (IsSupportedSliceOps) {
if (axes[0] != axis && !first_edge) {
can_fuse = false;
break;
}
Node& slice_node = *graph.GetNode(consumer->Index());
NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0];
nodes_to_fuse.push_back(slice_node);
split_outputs[slice_node_count++] = slice_output_args;
}
}
// condition check
if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue;
// generate the split node and merge the kernel
ONNX_NAMESPACE::TypeProto split_output_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
node_arg->TypeAsProto()->tensor_type().elem_type());
split_output_type.mutable_tensor_type()->set_elem_type(element_type);
for (int64_t i = 0; i < rank; i++) {
if (i == split_axis)
split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);
else
*(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
}
InlinedVector<NodeArg*> split_output_types;
for (size_t i = 0; i < consumer_count; ++i) {
split_output_types.push_back(
&graph.GetOrCreateNodeArg(
graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type));
}
// Generate the Split Node
ONNX_NAMESPACE::TensorProto split_initializer_proto;
split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split"));
split_initializer_proto.add_dims(static_cast<int64_t>(3));
split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
auto dim_value = shape->dim(static_cast<int>(split_axis)).dim_value();
// Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2
int64_t slice_dim = static_cast<int64_t>(dim_value - 2);
InlinedVector<int64_t> split_value{{slice_dim, 1, 1}};
split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t));
NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
Node& split_node =
graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion",
{graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs);
split_node.AddAttribute("axis", split_axis);
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
int onnx_opset_version = -1;
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
}
if (onnx_opset_version >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
}
for (Node& node_to_fuse : nodes_to_fuse) {
graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse);
graph.RemoveNode(node_to_fuse.Index());
}
modified = true;
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -1,32 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
@class GatherSliceToSplitFusion
Fuse (2 Gather nodes + 1 Slice) to 1 split node.
*/
class GatherSliceToSplitFusion : public GraphTransformer {
private:
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis,
int64_t& indices_n_dims) const;
bool IsSupportedSlice(const Graph& graph, const Node& node,
InlinedVector<int64_t>& starts,
InlinedVector<int64_t>& ends,
InlinedVector<int64_t>& axes,
InlinedVector<int64_t>& steps) const;
public:
GatherSliceToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -37,7 +37,6 @@
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/optimizer/gather_fusion.h"
#include "core/optimizer/gather_slice_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
@ -307,9 +306,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_dml_rocm_eps));

View file

@ -42,7 +42,6 @@
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/gather_fusion.h"
#include "core/optimizer/gather_slice_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
@ -7059,130 +7058,14 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) {
}
}
TEST_F(GraphTransformationTests, GatherToSplitFusion) {
TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllGather) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{54}});
auto* shape_arg = builder.MakeInput<int64_t>({{4}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(-2));
builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
return Status::OK();
};
// OpSet-12
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axes").ints().at(0)));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-14
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-18
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{54}});
auto* shape_arg = builder.MakeInput<int64_t>({{4}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
@ -7198,7 +7081,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
@ -7207,23 +7091,16 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
return Status::OK();
};
// OpSet-12
// OpSet-12, not support
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
@ -7233,156 +7110,159 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 2);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-18
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) {
TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllSlice_GraphInput) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
auto* data_arg = builder.MakeInput<float>({{2, 3, 8, 3}});
auto* starts_1 = builder.MakeInitializer<int64_t>({1}, {0});
auto* ends_1 = builder.MakeInitializer<int64_t>({1}, {2});
auto* axes_1 = builder.MakeInitializer<int64_t>({1}, {2});
auto* steps_1 = builder.MakeInitializer<int64_t>({1}, {1});
auto* starts_2 = builder.MakeInitializer<int64_t>({1}, {2});
auto* ends_2 = builder.MakeInitializer<int64_t>({1}, {-2});
auto* axes_2 = builder.MakeInitializer<int64_t>({1}, {-2});
auto* steps_2 = builder.MakeInitializer<int64_t>({1}, {1});
auto* starts_3 = builder.MakeInitializer<int64_t>({1}, {-2});
auto* ends_3 = builder.MakeInitializer<int64_t>({1}, {16});
auto* axes_3 = builder.MakeInitializer<int64_t>({1}, {2});
auto* slice_out_1 = builder.MakeIntermediate();
auto* slice_out_2 = builder.MakeIntermediate();
auto* slice_out_3 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(-2));
builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1});
builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2});
builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3});
builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 3);
return Status::OK();
};
// OpSet-12
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axes").ints().at(0)));
}
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
}
return Status::OK();
};
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-14
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// OpSet-18
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Combined) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{144}});
auto* shape_arg = builder.MakeInput<int64_t>({{4}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 8, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(5)});
auto* starts_2 = builder.MakeInitializer<int64_t>({1}, {6});
auto* ends_2 = builder.MakeInitializer<int64_t>({1}, {8});
auto* axes_2 = builder.MakeInitializer<int64_t>({1}, {-3});
auto* steps_2 = builder.MakeInitializer<int64_t>({1}, {1});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(4)});
auto* starts_4 = builder.MakeInitializer<int64_t>({1}, {-16});
auto* ends_4 = builder.MakeInitializer<int64_t>({1}, {4});
auto* axes_4 = builder.MakeInitializer<int64_t>({1}, {1});
auto* gather_out_1 = builder.MakeIntermediate();
auto* slice_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
auto* slice_out_4 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();
auto* transpose_out_4 = builder.MakeOutput();
builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
.AddAttribute("axis", static_cast<int64_t>(1));
builder.AddNode("Slice", {reshape_out, starts_2, ends_2, axes_2, steps_2}, {slice_out_2});
builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
.AddAttribute("axis", static_cast<int64_t>(-3));
builder.AddNode("Slice", {reshape_out, starts_4, ends_4, axes_4}, {slice_out_4});
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {slice_out_4}, {transpose_out_4})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 2);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(1 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(1 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}
TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Consume_Initializer) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInitializer<float>({2, 3, 3, 3}, std::vector<float>(54));
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
@ -7430,31 +7310,31 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}
TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
return Status::OK();
};
// Invalid shape.
// Not cover all elements of specific dimension.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{72}});
auto* shape_arg = builder.MakeInput<int64_t>({{1}});
auto* shape_arg = builder.MakeInput<int64_t>({{4}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 4, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
@ -7467,63 +7347,65 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(2));
.AddAttribute("axis", static_cast<int64_t>(-2));
builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// Invalid Gather indices.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{54}});
auto* shape_arg = builder.MakeInput<int64_t>({{1}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// Invalid Gather axis.
// Has overlap.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{2, 3, 8, 3}});
auto* starts_1 = builder.MakeInitializer<int64_t>({1}, {0});
auto* ends_1 = builder.MakeInitializer<int64_t>({1}, {3});
auto* axes_1 = builder.MakeInitializer<int64_t>({1}, {2});
auto* steps_1 = builder.MakeInitializer<int64_t>({1}, {1});
auto* starts_2 = builder.MakeInitializer<int64_t>({1}, {2});
auto* ends_2 = builder.MakeInitializer<int64_t>({1}, {-2});
auto* axes_2 = builder.MakeInitializer<int64_t>({1}, {-2});
auto* steps_2 = builder.MakeInitializer<int64_t>({1}, {1});
auto* starts_3 = builder.MakeInitializer<int64_t>({1}, {-2});
auto* ends_3 = builder.MakeInitializer<int64_t>({1}, {16});
auto* axes_3 = builder.MakeInitializer<int64_t>({1}, {2});
auto* slice_out_1 = builder.MakeIntermediate();
auto* slice_out_2 = builder.MakeIntermediate();
auto* slice_out_3 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1});
builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2});
builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3});
builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
// Invalid axis.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{54}});
auto* shape_arg = builder.MakeInput<int64_t>({{1}});
auto* shape_arg = builder.MakeInput<int64_t>({{4}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
@ -7550,7 +7432,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
@ -7643,143 +7525,5 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
}
}
TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) {
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{54}});
auto* reshape_arg = builder.MakeInput<int64_t>({{4}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 512, 73, 64}});
builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out});
// Create Gather-1 Ops
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(-2)});
auto* gather_out_1 = builder.MakeIntermediate<float>({{2, 512, 1, 64}});
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
.AddAttribute("axis", static_cast<int64_t>(2));
// Create Transpose 1-Ops
auto* transpose_out_1 = builder.MakeOutput();
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
// Create Gather-2 Ops
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(-1)});
auto* gather_out_2 = builder.MakeIntermediate<float>({{2, 512, 1, 64}});
builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(2));
// Create Transpose-2 Ops
auto* transpose_out_2 = builder.MakeOutput();
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
// Create Slice Ops
auto* slice_output = builder.MakeIntermediate();
auto* starts = builder.MakeInitializer<int64_t>({1}, {0});
auto* ends = builder.MakeInitializer<int64_t>({1}, {-2});
auto* axes = builder.MakeInitializer<int64_t>({1}, {2});
auto* steps = builder.MakeInitializer<int64_t>({1}, {1});
builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output});
// Create Shape-1 Ops
auto* shape_output_1 = builder.MakeOutput();
builder.AddNode("Shape", {slice_output}, {shape_output_1});
// Create Shape-2 Ops
auto* shape_output_2 = builder.MakeOutput();
builder.AddNode("Shape", {slice_output}, {shape_output_2});
// Create Transpose-3 Ops
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Transpose", {slice_output}, {transpose_out_3})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(static_cast<int>(attrs.at("axis").i()) == 2);
}
}
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{54}});
auto* reshape_arg = builder.MakeInput<int64_t>({{4}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 512, 73, 64}});
builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out});
// Create Gather-1 Ops
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(-2)});
auto* gather_out_1 = builder.MakeIntermediate<float>({{2, 512, 1, 64}});
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
.AddAttribute("axis", static_cast<int64_t>(2));
// Create Transpose 1-Ops
auto* transpose_out_1 = builder.MakeOutput();
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
// Create Slice Ops
auto* slice_output = builder.MakeIntermediate();
auto* starts = builder.MakeInitializer<int64_t>({1}, {0});
auto* ends = builder.MakeInitializer<int64_t>({1}, {-2});
auto* axes = builder.MakeInitializer<int64_t>({1}, {2});
auto* steps = builder.MakeInitializer<int64_t>({1}, {1});
builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output});
// Create Shape-1 Ops
auto* shape_output_1 = builder.MakeOutput();
builder.AddNode("Shape", {slice_output}, {shape_output_1});
// Create Shape-2 Ops
auto* shape_output_2 = builder.MakeOutput();
builder.AddNode("Shape", {slice_output}, {shape_output_2});
// Create Transpose-3 Ops
auto* transpose_out_3 = builder.MakeOutput();
builder.AddNode("Transpose", {slice_output}, {transpose_out_3})
.AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
}
}
} // namespace test
} // namespace onnxruntime

View file

@ -24,7 +24,6 @@
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/optimizer/gather_fusion.h"
#include "core/optimizer/gather_slice_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
@ -139,9 +138,8 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<FastGeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<QuickGeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(compatible_eps));
// If a model with Q, DQ nodes is being used for the purpose of training, it must be for
// Quantization Aware Training. So, replace QDQ nodes with FakeQuant.
transformers.emplace_back(std::make_unique<QDQFusion>(compatible_eps));