mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Gradient Accumulation optimization verified for correctness (#8273)
* Fetching frontier tensors to frontend * Move before session initialize call * Fetch tensor and add to cache * Rest of the changes for using cache * Review comments * Review changes * Review comments * switch to shared_ptr * Fix bug after rebase * FE docstring change
This commit is contained in:
parent
224380448d
commit
cc275e7529
18 changed files with 300 additions and 55 deletions
|
|
@ -233,6 +233,22 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve
|
|||
// construct OpKernelContext
|
||||
// TODO: log kernel inputs?
|
||||
OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, false);
|
||||
|
||||
// Cache lookup. Currently we only cache single-output nodes,
|
||||
// to keep memory overhead impact in check. Hence we only look in cache
|
||||
// if the current node has one output.
|
||||
bool reuse_cached_value = false;
|
||||
std::string cached_arg_name;
|
||||
if (cache_ != nullptr) {
|
||||
if (p_op_kernel->Node().OutputDefs().size() == 1) {
|
||||
cached_arg_name = p_op_kernel->Node().OutputDefs()[0]->Name();
|
||||
if (cache_.get()->count(cached_arg_name)) { // found arg in cache_
|
||||
VLOGS(logger, 1) << "Found OrtValue in cache for arg: " << cached_arg_name;
|
||||
reuse_cached_value = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: log kernel outputs?
|
||||
if (is_profiler_enabled) {
|
||||
sync_time_begin = session_state.Profiler().StartTime();
|
||||
|
|
@ -312,7 +328,11 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve
|
|||
ORT_RETURN_IF_ERROR(utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context));
|
||||
}
|
||||
#endif
|
||||
compute_status = p_op_kernel->Compute(&op_kernel_context);
|
||||
if (!reuse_cached_value) {
|
||||
compute_status = p_op_kernel->Compute(&op_kernel_context);
|
||||
} else {
|
||||
compute_status = op_kernel_context.SetOutputMLValue(0, cache_.get()->at(cached_arg_name));
|
||||
}
|
||||
}
|
||||
ORT_CATCH(const std::exception& ex) {
|
||||
ORT_HANDLE_EXCEPTION([&]() {
|
||||
|
|
|
|||
|
|
@ -19,8 +19,10 @@
|
|||
namespace onnxruntime {
|
||||
class PartialExecutor : public IExecutor {
|
||||
public:
|
||||
PartialExecutor(PartialGraphExecutionState& state)
|
||||
: state_{state} {}
|
||||
PartialExecutor(PartialGraphExecutionState& state,
|
||||
const OrtValueCachePtr& cache)
|
||||
: state_{state},
|
||||
cache_{cache} {}
|
||||
|
||||
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
|
|
@ -31,6 +33,7 @@ class PartialExecutor : public IExecutor {
|
|||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PartialExecutor);
|
||||
PartialGraphExecutionState& state_;
|
||||
const OrtValueCachePtr& cache_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
#endif
|
||||
|
|
@ -8,6 +8,10 @@
|
|||
#include "core/framework/execution_frame.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
typedef std::unordered_map<std::string, OrtValue> OrtValueCache;
|
||||
typedef std::shared_ptr<OrtValueCache> OrtValueCachePtr;
|
||||
|
||||
struct PartialGraphExecutionState {
|
||||
public:
|
||||
PartialGraphExecutionState() {
|
||||
|
|
@ -28,7 +32,7 @@ struct PartialGraphExecutionState {
|
|||
const SessionState& session_state) {
|
||||
if (execution_frame_ == nullptr) {
|
||||
execution_frame_ = std::make_unique<ExecutionFrame>(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
|
||||
fetch_allocators, session_state);
|
||||
fetch_allocators, session_state);
|
||||
} else {
|
||||
execution_frame_->UpdateFeeds(feed_mlvalue_idxs, feeds);
|
||||
execution_frame_->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
|
||||
|
|
|
|||
|
|
@ -631,10 +631,12 @@ common::Status ExecuteGraph(const SessionState& session_state,
|
|||
#ifdef ENABLE_TRAINING
|
||||
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
const logging::Logger& logger, PartialGraphExecutionState& state) {
|
||||
const logging::Logger& logger, PartialGraphExecutionState& state,
|
||||
const OrtValueCachePtr& cache) {
|
||||
|
||||
// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
|
||||
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);
|
||||
PartialExecutor executor{state};
|
||||
PartialExecutor executor{state, cache};
|
||||
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
|
||||
const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
|
||||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,8 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag
|
|||
#ifdef ENABLE_TRAINING
|
||||
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
const logging::Logger& logger, PartialGraphExecutionState& state);
|
||||
const logging::Logger& logger, PartialGraphExecutionState& state,
|
||||
const OrtValueCachePtr& cache);
|
||||
#endif
|
||||
|
||||
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.
|
||||
|
|
|
|||
|
|
@ -1650,7 +1650,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
|
|||
const std::vector<OrtValue>& feeds,
|
||||
std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state,
|
||||
FeedsFetchesManager& feeds_fetches_manager) {
|
||||
FeedsFetchesManager& feeds_fetches_manager,
|
||||
const OrtValueCachePtr& cache) {
|
||||
Status retval = Status::OK();
|
||||
std::vector<IExecutionProvider*> exec_providers_to_stop;
|
||||
exec_providers_to_stop.reserve(execution_providers_.NumProviders());
|
||||
|
|
@ -1691,7 +1692,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
|
|||
|
||||
// execute the graph
|
||||
ORT_CHECK_AND_SET_RETVAL(utils::ExecutePartialGraph(*session_state_, feeds_fetches_manager, feeds, fetches,
|
||||
run_logger, state));
|
||||
run_logger, state, cache));
|
||||
}
|
||||
ORT_CATCH(const std::exception& e) {
|
||||
ORT_HANDLE_EXCEPTION([&]() {
|
||||
|
|
|
|||
|
|
@ -316,12 +316,15 @@ class InferenceSession {
|
|||
* @param state State of the graph needed to resume partial graph run.
|
||||
* @param feeds_fetches_manager Contains feed/fetches name to internal indices mapping and information for device
|
||||
* copy/checks.
|
||||
* @param cache Contains node arg name to OrtValue map stashed from previous run
|
||||
* for frontier tensors
|
||||
*/
|
||||
common::Status PartialRun(onnxruntime::RunOptions& run_options,
|
||||
const std::vector<OrtValue>& feeds,
|
||||
std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state,
|
||||
FeedsFetchesManager& feeds_fetches_manager);
|
||||
FeedsFetchesManager& feeds_fetches_manager,
|
||||
const OrtValueCachePtr& cache);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -50,25 +50,26 @@ TrainingAgent::TrainingAgent(InferenceSession& session,
|
|||
TrainingAgent::~TrainingAgent() = default;
|
||||
|
||||
common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state) {
|
||||
PartialGraphExecutionState& state, const OrtValueCachePtr& cache) {
|
||||
state.SetProgramCounterStart(0);
|
||||
state.SetProgramCounterEnd(fw_program_counter_end_);
|
||||
return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_);
|
||||
return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache);
|
||||
}
|
||||
|
||||
common::Status TrainingAgent::RunBackward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state) {
|
||||
state.SetProgramCounterStart(fw_program_counter_end_);
|
||||
state.SetProgramCounterEnd(bw_program_counter_end_);
|
||||
return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_);
|
||||
return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr);
|
||||
}
|
||||
|
||||
common::Status TrainingAgent::RunCore(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager) {
|
||||
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
const OrtValueCachePtr& cache) {
|
||||
auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size();
|
||||
fetches.resize(fetches_size, {});
|
||||
RunOptions run_options;
|
||||
return inference_session_.PartialRun(run_options, feeds, fetches, state, feeds_fetches_manager);
|
||||
return inference_session_.PartialRun(run_options, feeds, fetches, state, feeds_fetches_manager, cache);
|
||||
}
|
||||
|
||||
void TrainingAgent::CreateAndInitializeFeedsFetchesManager(const SessionState& session_state,
|
||||
|
|
|
|||
|
|
@ -25,14 +25,15 @@ class TrainingAgent {
|
|||
~TrainingAgent();
|
||||
// For ORTModule.forward()
|
||||
common::Status RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state) ORT_MUST_USE_RESULT;
|
||||
PartialGraphExecutionState& state, const OrtValueCachePtr& cache) ORT_MUST_USE_RESULT;
|
||||
|
||||
// For ORTModule.backward()
|
||||
common::Status RunBackward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state) ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status RunCore(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager)
|
||||
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
const OrtValueCachePtr& cache)
|
||||
ORT_MUST_USE_RESULT;
|
||||
|
||||
void CreateAndInitializeFeedsFetchesManager(const SessionState& session_state,
|
||||
|
|
|
|||
|
|
@ -263,15 +263,15 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() {
|
|||
|
||||
// YieldOps non_differentiable_outputs attribute specifies the indices of outputs that are not differentiable
|
||||
const auto& non_differentiable_indices = graph_info_.output_grad_indices_non_differentiable;
|
||||
const std::string non_differentiable_outputs_name = "non_differentiable_outputs";
|
||||
ONNX_NAMESPACE::AttributeProto non_differentiable_outputs;
|
||||
non_differentiable_outputs.set_name(non_differentiable_outputs_name);
|
||||
non_differentiable_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
|
||||
|
||||
if (non_differentiable_indices.size() > 0) {
|
||||
ONNX_NAMESPACE::AttributeProto non_differentiable_outputs;
|
||||
const std::string non_differentiable_outputs_name = "non_differentiable_outputs";
|
||||
non_differentiable_outputs.set_name(non_differentiable_outputs_name);
|
||||
non_differentiable_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
|
||||
for (auto index : non_differentiable_indices) {
|
||||
non_differentiable_outputs.add_ints(index);
|
||||
}
|
||||
attributes.insert({non_differentiable_outputs_name, non_differentiable_outputs});
|
||||
}
|
||||
|
||||
// YieldOps full_shape_outputs attribute specifies the indices of outputs that must be full shape.
|
||||
|
|
@ -303,6 +303,26 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() {
|
|||
graph_info_.module_output_gradient_name.emplace_back(grad_name);
|
||||
}
|
||||
}
|
||||
|
||||
size_t input_count = yield_input_node_args.size();
|
||||
for (auto& iter : graph_info_.frontier_node_arg_map) {
|
||||
std::string name = iter.second;
|
||||
yield_input_node_args.emplace_back(gradient_graph.GetNodeArg(name));
|
||||
graph_info_.cached_node_arg_names.emplace_back(name);
|
||||
}
|
||||
|
||||
const auto& frontier_tensors = graph_info_.frontier_node_arg_map;
|
||||
if (frontier_tensors.size() > 0) {
|
||||
for (size_t index = input_count; index < input_count + frontier_tensors.size(); index++) {
|
||||
non_differentiable_outputs.add_ints(index);
|
||||
}
|
||||
}
|
||||
|
||||
// YieldOps non_differentiable_outputs /attribute specifies the indices of outputs that are not differentiable
|
||||
if (non_differentiable_indices.size() > 0 || frontier_tensors.size() > 0) {
|
||||
attributes.insert({non_differentiable_outputs_name, non_differentiable_outputs});
|
||||
}
|
||||
|
||||
attributes.insert({full_shape_outputs_name, full_shape_outputs});
|
||||
|
||||
// Handle potential duplciated output_gradient names
|
||||
|
|
|
|||
|
|
@ -61,8 +61,11 @@ struct GraphInfo {
|
|||
std::vector<size_t> module_output_indices_requires_save_for_backward{};
|
||||
// Names of module outputs' gradient
|
||||
std::vector<std::string> module_output_gradient_name{};
|
||||
|
||||
// Names of the frontier tensor corresponding to param
|
||||
std::unordered_map<std::string, std::string> frontier_node_arg_map{};
|
||||
// Names of the frontier NodeArgs in the order in which they will
|
||||
// be retrieved in the forward pass
|
||||
std::vector<std::string> cached_node_arg_names{};
|
||||
};
|
||||
|
||||
class OrtModuleGraphBuilder {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#include "python/onnxruntime_pybind_mlvalue.h"
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<OrtValue>);
|
||||
PYBIND11_MAKE_OPAQUE(onnxruntime::OrtValueCache);
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
|
@ -329,6 +330,29 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
return py::reinterpret_steal<py::object>(ToDlpack(v->at(idx)));
|
||||
});
|
||||
|
||||
py::class_<OrtValueCache, OrtValueCachePtr>(m, "OrtValueCache")
|
||||
.def(py::init<>())
|
||||
.def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) {
|
||||
cache_ptr->emplace(node_arg_name, value);
|
||||
})
|
||||
.def("keys", [](const OrtValueCachePtr& cache_ptr) {
|
||||
py::list keys;
|
||||
for(auto kv : *cache_ptr.get()) {
|
||||
keys.append(kv.first);
|
||||
}
|
||||
return keys;
|
||||
})
|
||||
.def("clear", [](const OrtValueCachePtr& cache_ptr) {
|
||||
cache_ptr->clear();
|
||||
})
|
||||
.def("count", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name) {
|
||||
return cache_ptr->count(node_arg_name);
|
||||
})
|
||||
.def("remove", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name) {
|
||||
const auto& num_entries_erased = cache_ptr->erase(node_arg_name);
|
||||
ORT_ENFORCE(num_entries_erased == 1, "NodeArg not found in cache: ", node_arg_name);
|
||||
});
|
||||
|
||||
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
|
||||
parameters.def(py::init())
|
||||
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
|
||||
|
|
@ -530,8 +554,8 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
return std::make_unique<TrainingAgent>(*session->GetSessionHandle(), fw_feed_names, fw_outputs_device_info,
|
||||
bw_fetches_names, bw_outputs_device_info);
|
||||
}))
|
||||
.def("run_forward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state) -> void {
|
||||
Status status = agent->RunForward(feeds, fetches, *state);
|
||||
.def("run_forward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void {
|
||||
Status status = agent->RunForward(feeds, fetches, *state, cache);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error("Error in forward pass execution: " + status.ErrorMessage());
|
||||
}
|
||||
|
|
@ -619,6 +643,7 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape)
|
||||
.def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward)
|
||||
.def_readwrite("frontier_node_arg_map", &GraphInfo::frontier_node_arg_map)
|
||||
.def_readwrite("cached_node_arg_names", &GraphInfo::cached_node_arg_names)
|
||||
.def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name);
|
||||
|
||||
py::class_<OrtModuleGraphBuilder> ortmodule_graph_builder(m, "OrtModuleGraphBuilder");
|
||||
|
|
|
|||
|
|
@ -116,14 +116,15 @@ class TrainingAgent(object):
|
|||
self._training_agent = C_TrainingAgent(self._inference_session._sess, fw_feed_names, fw_outputs_device_info,
|
||||
bw_fetches_names, bw_outputs_device_info)
|
||||
|
||||
def run_forward(self, feeds, fetches, state):
|
||||
def run_forward(self, feeds, fetches, state, cache=None):
|
||||
"""
|
||||
Compute the forward subgraph for given feeds and fetches.
|
||||
:param feeds: Inputs to the graph run.
|
||||
:param fetches: Outputs of the graph run.
|
||||
:param state: State of the graph that is used for executing partial graph runs.
|
||||
:param cache: Cache to store stashed OrtValues for intermediate activations.
|
||||
"""
|
||||
self._training_agent.run_forward(feeds, fetches, state)
|
||||
self._training_agent.run_forward(feeds, fetches, state, cache)
|
||||
|
||||
def run_backward(self, feeds, fetches, state):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,86 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from . import _utils
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
|
||||
|
||||
class GradientAccumulationManager(object):
|
||||
"""Handles Gradient accumulation optimization during training
|
||||
|
||||
This feature must be enabled once before training and cannot be turned off within a training run.
|
||||
"""
|
||||
# TODO: enable switching the feature on/off in the middle of the training
|
||||
|
||||
def __init__(self):
|
||||
self.cache = None
|
||||
self._param_name_value_map = None
|
||||
self._param_version_map = None
|
||||
self._frontier_node_arg_map = None
|
||||
self._enabled = False
|
||||
self._update_cache = False
|
||||
|
||||
def initialize(self, enabled, module, graph_info) -> None:
|
||||
"""Initializes Gradient Accumulation optimization.
|
||||
|
||||
Args:
|
||||
enabled (bool): Whether the optimization is enabled or disabled.
|
||||
module (torch.nn.Module): Training model
|
||||
graph_info (GraphInfo): The ORT Graph Info object holding information about backend graph.
|
||||
"""
|
||||
if enabled:
|
||||
self._enabled = True
|
||||
self.cache = C.OrtValueCache()
|
||||
# Since named_parameters() is a generator function, need to avoid overhead and
|
||||
# populate the params in memory to avoid generating the param map every
|
||||
# step. This will not work if the user adds or removes params between steps
|
||||
self._param_name_value_map = {
|
||||
name: param for name, param in module.named_parameters()}
|
||||
self._param_version_map = dict()
|
||||
self._frontier_node_arg_map = graph_info.frontier_node_arg_map
|
||||
self._cached_node_arg_names = graph_info.cached_node_arg_names
|
||||
self._cache_start = len(graph_info.user_output_names)
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
"""Indicates whether gradient accumulation optimization is enabled.
|
||||
"""
|
||||
return self._enabled
|
||||
|
||||
def extract_outputs_and_maybe_update_cache(self, forward_outputs):
|
||||
"""Extract the user outputs from the forward outputs as torch tensor and update cache, if needed
|
||||
|
||||
Args:
|
||||
forward_outputs (OrtValueVector): List of outputs returned by forward function
|
||||
"""
|
||||
if not self.enabled:
|
||||
return tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs)
|
||||
if self._update_cache:
|
||||
for i in range(self._cache_start, len(forward_outputs)):
|
||||
self.cache.insert(
|
||||
self._cached_node_arg_names[i-self._cache_start], forward_outputs[i])
|
||||
self._update_cache = False
|
||||
return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i]) for i in range(self._cache_start))
|
||||
|
||||
def maybe_update_cache_before_run(self):
|
||||
"""Update cache when model parameters are modified and optimization is enabled.
|
||||
"""
|
||||
# The current implementation relies on param._version, which might not be
|
||||
# updated in all cases(eg. inplace update)
|
||||
# TODO: Make detection of parameter update robust
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# parse param versions to detect change or no change
|
||||
for name, arg_name in self._frontier_node_arg_map.items():
|
||||
param = self._param_name_value_map[name]
|
||||
if name not in self._param_version_map:
|
||||
self._param_version_map[name] = param._version
|
||||
elif param._version != self._param_version_map[name]:
|
||||
# there is an updated param, so remove entry from cache
|
||||
# in order to recompute the value
|
||||
if self.cache.count(arg_name):
|
||||
self.cache.remove(arg_name)
|
||||
self._update_cache = True
|
||||
self._param_version_map[name] = param._version
|
||||
|
|
@ -14,6 +14,7 @@ from ._fallback import (_FallbackManager,
|
|||
ORTModuleONNXModelException,
|
||||
ORTModuleTorchModelException,
|
||||
wrap_exception)
|
||||
from ._gradient_accumulation_manager import GradientAccumulationManager
|
||||
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
|
|
@ -152,6 +153,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
# WIP feature to enable caching in Gradient accumulation scenario.
|
||||
self._enable_grad_acc_optimization = False
|
||||
self._gradient_accumulation_manager = GradientAccumulationManager()
|
||||
|
||||
# Memory aware gradient builder.
|
||||
self._use_memory_efficient_gradient = False
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
self._export_mode = torch.onnx.TrainingMode.TRAINING
|
||||
|
||||
@staticmethod
|
||||
def execution_session_run_forward(execution_session, onnx_model, *inputs):
|
||||
def execution_session_run_forward(execution_session, onnx_model, gradient_accumulation_manager, *inputs):
|
||||
"""Runs the forward graph on execution_session with given model inputs and device"""
|
||||
|
||||
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
|
||||
|
|
@ -43,8 +43,8 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
forward_outputs = C.OrtValueVector()
|
||||
# Run and return module outputs.
|
||||
execution_session.run_forward(forward_inputs, forward_outputs, state)
|
||||
user_outputs = tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs)
|
||||
execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
|
||||
user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs)
|
||||
|
||||
output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
|
||||
run_info = _RunStateInfo(state, output_info)
|
||||
|
|
@ -113,6 +113,10 @@ class TrainingManager(GraphExecutionManager):
|
|||
if self._debug_options.logging.log_level <= _logger.LogLevel.WARNING:
|
||||
warnings.warn("Fast path enabled - skipping checks for rebuilding gradient graph, execution agent creation, and device during training.",
|
||||
UserWarning)
|
||||
|
||||
self._gradient_accumulation_manager.initialize(self._enable_grad_acc_optimization, self._flattened_module, self._graph_info)
|
||||
|
||||
self._gradient_accumulation_manager.maybe_update_cache_before_run()
|
||||
|
||||
class _ORTModuleFunction(torch.autograd.Function):
|
||||
'''Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
|
|
@ -135,6 +139,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent,
|
||||
self._onnx_models.optimized_model,
|
||||
self._gradient_accumulation_manager,
|
||||
*inputs)
|
||||
|
||||
# Disable materializing grads then None object will not be
|
||||
|
|
@ -146,8 +151,12 @@ class TrainingManager(GraphExecutionManager):
|
|||
# ORT is NOT relying on save_for_backward() to actually save the tensor,
|
||||
# as this tensor is also kept in ORT's PartialGraphState
|
||||
# This call is to invoke pytorch's version check to detect the potential inplace corruption
|
||||
# If ORT is caching tensors, the module_output_indices_requires_save_for_backward field
|
||||
# might also have indices of cached tensors that are not passed over to pytorch, and they don't
|
||||
# need marking with save_for_backward()
|
||||
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
|
||||
ctx.save_for_backward(user_outputs[idx])
|
||||
if idx < len(self._graph_info.user_output_names):
|
||||
ctx.save_for_backward(user_outputs[idx])
|
||||
|
||||
# Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info
|
||||
# This will return torch the output tensors with correct requires_grad settings
|
||||
|
|
@ -273,7 +282,8 @@ class TrainingManager(GraphExecutionManager):
|
|||
C.OrtDevice(get_ort_device_type(self._device.type),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * len(self._graph_info.user_output_names)
|
||||
)] * (len(self._graph_info.user_output_names) +
|
||||
len(self._graph_info.frontier_node_arg_map))
|
||||
|
||||
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
|
||||
bw_outputs_device_info = [
|
||||
|
|
|
|||
|
|
@ -3140,6 +3140,59 @@ def test_debug_options_log_level_validation_fails_on_type_mismatch():
|
|||
_ = DebugOptions(log_level=log_level)
|
||||
assert f"Expected log_level of type LogLevel, got {type(log_level)}." in str(ex_info.value)
|
||||
|
||||
def test_ortmodule_gradient_accumulation_optimization_correctness():
|
||||
class NeuralNetWithCast(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNetWithCast, self).__init__()
|
||||
|
||||
self.fc1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1):
|
||||
out = self.fc1(input1)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = NeuralNetWithCast(D_in, H, D_out).to(device)
|
||||
|
||||
# baseline model with optimization disabled
|
||||
tgt_model = ORTModule(pt_model)
|
||||
tgt_optimizer = torch.optim.Adam(tgt_model.parameters())
|
||||
|
||||
# model with optimization enabled
|
||||
opt_model = ORTModule(copy.deepcopy(pt_model))
|
||||
opt_model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
opt_optimizer = torch.optim.Adam(opt_model.parameters())
|
||||
|
||||
def run_step(model, x):
|
||||
with amp.autocast():
|
||||
prediction = model(x)
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
return loss.detach()
|
||||
|
||||
def run_optim_step(optimizer):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
GA_steps = 2
|
||||
tgt_model.zero_grad()
|
||||
opt_model.zero_grad()
|
||||
|
||||
for step in range(10):
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
tgt_loss = run_step(tgt_model, x)
|
||||
opt_loss = run_step(opt_model, x)
|
||||
|
||||
# assert that loss values match
|
||||
_test_helpers.assert_values_are_close(tgt_loss, opt_loss)
|
||||
if step % GA_steps == 0:
|
||||
run_optim_step(tgt_optimizer)
|
||||
run_optim_step(opt_optimizer)
|
||||
|
||||
def test_ortmodule_dict_input():
|
||||
class DictNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -42,6 +42,12 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device,
|
|||
# vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
|
||||
model.train()
|
||||
|
||||
# Always clear any previously calculated gradients before performing a
|
||||
# backward pass. PyTorch doesn't do this automatically because
|
||||
# accumulating the gradients is "convenient while training RNNs".
|
||||
# (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# For each batch of training data...
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
||||
|
|
@ -61,12 +67,6 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device,
|
|||
b_input_mask = batch[1].to(device)
|
||||
b_labels = batch[2].to(device)
|
||||
|
||||
# Always clear any previously calculated gradients before performing a
|
||||
# backward pass. PyTorch doesn't do this automatically because
|
||||
# accumulating the gradients is "convenient while training RNNs".
|
||||
# (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Perform a forward pass (evaluate the model on this training batch).
|
||||
# This will return the loss (rather than the model output) because we have provided the `labels`.
|
||||
# The documentation for this `model` function is here:
|
||||
|
|
@ -108,14 +108,16 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device,
|
|||
# This is to help prevent the "exploding gradients" problem.
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
|
||||
# Update parameters and take a step using the computed gradient.
|
||||
# The optimizer dictates the "update rule"--how the parameters are
|
||||
# modified based on their gradients, the learning rate, etc.
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
if step % args.grad_acc_steps == 0:
|
||||
# Update parameters and take a step using the computed gradient.
|
||||
# The optimizer dictates the "update rule"--how the parameters are
|
||||
# modified based on their gradients, the learning rate, etc.
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
# Update the learning rate.
|
||||
scheduler.step()
|
||||
# Update the learning rate.
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Calculate the average loss over the training data.
|
||||
avg_train_loss = total_loss / len(train_dataloader)
|
||||
|
|
@ -320,6 +322,8 @@ def main():
|
|||
help='disables ONNX Runtime training')
|
||||
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('--do-val', action='store_true', default=False,
|
||||
help='disables validation')
|
||||
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
|
||||
help='input batch size for testing (default: 64)')
|
||||
parser.add_argument('--view-graphs', action='store_true', default=False,
|
||||
|
|
@ -340,6 +344,8 @@ def main():
|
|||
help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)')
|
||||
parser.add_argument('--data-dir', type=str, default='./cola_public/raw',
|
||||
help='Path to the bert data directory')
|
||||
parser.add_argument('--grad-acc-steps', type=int, default=2,
|
||||
help='Number of steps for accumulating gradients')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -376,23 +382,24 @@ def main():
|
|||
config=config,
|
||||
)
|
||||
|
||||
# Note: AdamW is a class from the huggingface library (as opposed to pytorch)
|
||||
optimizer = torch.optim.AdamW(model.parameters(),
|
||||
lr = 2e-2, # args.learning_rate - default is 5e-5, our notebook had 2e-5
|
||||
eps = 1e-8 # args.adam_epsilon - default is 1e-8.
|
||||
)
|
||||
|
||||
if not args.pytorch_only:
|
||||
# Just for future debugging
|
||||
debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassificationAutoCast')
|
||||
|
||||
model = ORTModule(model, debug_options)
|
||||
|
||||
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
|
||||
# Tell pytorch to run this model on the GPU.
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
model.cuda()
|
||||
|
||||
# Note: AdamW is a class from the huggingface library (as opposed to pytorch)
|
||||
optimizer = AdamW(model.parameters(),
|
||||
lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5
|
||||
eps = 1e-8 # args.adam_epsilon - default is 1e-8.
|
||||
)
|
||||
|
||||
|
||||
# Authors recommend between 2 and 4 epochs
|
||||
# Total number of training steps is number of batches * number of epochs.
|
||||
|
|
@ -418,10 +425,12 @@ def main():
|
|||
total_training_time += train(model, optimizer, scaler, scheduler, train_dataloader, epoch_i, device, args)
|
||||
if not args.pytorch_only and epoch_i == 0:
|
||||
epoch_0_training = total_training_time
|
||||
test_time, validation_accuracy = test(model, validation_dataloader, device, args)
|
||||
total_test_time += test_time
|
||||
if args.do_val:
|
||||
test_time, validation_accuracy = test(model, validation_dataloader, device, args)
|
||||
total_test_time += test_time
|
||||
|
||||
assert validation_accuracy > 0.5
|
||||
if args.do_val:
|
||||
assert validation_accuracy > 0.5
|
||||
|
||||
print('\n======== Global stats ========')
|
||||
if not args.pytorch_only:
|
||||
|
|
|
|||
Loading…
Reference in a new issue