Fix transformer layer detection for recompute (#20106)

### Fix transformer layer detection for recompute

The original logic failed to detect the layer boundary nodes in the Mistral model.
This PR simplifies the search by using stronger pattern matching,
so that it is flexible enough to cover different transformer
variants.

Also add a UT.

Add a warning when the user enables layerwise recompute but no layer boundary
nodes are found.
This commit is contained in:
pengwa 2024-03-29 17:44:38 +08:00 committed by GitHub
parent 2a184ac1a1
commit 2092bebc78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 194 additions and 50 deletions

View file

@ -257,10 +257,15 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,
is_forward_nodes,
logger));
InlinedHashSet<const Node*> layer_boundary_ln_nodes;
InlinedVector<const Node*> layer_boundary_ln_nodes;
FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map,
yield_op_order_in_topological_sort, layer_boundary_ln_nodes);
if (probe_config.enable_transformer_layer_as_boundary && layer_boundary_ln_nodes.size() == 0) {
LOGS(logger, WARNING) << "No transformer layer boundary nodes found, this might cause memory optimization "
"not working as expected. Please check the model and the configuration.";
}
// The first pass - find the candidate subgraphs.
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
const Node* p_node = graph_viewer.GetNode(node_ids[i]);

View file

@ -386,6 +386,26 @@ const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& GetAllowedRecompu
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
{
// Opset 1 in ONNX official does not have SimplifiedLayerNormalization,
// while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain.
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
{
@ -691,7 +711,7 @@ std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& grap
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const InlinedHashSet<const Node*>& layer_boundary_ln_nodes,
const InlinedVector<const Node*>& layer_boundary_ln_nodes,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation) {
@ -709,13 +729,14 @@ std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& grap
auto output_name = node.OutputDefs()[output_index]->Name();
auto consumers = graph_viewer.GetConsumerNodes(output_name);
for (auto& consumer : consumers) {
if (layer_boundary_ln_nodes.find(consumer) != layer_boundary_ln_nodes.end()) {
if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), consumer) !=
layer_boundary_ln_nodes.end()) {
int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]);
if (dest_in_index == 0) {
LOGS(logger, INFO) << "Node " << node.Name() << "(" << node.OpType()
<< ") is a Attention+MLP layer boundary node, "
<< "its stashed activation outputs are used by LayerNormalization's inputs, "
<< "we don't need to recompute it.";
MO_LOG_DEBUG_INFO(logger, "Node " + node.Name() + "(" + node.OpType() +
") is a Attention+MLP layer boundary node, " +
"its stashed activation outputs are used by LayerNormalization's inputs, " +
"we don't need to recompute it.");
return nullptr;
}
}

View file

@ -164,7 +164,7 @@ std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& grap
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const InlinedHashSet<const Node*>& layer_boundary_ln_nodes,
const InlinedVector<const Node*>& layer_boundary_ln_nodes,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation);

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include <charconv>
#include <tuple>
#include <vector>
#include <utility>
@ -16,43 +17,139 @@
namespace onnxruntime::optimizer::memory_optimizer {
namespace {
// Tells whether `node` is one of the layer-normalization operator variants
// (plain, skip, simplified, or skip-simplified).
// Only the op type string is compared; the op domain is not inspected here.
bool IsLayerNormNode(const Node& node) {
  const std::string& op_type = node.OpType();
  return op_type == "LayerNormalization" ||
         op_type == "SkipLayerNormalization" ||
         op_type == "SimplifiedLayerNormalization" ||
         op_type == "SkipSimplifiedLayerNormalization";
}
// Tells whether `node` is a softmax-style operator ("Softmax" or "BiasSoftmax").
// Only the op type string is compared; the op domain is not inspected here.
bool IsSoftmaxNode(const Node& node) {
  const std::string& op_type = node.OpType();
  return op_type == "Softmax" || op_type == "BiasSoftmax";
}
std::tuple<bool, const Node*, const Node*> IsResidualNodeArg(const GraphViewer& graph_viewer, const NodeArg* node_arg) {
auto consumers = graph_viewer.GetConsumerNodes(node_arg->Name());
if (2 > consumers.size()) {
return std::make_tuple(false, nullptr, nullptr);
}
// Find the Add node from the consumer list.
const Node* add_node = nullptr;
const Node* other_node = nullptr;
for (const auto* consumer : consumers) {
if (consumer->OpType() == "Add") {
add_node = consumer;
} else if (IsLayerNormNode(*consumer)) {
other_node = consumer;
}
}
return std::make_tuple(add_node != nullptr && other_node != nullptr, add_node, other_node);
}
} // namespace
/*
One classical layer includes 1). input layer norm, 2). attention, 3). residual add
(input layer norm input + attention output), 4). post attention layer norm feedforward, and 5). residual add
(post attention layer norm input + feedforward out).
The pattern graph looks like below for each transformer layer (taking the example of MistralDecoderLayer):
|
Embedding
|
----------------------|
| |
| |
| SimplifiedLayerNormalization (layer boundary node)
| |
| |
| MistralAttention
| |
| |
|____________________Add
|
----------------------|
| |
| |
| SimplifiedLayerNormalization
| |
| |
| MultiLayerPerceptron
| |
| |
|____________________Add
|
(new layer)
----------------------|
| |
| SimplifiedLayerNormalization
....
*/
void FindLayerBoundaryLayerNormNodes(
const GraphViewer& graph_viewer,
const logging::Logger&,
const logging::Logger& logger,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const ptrdiff_t& yield_op_order_in_topological_sort,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes) {
InlinedVector<const Node*>& layer_boundary_ln_nodes) {
// Loop all nodes to find LayerNormalization nodes.
// For each LayerNormalization node, keep checking its output nodes,
// until find a node that is Softmax or BiasSoftmax or another LayerNormalization.
// If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION.
// If the found node is another LayerNormalization, the LayerNormalization node as MLP.
const InlinedHashSet<std::string_view> softmax_ops{"Softmax", "BiasSoftmax"};
const InlinedHashSet<std::string_view> layernorm_ops{"LayerNormalization", "SkipLayerNormalization"};
layer_boundary_ln_nodes.clear();
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
for (auto node_index : node_topology_list) {
auto& node = *graph_viewer.GetNode(node_index);
if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) {
if (!IsLayerNormNode(node)) {
continue;
}
const NodeArg* input_arg = node.InputDefs()[0];
// IsResidualNodeArg checks input_arg
auto [is_residual_node_arg, add_node, other_node] = IsResidualNodeArg(graph_viewer, input_arg);
if (!is_residual_node_arg) {
MO_LOG_DEBUG_INFO(logger, "Not a residual node arg " + input_arg->Name());
continue;
}
// At this point, there should not be any recompute node, so we don't need check the node existence in
// node_index_to_its_order_in_topological_sort_map.
ptrdiff_t attention_residual_add_node_order =
node_index_to_its_order_in_topological_sort_map.at(add_node->Index());
ptrdiff_t attention_residual_ln_node_order =
node_index_to_its_order_in_topological_sort_map.at(other_node->Index());
if (attention_residual_add_node_order >= yield_op_order_in_topological_sort ||
attention_residual_ln_node_order >= yield_op_order_in_topological_sort) {
MO_LOG_DEBUG_INFO(logger, "Not a valid residual node arg " + input_arg->Name());
continue;
}
// Search all forward nodes that is before `add_node` in topo order, unless we find a softmax or
// nodes_to_check is empty.
std::deque<const Node*> nodes_to_check;
std::set<const Node*> visited_nodes;
for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) {
// Ignore those nodes after YieldOp.
if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) {
auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index());
if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) {
nodes_to_check.push_back(&(*node_it));
}
}
bool unexpected_failure = false;
bool found_softmax = false;
bool found_layernorm = false;
ptrdiff_t next_layernorm_execution_oder = -1;
while (!nodes_to_check.empty()) {
const Node* next_node = nodes_to_check.front();
nodes_to_check.pop_front();
@ -62,41 +159,21 @@ void FindLayerBoundaryLayerNormNodes(
}
visited_nodes.insert(next_node);
if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) {
found_softmax = true;
} else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) {
if (found_layernorm) {
// If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection.
unexpected_failure = true;
break;
}
found_layernorm = true; // don't trace further
next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index());
continue;
} else {
for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) {
// Stop if the node is after next Layernorm node in execution order.
if (found_layernorm &&
node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) {
continue;
}
if (IsSoftmaxNode(*next_node)) {
MO_LOG_DEBUG_INFO(logger, "Found layer boundary node " + node.Name() + " with its input arg: " +
input_arg->Name());
layer_boundary_ln_nodes.push_back(&node);
break;
}
for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) {
// Stop if the node is after next Layernorm node in execution order.
auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index());
if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) {
nodes_to_check.push_back(&(*node_it));
}
}
}
if (unexpected_failure) {
layer_boundary_ln_nodes.clear();
break;
}
if (found_softmax) {
layer_boundary_ln_nodes.insert(&node);
} else if (!found_layernorm) {
// If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node,
// we also consider it as boundary node.
layer_boundary_ln_nodes.insert(&node);
}
}
}

View file

@ -23,6 +23,6 @@ void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const ptrdiff_t& yield_op_order_in_topological_sort,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes);
InlinedVector<const Node*>& layer_boundary_ln_nodes);
} // namespace onnxruntime::optimizer::memory_optimizer

View file

@ -29,6 +29,7 @@
#include "orttraining/core/optimizer/memory_optimizer/common.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h"
#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h"
using namespace std;
using namespace ONNX_NAMESPACE;
@ -312,5 +313,45 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) {
}
}
// Loads the 3-layer BLOOM training graph and verifies that FindLayerBoundaryLayerNormNodes
// reports exactly three boundary layer-norm nodes, in the expected order.
TEST(MemoryOptimizerTests, TransformerLayerDetectionTest) {
const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
auto model_uri = MODEL_FOLDER "3layer_bloom_optimized_training.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger));
Graph& graph = model->MainGraph();
GraphViewer graph_viewer(graph);
// Maps each node index to its position in the priority-based topological order built below.
InlinedHashMap<NodeIndex, ptrdiff_t> node_index_to_its_order_in_topological_sort_map;
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
// Find boundary ops between forward and backward pass, currently, it's limited to YieldOp.
ptrdiff_t yield_op_order_in_topological_sort = -1;
for (size_t i = 0; i < node_ids.size(); ++i) {
const Node* p_node = graph_viewer.GetNode(node_ids[i]);
if (p_node == nullptr) { /* skip removed nodes*/
continue;
}
if (p_node->OpType() == "YieldOp") {
// Exactly one YieldOp is expected: the order must still be unset (-1) when we get here,
// so a second YieldOp in the graph fails the test.
ASSERT_EQ(yield_op_order_in_topological_sort, -1);
yield_op_order_in_topological_sort = static_cast<ptrdiff_t>(i);
}
node_index_to_its_order_in_topological_sort_map[p_node->Index()] = static_cast<ptrdiff_t>(i);
}
InlinedVector<const Node*> layer_boundary_ln_node;
optimizer::memory_optimizer::FindLayerBoundaryLayerNormNodes(graph_viewer, *logger,
node_index_to_its_order_in_topological_sort_map,
yield_op_order_in_topological_sort,
layer_boundary_ln_node);
// One boundary node per transformer layer; the size check guards the indexed accesses below.
ASSERT_EQ(layer_boundary_ln_node.size(), 3);
ASSERT_EQ(layer_boundary_ln_node[0]->Name(), "LayerNormalization_token_0");
ASSERT_EQ(layer_boundary_ln_node[1]->Name(), "LayerNormalization_token_6");
ASSERT_EQ(layer_boundary_ln_node[2]->Name(), "LayerNormalization_token_12");
}
} // namespace test
} // namespace onnxruntime