Modify embedlayernorm fusion due to shape node merging (#4967)

* modify embedlayernorm fusion due to shape integration

* update

* update comments

* review comments

* review comments

* fix test
This commit is contained in:
Ye Wang 2020-09-08 14:17:29 -07:00 committed by GitHub
parent 38453acae3
commit c239ff0750
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 182 additions and 34 deletions

View file

@ -111,6 +111,8 @@ static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Nod
The Expand and Gather on the bottom will not be added to subgraph_node_indices.
It is because they are matched as part of other subgraph.
Two Shape nodes may merge into one.
*/
static bool MatchInputToConcatSubgraph(
@ -134,8 +136,14 @@ static bool MatchInputToConcatSubgraph(
DEBUG_LOG("Failed to find path 1 of position shape.");
return false;
}
const size_t shape_index = edges.size() - 1;
for (size_t i = 0; i < edges.size(); i++) {
if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), 1)) {
// The Shape node may have multiple output edges when two Shape nodes were merged into one,
// so its output edge count is checked later.
if (i == shape_index) {
continue;
}
DEBUG_LOG("Output edge count not expected for nodes in path 1 of position shape.");
return false;
}
@ -161,9 +169,10 @@ static bool MatchInputToConcatSubgraph(
return false;
}
// The Shape node may have multiple output edges when two Shape nodes were merged into one,
// so its output edge count is checked later.
if (!optimizer_utils::CheckOutputEdges(graph, edges[0]->GetNode(), 1) ||
!optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2) ||
!optimizer_utils::CheckOutputEdges(graph, edges[2]->GetNode(), 1)) {
!optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2)) {
DEBUG_LOG("Output edge count not expected for nodes in path 2 of position shape.");
return false;
}
@ -189,6 +198,19 @@ static bool MatchInputToConcatSubgraph(
return false;
}
// A Shape node may have more than one output edge when two Shape nodes were merged into one.
// In that case, check that both paths share the same Shape node.
if (!optimizer_utils::CheckOutputEdges(graph, shape_node_0, 1) ||
!optimizer_utils::CheckOutputEdges(graph, shape_node_1, 1)) {
if (shape_node_0.Index() == shape_node_1.Index() &&
(shape_node_0.GetOutputEdgesCount() == 2 ||
shape_node_0.GetOutputEdgesCount() == 4)) {
DEBUG_LOG("two paths share the same shape");
} else {
return false;
}
}
AddNodes(subgraph_node_indices, edges);
return true;
}
@ -210,6 +232,7 @@ static bool MatchInputToConcatSubgraph(
* Note that position gather node is the node in the bottom of above sub-graph.
* Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
* Path in ** is an alternative path to check.
* The two Shape nodes may be merged into one.
*/
static bool MatchPositionEmbeddingSubgraphsFromGather(
Graph& graph,
@ -268,12 +291,19 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
return false;
}
const size_t gather_index = pg_edges.size() - 2;
const size_t shape_index = pg_edges.size() - 1;
// All nodes in Path 1 must have exactly 1 output edge, except the Gather node, which may have 1 or 2,
// and the Shape node, which may have 2 or 4 output edges when Shape nodes were merged into one.
for (size_t i = 0; i < pg_edges.size(); i++) {
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) {
if (i == gather_index && optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2)) {
continue;
}
if (i == shape_index &&
(optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2) ||
optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 4))) {
continue;
}
DEBUG_LOG("Output edge count not expected for nodes in path1.");
return false;
}

View file

@ -2630,19 +2630,25 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
}
//DistilBert
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx";
static void TestEmbedLayerNormFusionDistilBert(const std::basic_string<ORTCHAR_T>& model_uri,
std::map<std::string, int>& op_to_count,
logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret1.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
op_to_count = CountOpsInGraph(graph);
}
//DistilBert
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx", op_to_count, logger_.get());
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
EXPECT_EQ(op_to_count["Attention"], 1);
EXPECT_EQ(op_to_count["Cast"], 2);
@ -2652,6 +2658,30 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}
// DistilBert-style model in which the Shape nodes are merged into one and the graph still
// takes an input mask (embed_layer_norm_format8.onnx).
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) {
  std::map<std::string, int> op_to_count;
  // The helper reports load/transform errors through ASSERT_* macros, which only return from
  // the helper itself. Abort this test too, instead of checking counts from an empty map.
  ASSERT_NO_FATAL_FAILURE(TestEmbedLayerNormFusionDistilBert(
      MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get()));

  // Exactly one fused embedding node and one Attention node are expected after the transform.
  EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
  EXPECT_EQ(op_to_count["Attention"], 1);
  EXPECT_EQ(op_to_count["Cast"], 2);
  // No Shape, Gather or Unsqueeze node is expected to remain in the graph.
  EXPECT_EQ(op_to_count["Shape"], 0);
  EXPECT_EQ(op_to_count["Gather"], 0);
  EXPECT_EQ(op_to_count["Unsqueeze"], 0);
  EXPECT_EQ(op_to_count["ReduceSum"], 1);
}
// DistilBert-style model in which the Shape nodes are merged into one and the graph has no
// input mask (embed_layer_norm_format9.onnx).
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
  std::map<std::string, int> op_to_count;
  // The helper reports load/transform errors through ASSERT_* macros, which only return from
  // the helper itself. Abort this test too, instead of checking counts from an empty map.
  ASSERT_NO_FATAL_FAILURE(TestEmbedLayerNormFusionDistilBert(
      MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get()));

  // Exactly one fused embedding node and one Attention node are expected after the transform.
  EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
  EXPECT_EQ(op_to_count["Attention"], 1);
  EXPECT_EQ(op_to_count["Cast"], 2);
  EXPECT_EQ(op_to_count["Shape"], 0);
  // Unlike format8, two Gather and two Unsqueeze nodes are expected to remain after fusion.
  EXPECT_EQ(op_to_count["Gather"], 2);
  EXPECT_EQ(op_to_count["Unsqueeze"], 2);
  EXPECT_EQ(op_to_count["ReduceSum"], 1);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
std::shared_ptr<Model> p_model;

View file

