mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add support for 'axes' attr of unsqueeze in opset 13 and add ut (#14071)
Since opset 13, 'axes' attr of unsqueeze become an input of unsqueeze, add support for it and add ut.
This commit is contained in:
parent
7654cd50e8
commit
f5b4b0f77d
2 changed files with 138 additions and 4 deletions
|
|
@ -10,6 +10,17 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
bool GetAxesFromUnsqueezeNode(const Graph& graph, const Node& unsqueeze, InlinedVector<int64_t>& axes) {
|
||||
if (graph_utils::MatchesOpSinceVersion(unsqueeze, {1, 11})) {
|
||||
return graph_utils::GetRepeatedNodeAttributeValues(unsqueeze, "axes", axes);
|
||||
} else if (graph_utils::MatchesOpSinceVersion(unsqueeze, {13})) {
|
||||
const NodeArg* axes_node_arg = unsqueeze.InputDefs()[1];
|
||||
return optimizer_utils::AppendTensorFromInitializer(graph, *axes_node_arg, axes, true);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
@ -145,15 +156,23 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node
|
|||
std::vector<graph_utils::EdgeEndToMatch> parent_path{
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
|
||||
{0, 0, "Shape", {1, 13, 15}, kOnnxDomain}};
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (graph_utils::FindPath(concat, true, parent_path, edges, logger)) {
|
||||
const Node& unsqueeze = edges[0]->GetNode();
|
||||
const Node& gather = edges[1]->GetNode();
|
||||
const Node& shape = edges[2]->GetNode();
|
||||
|
||||
if (graph_utils::MatchesOpSinceVersion(shape, {15})) {
|
||||
const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape, "start");
|
||||
const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape, "end");
|
||||
if (!((!start_attr || static_cast<int>(start_attr->i()) == 0) && (!end_attr))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
InlinedVector<int64_t> axes;
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(unsqueeze, "axes", axes) && axes.size() == 1 && axes[0] == 0)) {
|
||||
if (!(GetAxesFromUnsqueezeNode(graph, unsqueeze, axes) && axes.size() == 1 && axes[0] == 0)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -275,7 +294,7 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg&
|
|||
graph_utils::FindPath(concat, true, unsqueeze_path, edges, logger)) {
|
||||
const Node& unsqueeze = edges[0]->GetNode();
|
||||
InlinedVector<int64_t> axes;
|
||||
if (!(graph_utils::GetRepeatedNodeAttributeValues(unsqueeze, "axes", axes) && axes.size() == 1 && axes[0] == 0)) {
|
||||
if (!(GetAxesFromUnsqueezeNode(graph, unsqueeze, axes) && axes.size() == 1 && axes[0] == 0)) {
|
||||
return false;
|
||||
}
|
||||
// Unsqueeze_path is found, check for "one-element subgraph -> concat" or "shape -> slice -> squeeze ->
|
||||
|
|
@ -344,7 +363,8 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
|
|||
}
|
||||
const Node& concat = *p_concat;
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11, 13})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "ConcatTraining", {1}, kMSDomain)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4312,6 +4312,120 @@ TEST_F(GraphTransformationTests, BitmaskDropoutFusionTest) {
|
|||
TestBitmaskDropoutFusion(MODEL_FOLDER "fusion/bitmask_bias_dropout_fusion_residual.onnx", true, *logger_, 0, 0, 0, 0,
|
||||
1, 0, 1);
|
||||
}
|
||||
|
||||
/*
|
||||
This test build a graph like:
|
||||
input0 input1
|
||||
\ /
|
||||
Add
|
||||
-----------------|
|
||||
| |
|
||||
| Shape
|
||||
| / \
|
||||
| Gather0 Gather1
|
||||
| / \
|
||||
| Unsqueeze0 Unsqueeze1 (Constant Initializer) (Constant Initializer)
|
||||
| \ / / /
|
||||
| \ / / /
|
||||
| ConcatTraining ------- ------------
|
||||
\ /
|
||||
\ /
|
||||
Reshape
|
||||
|
||||
|
||||
After fusion, the graph become:
|
||||
input0 input1
|
||||
\ /
|
||||
Add (Constant Initializer)
|
||||
\ /
|
||||
Reshape
|
||||
|
||||
*/
|
||||
TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
|
||||
constexpr const int batch_size = 64;
|
||||
constexpr const int seq_lenth = 1024;
|
||||
constexpr const int hidden_size = 1024;
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Shape"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.ConcatTraining"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Reshape"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Shape"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.ConcatTraining"] == 0);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Reshape"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
const std::vector<int> opsets{11, 12, 13, 14, 15, 15};
|
||||
bool shape_test_for_opset15 = false;
|
||||
|
||||
for (auto& opset_version : opsets) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg0 = builder.MakeInput<float>({{batch_size, seq_lenth, hidden_size}});
|
||||
auto* input_arg1 = builder.MakeInput<float>({{hidden_size}});
|
||||
auto* scalar_int_0 = builder.MakeInitializer<int64_t>({}, {0});
|
||||
auto* scalar_int_1 = builder.MakeInitializer<int64_t>({}, {1});
|
||||
auto* single_value_1d_int_0 = builder.MakeInitializer<int64_t>({1}, {0});
|
||||
auto* single_value_1d_int_16 = builder.MakeInitializer<int64_t>({1}, {16});
|
||||
auto* single_value_1d_int_64 = builder.MakeInitializer<int64_t>({1}, {64});
|
||||
auto* add_out = builder.MakeIntermediate();
|
||||
auto* shape_out = builder.MakeIntermediate();
|
||||
auto* gather_out_0 = builder.MakeIntermediate();
|
||||
auto* gather_out_1 = builder.MakeIntermediate();
|
||||
auto* unsqueeze_out_0 = builder.MakeIntermediate();
|
||||
|
||||
auto* unsqueeze_out_1 = builder.MakeIntermediate();
|
||||
auto* concattraining1_out = builder.MakeIntermediate();
|
||||
auto* concattraining1_length = builder.MakeIntermediate();
|
||||
auto* out = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Add", {input_arg0, input_arg1}, {add_out});
|
||||
if (opset_version == 15) {
|
||||
if (shape_test_for_opset15) {
|
||||
auto& shape_1 = builder.AddNode("Shape", {add_out}, {shape_out});
|
||||
shape_1.AddAttribute("start", (int64_t)1);
|
||||
shape_1.AddAttribute("end", (int64_t)2);
|
||||
} else {
|
||||
builder.AddNode("Shape", {add_out}, {shape_out}).AddAttribute("start", (int64_t)0);
|
||||
shape_test_for_opset15 = true;
|
||||
}
|
||||
} else {
|
||||
builder.AddNode("Shape", {add_out}, {shape_out});
|
||||
}
|
||||
builder.AddNode("Gather", {shape_out, scalar_int_0}, {gather_out_0});
|
||||
builder.AddNode("Gather", {shape_out, scalar_int_1}, {gather_out_1});
|
||||
if (opset_version >= 13) {
|
||||
builder.AddNode("Unsqueeze", {gather_out_0, single_value_1d_int_0}, {unsqueeze_out_0});
|
||||
builder.AddNode("Unsqueeze", {gather_out_1, single_value_1d_int_0}, {unsqueeze_out_1});
|
||||
} else {
|
||||
builder.AddNode("Unsqueeze", {gather_out_0}, {unsqueeze_out_0}).AddAttribute("axes", std::vector<int64_t>{0});
|
||||
builder.AddNode("Unsqueeze", {gather_out_1}, {unsqueeze_out_1}).AddAttribute("axes", std::vector<int64_t>{0});
|
||||
}
|
||||
builder.AddNode("ConcatTraining", {unsqueeze_out_0, unsqueeze_out_1, single_value_1d_int_16, single_value_1d_int_64},
|
||||
{concattraining1_out, concattraining1_length}, "com.microsoft").AddAttribute("axis", static_cast<int64_t>(0));
|
||||
builder.AddNode("Reshape", {add_out, concattraining1_out}, {out});
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<ReshapeFusion>();
|
||||
if (opset_version == 15 && shape_test_for_opset15) {
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, pre_graph_checker));
|
||||
} else{
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(GraphTransformationTests, LayerNormFusionTest) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue