diff --git a/onnxruntime/core/framework/orttraining_partial_executor.cc b/onnxruntime/core/framework/orttraining_partial_executor.cc index cab6a5ba27..6cd7636373 100644 --- a/onnxruntime/core/framework/orttraining_partial_executor.cc +++ b/onnxruntime/core/framework/orttraining_partial_executor.cc @@ -233,6 +233,22 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve // construct OpKernelContext // TODO: log kernel inputs? OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, false); + + // Cache lookup. Currently we only cache single-output nodes, + // to keep memory overhead impact in check. Hence we only look in cache + // if the current node has one output. + bool reuse_cached_value = false; + std::string cached_arg_name; + if (cache_ != nullptr) { + if (p_op_kernel->Node().OutputDefs().size() == 1) { + cached_arg_name = p_op_kernel->Node().OutputDefs()[0]->Name(); + if (cache_.get()->count(cached_arg_name)) { // found arg in cache_ + VLOGS(logger, 1) << "Found OrtValue in cache for arg: " << cached_arg_name; + reuse_cached_value = true; + } + } + } + // TODO: log kernel outputs? 
if (is_profiler_enabled) { sync_time_begin = session_state.Profiler().StartTime(); @@ -312,7 +328,11 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve ORT_RETURN_IF_ERROR(utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context)); } #endif - compute_status = p_op_kernel->Compute(&op_kernel_context); + if (!reuse_cached_value) { + compute_status = p_op_kernel->Compute(&op_kernel_context); + } else { + compute_status = op_kernel_context.SetOutputMLValue(0, cache_.get()->at(cached_arg_name)); + } } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { diff --git a/onnxruntime/core/framework/orttraining_partial_executor.h b/onnxruntime/core/framework/orttraining_partial_executor.h index d97396c972..0a17a0937b 100644 --- a/onnxruntime/core/framework/orttraining_partial_executor.h +++ b/onnxruntime/core/framework/orttraining_partial_executor.h @@ -19,8 +19,10 @@ namespace onnxruntime { class PartialExecutor : public IExecutor { public: - PartialExecutor(PartialGraphExecutionState& state) - : state_{state} {} + PartialExecutor(PartialGraphExecutionState& state, + const OrtValueCachePtr& cache) + : state_{state}, + cache_{cache} {} common::Status Execute(const SessionState& session_state, const std::vector& feed_mlvalue_idxs, const std::vector& feeds, const std::vector& fetch_mlvalue_idxs, @@ -31,6 +33,7 @@ class PartialExecutor : public IExecutor { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PartialExecutor); PartialGraphExecutionState& state_; + const OrtValueCachePtr cache_;  // held by value (may be null): a reference member would dangle if the executor were built from a temporary shared_ptr }; } // namespace onnxruntime #endif \ No newline at end of file diff --git a/onnxruntime/core/framework/partial_graph_execution_state.h b/onnxruntime/core/framework/partial_graph_execution_state.h index 10768bae3b..7ce701cf16 100644 --- a/onnxruntime/core/framework/partial_graph_execution_state.h +++ b/onnxruntime/core/framework/partial_graph_execution_state.h @@ -8,6 +8,10 @@ #include "core/framework/execution_frame.h" namespace 
onnxruntime { + +typedef std::unordered_map OrtValueCache; +typedef std::shared_ptr OrtValueCachePtr; + struct PartialGraphExecutionState { public: PartialGraphExecutionState() { @@ -28,7 +32,7 @@ struct PartialGraphExecutionState { const SessionState& session_state) { if (execution_frame_ == nullptr) { execution_frame_ = std::make_unique(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches, - fetch_allocators, session_state); + fetch_allocators, session_state); } else { execution_frame_->UpdateFeeds(feed_mlvalue_idxs, feeds); execution_frame_->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors()); diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 022182062e..8af2b89af9 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -631,10 +631,12 @@ common::Status ExecuteGraph(const SessionState& session_state, #ifdef ENABLE_TRAINING common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, const std::vector& feeds, std::vector& fetches, - const logging::Logger& logger, PartialGraphExecutionState& state) { + const logging::Logger& logger, PartialGraphExecutionState& state, + const OrtValueCachePtr& cache) { + // finalize the copy info using the provided feeds and fetches. 
will update device_copy_checks in the background FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches); - PartialExecutor executor{state}; + PartialExecutor executor{state, cache}; const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo(); const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks(); diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 229d7e4eed..c9bd4cedab 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -94,7 +94,8 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag #ifdef ENABLE_TRAINING common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, const std::vector& feeds, std::vector& fetches, - const logging::Logger& logger, PartialGraphExecutionState& state); + const logging::Logger& logger, PartialGraphExecutionState& state, + const OrtValueCachePtr& cache); #endif // Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function. 
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8b7bcc8c1..84a1ce5602 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1650,7 +1650,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state, - FeedsFetchesManager& feeds_fetches_manager) { + FeedsFetchesManager& feeds_fetches_manager, + const OrtValueCachePtr& cache) { Status retval = Status::OK(); std::vector exec_providers_to_stop; exec_providers_to_stop.reserve(execution_providers_.NumProviders()); @@ -1691,7 +1692,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // execute the graph ORT_CHECK_AND_SET_RETVAL(utils::ExecutePartialGraph(*session_state_, feeds_fetches_manager, feeds, fetches, - run_logger, state)); + run_logger, state, cache)); } ORT_CATCH(const std::exception& e) { ORT_HANDLE_EXCEPTION([&]() { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 32abe60b98..7345f6c683 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -316,12 +316,15 @@ class InferenceSession { * @param state State of the graph needed to resume partial graph run. * @param feeds_fetches_manager Contains feed/fetches name to internal indices mapping and information for device * copy/checks. 
+ * @param cache Contains node arg name to OrtValue map stashed from previous run + * for frontier tensors */ common::Status PartialRun(onnxruntime::RunOptions& run_options, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state, - FeedsFetchesManager& feeds_fetches_manager); + FeedsFetchesManager& feeds_fetches_manager, + const OrtValueCachePtr& cache); #endif /** diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc index e2b60528e7..88cb317eae 100644 --- a/orttraining/orttraining/core/agent/training_agent.cc +++ b/orttraining/orttraining/core/agent/training_agent.cc @@ -50,25 +50,26 @@ TrainingAgent::TrainingAgent(InferenceSession& session, TrainingAgent::~TrainingAgent() = default; common::Status TrainingAgent::RunForward(const std::vector& feeds, std::vector& fetches, - PartialGraphExecutionState& state) { + PartialGraphExecutionState& state, const OrtValueCachePtr& cache) { state.SetProgramCounterStart(0); state.SetProgramCounterEnd(fw_program_counter_end_); - return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_); + return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache); } common::Status TrainingAgent::RunBackward(const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state) { state.SetProgramCounterStart(fw_program_counter_end_); state.SetProgramCounterEnd(bw_program_counter_end_); - return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_); + return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr); } common::Status TrainingAgent::RunCore(const std::vector& feeds, std::vector& fetches, - PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager) { + PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, + const OrtValueCachePtr& cache) { auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size(); 
fetches.resize(fetches_size, {}); RunOptions run_options; - return inference_session_.PartialRun(run_options, feeds, fetches, state, feeds_fetches_manager); + return inference_session_.PartialRun(run_options, feeds, fetches, state, feeds_fetches_manager, cache); } void TrainingAgent::CreateAndInitializeFeedsFetchesManager(const SessionState& session_state, diff --git a/orttraining/orttraining/core/agent/training_agent.h b/orttraining/orttraining/core/agent/training_agent.h index 63f4794a97..5645cc1059 100644 --- a/orttraining/orttraining/core/agent/training_agent.h +++ b/orttraining/orttraining/core/agent/training_agent.h @@ -25,14 +25,15 @@ class TrainingAgent { ~TrainingAgent(); // For ORTModule.forward() common::Status RunForward(const std::vector& feeds, std::vector& fetches, - PartialGraphExecutionState& state) ORT_MUST_USE_RESULT; + PartialGraphExecutionState& state, const OrtValueCachePtr& cache) ORT_MUST_USE_RESULT; // For ORTModule.backward() common::Status RunBackward(const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state) ORT_MUST_USE_RESULT; common::Status RunCore(const std::vector& feeds, std::vector& fetches, - PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager) + PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, + const OrtValueCachePtr& cache) ORT_MUST_USE_RESULT; void CreateAndInitializeFeedsFetchesManager(const SessionState& session_state, diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index 31c2ff4f3c..a805c25b04 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -263,15 +263,15 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() { // YieldOps non_differentiable_outputs attribute specifies the indices of outputs that are not differentiable const auto& 
non_differentiable_indices = graph_info_.output_grad_indices_non_differentiable; + const std::string non_differentiable_outputs_name = "non_differentiable_outputs"; + ONNX_NAMESPACE::AttributeProto non_differentiable_outputs; + non_differentiable_outputs.set_name(non_differentiable_outputs_name); + non_differentiable_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS); + if (non_differentiable_indices.size() > 0) { - ONNX_NAMESPACE::AttributeProto non_differentiable_outputs; - const std::string non_differentiable_outputs_name = "non_differentiable_outputs"; - non_differentiable_outputs.set_name(non_differentiable_outputs_name); - non_differentiable_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS); for (auto index : non_differentiable_indices) { non_differentiable_outputs.add_ints(index); } - attributes.insert({non_differentiable_outputs_name, non_differentiable_outputs}); } // YieldOps full_shape_outputs attribute specifies the indices of outputs that must be full shape. @@ -303,6 +303,26 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() { graph_info_.module_output_gradient_name.emplace_back(grad_name); } } + + size_t input_count = yield_input_node_args.size(); + for (const auto& iter : graph_info_.frontier_node_arg_map) { + const std::string& name = iter.second;  // bind instead of copying; cached_node_arg_names stores its own copy below + yield_input_node_args.emplace_back(gradient_graph.GetNodeArg(name)); + graph_info_.cached_node_arg_names.emplace_back(name); + } + + const auto& frontier_tensors = graph_info_.frontier_node_arg_map; + if (frontier_tensors.size() > 0) { + for (size_t index = input_count; index < input_count + frontier_tensors.size(); index++) { + non_differentiable_outputs.add_ints(index); + } + } + + // YieldOps non_differentiable_outputs attribute specifies the indices of outputs that are not differentiable (including the frontier tensors appended above) + if (non_differentiable_indices.size() > 0 || frontier_tensors.size() > 0) { + attributes.insert({non_differentiable_outputs_name, non_differentiable_outputs}); + } + attributes.insert({full_shape_outputs_name, 
full_shape_outputs}); // Handle potential duplciated output_gradient names diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h index bfddb5b8e0..aefa558699 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h @@ -61,8 +61,11 @@ struct GraphInfo { std::vector module_output_indices_requires_save_for_backward{}; // Names of module outputs' gradient std::vector module_output_gradient_name{}; - + // Names of the frontier tensor corresponding to param std::unordered_map frontier_node_arg_map{}; + // Names of the frontier NodeArgs in the order in which they will + // be retrieved in the forward pass + std::vector cached_node_arg_names{}; }; class OrtModuleGraphBuilder { diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 8eff5f0869..54f313cd5b 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -18,6 +18,7 @@ #include "python/onnxruntime_pybind_mlvalue.h" PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(onnxruntime::OrtValueCache); namespace onnxruntime { namespace python { @@ -329,6 +330,29 @@ void addObjectMethodsForTraining(py::module& m) { return py::reinterpret_steal(ToDlpack(v->at(idx))); }); + py::class_(m, "OrtValueCache") + .def(py::init<>()) + .def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) { + cache_ptr->emplace(node_arg_name, value); + }) + .def("keys", [](const OrtValueCachePtr& cache_ptr) { + py::list keys; + for (const auto& kv : *cache_ptr) {  // by reference: avoids copying every key string and OrtValue just to read the key + keys.append(kv.first); + } + return keys; + }) + .def("clear", [](const OrtValueCachePtr& cache_ptr) { + cache_ptr->clear(); + }) + .def("count", [](const OrtValueCachePtr& cache_ptr, std::string 
node_arg_name) { + return cache_ptr->count(node_arg_name); + }) + .def("remove", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name) { + const auto& num_entries_erased = cache_ptr->erase(node_arg_name); + ORT_ENFORCE(num_entries_erased == 1, "NodeArg not found in cache: ", node_arg_name); + }); + py::class_ parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc"); parameters.def(py::init()) .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name) @@ -530,8 +554,8 @@ void addObjectMethodsForTraining(py::module& m) { return std::make_unique(*session->GetSessionHandle(), fw_feed_names, fw_outputs_device_info, bw_fetches_names, bw_outputs_device_info); })) - .def("run_forward", [](TrainingAgent* agent, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState* state) -> void { - Status status = agent->RunForward(feeds, fetches, *state); + .def("run_forward", [](TrainingAgent* agent, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void { + Status status = agent->RunForward(feeds, fetches, *state, cache); if (!status.IsOK()) { throw std::runtime_error("Error in forward pass execution: " + status.ErrorMessage()); } @@ -619,6 +643,7 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape) .def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward) .def_readwrite("frontier_node_arg_map", &GraphInfo::frontier_node_arg_map) + .def_readwrite("cached_node_arg_names", &GraphInfo::cached_node_arg_names) .def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name); py::class_ ortmodule_graph_builder(m, "OrtModuleGraphBuilder"); diff --git a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py 
b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py index 878b0738b2..a5eef94e05 100644 --- a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py +++ b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py @@ -116,14 +116,15 @@ class TrainingAgent(object): self._training_agent = C_TrainingAgent(self._inference_session._sess, fw_feed_names, fw_outputs_device_info, bw_fetches_names, bw_outputs_device_info) - def run_forward(self, feeds, fetches, state): + def run_forward(self, feeds, fetches, state, cache=None): """ Compute the forward subgraph for given feeds and fetches. :param feeds: Inputs to the graph run. :param fetches: Outputs of the graph run. :param state: State of the graph that is used for executing partial graph runs. + :param cache: Cache to store stashed OrtValues for intermediate activations. """ - self._training_agent.run_forward(feeds, fetches, state) + self._training_agent.run_forward(feeds, fetches, state, cache) def run_backward(self, feeds, fetches, state): """ diff --git a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py new file mode 100644 index 0000000000..22fce2fe18 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py @@ -0,0 +1,86 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from . import _utils +from onnxruntime.capi import _pybind_state as C + + +class GradientAccumulationManager(object): + """Handles Gradient accumulation optimization during training + + This feature must be enabled once before training and cannot be turned off within a training run. 
+ """ + # TODO: enable switching the feature on/off in the middle of the training + + def __init__(self): + self.cache = None + self._param_name_value_map = None + self._param_version_map = None + self._frontier_node_arg_map = None + self._enabled = False + self._update_cache = False + + def initialize(self, enabled, module, graph_info) -> None: + """Initializes Gradient Accumulation optimization. + + Args: + enabled (bool): Whether the optimization is enabled or disabled. + module (torch.nn.Module): Training model + graph_info (GraphInfo): The ORT Graph Info object holding information about backend graph. + """ + if enabled: + self._enabled = True + self.cache = C.OrtValueCache() + # Since named_parameters() is a generator function, need to avoid overhead and + # populate the params in memory to avoid generating the param map every + # step. This will not work if the user adds or removes params between steps + self._param_name_value_map = { + name: param for name, param in module.named_parameters()} + self._param_version_map = dict() + self._frontier_node_arg_map = graph_info.frontier_node_arg_map + self._cached_node_arg_names = graph_info.cached_node_arg_names + self._cache_start = len(graph_info.user_output_names) + + @property + def enabled(self): + """Indicates whether gradient accumulation optimization is enabled. 
+ """ + return self._enabled + + def extract_outputs_and_maybe_update_cache(self, forward_outputs): + """Extract the user outputs from the forward outputs as torch tensor and update cache, if needed + + Args: + forward_outputs (OrtValueVector): List of outputs returned by forward function + """ + if not self.enabled: + return tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs) + if self._update_cache: + for i in range(self._cache_start, len(forward_outputs)): + self.cache.insert( + self._cached_node_arg_names[i-self._cache_start], forward_outputs[i]) + self._update_cache = False + return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i]) for i in range(self._cache_start)) + + def maybe_update_cache_before_run(self): + """Update cache when model parameters are modified and optimization is enabled. + """ + # The current implementation relies on param._version, which might not be + # updated in all cases(eg. inplace update) + # TODO: Make detection of parameter update robust + if not self.enabled: + return + + # parse param versions to detect change or no change + for name, arg_name in self._frontier_node_arg_map.items(): + param = self._param_name_value_map[name] + if name not in self._param_version_map: + self._param_version_map[name] = param._version + elif param._version != self._param_version_map[name]: + # there is an updated param, so remove entry from cache + # in order to recompute the value + if self.cache.count(arg_name): + self.cache.remove(arg_name) + self._update_cache = True + self._param_version_map[name] = param._version diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 1afa9e8c51..70c8b4cb60 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -14,6 
+14,7 @@ from ._fallback import (_FallbackManager, ORTModuleONNXModelException, ORTModuleTorchModelException, wrap_exception) +from ._gradient_accumulation_manager import GradientAccumulationManager from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION from onnxruntime.capi import _pybind_state as C @@ -152,6 +153,7 @@ class GraphExecutionManager(GraphExecutionInterface): # WIP feature to enable caching in Gradient accumulation scenario. self._enable_grad_acc_optimization = False + self._gradient_accumulation_manager = GradientAccumulationManager() # Memory aware gradient builder. self._use_memory_efficient_gradient = False diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 48041c2424..5d6bf4aaa0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -28,7 +28,7 @@ class TrainingManager(GraphExecutionManager): self._export_mode = torch.onnx.TrainingMode.TRAINING @staticmethod - def execution_session_run_forward(execution_session, onnx_model, *inputs): + def execution_session_run_forward(execution_session, onnx_model, gradient_accumulation_manager, *inputs): """Runs the forward graph on execution_session with given model inputs and device""" # TODO: Try to reuse the output buffers as some of the output tensors are same sizes, @@ -43,8 +43,8 @@ class TrainingManager(GraphExecutionManager): forward_outputs = C.OrtValueVector() # Run and return module outputs. 
- execution_session.run_forward(forward_inputs, forward_outputs, state) - user_outputs = tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs) + execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache) + user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs) output_info = [(output.shape, output.device, output.dtype) for output in user_outputs] run_info = _RunStateInfo(state, output_info) @@ -113,6 +113,10 @@ class TrainingManager(GraphExecutionManager): if self._debug_options.logging.log_level <= _logger.LogLevel.WARNING: warnings.warn("Fast path enabled - skipping checks for rebuilding gradient graph, execution agent creation, and device during training.", UserWarning) + + self._gradient_accumulation_manager.initialize(self._enable_grad_acc_optimization, self._flattened_module, self._graph_info) + + self._gradient_accumulation_manager.maybe_update_cache_before_run() class _ORTModuleFunction(torch.autograd.Function): '''Use a custom torch.autograd.Function to associate self.backward_graph as the @@ -135,6 +139,7 @@ class TrainingManager(GraphExecutionManager): user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent, self._onnx_models.optimized_model, + self._gradient_accumulation_manager, *inputs) # Disable materializing grads then None object will not be @@ -146,8 +151,12 @@ class TrainingManager(GraphExecutionManager): # ORT is NOT relying on save_for_backward() to actually save the tensor, # as this tensor is also kept in ORT's PartialGraphState # This call is to invoke pytorch's version check to detect the potential inplace corruption + # If ORT is caching tensors, the module_output_indices_requires_save_for_backward field + # might also have indices of cached tensors that are not passed over to pytorch, and they don't + # need marking with save_for_backward() for idx in 
self._graph_info.module_output_indices_requires_save_for_backward: - ctx.save_for_backward(user_outputs[idx]) + if idx < len(self._graph_info.user_output_names): + ctx.save_for_backward(user_outputs[idx]) # Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info # This will return torch the output tensors with correct requires_grad settings @@ -273,7 +282,8 @@ class TrainingManager(GraphExecutionManager): C.OrtDevice(get_ort_device_type(self._device.type), C.OrtDevice.default_memory(), _utils.get_device_index(self._device) - )] * len(self._graph_info.user_output_names) + )] * (len(self._graph_info.user_output_names) + + len(self._graph_info.frontier_node_arg_map)) bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output] bw_outputs_device_info = [ diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index d722c93b4b..7d1fe3da9c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -3140,6 +3140,59 @@ def test_debug_options_log_level_validation_fails_on_type_mismatch(): _ = DebugOptions(log_level=log_level) assert f"Expected log_level of type LogLevel, got {type(log_level)}." 
in str(ex_info.value) +def test_ortmodule_gradient_accumulation_optimization_correctness(): + class NeuralNetWithCast(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(NeuralNetWithCast, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input1): + out = self.fc1(input1) + out = self.relu(out) + out = self.fc2(out) + return out + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + pt_model = NeuralNetWithCast(D_in, H, D_out).to(device) + + # baseline model with optimization disabled + tgt_model = ORTModule(pt_model) + tgt_optimizer = torch.optim.Adam(tgt_model.parameters()) + + # model with optimization enabled + opt_model = ORTModule(copy.deepcopy(pt_model)) + opt_model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True + opt_optimizer = torch.optim.Adam(opt_model.parameters()) + + def run_step(model, x): + with amp.autocast(): + prediction = model(x) + loss = prediction.sum() + loss.backward() + return loss.detach() + + def run_optim_step(optimizer): + optimizer.step() + optimizer.zero_grad() + + GA_steps = 2 + tgt_model.zero_grad() + opt_model.zero_grad() + + for step in range(10): + x = torch.randn(N, D_in, device=device) + tgt_loss = run_step(tgt_model, x) + opt_loss = run_step(opt_model, x) + + # assert that loss values match + _test_helpers.assert_values_are_close(tgt_loss, opt_loss) + if step % GA_steps == 0: + run_optim_step(tgt_optimizer) + run_optim_step(opt_optimizer) + def test_ortmodule_dict_input(): class DictNet(torch.nn.Module): def __init__(self): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py index c73eec9509..ddc524a665 100644 --- 
a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py @@ -42,6 +42,12 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch) model.train() + # Always clear any previously calculated gradients before performing a + # backward pass. PyTorch doesn't do this automatically because + # accumulating the gradients is "convenient while training RNNs". + # (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch) + optimizer.zero_grad() + # For each batch of training data... for step, batch in enumerate(train_dataloader): @@ -61,12 +67,6 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, b_input_mask = batch[1].to(device) b_labels = batch[2].to(device) - # Always clear any previously calculated gradients before performing a - # backward pass. PyTorch doesn't do this automatically because - # accumulating the gradients is "convenient while training RNNs". - # (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch) - optimizer.zero_grad() - # Perform a forward pass (evaluate the model on this training batch). # This will return the loss (rather than the model output) because we have provided the `labels`. # The documentation for this `model` function is here: @@ -108,14 +108,16 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, # This is to help prevent the "exploding gradients" problem. torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - # Update parameters and take a step using the computed gradient. - # The optimizer dictates the "update rule"--how the parameters are - # modified based on their gradients, the learning rate, etc. 
- scaler.step(optimizer) - scaler.update() + if (step + 1) % args.grad_acc_steps == 0: + # Update parameters and take a step using the computed gradient. + # The optimizer dictates the "update rule"--how the parameters are + # modified based on their gradients, the learning rate, etc. + scaler.step(optimizer) + scaler.update() - # Update the learning rate. - scheduler.step() + # Update the learning rate. + scheduler.step() + optimizer.zero_grad() # Calculate the average loss over the training data. avg_train_loss = total_loss / len(train_dataloader) @@ -320,6 +322,8 @@ def main(): help='disables ONNX Runtime training') parser.add_argument('--batch-size', type=int, default=32, metavar='N', help='input batch size for training (default: 32)') + parser.add_argument('--do-val', action='store_true', default=False, + help='enables validation') parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', help='input batch size for testing (default: 64)') parser.add_argument('--view-graphs', action='store_true', default=False, @@ -340,6 +344,8 @@ def main(): help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)') parser.add_argument('--data-dir', type=str, default='./cola_public/raw', help='Path to the bert data directory') + parser.add_argument('--grad-acc-steps', type=int, default=2, + help='Number of steps for accumulating gradients') args = parser.parse_args() @@ -376,23 +382,24 @@ def main(): config=config, ) + # Note: this uses PyTorch's torch.optim.AdamW (not the huggingface AdamW class) + optimizer = torch.optim.AdamW(model.parameters(), + lr = 2e-2, # args.learning_rate - default is 5e-5, our notebook had 2e-5 + eps = 1e-8 # args.adam_epsilon - default is 1e-8. 
+ ) + if not args.pytorch_only: # Just for future debugging debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassificationAutoCast') model = ORTModule(model, debug_options) - - model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True + model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True # Tell pytorch to run this model on the GPU. if torch.cuda.is_available() and not args.no_cuda: model.cuda() - # Note: AdamW is a class from the huggingface library (as opposed to pytorch) - optimizer = AdamW(model.parameters(), - lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5 - eps = 1e-8 # args.adam_epsilon - default is 1e-8. - ) + # Authors recommend between 2 and 4 epochs # Total number of training steps is number of batches * number of epochs. @@ -418,10 +425,12 @@ def main(): total_training_time += train(model, optimizer, scaler, scheduler, train_dataloader, epoch_i, device, args) if not args.pytorch_only and epoch_i == 0: epoch_0_training = total_training_time - test_time, validation_accuracy = test(model, validation_dataloader, device, args) - total_test_time += test_time + if args.do_val: + test_time, validation_accuracy = test(model, validation_dataloader, device, args) + total_test_time += test_time - assert validation_accuracy > 0.5 + if args.do_val: + assert validation_accuracy > 0.5 print('\n======== Global stats ========') if not args.pytorch_only: