Fix a bug in EmbedLayerNorm fusion (#5150)

* fix embedlayernorm bug

* review comments

* interim checkin

* review comments

* Fix core dump in MacOS

* remove unnecessary lines

* update document

* Update graph_utils.cc

* Update onnx_exporter.py

* resolve comments
This commit is contained in:
Ye Wang 2020-09-21 12:26:14 -07:00 committed by GitHub
parent aefb2cc49b
commit 65740deb10
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 52 additions and 56 deletions

View file

@ -753,19 +753,26 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec
}
bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
std::queue<const Node*> q;
std::vector<NodeIndex> nodes_to_remove;
q.push(&start_node);
std::queue<NodeIndex> q;
std::unordered_set<NodeIndex> removed_nodes;
q.push(start_node.Index());
bool is_start_node(true);
// Starting from the current node, remove nodes bottom-up until reaching a node with multiple output edges or a graph output.
while (q.size() != 0) {
const Node& cur_node = *(q.front());
while (!q.empty()) {
NodeIndex cur_node_index = q.front();
q.pop();
if (removed_nodes.find(cur_node_index) != removed_nodes.end()) {
continue;
}
// Each eligible node in the subgraph must have at most one output edge, and none of its outputs may be
// a graph output.
const Node& cur_node = *graph.GetNode(cur_node_index);
if (cur_node.GetOutputEdgesCount() > 1 || !graph.GetNodeOutputsInGraphOutputs(cur_node).empty()) {
continue;
}
nodes_to_remove.push_back(cur_node.Index());
// push the parents of current node to the queue.
for (unsigned int i = 0; i < cur_node.InputDefs().size(); ++i) {
const std::string& input_name = GetNodeInputName(cur_node, i);
@ -773,19 +780,25 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
// skip initializers and graph inputs
continue;
}
q.push(GetInputNode(cur_node, i));
const Node* parent_node = GetInputNode(cur_node, i);
q.push(parent_node->Index());
}
if (is_start_node || cur_node.GetOutputEdgesCount() == 0) {
Node* cur_node_p = graph.GetNode(cur_node_index);
RemoveNodeOutputEdges(graph, *cur_node_p);
graph.RemoveNode(cur_node_index);
removed_nodes.insert(cur_node_index);
is_start_node = false;
}
}
if (nodes_to_remove.size() <= 0) {
if (removed_nodes.size() == 0) {
// Nothing to remove
return false;
}
// Remove nodes that are not used anymore.
for (const auto& node_index : nodes_to_remove) {
Node* node = graph.GetNode(node_index);
RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
}
return true;
}

View file

@ -65,18 +65,6 @@ static bool CheckInput(NodeArg* input, const logging::Logger& logger) {
return true;
}
static void AddNodes(std::vector<NodeIndex>& node_indices,
const std::vector<const Node::EdgeEnd*>& edges) {
for (size_t i = 0; i < edges.size(); i++) {
auto item = edges[i]->GetNode().Index();
// Avoid duplication.
if (std::find(node_indices.begin(), node_indices.end(), item) != node_indices.end()) {
continue;
}
node_indices.push_back(item);
}
}
static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Node::NodeConstIterator end, const std::vector<std::string>& expected_types) {
for (const std::string& expected_type : expected_types) {
if (start == end || (*start).OpType().compare(expected_type) != 0) {
@ -121,9 +109,7 @@ static bool MatchInputToConcatSubgraph(
const NodeArg* input_ids,
const int index,
const logging::Logger& logger,
std::vector<NodeIndex>& subgraph_node_indices,
const NodeIndex expected_gather_node_1_index) {
subgraph_node_indices.clear();
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path1{
{0, index, "Concat", {4, 11}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
@ -157,8 +143,6 @@ static bool MatchInputToConcatSubgraph(
return false;
}
AddNodes(subgraph_node_indices, edges);
std::vector<graph_utils::EdgeEndToMatch> concat_parent_path{
{0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
@ -211,7 +195,6 @@ static bool MatchInputToConcatSubgraph(
}
}
AddNodes(subgraph_node_indices, edges);
return true;
}
@ -238,9 +221,7 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
Graph& graph,
const Node& position_gather_node,
const NodeArg* input_ids,
const logging::Logger& logger,
std::vector<NodeIndex>& subgraph_node_indices) {
subgraph_node_indices.clear();
const logging::Logger& logger) {
std::vector<const Node::EdgeEnd*> pg_edges;
// Look for Path 1:
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze
@ -348,8 +329,6 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
DEBUG_LOG("The parent of shape nodes are expected to be input_ids.");
return false;
}
subgraph_node_indices.push_back(shape_node_index);
} else { // gather_output_edges_count == 2
// Match optional Reshape -> Equal -> Where -> Expand
// | |
@ -373,20 +352,17 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
return false;
}
// Match [input_ids] -> Gather -> Shape -> Unsqueeze from Reshape node.
if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, subgraph_node_indices, gather_node.Index())) {
if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, gather_node.Index())) {
DEBUG_LOG("Failed to match position subgraph.");
return false;
}
AddNodes(subgraph_node_indices, pg_edges_2);
} else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, subgraph_node_indices, gather_node.Index())) {
} else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, gather_node.Index())) {
// Match [input_ids] -> Gather -> Shape -> Unsqueeze from Expand node.
DEBUG_LOG("Failed to match position subgraph.");
return false;
}
}
AddNodes(subgraph_node_indices, pg_edges);
return true;
}
@ -438,11 +414,12 @@ static bool MatchPositionEmbeddingSubgraph(
}
}
} else {
if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger)) {
return false;
}
}
subgraph_node_indices.clear();
subgraph_node_indices.push_back(position_gather_node.Index());
return true;
}
@ -507,7 +484,6 @@ static void CreateEmbedLayernormNode(Graph& graph,
NodeArg* word_embedding,
NodeArg* position_embedding,
NodeArg* segment_embedding,
Node& layer_norm_node) {
// Cast input_ids and segment_ids to int32 if needed.
input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
@ -528,7 +504,6 @@ static void CreateEmbedLayernormNode(Graph& graph,
position_embedding,
segment_embedding,
layer_norm_node.MutableInputDefs()[1],
layer_norm_node.MutableInputDefs()[2]};
auto& mask_index = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mask_index"), nullptr);
@ -705,6 +680,12 @@ static bool FuseSubGraph(Graph& graph,
CreateEmbedLayernormNode(graph, input_ids, segment_ids, word_embedding, position_embedding, segment_embedding,
layer_norm_node);
if (!nodes_to_remove.empty()) {
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *graph.GetNode(nodes_to_remove[0]));
}
nodes_to_remove.clear();
nodes_to_remove.push_back(word_gather_node.Index());
nodes_to_remove.push_back(segment_gather_node.Index());
nodes_to_remove.push_back(add_node.Index());
@ -712,7 +693,7 @@ static bool FuseSubGraph(Graph& graph,
nodes_to_remove.push_back(layer_norm_add_node.Index());
nodes_to_remove.push_back(layer_norm_node.Index());
for (const auto& index : nodes_to_remove) {
for (const NodeIndex index : nodes_to_remove) {
Node* node = graph.GetNode(index);
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
@ -725,7 +706,6 @@ static bool FuseSubGraph(Graph& graph,
static bool FuseSubGraphDistilBert(Graph& graph,
Node& layer_norm_add_node,
Node& layer_norm_node,
const logging::Logger& logger) {
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
@ -796,12 +776,18 @@ static bool FuseSubGraphDistilBert(Graph& graph,
CreateEmbedLayernormNode(graph, input_ids, nullptr, word_embedding, position_embedding, nullptr,
layer_norm_node);
if (!nodes_to_remove.empty()) {
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *graph.GetNode(nodes_to_remove[0]));
}
nodes_to_remove.clear();
nodes_to_remove.push_back(word_gather_node.Index());
nodes_to_remove.push_back(add_node.Index());
nodes_to_remove.push_back(layer_norm_node.Index());
for (const auto& index : nodes_to_remove) {
for (const NodeIndex index : nodes_to_remove) {
Node* node = graph.GetNode(index);
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());

View file

@ -264,7 +264,7 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, batch_sizes, se
for model_name in model_names:
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, if_tf_model=True)
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, is_tf_model=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

View file

@ -211,19 +211,16 @@ def modelclass_dispatcher(model_name, custom_model_class):
return "AutoModel"
def load_pretrained_model(model_name, config, cache_dir, custom_model_class, if_tf_model=False):
def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_tf_model=False):
model_class_name = modelclass_dispatcher(model_name, custom_model_class)
if model_class_name == "GPT2ModelNoPastState":
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
if model_class_name == "GPT2ModelNoPastState":
if is_tf_model:
raise NotImplementedError("TFGPT2ModelNoPastState is currently not supported.")
else:
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
if if_tf_model:
if is_tf_model:
model_class_name = 'TF' + model_class_name
transformers_module = __import__("transformers", fromlist=[model_class_name])
@ -329,7 +326,7 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, if_tf_model=True)
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, is_tf_model=True)
model._saved_model_inputs_spec = None

View file

@ -2799,7 +2799,7 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
EXPECT_EQ(op_to_count["Attention"], 1);
EXPECT_EQ(op_to_count["Cast"], 2);
EXPECT_EQ(op_to_count["Shape"], 0);
EXPECT_EQ(op_to_count["Shape"], 1);
EXPECT_EQ(op_to_count["Gather"], 2);
EXPECT_EQ(op_to_count["Unsqueeze"], 2);
EXPECT_EQ(op_to_count["ReduceSum"], 1);