Match More EmbedLayerNormalization Patterns for Bert Model Graph Fusion (#4354)

match more embed patterns for bert base cased
This commit is contained in:
Cecilia Liu 2020-06-30 13:12:50 -07:00 committed by GitHub
parent 755675541a
commit 37b624b688
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 220 additions and 138 deletions

View file

@ -103,24 +103,24 @@ static void AddNodes(std::vector<NodeIndex>& node_indices,
It is because they are matched as part of other subgraph.
*/
static bool MatchPositionSubgraph(
static bool MatchInputToConcatSubgraph(
Graph& graph,
const Node& expand_node,
const Node& cur_node,
const NodeArg* input_ids,
const int index,
const logging::Logger& logger,
std::vector<NodeIndex>& subgraph_node_indices,
const NodeIndex expected_gather_node_1_index) {
subgraph_node_indices.clear();
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path1{
{0, 1, "Concat", {4, 11}, kOnnxDomain},
{0, index, "Concat", {4, 11}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain},
};
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(expand_node, true, expand_parent_path1, edges, logger)) {
if (!graph_utils::FindPath(cur_node, true, expand_parent_path1, edges, logger)) {
DEBUG_LOG("Failed to find path 1 of position shape.");
return false;
}
@ -184,60 +184,80 @@ static bool MatchPositionSubgraph(
}
/** Match subgraph like the following:
(input_ids)
/ \
Shape Shape
| |
^Gather (indice=0)^ Gather (indice=1)--+
^|^ ^|^ |
^Unsqueeze^ ^Unsqueeze^ Unsqueeze
^\^ ^/^ |
^\^ ^/^ ConstantOfShape
^\^ ^/^ |
^Concat^ NonZero
| |
| Transpose
| |
| Squeeze
| |
| Cast
| |
| Unsqueeze
+--|----------------------------+
| |
Expand
|
Gather
Note that position gather node is the node in the bottom of above sub-graph.
Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
*/
static bool MatchPositionEmbeddingSubgraph1(
*
* Shape -> ^Gather (indice=0)^ -> ^Unsqueeze^
* / | +-----------------------+
* / v | |
* [input_ids] ^Concat^ -> *Reshape* -> *Equal* -> *Where* -> Expand -> Gather
* \ | | ("position")
* Shape -> ^Gather (indice=1)^ -> ^Unsqueeze^ |
* | |
* +-------------- # one of the below subgraph patterns # ---------------+
* # Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> (Cast) -> Unsqueeze #
* # or #
* # (Cast (to=7)) -> Range (start=0, delta=1) -> Unsqueeze #
*
* Note that position gather node is the node in the bottom of above sub-graph.
* Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
* Path in ** is an alternative path to check.
*/
static bool MatchPositionEmbeddingSubgraphsFromGather(
Graph& graph,
const Node& position_gather_node,
const NodeArg* input_ids,
const logging::Logger& logger,
std::vector<NodeIndex>& subgraph_node_indices) {
subgraph_node_indices.clear();
std::vector<const Node::EdgeEnd*> pg_edges;
// Look for Path 1:
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand --> Gather
if (!graph_utils::FindPath(position_gather_node, true,
{{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Cast", {9}, kOnnxDomain},
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
{0, 0, "Transpose", {1}, kOnnxDomain},
{0, 0, "NonZero", {9}, kOnnxDomain},
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}},
pg_edges, logger)) {
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze
// --> Cast --> Unsqueeze --> Expand --> Gather
std::vector<graph_utils::EdgeEndToMatch> parent_path_1{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Cast", {9}, kOnnxDomain},
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
{0, 0, "Transpose", {1}, kOnnxDomain},
{0, 0, "NonZero", {9}, kOnnxDomain},
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
// Look for Path 2 (Path 1 with no cast):
std::vector<graph_utils::EdgeEndToMatch> parent_path_2{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
{0, 0, "Transpose", {1}, kOnnxDomain},
{0, 0, "NonZero", {9}, kOnnxDomain},
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
// Path 3 Pattern:
// Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand
std::vector<graph_utils::EdgeEndToMatch> parent_path_3{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Range", {1, 11}, kOnnxDomain},
{0, 1, "Cast", {9}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
// Path 4 pattern (Path 3 with no "Cast"):
std::vector<graph_utils::EdgeEndToMatch> parent_path_4{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Range", {1, 11}, kOnnxDomain},
{0, 1, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
// Match one of the three path patterns.
if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) &&
!graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) &&
!graph_utils::FindPath(position_gather_node, true, parent_path_3, pg_edges, logger) &&
!graph_utils::FindPath(position_gather_node, true, parent_path_4, pg_edges, logger)) {
return false;
}
const size_t gather_index = 8;
const size_t gather_index = pg_edges.size() - 2;
// All nodes in Path 1 must have only 1 output edge, except the gather node allowed 1 or 2 output edges
for (size_t i = 0; i < pg_edges.size(); i++) {
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) {
@ -251,6 +271,18 @@ static bool MatchPositionEmbeddingSubgraph1(
Node& expand_node = *graph.GetNode(pg_edges[0]->GetNode().Index());
Node& gather_node = *graph.GetNode(pg_edges[gather_index]->GetNode().Index());
if (pg_edges[2]->GetNode().OpType() == "Range") {
// Check if the values in "start" and "delta" attributes in Range are expected.
Node& range_node = *graph.GetNode(pg_edges[2]->GetNode().Index());
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) {
DEBUG_LOG("The first input of Range should be a constant with value 0.");
return false;
}
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) {
DEBUG_LOG("The third input of Range should be a constant with value 1.");
return false;
}
}
if (gather_node.GetOutputEdgesCount() == 1) {
// Check if the second input of the Gather node in the path has a constant input of 1
@ -279,7 +311,35 @@ static bool MatchPositionEmbeddingSubgraph1(
subgraph_node_indices.push_back(shape_node_index);
} else { // gather_output_edges_count == 2
if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node.Index())) {
// Match optional Reshape -> Equal -> Where -> Expand
// | |
// --------------------
std::vector<const Node::EdgeEnd*> pg_edges_2;
std::vector<graph_utils::EdgeEndToMatch> path_to_match_1{
{0, 1, "Where", {9}, kOnnxDomain},
{0, 0, "Equal", {1, 11}, kOnnxDomain},
{0, 0, "Reshape", {5}, kOnnxDomain}};
if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) {
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) ||
!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) ||
!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[2]->GetNode(), 2)) {
DEBUG_LOG("Optional position subgraph nodes number of outputs unexpected.");
return false;
}
Node& where_node = *graph.GetNode(pg_edges_2[0]->GetNode().Index());
Node& reshape_node = *graph.GetNode(pg_edges_2[2]->GetNode().Index());
if (where_node.MutableInputDefs()[2] != reshape_node.MutableOutputDefs()[0]) {
DEBUG_LOG("Optional position subgraph nodes Where node is expected to be the parent of Reshape.");
return false;
}
// Match [input_ids] -> Gather -> Shape -> Unsqueeze from Reshape node.
if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, subgraph_node_indices, gather_node.Index())) {
DEBUG_LOG("Failed to match position subgraph.");
return false;
}
AddNodes(subgraph_node_indices, pg_edges_2);
} else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, subgraph_node_indices, gather_node.Index())) {
// Match [input_ids] -> Gather -> Shape -> Unsqueeze from Expand node.
DEBUG_LOG("Failed to match position subgraph.");
return false;
}
@ -290,90 +350,6 @@ static bool MatchPositionEmbeddingSubgraph1(
return true;
}
/** Match subgraph like the following:
(input_ids)
/ \
Shape Shape
| |
Gather (indice=0) Gather (indice=1)--+
| | |
Unsqueeze Unsqueeze Cast(to=7) (Cast is optional)
\ / |
\ / Range(start=0, delta=1)
\ / |
Concat Unsqueeze
| |
+--|----------------------------+
| |
Expand
|
Gather
Note that position gather node is the node in the bottom of above sub-graph.
*/
static bool MatchPositionEmbeddingSubgraph2(
Graph& graph,
const Node& position_gather_node,
const NodeArg* input_ids,
const logging::Logger& logger,
std::vector<NodeIndex>& subgraph_node_indices) {
subgraph_node_indices.clear();
// Match Gather <-- Expand <-- Unsqueeze <-- Range <-- Cast <-- Gather
// Since Range is from opset 11, we only match opset 11 here.
std::vector<NodeIndex> position_parent_nodes;
std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_symbolic{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {11}, kOnnxDomain},
{0, 0, "Range", {11}, kOnnxDomain},
{0, 1, "Cast", {9}, kOnnxDomain},
{0, 0, "Gather", {11}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) {
// Cast node might be removed by other optimizer. Here we check a pattern without Cast node.
std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_no_cast{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {11}, kOnnxDomain},
{0, 0, "Range", {11}, kOnnxDomain},
{0, 1, "Gather", {11}, kOnnxDomain}};
if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_no_cast, edges, logger)) {
DEBUG_LOG("Failed to find path 1.");
return false;
}
}
size_t last_edge = edges.size() - 1;
for (size_t i = 0; i < edges.size(); i++) {
if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), (i == last_edge ? 2u : 1u))) {
DEBUG_LOG("Output edge count not expected for nodes in path 1.");
return false;
}
}
Node& expand_node = *graph.GetNode(edges[0]->GetNode().Index());
Node& range_node = *graph.GetNode(edges[2]->GetNode().Index());
Node& gather_node_1 = *graph.GetNode(edges[last_edge]->GetNode().Index());
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) {
DEBUG_LOG("The first input of Range should be a constant with value 0.");
return false;
}
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) {
DEBUG_LOG("The third input of Range should be a constant with value 1.");
return false;
}
if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node_1.Index())) {
DEBUG_LOG("Failed to match position subgraph.");
return false;
}
AddNodes(subgraph_node_indices, edges);
return true;
}
static bool MatchPositionEmbeddingSubgraph(
Graph& graph,
const Node& add_node,
@ -422,10 +398,8 @@ static bool MatchPositionEmbeddingSubgraph(
}
}
} else {
if (!MatchPositionEmbeddingSubgraph1(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
if (!MatchPositionEmbeddingSubgraph2(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
return false;
}
if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
return false;
}
}

View file

@ -2424,6 +2424,35 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
}
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Shape"], 0);
EXPECT_EQ(op_to_count["Expand"], 0);
EXPECT_EQ(op_to_count["Gather"], 0);
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
EXPECT_EQ(op_to_count["Reshape"], 0);
EXPECT_EQ(op_to_count["Equal"], 0);
EXPECT_EQ(op_to_count["Where"], 0);
EXPECT_EQ(op_to_count["LayerNormalization"], 0);
EXPECT_EQ(op_to_count["SkipLayerNormalization"], 0);
EXPECT_EQ(op_to_count["ReduceSum"], 0);
EXPECT_EQ(op_to_count["MatMul"], 1);
EXPECT_EQ(op_to_count["Add"], 2);
EXPECT_EQ(op_to_count["Cast"], 3);
EXPECT_EQ(op_to_count["Attention"], 1);
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
}
TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) {
auto model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx";
std::shared_ptr<Model> p_model;

View file

@ -143,7 +143,86 @@ def GenerateModel5(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
def GenerateModel6(model_name):
nodes = [ # LayerNorm subgraph
helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"),
helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"),
helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"),
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0),
helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"),
helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"),
helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"),
helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"],
"range"),
helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"),
helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"),
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"),
helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["word_add_pos_out"], "word_add_pos"),
helper.make_node("Gather", ["seg_embed", "segment_ids"], ["seg_gather_out"], "seg_gather"),
helper.make_node("Add", ["word_add_pos_out", "seg_gather_out"], ["add3_out"], "add3"),
helper.make_node("LayerNormalization", ["add3_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"],
"layernorm",
axis=-1,
epsion=0.000009999999747378752),
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
"att",
domain="com.microsoft",
num_heads=2),
helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"),
helper.make_node("Add", ["matmul_out", "add_bias"], ["add_out"], "add"),
helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2")
]
# hidden_size=4, num_heads=2, max_seq_length=3
initializers = [ # initializers
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
helper.make_tensor('start_0', TensorProto.INT64, [], [0]),
helper.make_tensor('delta_1', TensorProto.INT64, [], [1]),
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('pos_embed', TensorProto.FLOAT, [4, 4],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]),
helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]),
helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]),
]
graph = helper.make_graph(
nodes,
"EmbedLayerNorm_format6", #name
[ # inputs
helper.make_tensor_value_info('input_ids', TensorProto.INT64, ['batch', 3]),
helper.make_tensor_value_info('segment_ids', TensorProto.INT64, ['batch', 3]),
helper.make_tensor_value_info('input_mask', TensorProto.INT64, ['batch', 3]),
],
[ # outputs
helper.make_tensor_value_info('add2_out', TensorProto.FLOAT, ['batch', 3, 4]),
],
initializers)
model = helper.make_model(graph)
onnx.save(model, model_name)
GenerateModel3('embed_layer_norm_format3.onnx', True)
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
GenerateModel5('embed_layer_norm_format5.onnx')
GenerateModel6('embed_layer_norm_format6.onnx')