From eb0f57f0e47120ce1d98797e2f0a1c15eb2cd050 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Tue, 4 Aug 2020 21:48:15 -0700 Subject: [PATCH] Localized Recompute for Gelu and AttentionDropout (#4402) * Gelu Activation Recompute Draft * Prototype for localized recompute * Introduce localized_recompute rewriter * Command line args for enabling recompute * Add logger to Gradient Graph Builder * use const when possible --- include/onnxruntime/core/graph/graph_viewer.h | 4 +- onnxruntime/core/framework/execution_frame.cc | 9 ++- onnxruntime/core/graph/graph_viewer.cc | 5 ++ .../core/framework/gradient_graph_builder.cc | 15 ++-- .../core/framework/gradient_graph_builder.h | 11 ++- .../core/graph/gradient_builder_base.h | 31 ++++++-- .../core/graph/gradient_builder_registry.cc | 11 +-- .../core/graph/gradient_builder_registry.h | 10 ++- .../orttraining/core/graph/gradient_config.h | 4 ++ ..._schema_defs.h => recompute_graph_utils.h} | 8 ++- .../core/optimizer/graph_transformer_utils.cc | 11 ++- .../core/optimizer/graph_transformer_utils.h | 3 +- .../core/optimizer/localized_recompute.cc | 71 +++++++++++++++++++ .../core/optimizer/localized_recompute.h | 50 +++++++++++++ .../core/session/training_session.cc | 24 +++---- .../core/session/training_session.h | 24 ++++--- orttraining/orttraining/models/bert/main.cc | 6 ++ .../models/runner/training_runner.cc | 12 +++- .../models/runner/training_runner.h | 5 +- .../python/orttraining_pybind_state.cc | 11 ++- .../test/gradient/gradient_checker.cc | 3 +- .../test/gradient/gradient_op_test_utils.cc | 3 +- 22 files changed, 256 insertions(+), 75 deletions(-) rename orttraining/orttraining/core/graph/{gradient_schema_defs.h => recompute_graph_utils.h} (53%) create mode 100644 orttraining/orttraining/core/optimizer/localized_recompute.cc create mode 100644 orttraining/orttraining/core/optimizer/localized_recompute.h diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index da2a37b7c2..9935bc275b 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -14,9 +14,7 @@ namespace onnxruntime { // use value-based compare to make sure transformer output order is consistent struct NodeCompare { - bool operator()(const Node* n1, const Node* n2) const { - return n1->Index() < n2->Index(); - } + bool operator()(const Node* n1, const Node* n2) const; }; /** diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 848b227692..4783c1f118 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -76,7 +76,7 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(int index, const TensorShap bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const { // By default, there is not information about inferred shape, so this default // implementation always returns false. The derived class of IExecutionFrame - // can override this function to provide, for example, activations' shape information. + // can override this function to provide, for example, activations' shape information. return false; } @@ -261,7 +261,7 @@ ExecutionFrame::ExecutionFrame(const std::vector& feed_mlvalue_idxs, const // it's less efficient (the arena will add some overhead to coalesce individual allocations // back into blocks on 'free'), but better than failing completely. try { - // static_activation_memory_in_bytes_ is max virtual memory size the planner computes + // static_activation_memory_in_bytes_ is max virtual memory size the planner computes auto peak_size = mem_patterns_->patterns[i].PeakSize(); // Planning of one memory type should only happen once. buffer = alloc->Alloc(peak_size); @@ -326,7 +326,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va // Lazily get the allocator only if needed. AllocatorPtr alloc = nullptr; - + // create fence if needed if (create_fence) { ORT_ENFORCE(ort_value.Fence() == nullptr); @@ -630,7 +630,6 @@ Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const { return planner_->GeneratePatterns(out); } - bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const { // NodeArg index to OrtValue index. int ort_value_idx = GetNodeIdxToMLValueIdx(index); @@ -639,7 +638,7 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const { if (ort_value_idx == NodeIndexInfo::kInvalidEntry) { return false; } - + // Search for inferred shape. // If inferred shape is found, it's assigned to "shape" so that caller can use it. auto it = inferred_shapes_.find(ort_value_idx); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index bb204e4d48..100e629a28 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -11,6 +11,11 @@ #include "core/graph/graph_utils.h" namespace onnxruntime { + +bool NodeCompare::operator()(const Node* n1, const Node* n2) const { + return n1->Index() < n2->Index(); +} + GraphViewer::GraphViewer(const Graph& graph) { graph_ = &graph; std::vector leaf_nodes; diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index fe44ae78b6..8f0e01a6bb 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -22,13 +22,13 @@ using namespace common; GradientGraphBuilder::GradientGraphBuilder(Graph* graph, const unordered_set& y_node_arg_names, const unordered_set& x_node_arg_names, - string loss_node_arg_name, + const std::string& loss_node_arg_name, const GradientGraphConfiguration& gradient_graph_config, - const bool set_gradient_as_graph_output) + const logging::Logger& logger) : graph_(graph), loss_node_arg_name_(loss_node_arg_name), gradient_graph_config_(gradient_graph_config), - set_gradient_as_graph_output_(set_gradient_as_graph_output) { + logger_(logger) { auto rule_based_graph_transformer = onnxruntime::make_unique("pre_training_rule_based_graph_transformer"); rule_based_graph_transformer->Register(make_unique()); @@ -91,7 +91,7 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) { for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) { auto it = STOP_GRADIENT_EDGES.find(n->OpType()); if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) { - std::cout << "Skip building gradient for node: " << edge_it->GetNode().Name() << std::endl; + LOGS(logger_, WARNING) << "Skip building gradient for node: " << edge_it->GetNode().Name() ; continue; } @@ -126,7 +126,7 @@ Status GradientGraphBuilder::CheckNodeArgsReachable(const NodeSet& reachable_nod } Status GradientGraphBuilder::Build() { - auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logging::LoggingManager::DefaultLogger()); + auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logger_); ORT_RETURN_IF_ERROR(opt_ret); GraphAugmenter::GraphDefs gradient_graph_defs; @@ -175,7 +175,6 @@ Status GradientGraphBuilder::Build() { // so far, visited are the minimum node in between // visited_node_args are the node_args involved - for (auto node : visited) { //TODO: might not need two sets, the union of them might be enough unordered_set input_args_need_grad, output_args_need_grad; @@ -190,7 +189,7 @@ Status GradientGraphBuilder::Build() { } } - GradientDef node_defs = GetGradientForOp(gradient_graph_config_, node, output_args_need_grad, input_args_need_grad); + GradientDef node_defs = GetGradientForOp(gradient_graph_config_, graph_, node, output_args_need_grad, input_args_need_grad, logger_); // updates arg name if gradient accumulation is needed for (auto& op_def : node_defs) { @@ -218,7 +217,7 @@ Status GradientGraphBuilder::Build() { "AccumulateGrad_" + gradient_pair.first.name)}); } - if (set_gradient_as_graph_output_) { + if (gradient_graph_config_.set_gradients_as_graph_outputs) { for (auto x_node_arg : x_node_args_) { gradient_graph_defs.AddGraphOutputs({GradientBuilderBase::GradientName(x_node_arg->Name())}); } diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index a0e09f055c..a27c77649d 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -32,7 +32,7 @@ static std::unordered_map> {"Reshape", {1}}, {"Expand", {1}}, {"TrainableDropout", {1}}, - {"Dropout", {1}}, + {"Dropout", {1, 2}}, {"Slice", {1, 2, 3, 4}}, {"SparseSoftmaxCrossEntropy", {1, 2}}, {"SoftmaxCrossEntropyLoss", {1, 2}}, @@ -58,9 +58,9 @@ class GradientGraphBuilder { GradientGraphBuilder(Graph* graph, const std::unordered_set& y_node_arg_names, const std::unordered_set& x_node_arg_names, - std::string loss_node_arg_name, + const std::string& loss_node_arg_name, const GradientGraphConfiguration& gradient_graph_config, - const bool set_gradient_as_graph_output = false); + const logging::Logger& logger); Status Build(); @@ -77,6 +77,8 @@ class GradientGraphBuilder { const GradientGraphConfiguration& gradient_graph_config_; + const logging::Logger& logger_; + onnxruntime::GraphTransformerManager graph_transformation_mgr_{5}; // key: ArgDef for the gradient after accumulation @@ -104,9 +106,6 @@ class GradientGraphBuilder { @returns OK if all 'x_node_args_' are reachable, else an ONNXRUNTIME INVALID_ARGUMENT status */ Status CheckNodeArgsReachable(const NodeSet& reachable_nodes); - - // if it is true, set gradient of trainable weight as graph output - const bool set_gradient_as_graph_output_; }; } // namespace training diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index ebfc758d4a..eab707c04b 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -8,6 +8,7 @@ #include "core/graph/graph.h" #include "orttraining/core/graph/graph_augmenter.h" #include "orttraining/core/graph/gradient_config.h" +#include "orttraining/core/graph/recompute_graph_utils.h" #include "onnx/defs/attr_proto_util.h" namespace onnxruntime { @@ -35,12 +36,18 @@ typedef std::vector GradientDef; class GradientBuilderBase { public: - GradientBuilderBase( - const GradientGraphConfiguration& gradient_graph_config, - const Node* node, - const std::unordered_set& gradient_inputs, - const std::unordered_set& gradient_outputs) - : gradient_graph_config_(gradient_graph_config), node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) { + GradientBuilderBase(const GradientGraphConfiguration& gradient_graph_config, + const Graph* graph, + const Node* node, + const std::unordered_set& gradient_inputs, + const std::unordered_set& gradient_outputs, + const logging::Logger& logger) + : gradient_graph_config_(gradient_graph_config), + graph_(graph), + node_(node), + gradient_inputs_(gradient_inputs), + gradient_outputs_(gradient_outputs), + logger_(logger) { unique_node_prefix_ = CreateUniqueNodePrefix(); } @@ -71,6 +78,15 @@ class GradientBuilderBase { // i-th input of forward op ArgDef I(const size_t i) const { ORT_ENFORCE(i < node_->InputDefs().size()); + + const std::string& name = node_->InputDefs()[i]->Name(); + const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name)); + if (recomputed_nodearg) { + const Node* producer_node = graph_->GetProducerNode(name); + LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name(); + return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto()); + } + return ArgDef(node_->InputDefs()[i]->Name(), node_->InputDefs()[i]->TypeAsProto()); } @@ -207,6 +223,7 @@ class GradientBuilderBase { } const GradientGraphConfiguration& gradient_graph_config_; + const Graph* graph_; const Node* node_; std::string unique_node_prefix_; @@ -215,6 +232,8 @@ class GradientBuilderBase { // contains set of input arg names of node_ which requires gradient std::unordered_set gradient_outputs_; + + const logging::Logger& logger_; }; class EmptyGradientBuilder : public GradientBuilderBase { diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 7631ce25eb..5f50887fa7 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -9,22 +9,25 @@ namespace onnxruntime { namespace training { GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config, + const Graph* graph, const Node* node, const std::unordered_set& output_args_need_grad, - const std::unordered_set& input_args_need_grad) { - + const std::unordered_set& input_args_need_grad, + const logging::Logger& logger) { // REVIEW(bahuang): We don't have a version control for forward to backward op mapping. // Current SliceGrad(kMSDomain, 1) only supports Slice(kOnnxDomain, 10/11) because adding grad operator for versions // less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11. auto gradient_builder = GradientBuilderRegistry::GetInstance().MakeUnique(node->OpType(), gradient_graph_config, + graph, node, output_args_need_grad, - input_args_need_grad); + input_args_need_grad, + logger); ORT_ENFORCE(gradient_builder != nullptr, - "The gradient builder has not been registered:", node->OpType()); + "The gradient builder has not been registered:", node->OpType(), " for node ", node->Name()); return gradient_builder->GetGradientDefs(); } diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.h b/orttraining/orttraining/core/graph/gradient_builder_registry.h index 7acec7e7f4..ba16d878b1 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.h +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.h @@ -13,9 +13,11 @@ namespace training { typedef GenericRegistry&, // gradient_inputs - const std::unordered_set&> // gradient_outputs + const std::unordered_set&, // gradient_outputs + const logging::Logger&> GradientRegistryType; class GradientBuilderRegistry : public GradientRegistryType { @@ -33,9 +35,11 @@ class GradientBuilderRegistry : public GradientRegistryType { }; GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config, + const Graph* graph, const Node* node, const std::unordered_set& output_args_need_grad, - const std::unordered_set& input_args_need_grad); + const std::unordered_set& input_args_need_grad, + const logging::Logger& logger); } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_config.h b/orttraining/orttraining/core/graph/gradient_config.h index bee896965e..13b86348fb 100644 --- a/orttraining/orttraining/core/graph/gradient_config.h +++ b/orttraining/orttraining/core/graph/gradient_config.h @@ -12,6 +12,10 @@ struct GradientGraphConfiguration { // To save memory, ideally, only one(input vs output) should be stashed rather than both. // By default, the input based algorithm is used. This flag is to enable the output based algorithm. bool use_invertible_layernorm_grad{false}; + + // If set to true, all gradients will be exposed as graph output. + // This should only be used for unit test or debugging purpose. + bool set_gradients_as_graph_outputs{false}; }; } // namespace training diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.h b/orttraining/orttraining/core/graph/recompute_graph_utils.h similarity index 53% rename from orttraining/orttraining/core/graph/gradient_schema_defs.h rename to orttraining/orttraining/core/graph/recompute_graph_utils.h index ca7e46d10b..f4d7e88a07 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.h +++ b/orttraining/orttraining/core/graph/recompute_graph_utils.h @@ -4,9 +4,11 @@ #pragma once namespace onnxruntime { -namespace training { +namespace graph_utils { -void RegisterGradientSchemas(); +inline std::string RecomputeName(const std::string& name) { + return name + "_recompute"; +} -} // namespace training +} // namespace graph_utils } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 9424a78c26..b49e144203 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -38,6 +38,7 @@ #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/optimizer/bias_dropout_fusion.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" +#include "orttraining/core/optimizer/localized_recompute.h" #include "orttraining/core/optimizer/megatron_transformer.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" @@ -48,7 +49,7 @@ namespace transformer_utils { std::vector> GeneratePreTrainingTransformers( TransformerLevel level, const std::unordered_set& weights_to_train, - bool enable_gelu_approximation, + const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config, const std::vector& transformers_and_rules_to_enable) { std::vector> transformers; std::unique_ptr rule_transformer = nullptr; @@ -69,6 +70,12 @@ std::vector> GeneratePreTrainingTransformers( rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); + if (config.gelu_checkpoint) { + rule_transformer->Register(make_unique()); + } + if (config.attn_dropout_checkpoint) { + rule_transformer->Register(make_unique()); + } transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); @@ -76,7 +83,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); - if (enable_gelu_approximation) { + if (config.enable_gelu_approximation) { transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); } diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h index 87a107e5df..c7b6cfe085 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h @@ -6,6 +6,7 @@ #include #include "core/optimizer/graph_transformer.h" +#include "orttraining/core/session/training_session.h" namespace onnxruntime { struct FreeDimensionOverride; @@ -17,7 +18,7 @@ namespace transformer_utils { std::vector> GeneratePreTrainingTransformers( TransformerLevel level, const std::unordered_set& weights_to_train, - bool enable_gelu_approximation, + const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config, const std::vector& rules_and_transformers_to_enable = {}); /** Generates all predefined (both rule-based and non-rule-based) transformers for this level. diff --git a/orttraining/orttraining/core/optimizer/localized_recompute.cc b/orttraining/orttraining/core/optimizer/localized_recompute.cc new file mode 100644 index 0000000000..7a52f69805 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/localized_recompute.cc @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/graph_utils.h" +#include "orttraining/core/graph/recompute_graph_utils.h" +#include "orttraining/core/optimizer/localized_recompute.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { + +bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const { + const auto next_node = node.OutputNodesBegin(); + if (next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") { + return true; + } + return false; +} + +Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const { + const auto& output = node.OutputDefs()[0]; + + auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), + output->TypeAsProto()); + + graph.AddNode(node.Name() + "_recompute", + node.OpType(), + "Recompute of " + node.Name(), + {node.MutableInputDefs()[0]}, + {&recomputed_output}, + &node.GetAttributes(), + node.Domain()); + + rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; + return Status::OK(); +} + +bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const { + const auto prev_node = node.InputNodesBegin(); + const auto next_node = node.OutputNodesBegin(); + if (prev_node != node.InputNodesEnd() && prev_node->OpType() == "Softmax" && + next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") { + return true; + } + return false; +} + +Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const { + const auto& output = node.OutputDefs()[0]; + + auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), + output->TypeAsProto()); + + graph.AddNode(node.Name() + "_recompute", + "DropoutGrad", // Reusing DropoutGrad as the recompute op + "Recompute of " + node.Name(), + { + node.MutableInputDefs()[0], // X + node.MutableOutputDefs()[1], // mask + node.MutableInputDefs()[1], // ratio + node.MutableInputDefs()[2] // training_mode + }, + {&recomputed_output}, + {}, + kMSDomain); + + rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/localized_recompute.h b/orttraining/orttraining/core/optimizer/localized_recompute.h new file mode 100644 index 0000000000..1c8c58e96b --- /dev/null +++ b/orttraining/orttraining/core/optimizer/localized_recompute.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class GeluRecompute + +Recompute Gelu/BiasGelu/FastGelu + +*/ +class GeluRecompute : public RewriteRule { + public: + GeluRecompute() noexcept : RewriteRule("GeluRecompute") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Gelu", "FastGelu", "BiasGelu"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +/** +@Class AttentionDropoutRecompute + +Recompute Dropout in the attention layer + +*/ +class AttentionDropoutRecompute : public RewriteRule { + public: + AttentionDropoutRecompute() noexcept : RewriteRule("AttentionDropoutRecompute") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Dropout"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 9d708d263b..e3f99e2bf9 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -216,7 +216,7 @@ Status TrainingSession::ConfigureForTraining( } } - ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.enable_gelu_approximation)); + ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config)); // derive actual set of weights to train std::unordered_set weight_names_to_train = @@ -237,7 +237,7 @@ Status TrainingSession::ConfigureForTraining( } ORT_RETURN_IF_ERROR(BuildGradientGraph( - weight_names_to_train, loss_name, config.gradient_graph_config, config.set_gradients_as_graph_outputs)); + weight_names_to_train, loss_name, config.gradient_graph_config, *session_logger_)); // transform for mixed precision std::unordered_map fp32_weight_name_to_fp16_node_arg{}; @@ -443,14 +443,14 @@ static Status BuildGradientGraphInternal(Graph& graph, const std::string& loss_function_output_name, const std::unordered_set& node_arg_names_to_train, const GradientGraphConfiguration& gradient_graph_config, - const bool set_gradient_as_graph_output = false) { + const logging::Logger& logger) { // Compute the gradient graph def. GradientGraphBuilder grad_graph_builder(&graph, {loss_function_output_name}, node_arg_names_to_train, loss_function_output_name, gradient_graph_config, - set_gradient_as_graph_output); + logger); return grad_graph_builder.Build(); } @@ -488,10 +488,10 @@ static Status AddGradientAccumulationNodes(Graph& graph, return GraphAugmenter::AugmentGraph(graph, graph_defs); } -Status TrainingSession::ApplyTransformationsToMainGraph( - const std::unordered_set& weights_to_train, bool enable_gelu_approximation) { +Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set& weights_to_train, + const TrainingConfiguration::GraphTransformerConfiguration& config) { GraphTransformerManager graph_transformation_mgr{1}; - AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, enable_gelu_approximation); + AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, config); // apply transformers Graph& graph = model_->MainGraph(); @@ -505,13 +505,13 @@ Status TrainingSession::ApplyTransformationsToMainGraph( // Registers all the pre transformers with transformer manager void TrainingSession::AddPreTrainingTransformers(GraphTransformerManager& transformer_manager, const std::unordered_set& weights_to_train, - bool enable_gelu_approximation, + const TrainingConfiguration::GraphTransformerConfiguration& config, TransformerLevel graph_optimization_level, const std::vector& custom_list) { auto add_transformers = [&](TransformerLevel level) { // Generate and register transformers for level auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers( - level, weights_to_train, enable_gelu_approximation, custom_list); + level, weights_to_train, config, custom_list); for (auto& entry : transformers_to_register) { transformer_manager.Register(std::move(entry), level); } @@ -664,7 +664,7 @@ Status TrainingSession::EnableMixedPrecision( Status TrainingSession::BuildGradientGraph(const std::unordered_set& weights_to_train, const std::string& loss_function_output_name, const GradientGraphConfiguration& gradient_graph_config, - const bool set_gradient_as_graph_output) { + const logging::Logger& logger) { // Fill weights_to_train_ according to weights_to_train weights_to_train_ = weights_to_train; gradient_graph_config_ = gradient_graph_config; @@ -673,7 +673,7 @@ Status TrainingSession::BuildGradientGraph(const std::unordered_set loss_function_output_name, weights_to_train_, gradient_graph_config_, - set_gradient_as_graph_output)); + logger)); return DoPostLoadProcessing(*model_); } @@ -791,7 +791,7 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO actual_loss_name, weights_to_train_, gradient_graph_config_, - false)); + *session_logger_)); OptimizerOutputKeyMap opt_graph_outputs; std::unordered_set opt_state_initializer_names; diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 41ebd04395..e8c57db5d3 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -46,9 +46,6 @@ class TrainingSession : public InferenceSession { // Gradient graph configuration GradientGraphConfiguration gradient_graph_config{}; - // Whether to set the gradients as graph outputs. - bool set_gradients_as_graph_outputs{false}; - // The number of gradient accumulation steps. int gradient_accumulation_steps{1}; @@ -177,8 +174,16 @@ class TrainingSession : public InferenceSession { // Otherwise, it returns false. optional pipeline_config{}; - // Whether to enable GELU approximation which is faster but produces different results. - bool enable_gelu_approximation{false}; + struct GraphTransformerConfiguration { + // Whether to enable GELU approximation which is faster but produces different results. + bool enable_gelu_approximation{false}; + // Enable checkpointing of attention dropout to save memory + bool attn_dropout_checkpoint{false}; + // Enable checkpointing of Gelu activation output to save memory + bool gelu_checkpoint{false}; + }; + + GraphTransformerConfiguration graph_transformer_config{}; }; /** @@ -390,13 +395,13 @@ class TrainingSession : public InferenceSession { std::string& backward_waited_event_after_recv_name, std::string& backward_recorded_event_before_send_name); - common::Status ApplyTransformationsToMainGraph( - const std::unordered_set& weights_to_train, bool enable_gelu_approximation); + common::Status ApplyTransformationsToMainGraph(const std::unordered_set& weights_to_train, + const TrainingConfiguration::GraphTransformerConfiguration& config); /** configure initial transformers for training */ void AddPreTrainingTransformers(GraphTransformerManager& transformer_manager, const std::unordered_set& weights_to_train, - bool enable_gelu_approximation, + const TrainingConfiguration::GraphTransformerConfiguration& config, TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel, const std::vector& custom_list = {}); @@ -408,12 +413,11 @@ class TrainingSession : public InferenceSession { /** Perform auto-diff to add backward graph into the model. @param weights_to_train a set of weights to be training. @param loss_function_output_name the name of the loss function's output. - @param set_gradient_as_graph_output if it is true, set gradient of trainable weight as graph output */ common::Status BuildGradientGraph(const std::unordered_set& weights_to_train, const std::string& loss_function_output_name, const GradientGraphConfiguration& gradient_graph_config, - const bool set_gradient_as_graph_output = false); + const logging::Logger& logger); common::Status BuildAccumulationNode(const std::unordered_set& weights_to_train); diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index 8a3145d664..816d1018de 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -166,6 +166,10 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet cxxopts::value()->default_value("true")) ("enable_gelu_approximation", "Specify whether to enable GELU approximation.", cxxopts::value()->default_value("true")) + ("attn_dropout_checkpoint", "Enable checkpointing of attention dropout to save memory.", + cxxopts::value()->default_value("false")) + ("gelu_checkpoint", "Enable checkpointing of Gelu activation output to save memory.", + cxxopts::value()->default_value("false")) ("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)", cxxopts::value()->default_value("false")); options @@ -453,6 +457,8 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet } params.enable_gelu_approximation = flags["enable_gelu_approximation"].as(); + params.attn_dropout_checkpoint = flags["attn_dropout_checkpoint"].as(); + params.gelu_checkpoint = flags["gelu_checkpoint"].as(); ort_params.log_severity = static_cast(flags["ort_log_severity"].as()); ORT_RETURN_IF_NOT( diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 55ab25a951..cd992d8c26 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -92,7 +92,7 @@ Status TrainingRunner::Initialize() { config.immutable_weights = params_.immutable_weights; config.gradient_graph_config.use_invertible_layernorm_grad = params_.use_invertible_layernorm_grad; - config.set_gradients_as_graph_outputs = false; + config.gradient_graph_config.set_gradients_as_graph_outputs = false; config.gradient_accumulation_steps = params_.gradient_accumulation_steps; @@ -164,7 +164,15 @@ Status TrainingRunner::Initialize() { config.pipeline_config = pipe; } - config.enable_gelu_approximation = params_.enable_gelu_approximation; + // always configure the graph transformer + { + TrainingSession::TrainingConfiguration::GraphTransformerConfiguration gt_config{}; + gt_config.enable_gelu_approximation = params_.enable_gelu_approximation; + gt_config.attn_dropout_checkpoint = params_.attn_dropout_checkpoint; + gt_config.gelu_checkpoint = params_.gelu_checkpoint; + + config.graph_transformer_config = gt_config; + } TrainingSession::TrainingConfigurationResult config_result{}; diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index ebd25411dc..ae65db6db1 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -169,7 +169,10 @@ class TrainingRunner { // Enable GELU approximation bool enable_gelu_approximation = false; - + // Enable checkpointing of attention dropout to save memory + bool attn_dropout_checkpoint = false; + // Enable checkpointing of Gelu activation output to save memory + bool gelu_checkpoint = false; // Use invertible layernorm grad bool use_invertible_layernorm_grad = false; }; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 415e1c2c0e..13097092d8 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -14,7 +14,6 @@ #include "orttraining/core/framework/mpi_setup.h" #include "python/onnxruntime_pybind_mlvalue.h" - namespace onnxruntime { namespace python { namespace py = pybind11; @@ -91,8 +90,6 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.weight_names_to_not_train = parameters.weights_not_to_train; config.immutable_weights = parameters.immutable_weights; - config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs; - config.gradient_accumulation_steps = parameters.gradient_accumulation_steps; config.distributed_config.world_rank = parameters.world_rank; @@ -146,6 +143,7 @@ TrainingConfigurationResult ConfigureSessionForTraining( } config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad; + config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs; training::TrainingSession::TrainingConfigurationResult config_result{}; @@ -194,9 +192,9 @@ void addObjectMethodsForTraining(py::module& m) { py::class_ training_session(m, "TrainingSession"); training_session.def(py::init([](const SessionOptions& so) { - Environment& env = get_env(); - return onnxruntime::make_unique(so, env); - })) + Environment& env = get_env(); + return onnxruntime::make_unique(so, env); + })) .def(py::init([]() { Environment& env = get_env(); return onnxruntime::make_unique(GetDefaultCPUSessionOptions(), env); @@ -274,7 +272,6 @@ void addObjectMethodsForTraining(py::module& m) { .def("is_output_fp32_node", [](onnxruntime::training::TrainingSession* sess, const std::string& output_name) { return sess->IsGraphOutputFp32Node(output_name); }); - } } // namespace python } // namespace onnxruntime diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index 314b2e9bd5..464ff515a7 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -307,12 +307,13 @@ inline Status GradientChecker::InitOpTesterWithGradGraph( } training::GradientGraphConfiguration gradient_graph_config; + gradient_graph_config.set_gradients_as_graph_outputs = true; training::GradientGraphBuilder grad_graph_builder(&graph, dy_values, weights_to_train, "", gradient_graph_config, - true); + logging::LoggingManager::DefaultLogger()); Status status = grad_graph_builder.Build(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index 048efc6078..7cb8afc01a 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -71,12 +71,13 @@ void GradientOpTester::Run( } training::GradientGraphConfiguration gradient_graph_config; + gradient_graph_config.set_gradients_as_graph_outputs = true; training::GradientGraphBuilder grad_graph_builder(&graph, dy_values, weights_to_train, "", gradient_graph_config, - true); + logging::LoggingManager::DefaultLogger()); status = grad_graph_builder.Build(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); }