@ -278,12 +278,33 @@ def GenerateModel6(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
def GenerateModel7(model_name):
batch_size = 2
hidden_size = 4
attention_heads = 2
sequence_length = 3
def GenerateInitializers2(hidden_size):
    """Build the list of initializer tensors shared by the DistilBert-style test models.

    Args:
        hidden_size: size used for the hidden dimension in the declared tensor shapes.
            The value lists below are hard-coded, so they only line up with the declared
            shapes for one particular hidden_size (callers in this script pass 4).

    Returns:
        A list of tensors created with helper.make_tensor, in a fixed order.
    """

    def float_tensor(name, dims, values):
        # Shorthand for a FLOAT initializer.
        return helper.make_tensor(name, TensorProto.FLOAT, dims, values)

    def int64_scalar(name, value):
        # Shorthand for a rank-0 INT64 initializer.
        return helper.make_tensor(name, TensorProto.INT64, [], [value])

    # NOTE(review): this literal appears to hold 64 values while the declared shape below is
    # [hidden_size, 3 * hidden_size] (48 elements for hidden_size == 4) -- confirm.
    qkv_weights = [
        1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
        3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
        1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
    ]

    return [
        # Embedding tables.
        float_tensor('word_embed', [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        float_tensor('pos_embed', [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        # Scalar constants referenced by the graph nodes.
        int64_scalar('indices_0', 0),
        int64_scalar('indices_1', 1),
        int64_scalar('start', 0),
        int64_scalar('delta', 1),
        # Layer normalization parameters.
        float_tensor('layer_norm_weight', [hidden_size], [1.0, 2.0, 3.0, 4.0]),
        float_tensor('layer_norm_bias', [hidden_size], [0.1, 0.2, 0.3, 0.4]),
        # Attention and projection parameters.
        float_tensor('qkv_weights', [hidden_size, 3 * hidden_size], qkv_weights),
        float_tensor('qkv_bias', [3 * hidden_size],
                     [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
        float_tensor('matmul_weight', [hidden_size, hidden_size],
                     [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
        float_tensor('add_bias', [hidden_size], [0.1, 0.2, 0.3, 0.4]),
    ]
def GenerateNodes2(attention_heads):
nodes = [
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0),
@ -313,28 +334,17 @@ def GenerateModel7(model_name):
helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3")
]
qkv_weights = [
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
]
return nodes
initializers = [ # initializers
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
helper.make_tensor('start', TensorProto.INT64, [], [0]),
helper.make_tensor('delta', TensorProto.INT64, [], [1]),
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights),
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size],
[0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
]
def GenerateModel7(model_name):
batch_size = 2
hidden_size = 4
attention_heads = 2
sequence_length = 3
nodes = GenerateNodes2(attention_heads)
initializers = GenerateInitializers2(hidden_size)
graph = helper.make_graph(
nodes,
@ -351,9 +361,87 @@ def GenerateModel7(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
def GenerateModel8(model_name):
    """Save the "format8" test model: DistilBert-style graph with a single shared Shape node.

    The graph keeps both 'input_ids' and 'input_mask' as inputs. It starts from the node list
    returned by GenerateNodes2, removes four of its entries and appends one Shape node whose
    output feeds both a Gather and an Expand.

    Args:
        model_name: path the ONNX model is written to.
    """
    # NOTE(review): -1 is presumably used to mark these dimensions as dynamic -- confirm.
    batch_size = -1
    sequence_length = -1
    hidden_size = 4
    attention_heads = 2

    nodes = GenerateNodes2(attention_heads)
    # Remove the higher slice first so the lower indices still refer to the original list.
    # NOTE(review): presumably these are the template's separate Shape-path nodes -- verify
    # against GenerateNodes2.
    del nodes[5:7]
    del nodes[1:3]

    # Replacement nodes: "shape_out" is consumed by both the Gather and the Expand.
    nodes.extend([
        helper.make_node("Shape", ["input_ids"], ["shape_out"], "shape"),
        helper.make_node("Gather", ["shape_out", "indices_1"], ["gather0_out"], "gather0"),
        helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"),
    ])

    initializers = GenerateInitializers2(hidden_size)

    graph_inputs = [
        helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]),
        helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]),
    ]

    graph = helper.make_graph(nodes, "EmbedLayerNorm_format8", graph_inputs, graph_outputs, initializers)
    onnx.save(helper.make_model(graph), model_name)
def GenerateModel9(model_name):
    """Save the "format9" test model: single shared Shape node and no input mask.

    The graph takes only 'input_ids'. It starts from the node list returned by GenerateNodes2,
    removes five of its entries and appends a Shape node whose output is shared by the
    position path (Gather, Expand) and by a chain that ends in the 'mask_cast_out' tensor.

    Args:
        model_name: path the ONNX model is written to.
    """
    # NOTE(review): -1 is presumably used to mark these dimensions as dynamic -- confirm.
    batch_size = -1
    sequence_length = -1
    hidden_size = 4
    attention_heads = 2

    nodes = GenerateNodes2(attention_heads)
    # Delete from the highest index downwards so the lower indices still refer to the
    # original list.
    # NOTE(review): which template nodes these are depends on GenerateNodes2 -- verify there.
    del nodes[10]
    del nodes[5:7]
    del nodes[1:3]

    mask_value = helper.make_tensor('mask_shape', TensorProto.FLOAT, [1], [1.0])
    nodes.extend([
        # Position path: one Shape node feeds both the Gather and the Expand.
        helper.make_node("Shape", ["input_ids"], ["shape_out"], "shape"),
        helper.make_node("Gather", ["shape_out", "indices_1"], ["gather0_out"], "gather0"),
        helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"),
        # Mask path: the same "shape_out" is split into two dims, re-concatenated and used as
        # the shape of a constant tensor that is then cast.
        helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"),
        helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"),
        helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
        helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
        helper.make_node("Concat", ["unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0),
        helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], "constant_of_shape",
                         value=mask_value),
        # 6 is the ONNX TensorProto data type code for INT32.
        helper.make_node("Cast", ["constant_of_shape_out"], ["mask_cast_out"], "mask_cast", to=6),
    ])

    initializers = GenerateInitializers2(hidden_size)

    graph_inputs = [
        helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]),
    ]

    graph = helper.make_graph(nodes, "EmbedLayerNorm_format9", graph_inputs, graph_outputs, initializers)
    onnx.save(helper.make_model(graph), model_name)
GenerateModel3('embed_layer_norm_format3.onnx', True)
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
GenerateModel5('embed_layer_norm_format5.onnx')
GenerateModel6('embed_layer_norm_format6.onnx')
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask
GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')