From e71668f92c1ec7be50fdcd05ce4eebf59dce2af8 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Fri, 2 Oct 2020 09:49:47 -0700 Subject: [PATCH] Expose recompute configs to the frontend (#5318) * Expose recompute configs to the frontend * Add frontend test * Ensure recompute graph transformer is only applied once Co-authored-by: Sherlock Huang --- .../core/optimizer/graph_transformer.h | 2 + onnxruntime/__init__.py | 2 +- .../core/optimizer/graph_transformer_mgr.cc | 3 + .../python/onnxruntime_pybind_state.cc | 6 +- .../core/optimizer/graph_transformer_utils.cc | 15 ++-- .../core/optimizer/localized_recompute.cc | 69 ++++++++++++++----- .../core/optimizer/localized_recompute.h | 30 ++++---- .../optimizer/transformer_layer_recompute.cc | 22 +++++- .../optimizer/transformer_layer_recompute.h | 10 ++- .../core/session/training_session.h | 2 + orttraining/orttraining/models/bert/main.cc | 3 + .../models/runner/training_runner.cc | 1 + .../models/runner/training_runner.h | 2 + .../python/orttraining_pybind_state.cc | 17 ++++- .../orttraining/python/training/orttrainer.py | 9 +++ .../python/training/orttrainer_options.py | 59 ++++++++++++++++ .../orttraining_test_orttrainer_frontend.py | 61 ++++++++++++++++ 17 files changed, 262 insertions(+), 51 deletions(-) diff --git a/include/onnxruntime/core/optimizer/graph_transformer.h b/include/onnxruntime/core/optimizer/graph_transformer.h index a806c483fb..865a9f3c94 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer.h +++ b/include/onnxruntime/core/optimizer/graph_transformer.h @@ -39,6 +39,8 @@ class GraphTransformer { */ common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const; + virtual bool ShouldOnlyApplyOnce() const { return false; } + protected: /** Helper method to call ApplyImpl on any subgraphs in the Node. */ common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const { diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index f464da3729..21090d268f 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -12,7 +12,7 @@ __author__ = "Microsoft" from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \ RunOptions, SessionOptions, set_default_logger_severity, NodeArg, ModelMetadata, GraphOptimizationLevel, \ - ExecutionMode, OrtDevice, SessionIOBinding + ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding try: from onnxruntime.capi._pybind_state import set_cuda_mem_limit, set_cuda_device_id diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index f161fcf0cb..ce17708f32 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -28,6 +28,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor for (unsigned step = 0; step < steps_; ++step) { bool graph_changed = false; for (const auto& transformer : transformers->second) { + if (step > 0 && transformer->ShouldOnlyApplyOnce()) + continue; + bool modified = false; ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); graph_changed = graph_changed || modified; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9b55c8683a..0be1a573b2 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -937,6 +937,10 @@ void addObjectMethods(py::module& m, Environment& env) { .value("ORT_SEQUENTIAL", ExecutionMode::ORT_SEQUENTIAL) .value("ORT_PARALLEL", ExecutionMode::ORT_PARALLEL); + py::enum_(m, "ExecutionOrder") + .value("DEFAULT", ExecutionOrder::DEFAULT) + .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED); + py::class_ device(m, "OrtDevice", R"pbdoc(ONNXRuntime device informaion.)pbdoc"); device.def(py::init()) .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") @@ -1089,7 +1093,7 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc") .def_readwrite("execution_mode", &PySessionOptions::execution_mode, R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc") - .def_readwrite("execution_order", &SessionOptions::execution_order, + .def_readwrite("execution_order", &PySessionOptions::execution_order, R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc") .def_property( "graph_optimization_level", diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 34f095dd2a..4c515e20b4 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -74,12 +74,6 @@ std::vector> GeneratePreTrainingTransformers( rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); - if (config.gelu_recompute) { - rule_transformer->Register(make_unique()); - } - if (config.attn_dropout_recompute) { - rule_transformer->Register(make_unique()); - } transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); @@ -106,8 +100,15 @@ std::vector> GeneratePreTrainingTransformers( } transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); + if (config.gelu_recompute) { + transformers.emplace_back(onnxruntime::make_unique()); + } + if (config.attn_dropout_recompute) { + transformers.emplace_back(onnxruntime::make_unique()); + } if (config.transformer_layer_recompute) { - transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); + transformers.emplace_back(onnxruntime::make_unique( + config.number_recompute_layers, compatible_eps)); } } break; diff --git a/orttraining/orttraining/core/optimizer/localized_recompute.cc b/orttraining/orttraining/core/optimizer/localized_recompute.cc index 02164fa0dc..df60be5625 100644 --- a/orttraining/orttraining/core/optimizer/localized_recompute.cc +++ b/orttraining/orttraining/core/optimizer/localized_recompute.cc @@ -10,7 +10,12 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { -bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const { +bool GeluRecompute::SatisfyCondition(const Node& node) const { + static const std::unordered_set target_optypes = {"Gelu", "FastGelu", "BiasGelu"}; + if (target_optypes.find(node.OpType()) == target_optypes.end()) { + return false; + } + const auto next_node = node.OutputNodesBegin(); if (next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") { return true; @@ -18,27 +23,42 @@ bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, c return false; } -Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const { - const auto& output = node.OutputDefs()[0]; +Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const { + GraphViewer graph_viewer(graph); + const auto& order = graph_viewer.GetNodesInTopologicalOrder(); - auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), - output->TypeAsProto()); + for (NodeIndex i : order) { + Node& node = *graph.GetNode(i); - Node& recompute_node = graph.AddNode(node.Name() + "_recompute", - node.OpType(), - "Recompute of " + node.Name(), - {node.MutableInputDefs()[0]}, - {&recomputed_output}, - &node.GetAttributes(), - node.Domain()); + if (!SatisfyCondition(node)) { + continue; + } - recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + const auto& output = node.OutputDefs()[0]; + + auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), + output->TypeAsProto()); + + Node& recompute_node = graph.AddNode(node.Name() + "_recompute", + node.OpType(), + "Recompute of " + node.Name(), + {node.MutableInputDefs()[0]}, + {&recomputed_output}, + &node.GetAttributes(), + node.Domain()); + + recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + + modified = true; + } - rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; return Status::OK(); } -bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const { +bool AttentionDropoutRecompute::SatisfyCondition(const Node& node) const { + if (node.OpType() != "Dropout") + return false; + const auto prev_node = node.InputNodesBegin(); const auto next_node = node.OutputNodesBegin(); if (prev_node != node.InputNodesEnd() && prev_node->OpType() == "Softmax" && @@ -48,11 +68,22 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const N return false; } -Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const { - Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true); - recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); +Status AttentionDropoutRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const { + GraphViewer graph_viewer(graph); + const auto& order = graph_viewer.GetNodesInTopologicalOrder(); - rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; + for (NodeIndex i : order) { + Node& node = *graph.GetNode(i); + + if (!SatisfyCondition(node)) { + continue; + } + + Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true); + recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + + modified = true; + } return Status::OK(); } diff --git a/orttraining/orttraining/core/optimizer/localized_recompute.h b/orttraining/orttraining/core/optimizer/localized_recompute.h index 1c8c58e96b..d0b91b3b65 100644 --- a/orttraining/orttraining/core/optimizer/localized_recompute.h +++ b/orttraining/orttraining/core/optimizer/localized_recompute.h @@ -3,7 +3,7 @@ #pragma once -#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/graph_transformer.h" namespace onnxruntime { @@ -13,18 +13,16 @@ namespace onnxruntime { Recompute Gelu/BiasGelu/FastGelu */ -class GeluRecompute : public RewriteRule { +class GeluRecompute : public GraphTransformer { public: - GeluRecompute() noexcept : RewriteRule("GeluRecompute") {} + GeluRecompute() noexcept : GraphTransformer("GeluRecompute") {} - std::vector TargetOpTypes() const noexcept override { - return {"Gelu", "FastGelu", "BiasGelu"}; - } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + bool ShouldOnlyApplyOnce() const override { return true; } private: - bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; + bool SatisfyCondition(const Node& node) const; }; /** @@ -33,18 +31,16 @@ class GeluRecompute : public RewriteRule { Recompute Dropout in the attention layer */ -class AttentionDropoutRecompute : public RewriteRule { +class AttentionDropoutRecompute : public GraphTransformer { public: - AttentionDropoutRecompute() noexcept : RewriteRule("AttentionDropoutRecompute") {} + AttentionDropoutRecompute() noexcept : GraphTransformer("AttentionDropoutRecompute") {} - std::vector TargetOpTypes() const noexcept override { - return {"Dropout"}; - } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + bool ShouldOnlyApplyOnce() const override { return true; } private: - bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; + bool SatisfyCondition(const Node& node) const; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/transformer_layer_recompute.cc b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.cc index c43192a129..c608b7289a 100644 --- a/orttraining/orttraining/core/optimizer/transformer_layer_recompute.cc +++ b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.cc @@ -116,7 +116,7 @@ void TransformerLayerRecompute::InsertRecomputeNodes(Graph& graph, const std::ve Node* node = graph.GetNode(n->Index()); // recomputed Dropout need to produce the same output as original dropout - // currently reusing original dropout's mask to achieve this + // currently reusing original dropout's mask to achieve this if (node->OpType() == "Dropout") { const NodeArg* input = node->InputDefs()[0]; const Node* p_node = graph.GetProducerNode(input->Name()); @@ -175,9 +175,25 @@ Status TransformerLayerRecompute::ApplyImpl(Graph& graph, bool& modified, int /* return Status::OK(); } - // insert recompute nodes expect for the last transformer layer + // by default, apply recompute expect for the last transformer layer + // otherwise, take user specified 'number_recompute_layers_' + + int n_layers; + const int n_layers_limit = static_cast(start_end_edges.size() - 1); + if (number_recompute_layers_ > n_layers_limit) { + LOGS(logger, WARNING) << "User specified number_recompute_layers " << number_recompute_layers_ + << " is larger than limit " << n_layers_limit << "." + << "number_recompute_layers is now cliped to limit."; + n_layers = n_layers_limit; + } else if (number_recompute_layers_ > 0) { + n_layers = number_recompute_layers_; + } else { + LOGS(logger, INFO) << "number_recompute_layers is not set by user, using default " << n_layers_limit << "."; + n_layers = n_layers_limit; + } + // latter recompute layers have higher execution priorty - for (size_t i = 0; i < start_end_edges.size() - 1; ++i) { + for (int i = 0; i < n_layers; ++i) { std::vector nodes = NodesBetweenEdges(graph, start_end_edges[i].first, start_end_edges[i].second); InsertRecomputeNodes(graph, nodes, static_cast(start_end_edges.size() - i)); } diff --git a/orttraining/orttraining/core/optimizer/transformer_layer_recompute.h b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.h index f3bb00ff20..bad47e5b51 100644 --- a/orttraining/orttraining/core/optimizer/transformer_layer_recompute.h +++ b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.h @@ -10,11 +10,15 @@ namespace onnxruntime { class TransformerLayerRecompute : public GraphTransformer { public: - TransformerLayerRecompute(const std::unordered_set& compatible_execution_providers = {}) noexcept - : GraphTransformer("TransformerLayerRecompute", compatible_execution_providers) {} + TransformerLayerRecompute(int number_recompute_layers, + const std::unordered_set& compatible_execution_providers = {}) noexcept + : GraphTransformer("TransformerLayerRecompute", compatible_execution_providers), + number_recompute_layers_(number_recompute_layers) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + bool ShouldOnlyApplyOnce() const override { return true; } + private: Status IdentifyTransformerLayerEdges(const Graph& graph, std::vector>& start_end_edges, @@ -23,6 +27,8 @@ class TransformerLayerRecompute : public GraphTransformer { std::vector NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const; void InsertRecomputeNodes(Graph& graph, const std::vector& nodes, int priority) const; + + int number_recompute_layers_; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 3fc97a306a..58eb1d526c 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -200,6 +200,8 @@ class TrainingSession : public InferenceSession { bool gelu_recompute{false}; // Enable recompute of transformer layer ouput to save memory bool transformer_layer_recompute{false}; + // Number of layers to apply recompute + int number_recompute_layers{0}; }; GraphTransformerConfiguration graph_transformer_config{}; diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index 91a1aa1d88..0e24c4d9a8 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -173,6 +173,8 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet cxxopts::value()->default_value("false")) ("transformer_layer_recompute", "Enable checkpointing of transformer layer output to save memory.", cxxopts::value()->default_value("false")) + ("number_recompute_layers", "Number of layers to apply recompute.", + cxxopts::value()->default_value("0")) ("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)", cxxopts::value()->default_value("false")); options @@ -463,6 +465,7 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet params.attn_dropout_recompute = flags["attn_dropout_recompute"].as(); params.gelu_recompute = flags["gelu_recompute"].as(); params.transformer_layer_recompute = flags["transformer_layer_recompute"].as(); + params.number_recompute_layers = flags["number_recompute_layers"].as(); ort_params.log_severity = static_cast(flags["ort_log_severity"].as()); ORT_RETURN_IF_NOT( diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 85fab1c5e1..89e5aa6af1 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -187,6 +187,7 @@ Status TrainingRunner::Initialize() { gt_config.attn_dropout_recompute = params_.attn_dropout_recompute; gt_config.gelu_recompute = params_.gelu_recompute; gt_config.transformer_layer_recompute = params_.transformer_layer_recompute; + gt_config.number_recompute_layers = params_.number_recompute_layers; config.graph_transformer_config = gt_config; } diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 28024f1f6c..04b6f5b211 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -180,6 +180,8 @@ class TrainingRunner { bool gelu_recompute = false; // Enable checkpointing of transformer layer output to save memory bool transformer_layer_recompute = false; + // Number of layers to apply recompute + int number_recompute_layers = 0; // Use invertible layernorm grad bool use_invertible_layernorm_grad = false; }; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index e7524fd0dd..757c75a861 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -47,6 +47,12 @@ struct TrainingParameters { bool enable_grad_norm_clip = true; bool set_gradients_as_graph_outputs = false; bool use_invertible_layernorm_grad = false; + + // recompute + bool attn_dropout_recompute = false; + bool gelu_recompute = false; + bool transformer_layer_recompute = false; + int number_recompute_layers = 0; }; struct TrainingConfigurationResult { @@ -130,6 +136,11 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad; config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs; + config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute; + config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute; + config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute; + config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers; + training::TrainingSession::TrainingConfigurationResult config_result{}; OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result)); @@ -186,7 +197,11 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage) .def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip) .def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs) - .def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad); + .def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad) + .def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute) + .def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute) + .def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute) + .def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers); #if defined(USE_NCCL) m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); }); diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 755173dffd..bbb24e1c9d 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -633,9 +633,18 @@ class ORTTrainer(object): ort_parameters.optimizer_attributes_map = optimizer_attributes_map ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map + ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute + ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute + ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute + ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers + # SessionOptions session_options = ort.SessionOptions() session_options.use_deterministic_compute = self.options.debug.deterministic_compute + if (self.options.graph_transformer.attn_dropout_recompute or + self.options.graph_transformer.gelu_recompute or + self.options.graph_transformer.transformer_layer_recompute): + session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED # TrainingSession self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 0ed9a68f5f..bde8a60290 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -116,6 +116,30 @@ class ORTTrainerOptions(object): } } }, + 'graph_transformer': { + 'type': 'dict', + 'required': False, + 'default': {}, + 'schema': { + 'attn_dropout_recompute': { + 'type': 'boolean', + 'default': False + }, + 'gelu_recompute': { + 'type': 'boolean', + 'default': False + }, + 'transformer_layer_recompute': { + 'type': 'boolean', + 'default': False + }, + 'number_recompute_layers': { + 'type': 'integer', + 'min': 0, + 'default': 0 + } + } + }, 'utils' : { 'type' : 'dict', 'required': False, @@ -221,6 +245,17 @@ class ORTTrainerOptions(object): Users can also instantiate :py:class:`.DynamicLossScaler` and override its parameters. Lastly, a completely new implementation can be specified by extending :py:class:`.LossScaler` class from scratch + graph_transformer (dict): + graph transformer related configurations + attn_dropout_recompute (bool, default is False): + enable recomputing attention dropout to save memory + gelu_recompute (bool, default is False): + enable recomputing Gelu activation output to save memory + transformer_layer_recompute (bool, default is False): + enable recomputing transformer layerwise to save memory + number_recompute_layers (int, default is 0) + number of layers to apply transformer_layer_recompute, by default system will + apply recompute to all the layers, except for the last one utils (dict): miscellaneous options utils.frozen_weights (list of str, []): @@ -435,6 +470,30 @@ _ORTTRAINER_OPTIONS_SCHEMA = { } } }, + 'graph_transformer': { + 'type': 'dict', + 'default_setter': lambda _: {}, + 'required': False, + 'schema': { + 'attn_dropout_recompute': { + 'type': 'boolean', + 'default': False + }, + 'gelu_recompute': { + 'type': 'boolean', + 'default': False + }, + 'transformer_layer_recompute': { + 'type': 'boolean', + 'default': False + }, + 'number_recompute_layers': { + 'type': 'integer', + 'min': 0, + 'default': 0 + } + } + }, 'utils': { 'type': 'dict', 'default_setter': lambda _: {}, diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 1039a72703..51f504982f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -94,6 +94,12 @@ def testORTTrainerOptionsDefaultValues(test_input): 'enabled': False, 'loss_scaler': None }, + 'graph_transformer': { + 'attn_dropout_recompute': False, + 'gelu_recompute': False, + 'transformer_layer_recompute': False, + 'number_recompute_layers': 0 + }, 'utils': { 'frozen_weights': [], 'grad_norm_clip': True, @@ -728,6 +734,61 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches) assert trainer._onnx_model is not None +def _recompute_data(): + device_capability_major = torch.cuda.get_device_capability()[0] + if device_capability_major == 7: # V100 for Dev machine + expected_loss = [10.577394, 10.444777, 10.425666, 10.299958, 10.290016] + return [ + (False, False, False, 0, expected_loss), # no recompute + (True, False, False, 0, expected_loss), # attn_dropout recompute + (False, True, False, 0, expected_loss), # gelu recompute + (False, False, True, 0, expected_loss), # transformer_layer recompute + (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer + ] + elif device_capability_major == 5: # M60 for CI machines + expected_loss = [10.56341 , 10.461096, 10.364473, 10.297504, 10.249142] + return [ + (False, False, False, 0, expected_loss), # no recompute + (True, False, False, 0, expected_loss), # attn_dropout recompute + (False, True, False, 0, expected_loss), # gelu recompute + (False, False, True, 0, expected_loss), # transformer_layer recompute + (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer + ] +@pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data()) +def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss): + seed = 321 + device = 'cuda' + rtol = 1e-3 + total_steps = len(expected_loss) + torch.manual_seed(seed) + set_seed(seed) + + # Setup ORTTrainer + loss_scaler = amp.DynamicLossScaler() + options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, + 'graph_transformer' : { + 'attn_dropout_recompute': attn_dropout, + 'gelu_recompute': gelu, + 'transformer_layer_recompute': transformer_layer, + 'number_recompute_layers': number_layers + }, + 'debug' : {'deterministic_compute' : True}}) + model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device) + optim_config = optim.LambConfig(lr=0.001) + trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) + + # Training loop + actual_loss = [] + for i in range(total_steps): + data, targets = batcher_fn(train_data, i) + loss, _ = trainer.train_step(data, targets) + actual_loss.append(loss.cpu()) + + # Compare loss to ground truth computed from current ORTTrainer API + _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=rtol) + assert trainer._onnx_model is not None + + @pytest.mark.parametrize("seed,device,gradient_accumulation_steps,total_steps,expected_loss", [ (0, 'cuda', 1, 12, [10.5368022919, 10.4146203995, 10.3635568619, 10.2650547028, 10.2284049988, 10.1304626465,\ 10.0853414536, 9.9987659454, 9.9472427368, 9.8832416534, 9.8223171234, 9.8222122192]),