diff --git a/orttraining/orttraining/core/optimizer/localized_recompute.cc b/orttraining/orttraining/core/optimizer/localized_recompute.cc index df60be5625..0c5eb31c40 100644 --- a/orttraining/orttraining/core/optimizer/localized_recompute.cc +++ b/orttraining/orttraining/core/optimizer/localized_recompute.cc @@ -25,10 +25,12 @@ bool GeluRecompute::SatisfyCondition(const Node& node) const { Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const { GraphViewer graph_viewer(graph); - const auto& order = graph_viewer.GetNodesInTopologicalOrder(); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); - for (NodeIndex i : order) { - Node& node = *graph.GetNode(i); + // Traverse backward from the bottom of the graph, so that the recompute nodes + // for lower layers are executed earlier + for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) { + Node& node = *graph.GetNode(node_ids[i]); if (!SatisfyCondition(node)) { continue; @@ -70,10 +72,12 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Node& node) const { Status AttentionDropoutRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const { GraphViewer graph_viewer(graph); - const auto& order = graph_viewer.GetNodesInTopologicalOrder(); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); - for (NodeIndex i : order) { - Node& node = *graph.GetNode(i); + // Traverse backward from the bottom of the graph, so that the recompute nodes + // for lower layers are executed earlier + for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) { + Node& node = *graph.GetNode(node_ids[i]); if (!SatisfyCondition(node)) { continue; diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 51f504982f..c8f5397086 100644 --- 
a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -764,7 +764,6 @@ def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers set_seed(seed) # Setup ORTTrainer - loss_scaler = amp.DynamicLossScaler() options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, 'graph_transformer' : { 'attn_dropout_recompute': attn_dropout, diff --git a/orttraining/tools/ci_test/run_batch_size_test.py b/orttraining/tools/ci_test/run_batch_size_test.py index a8cc92b785..fa5949d568 100755 --- a/orttraining/tools/ci_test/run_batch_size_test.py +++ b/orttraining/tools/ci_test/run_batch_size_test.py @@ -19,12 +19,26 @@ def parse_args(): def main(): args = parse_args() - Config = collections.namedtuple("Config", ["enable_mixed_precision", "sequence_length", "max_batch_size"]) + Config = collections.namedtuple("Config", ["enable_mixed_precision", + "sequence_length", + "max_batch_size", + "max_predictions_per_seq", + "additional_options"]) configs = [ - Config(True, 128, 66), - Config(True, 512, 10), - Config(False, 128, 33), - Config(False, 512, 5), + Config(True, 128, 76, 20, ""), + Config(True, 512, 11, 80, ""), + Config(False, 128, 39, 20, ""), + Config(False, 512, 6, 80, ""), + + # BertLarge Phase 1 recompute + Config(True, 128, 91, 20, "--gelu_recompute"), + Config(True, 128, 83, 20, "--attn_dropout_recompute"), + Config(True, 128, 344, 20, "--transformer_layer_recompute"), + + # BertLarge Phase 2 recompute + Config(True, 512, 12, 80, "--gelu_recompute"), + Config(True, 512, 14, 80, "--attn_dropout_recompute"), + Config(True, 512, 50, 80, "--transformer_layer_recompute"), ] # run BERT training @@ -52,6 +66,7 @@ def main(): "--use_nccl", "--seed", "42", "--enable_grad_norm_clip=false", + config.additional_options ] if config.enable_mixed_precision: