mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Fix a bug in EmbedLayerNorm fusion (#5150)
* fix embedlayernorm bug
* review comments
* interim checkin
* review comments
* Fix core dump in MacOS
* remove unnecessary lines
* update document
* Update graph_utils.cc
* Update onnx_exporter.py
* resolve comments
This commit is contained in:
parent
aefb2cc49b
commit
65740deb10
5 changed files with 52 additions and 56 deletions
|
|
@ -753,19 +753,26 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec
|
|||
}
|
||||
|
||||
bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
|
||||
std::queue<const Node*> q;
|
||||
std::vector<NodeIndex> nodes_to_remove;
|
||||
q.push(&start_node);
|
||||
std::queue<NodeIndex> q;
|
||||
std::unordered_set<NodeIndex> removed_nodes;
|
||||
q.push(start_node.Index());
|
||||
|
||||
bool is_start_node(true);
|
||||
// From the current node, remove nodes bottom-up until it reaches a node with multiple outputs/graph output.
|
||||
while (q.size() != 0) {
|
||||
const Node& cur_node = *(q.front());
|
||||
while (!q.empty()) {
|
||||
NodeIndex cur_node_index = q.front();
|
||||
q.pop();
|
||||
|
||||
if (removed_nodes.find(cur_node_index) != removed_nodes.end()) {
|
||||
continue;
|
||||
}
|
||||
// Each eligible node in the subgraph must have at most one output edge and no output should be
|
||||
// the graph output
|
||||
const Node& cur_node = *graph.GetNode(cur_node_index);
|
||||
if (cur_node.GetOutputEdgesCount() > 1 || !graph.GetNodeOutputsInGraphOutputs(cur_node).empty()) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_remove.push_back(cur_node.Index());
|
||||
|
||||
// push the parents of current node to the queue.
|
||||
for (unsigned int i = 0; i < cur_node.InputDefs().size(); ++i) {
|
||||
const std::string& input_name = GetNodeInputName(cur_node, i);
|
||||
|
|
@ -773,19 +780,25 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
|
|||
// skip initializers and graph inputs
|
||||
continue;
|
||||
}
|
||||
q.push(GetInputNode(cur_node, i));
|
||||
const Node* parent_node = GetInputNode(cur_node, i);
|
||||
q.push(parent_node->Index());
|
||||
}
|
||||
|
||||
if (is_start_node || cur_node.GetOutputEdgesCount() == 0) {
|
||||
Node* cur_node_p = graph.GetNode(cur_node_index);
|
||||
RemoveNodeOutputEdges(graph, *cur_node_p);
|
||||
graph.RemoveNode(cur_node_index);
|
||||
|
||||
removed_nodes.insert(cur_node_index);
|
||||
is_start_node = false;
|
||||
}
|
||||
}
|
||||
if (nodes_to_remove.size() <= 0) {
|
||||
|
||||
if (removed_nodes.size() == 0) {
|
||||
// Nothing to remove
|
||||
return false;
|
||||
}
|
||||
// Remove nodes that are not used anymore.
|
||||
for (const auto& node_index : nodes_to_remove) {
|
||||
Node* node = graph.GetNode(node_index);
|
||||
RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(node->Index());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -65,18 +65,6 @@ static bool CheckInput(NodeArg* input, const logging::Logger& logger) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static void AddNodes(std::vector<NodeIndex>& node_indices,
|
||||
const std::vector<const Node::EdgeEnd*>& edges) {
|
||||
for (size_t i = 0; i < edges.size(); i++) {
|
||||
auto item = edges[i]->GetNode().Index();
|
||||
// Avoid duplication.
|
||||
if (std::find(node_indices.begin(), node_indices.end(), item) != node_indices.end()) {
|
||||
continue;
|
||||
}
|
||||
node_indices.push_back(item);
|
||||
}
|
||||
}
|
||||
|
||||
static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Node::NodeConstIterator end, const std::vector<std::string>& expected_types) {
|
||||
for (const std::string& expected_type : expected_types) {
|
||||
if (start == end || (*start).OpType().compare(expected_type) != 0) {
|
||||
|
|
@ -121,9 +109,7 @@ static bool MatchInputToConcatSubgraph(
|
|||
const NodeArg* input_ids,
|
||||
const int index,
|
||||
const logging::Logger& logger,
|
||||
std::vector<NodeIndex>& subgraph_node_indices,
|
||||
const NodeIndex expected_gather_node_1_index) {
|
||||
subgraph_node_indices.clear();
|
||||
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path1{
|
||||
{0, index, "Concat", {4, 11}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
|
|
@ -157,8 +143,6 @@ static bool MatchInputToConcatSubgraph(
|
|||
return false;
|
||||
}
|
||||
|
||||
AddNodes(subgraph_node_indices, edges);
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> concat_parent_path{
|
||||
{0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
|
|
@ -211,7 +195,6 @@ static bool MatchInputToConcatSubgraph(
|
|||
}
|
||||
}
|
||||
|
||||
AddNodes(subgraph_node_indices, edges);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -238,9 +221,7 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
|
|||
Graph& graph,
|
||||
const Node& position_gather_node,
|
||||
const NodeArg* input_ids,
|
||||
const logging::Logger& logger,
|
||||
std::vector<NodeIndex>& subgraph_node_indices) {
|
||||
subgraph_node_indices.clear();
|
||||
const logging::Logger& logger) {
|
||||
std::vector<const Node::EdgeEnd*> pg_edges;
|
||||
// Look for Path 1:
|
||||
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze
|
||||
|
|
@ -348,8 +329,6 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
|
|||
DEBUG_LOG("The parent of shape nodes are expected to be input_ids.");
|
||||
return false;
|
||||
}
|
||||
|
||||
subgraph_node_indices.push_back(shape_node_index);
|
||||
} else { // gather_output_edges_count == 2
|
||||
// Match optional Reshape -> Equal -> Where -> Expand
|
||||
// | |
|
||||
|
|
@ -373,20 +352,17 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
|
|||
return false;
|
||||
}
|
||||
// Match [input_ids] -> Gather -> Shape -> Unsqueeze from Reshape node.
|
||||
if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, subgraph_node_indices, gather_node.Index())) {
|
||||
if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, gather_node.Index())) {
|
||||
DEBUG_LOG("Failed to match position subgraph.");
|
||||
return false;
|
||||
}
|
||||
AddNodes(subgraph_node_indices, pg_edges_2);
|
||||
} else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, subgraph_node_indices, gather_node.Index())) {
|
||||
} else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, gather_node.Index())) {
|
||||
// Match [input_ids] -> Gather -> Shape -> Unsqueeze from Expand node.
|
||||
DEBUG_LOG("Failed to match position subgraph.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
AddNodes(subgraph_node_indices, pg_edges);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -438,11 +414,12 @@ static bool MatchPositionEmbeddingSubgraph(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) {
|
||||
if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
subgraph_node_indices.clear();
|
||||
subgraph_node_indices.push_back(position_gather_node.Index());
|
||||
return true;
|
||||
}
|
||||
|
|
@ -507,7 +484,6 @@ static void CreateEmbedLayernormNode(Graph& graph,
|
|||
NodeArg* word_embedding,
|
||||
NodeArg* position_embedding,
|
||||
NodeArg* segment_embedding,
|
||||
|
||||
Node& layer_norm_node) {
|
||||
// Cast input_ids and segment_ids to int32 if needed.
|
||||
input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
|
||||
|
|
@ -528,7 +504,6 @@ static void CreateEmbedLayernormNode(Graph& graph,
|
|||
position_embedding,
|
||||
segment_embedding,
|
||||
layer_norm_node.MutableInputDefs()[1],
|
||||
|
||||
layer_norm_node.MutableInputDefs()[2]};
|
||||
|
||||
auto& mask_index = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mask_index"), nullptr);
|
||||
|
|
@ -705,6 +680,12 @@ static bool FuseSubGraph(Graph& graph,
|
|||
CreateEmbedLayernormNode(graph, input_ids, segment_ids, word_embedding, position_embedding, segment_embedding,
|
||||
layer_norm_node);
|
||||
|
||||
if (!nodes_to_remove.empty()) {
|
||||
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *graph.GetNode(nodes_to_remove[0]));
|
||||
}
|
||||
|
||||
nodes_to_remove.clear();
|
||||
|
||||
nodes_to_remove.push_back(word_gather_node.Index());
|
||||
nodes_to_remove.push_back(segment_gather_node.Index());
|
||||
nodes_to_remove.push_back(add_node.Index());
|
||||
|
|
@ -712,7 +693,7 @@ static bool FuseSubGraph(Graph& graph,
|
|||
nodes_to_remove.push_back(layer_norm_add_node.Index());
|
||||
nodes_to_remove.push_back(layer_norm_node.Index());
|
||||
|
||||
for (const auto& index : nodes_to_remove) {
|
||||
for (const NodeIndex index : nodes_to_remove) {
|
||||
Node* node = graph.GetNode(index);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(node->Index());
|
||||
|
|
@ -725,7 +706,6 @@ static bool FuseSubGraph(Graph& graph,
|
|||
static bool FuseSubGraphDistilBert(Graph& graph,
|
||||
Node& layer_norm_add_node,
|
||||
Node& layer_norm_node,
|
||||
|
||||
const logging::Logger& logger) {
|
||||
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
|
|
@ -796,12 +776,18 @@ static bool FuseSubGraphDistilBert(Graph& graph,
|
|||
CreateEmbedLayernormNode(graph, input_ids, nullptr, word_embedding, position_embedding, nullptr,
|
||||
layer_norm_node);
|
||||
|
||||
if (!nodes_to_remove.empty()) {
|
||||
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *graph.GetNode(nodes_to_remove[0]));
|
||||
}
|
||||
|
||||
nodes_to_remove.clear();
|
||||
|
||||
nodes_to_remove.push_back(word_gather_node.Index());
|
||||
nodes_to_remove.push_back(add_node.Index());
|
||||
|
||||
nodes_to_remove.push_back(layer_norm_node.Index());
|
||||
|
||||
for (const auto& index : nodes_to_remove) {
|
||||
for (const NodeIndex index : nodes_to_remove) {
|
||||
Node* node = graph.GetNode(index);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(node->Index());
|
||||
|
|
|
|||
|
|
@ -264,7 +264,7 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, batch_sizes, se
|
|||
for model_name in model_names:
|
||||
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, if_tf_model=True)
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, is_tf_model=True)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -211,19 +211,16 @@ def modelclass_dispatcher(model_name, custom_model_class):
|
|||
return "AutoModel"
|
||||
|
||||
|
||||
def load_pretrained_model(model_name, config, cache_dir, custom_model_class, if_tf_model=False):
|
||||
def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_tf_model=False):
|
||||
model_class_name = modelclass_dispatcher(model_name, custom_model_class)
|
||||
|
||||
if model_class_name == "GPT2ModelNoPastState":
|
||||
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
|
||||
|
||||
if model_class_name == "GPT2ModelNoPastState":
|
||||
if is_tf_model:
|
||||
raise NotImplementedError("TFGPT2ModelNoPastState is currently not supported.")
|
||||
else:
|
||||
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
|
||||
if if_tf_model:
|
||||
if is_tf_model:
|
||||
model_class_name = 'TF' + model_class_name
|
||||
|
||||
transformers_module = __import__("transformers", fromlist=[model_class_name])
|
||||
|
|
@ -329,7 +326,7 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma
|
|||
|
||||
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, if_tf_model=True)
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, is_tf_model=True)
|
||||
|
||||
model._saved_model_inputs_spec = None
|
||||
|
||||
|
|
|
|||
|
|
@ -2799,7 +2799,7 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
|
|||
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
|
||||
EXPECT_EQ(op_to_count["Attention"], 1);
|
||||
EXPECT_EQ(op_to_count["Cast"], 2);
|
||||
EXPECT_EQ(op_to_count["Shape"], 0);
|
||||
EXPECT_EQ(op_to_count["Shape"], 1);
|
||||
EXPECT_EQ(op_to_count["Gather"], 2);
|
||||
EXPECT_EQ(op_to_count["Unsqueeze"], 2);
|
||||
EXPECT_EQ(op_to_count["ReduceSum"], 1);
|
||||
|
|
|
|||
Loading…
Reference in a new issue