Gradient Accumulation optimization verified for correctness (#8273)

* Fetching frontier tensors to frontend

* Move before session initialize call

* Fetch tensor and add to cache

* Rest of the changes for using cache

* Review comments

* Review changes

* Review comments

* switch to shared_ptr

* Fix bug after rebase

* FE docstring change
This commit is contained in:
ashbhandare 2021-08-17 16:24:44 -07:00 committed by GitHub
parent 224380448d
commit cc275e7529
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 300 additions and 55 deletions

View file

@ -233,6 +233,22 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve
// construct OpKernelContext
// TODO: log kernel inputs?
OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, false);
// Cache lookup. Currently we only cache single-output nodes,
// to keep memory overhead impact in check. Hence we only look in cache
// if the current node has one output.
bool reuse_cached_value = false;
std::string cached_arg_name;
if (cache_ != nullptr) {
if (p_op_kernel->Node().OutputDefs().size() == 1) {
cached_arg_name = p_op_kernel->Node().OutputDefs()[0]->Name();
if (cache_.get()->count(cached_arg_name)) { // found arg in cache_
VLOGS(logger, 1) << "Found OrtValue in cache for arg: " << cached_arg_name;
reuse_cached_value = true;
}
}
}
// TODO: log kernel outputs?
if (is_profiler_enabled) {
sync_time_begin = session_state.Profiler().StartTime();
@ -312,7 +328,11 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve
ORT_RETURN_IF_ERROR(utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context));
}
#endif
compute_status = p_op_kernel->Compute(&op_kernel_context);
if (!reuse_cached_value) {
compute_status = p_op_kernel->Compute(&op_kernel_context);
} else {
compute_status = op_kernel_context.SetOutputMLValue(0, cache_.get()->at(cached_arg_name));
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {

View file

@ -19,8 +19,10 @@
namespace onnxruntime {
class PartialExecutor : public IExecutor {
public:
PartialExecutor(PartialGraphExecutionState& state)
: state_{state} {}
PartialExecutor(PartialGraphExecutionState& state,
const OrtValueCachePtr& cache)
: state_{state},
cache_{cache} {}
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
@ -31,6 +33,7 @@ class PartialExecutor : public IExecutor {
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PartialExecutor);
PartialGraphExecutionState& state_;
const OrtValueCachePtr& cache_;
};
} // namespace onnxruntime
#endif

View file

@ -8,6 +8,10 @@
#include "core/framework/execution_frame.h"
namespace onnxruntime {
typedef std::unordered_map<std::string, OrtValue> OrtValueCache;
typedef std::shared_ptr<OrtValueCache> OrtValueCachePtr;
struct PartialGraphExecutionState {
public:
PartialGraphExecutionState() {
@ -28,7 +32,7 @@ struct PartialGraphExecutionState {
const SessionState& session_state) {
if (execution_frame_ == nullptr) {
execution_frame_ = std::make_unique<ExecutionFrame>(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
fetch_allocators, session_state);
fetch_allocators, session_state);
} else {
execution_frame_->UpdateFeeds(feed_mlvalue_idxs, feeds);
execution_frame_->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());

View file

@ -631,10 +631,12 @@ common::Status ExecuteGraph(const SessionState& session_state,
#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state) {
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache) {
// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);
PartialExecutor executor{state};
PartialExecutor executor{state, cache};
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();

View file

@ -94,7 +94,8 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag
#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state);
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache);
#endif
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.

View file

@ -1650,7 +1650,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
const std::vector<OrtValue>& feeds,
std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state,
FeedsFetchesManager& feeds_fetches_manager) {
FeedsFetchesManager& feeds_fetches_manager,
const OrtValueCachePtr& cache) {
Status retval = Status::OK();
std::vector<IExecutionProvider*> exec_providers_to_stop;
exec_providers_to_stop.reserve(execution_providers_.NumProviders());
@ -1691,7 +1692,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
// execute the graph
ORT_CHECK_AND_SET_RETVAL(utils::ExecutePartialGraph(*session_state_, feeds_fetches_manager, feeds, fetches,
run_logger, state));
run_logger, state, cache));
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {

View file

@ -316,12 +316,15 @@ class InferenceSession {
* @param state State of the graph needed to resume partial graph run.
* @param feeds_fetches_manager Contains feed/fetches name to internal indices mapping and information for device
* copy/checks.
* @param cache Contains node arg name to OrtValue map stashed from previous run
* for frontier tensors
*/
common::Status PartialRun(onnxruntime::RunOptions& run_options,
const std::vector<OrtValue>& feeds,
std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state,
FeedsFetchesManager& feeds_fetches_manager);
FeedsFetchesManager& feeds_fetches_manager,
const OrtValueCachePtr& cache);
#endif
/**

View file

@ -50,25 +50,26 @@ TrainingAgent::TrainingAgent(InferenceSession& session,
TrainingAgent::~TrainingAgent() = default;
common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state) {
PartialGraphExecutionState& state, const OrtValueCachePtr& cache) {
state.SetProgramCounterStart(0);
state.SetProgramCounterEnd(fw_program_counter_end_);
return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_);
return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache);
}
common::Status TrainingAgent::RunBackward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state) {
state.SetProgramCounterStart(fw_program_counter_end_);
state.SetProgramCounterEnd(bw_program_counter_end_);
return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_);
return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr);
}
common::Status TrainingAgent::RunCore(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager) {
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager,
const OrtValueCachePtr& cache) {
auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size();
fetches.resize(fetches_size, {});
RunOptions run_options;
return inference_session_.PartialRun(run_options, feeds, fetches, state, feeds_fetches_manager);
return inference_session_.PartialRun(run_options, feeds, fetches, state, feeds_fetches_manager, cache);
}
void TrainingAgent::CreateAndInitializeFeedsFetchesManager(const SessionState& session_state,

View file

@ -25,14 +25,15 @@ class TrainingAgent {
~TrainingAgent();
// For ORTModule.forward()
common::Status RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state) ORT_MUST_USE_RESULT;
PartialGraphExecutionState& state, const OrtValueCachePtr& cache) ORT_MUST_USE_RESULT;
// For ORTModule.backward()
common::Status RunBackward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state) ORT_MUST_USE_RESULT;
common::Status RunCore(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager)
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager,
const OrtValueCachePtr& cache)
ORT_MUST_USE_RESULT;
void CreateAndInitializeFeedsFetchesManager(const SessionState& session_state,

View file

@ -263,15 +263,15 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() {
// YieldOps non_differentiable_outputs attribute specifies the indices of outputs that are not differentiable
const auto& non_differentiable_indices = graph_info_.output_grad_indices_non_differentiable;
const std::string non_differentiable_outputs_name = "non_differentiable_outputs";
ONNX_NAMESPACE::AttributeProto non_differentiable_outputs;
non_differentiable_outputs.set_name(non_differentiable_outputs_name);
non_differentiable_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
if (non_differentiable_indices.size() > 0) {
ONNX_NAMESPACE::AttributeProto non_differentiable_outputs;
const std::string non_differentiable_outputs_name = "non_differentiable_outputs";
non_differentiable_outputs.set_name(non_differentiable_outputs_name);
non_differentiable_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
for (auto index : non_differentiable_indices) {
non_differentiable_outputs.add_ints(index);
}
attributes.insert({non_differentiable_outputs_name, non_differentiable_outputs});
}
// YieldOps full_shape_outputs attribute specifies the indices of outputs that must be full shape.
@ -303,6 +303,26 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() {
graph_info_.module_output_gradient_name.emplace_back(grad_name);
}
}
size_t input_count = yield_input_node_args.size();
for (auto& iter : graph_info_.frontier_node_arg_map) {
std::string name = iter.second;
yield_input_node_args.emplace_back(gradient_graph.GetNodeArg(name));
graph_info_.cached_node_arg_names.emplace_back(name);
}
const auto& frontier_tensors = graph_info_.frontier_node_arg_map;
if (frontier_tensors.size() > 0) {
for (size_t index = input_count; index < input_count + frontier_tensors.size(); index++) {
non_differentiable_outputs.add_ints(index);
}
}
// YieldOps non_differentiable_outputs attribute specifies the indices of outputs that are not differentiable
if (non_differentiable_indices.size() > 0 || frontier_tensors.size() > 0) {
attributes.insert({non_differentiable_outputs_name, non_differentiable_outputs});
}
attributes.insert({full_shape_outputs_name, full_shape_outputs});
// Handle potential duplicated output_gradient names

View file

@ -61,8 +61,11 @@ struct GraphInfo {
std::vector<size_t> module_output_indices_requires_save_for_backward{};
// Names of module outputs' gradient
std::vector<std::string> module_output_gradient_name{};
// Maps each parameter name to the name of its frontier tensor (NodeArg)
std::unordered_map<std::string, std::string> frontier_node_arg_map{};
// Names of the frontier NodeArgs in the order in which they will
// be retrieved in the forward pass
std::vector<std::string> cached_node_arg_names{};
};
class OrtModuleGraphBuilder {

View file

@ -18,6 +18,7 @@
#include "python/onnxruntime_pybind_mlvalue.h"
PYBIND11_MAKE_OPAQUE(std::vector<OrtValue>);
PYBIND11_MAKE_OPAQUE(onnxruntime::OrtValueCache);
namespace onnxruntime {
namespace python {
@ -329,6 +330,29 @@ void addObjectMethodsForTraining(py::module& m) {
return py::reinterpret_steal<py::object>(ToDlpack(v->at(idx)));
});
// Python binding for OrtValueCache, the NodeArg-name -> OrtValue map that is
// passed to TrainingAgent.run_forward. OrtValueCachePtr is used as the holder
// type, so every lambda below receives the shared_ptr that owns the map.
py::class_<OrtValueCache, OrtValueCachePtr>(m, "OrtValueCache")
    .def(py::init<>())
    // Uses emplace: if node_arg_name is already present the existing entry is kept.
    .def("insert", [](const OrtValueCachePtr& cache_ptr, const std::string& node_arg_name, const OrtValue& value) {
      cache_ptr->emplace(node_arg_name, value);
    })
    // Returns the cached NodeArg names as a Python list (unordered_map iteration order).
    .def("keys", [](const OrtValueCachePtr& cache_ptr) {
      py::list keys;
      // Iterate by const reference so the (name, OrtValue) pairs are not copied.
      for (const auto& kv : *cache_ptr) {
        keys.append(kv.first);
      }
      return keys;
    })
    .def("clear", [](const OrtValueCachePtr& cache_ptr) {
      cache_ptr->clear();
    })
    // Number of entries stored under node_arg_name (0 or 1 for an unordered_map).
    .def("count", [](const OrtValueCachePtr& cache_ptr, const std::string& node_arg_name) {
      return cache_ptr->count(node_arg_name);
    })
    // Erases the entry for node_arg_name; enforces that exactly one entry was removed.
    .def("remove", [](const OrtValueCachePtr& cache_ptr, const std::string& node_arg_name) {
      const auto num_entries_erased = cache_ptr->erase(node_arg_name);
      ORT_ENFORCE(num_entries_erased == 1, "NodeArg not found in cache: ", node_arg_name);
    });
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
@ -530,8 +554,8 @@ void addObjectMethodsForTraining(py::module& m) {
return std::make_unique<TrainingAgent>(*session->GetSessionHandle(), fw_feed_names, fw_outputs_device_info,
bw_fetches_names, bw_outputs_device_info);
}))
.def("run_forward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state) -> void {
Status status = agent->RunForward(feeds, fetches, *state);
.def("run_forward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void {
Status status = agent->RunForward(feeds, fetches, *state, cache);
if (!status.IsOK()) {
throw std::runtime_error("Error in forward pass execution: " + status.ErrorMessage());
}
@ -619,6 +643,7 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape)
.def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward)
.def_readwrite("frontier_node_arg_map", &GraphInfo::frontier_node_arg_map)
.def_readwrite("cached_node_arg_names", &GraphInfo::cached_node_arg_names)
.def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name);
py::class_<OrtModuleGraphBuilder> ortmodule_graph_builder(m, "OrtModuleGraphBuilder");

View file

@ -116,14 +116,15 @@ class TrainingAgent(object):
self._training_agent = C_TrainingAgent(self._inference_session._sess, fw_feed_names, fw_outputs_device_info,
bw_fetches_names, bw_outputs_device_info)
def run_forward(self, feeds, fetches, state):
def run_forward(self, feeds, fetches, state, cache=None):
"""
Compute the forward subgraph for given feeds and fetches.
:param feeds: Inputs to the graph run.
:param fetches: Outputs of the graph run.
:param state: State of the graph that is used for executing partial graph runs.
:param cache: Cache to store stashed OrtValues for intermediate activations.
"""
self._training_agent.run_forward(feeds, fetches, state)
self._training_agent.run_forward(feeds, fetches, state, cache)
def run_backward(self, feeds, fetches, state):
"""

View file

@ -0,0 +1,86 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from . import _utils
from onnxruntime.capi import _pybind_state as C
class GradientAccumulationManager(object):
    """Handles Gradient accumulation optimization during training

    This feature must be enabled once before training and cannot be turned off within a training run.

    The manager owns an OrtValueCache keyed by NodeArg name. Entries are dropped when the
    `_version` counter of the associated parameter changes, and are refilled from the extra
    forward outputs that follow the user outputs.
    """
    # TODO: enable switching the feature on/off in the middle of the training

    def __init__(self):
        self.cache = None
        self._param_name_value_map = None
        self._param_version_map = None
        self._frontier_node_arg_map = None
        # Declared here so the instance has a complete attribute set before initialize() runs.
        self._cached_node_arg_names = None
        self._cache_start = 0
        self._enabled = False
        self._update_cache = False

    def initialize(self, enabled, module, graph_info) -> None:
        """Initializes Gradient Accumulation optimization.

        Args:
            enabled (bool): Whether the optimization is enabled or disabled.
            module (torch.nn.Module): Training model
            graph_info (GraphInfo): The ORT Graph Info object holding information about backend graph.
        """
        if not enabled:
            # Nothing to set up; the manager stays in whatever state it already had.
            return

        self._enabled = True
        self.cache = C.OrtValueCache()
        # Since named_parameters() is a generator function, need to avoid overhead and
        # populate the params in memory to avoid generating the param map every
        # step. This will not work if the user adds or removes params between steps
        self._param_name_value_map = dict(module.named_parameters())
        self._param_version_map = {}
        self._frontier_node_arg_map = graph_info.frontier_node_arg_map
        self._cached_node_arg_names = graph_info.cached_node_arg_names
        # Forward outputs at index >= _cache_start are the tensors destined for the cache.
        self._cache_start = len(graph_info.user_output_names)

    @property
    def enabled(self):
        """Indicates whether gradient accumulation optimization is enabled.
        """
        return self._enabled

    def extract_outputs_and_maybe_update_cache(self, forward_outputs):
        """Extract the user outputs from the forward outputs as torch tensor and update cache, if needed

        Args:
            forward_outputs (OrtValueVector): List of outputs returned by forward function

        Returns:
            tuple: The user outputs converted with _utils._ortvalue_to_torch_tensor. When the
            optimization is disabled every forward output is converted; otherwise only the
            first `_cache_start` outputs are.
        """
        if not self.enabled:
            return tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs)

        if self._update_cache:
            # Trailing outputs line up one-to-one with _cached_node_arg_names.
            for i in range(self._cache_start, len(forward_outputs)):
                self.cache.insert(
                    self._cached_node_arg_names[i - self._cache_start], forward_outputs[i])
            self._update_cache = False

        return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i]) for i in range(self._cache_start))

    def maybe_update_cache_before_run(self):
        """Update cache when model parameters are modified and optimization is enabled.
        """
        # The current implementation relies on param._version, which might not be
        # updated in all cases (e.g. inplace update)
        # TODO: Make detection of parameter update robust
        if not self.enabled:
            return

        # parse param versions to detect change or no change
        for name, arg_name in self._frontier_node_arg_map.items():
            param = self._param_name_value_map[name]
            if name not in self._param_version_map:
                # First time this parameter is seen: only record its version.
                self._param_version_map[name] = param._version
            elif param._version != self._param_version_map[name]:
                # there is an updated param, so remove entry from cache
                # in order to recompute the value
                if self.cache.count(arg_name):
                    self.cache.remove(arg_name)
                # Ask the next forward pass to stash its trailing outputs again.
                self._update_cache = True
                self._param_version_map[name] = param._version

View file

@ -14,6 +14,7 @@ from ._fallback import (_FallbackManager,
ORTModuleONNXModelException,
ORTModuleTorchModelException,
wrap_exception)
from ._gradient_accumulation_manager import GradientAccumulationManager
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
from onnxruntime.capi import _pybind_state as C
@ -152,6 +153,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# WIP feature to enable caching in Gradient accumulation scenario.
self._enable_grad_acc_optimization = False
self._gradient_accumulation_manager = GradientAccumulationManager()
# Memory aware gradient builder.
self._use_memory_efficient_gradient = False

View file

@ -28,7 +28,7 @@ class TrainingManager(GraphExecutionManager):
self._export_mode = torch.onnx.TrainingMode.TRAINING
@staticmethod
def execution_session_run_forward(execution_session, onnx_model, *inputs):
def execution_session_run_forward(execution_session, onnx_model, gradient_accumulation_manager, *inputs):
"""Runs the forward graph on execution_session with given model inputs and device"""
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
@ -43,8 +43,8 @@ class TrainingManager(GraphExecutionManager):
forward_outputs = C.OrtValueVector()
# Run and return module outputs.
execution_session.run_forward(forward_inputs, forward_outputs, state)
user_outputs = tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs)
execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs)
output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
run_info = _RunStateInfo(state, output_info)
@ -113,6 +113,10 @@ class TrainingManager(GraphExecutionManager):
if self._debug_options.logging.log_level <= _logger.LogLevel.WARNING:
warnings.warn("Fast path enabled - skipping checks for rebuilding gradient graph, execution agent creation, and device during training.",
UserWarning)
self._gradient_accumulation_manager.initialize(self._enable_grad_acc_optimization, self._flattened_module, self._graph_info)
self._gradient_accumulation_manager.maybe_update_cache_before_run()
class _ORTModuleFunction(torch.autograd.Function):
'''Use a custom torch.autograd.Function to associate self.backward_graph as the
@ -135,6 +139,7 @@ class TrainingManager(GraphExecutionManager):
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent,
self._onnx_models.optimized_model,
self._gradient_accumulation_manager,
*inputs)
# Disable materializing grads then None object will not be
@ -146,8 +151,12 @@ class TrainingManager(GraphExecutionManager):
# ORT is NOT relying on save_for_backward() to actually save the tensor,
# as this tensor is also kept in ORT's PartialGraphState
# This call is to invoke pytorch's version check to detect the potential inplace corruption
# If ORT is caching tensors, the module_output_indices_requires_save_for_backward field
# might also have indices of cached tensors that are not passed over to pytorch, and they don't
# need marking with save_for_backward()
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
ctx.save_for_backward(user_outputs[idx])
if idx < len(self._graph_info.user_output_names):
ctx.save_for_backward(user_outputs[idx])
# Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info
# This will return torch the output tensors with correct requires_grad settings
@ -273,7 +282,8 @@ class TrainingManager(GraphExecutionManager):
C.OrtDevice(get_ort_device_type(self._device.type),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * len(self._graph_info.user_output_names)
)] * (len(self._graph_info.user_output_names) +
len(self._graph_info.frontier_node_arg_map))
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
bw_outputs_device_info = [

View file

@ -3140,6 +3140,59 @@ def test_debug_options_log_level_validation_fails_on_type_mismatch():
_ = DebugOptions(log_level=log_level)
assert f"Expected log_level of type LogLevel, got {type(log_level)}." in str(ex_info.value)
def test_ortmodule_gradient_accumulation_optimization_correctness():
"""Check that losses match between two ORTModule copies of the same network,
one with `_enable_grad_acc_optimization` left at its default and one with it set to True.

Both models see the same random input on each of 10 steps; the optimizers are stepped
only on steps where `step % GA_steps == 0`.
"""
# NOTE(review): the class name mentions a cast, but the visible forward() performs none.
class NeuralNetWithCast(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNetWithCast, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = self.fc2(out)
return out
# The test is hard-wired to run on CUDA.
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = NeuralNetWithCast(D_in, H, D_out).to(device)
# baseline model with optimization disabled
tgt_model = ORTModule(pt_model)
tgt_optimizer = torch.optim.Adam(tgt_model.parameters())
# model with optimization enabled
opt_model = ORTModule(copy.deepcopy(pt_model))
opt_model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
opt_optimizer = torch.optim.Adam(opt_model.parameters())
# One forward/backward pass; returns the detached loss (sum of the predictions).
def run_step(model, x):
with amp.autocast():
prediction = model(x)
loss = prediction.sum()
loss.backward()
return loss.detach()
# Apply the accumulated gradients, then reset them.
def run_optim_step(optimizer):
optimizer.step()
optimizer.zero_grad()
GA_steps = 2
tgt_model.zero_grad()
opt_model.zero_grad()
for step in range(10):
x = torch.randn(N, D_in, device=device)
tgt_loss = run_step(tgt_model, x)
opt_loss = run_step(opt_model, x)
# assert that loss values match
_test_helpers.assert_values_are_close(tgt_loss, opt_loss)
# NOTE(review): this condition is true at step 0, so the first optimizer step
# happens after a single backward pass - confirm that is intended.
if step % GA_steps == 0:
run_optim_step(tgt_optimizer)
run_optim_step(opt_optimizer)
def test_ortmodule_dict_input():
class DictNet(torch.nn.Module):
def __init__(self):

View file

@ -42,6 +42,12 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device,
# vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
model.train()
# Always clear any previously calculated gradients before performing a
# backward pass. PyTorch doesn't do this automatically because
# accumulating the gradients is "convenient while training RNNs".
# (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)
optimizer.zero_grad()
# For each batch of training data...
for step, batch in enumerate(train_dataloader):
@ -61,12 +67,6 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device,
b_input_mask = batch[1].to(device)
b_labels = batch[2].to(device)
# Always clear any previously calculated gradients before performing a
# backward pass. PyTorch doesn't do this automatically because
# accumulating the gradients is "convenient while training RNNs".
# (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)
optimizer.zero_grad()
# Perform a forward pass (evaluate the model on this training batch).
# This will return the loss (rather than the model output) because we have provided the `labels`.
# The documentation for this `model` function is here:
@ -108,14 +108,16 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device,
# This is to help prevent the "exploding gradients" problem.
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Update parameters and take a step using the computed gradient.
# The optimizer dictates the "update rule"--how the parameters are
# modified based on their gradients, the learning rate, etc.
scaler.step(optimizer)
scaler.update()
if step % args.grad_acc_steps == 0:
# Update parameters and take a step using the computed gradient.
# The optimizer dictates the "update rule"--how the parameters are
# modified based on their gradients, the learning rate, etc.
scaler.step(optimizer)
scaler.update()
# Update the learning rate.
scheduler.step()
# Update the learning rate.
scheduler.step()
optimizer.zero_grad()
# Calculate the average loss over the training data.
avg_train_loss = total_loss / len(train_dataloader)
@ -320,6 +322,8 @@ def main():
help='disables ONNX Runtime training')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('--do-val', action='store_true', default=False,
help='disables validation')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
help='input batch size for testing (default: 64)')
parser.add_argument('--view-graphs', action='store_true', default=False,
@ -340,6 +344,8 @@ def main():
help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)')
parser.add_argument('--data-dir', type=str, default='./cola_public/raw',
help='Path to the bert data directory')
parser.add_argument('--grad-acc-steps', type=int, default=2,
help='Number of steps for accumulating gradients')
args = parser.parse_args()
@ -376,23 +382,24 @@ def main():
config=config,
)
# Note: AdamW is a class from the huggingface library (as opposed to pytorch)
optimizer = torch.optim.AdamW(model.parameters(),
lr = 2e-2, # args.learning_rate - default is 5e-5, our notebook had 2e-5
eps = 1e-8 # args.adam_epsilon - default is 1e-8.
)
if not args.pytorch_only:
# Just for future debugging
debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassificationAutoCast')
model = ORTModule(model, debug_options)
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
# Tell pytorch to run this model on the GPU.
if torch.cuda.is_available() and not args.no_cuda:
model.cuda()
# Note: AdamW is a class from the huggingface library (as opposed to pytorch)
optimizer = AdamW(model.parameters(),
lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5
eps = 1e-8 # args.adam_epsilon - default is 1e-8.
)
# Authors recommend between 2 and 4 epochs
# Total number of training steps is number of batches * number of epochs.
@ -418,10 +425,12 @@ def main():
total_training_time += train(model, optimizer, scaler, scheduler, train_dataloader, epoch_i, device, args)
if not args.pytorch_only and epoch_i == 0:
epoch_0_training = total_training_time
test_time, validation_accuracy = test(model, validation_dataloader, device, args)
total_test_time += test_time
if args.do_val:
test_time, validation_accuracy = test(model, validation_dataloader, device, args)
total_test_time += test_time
assert validation_accuracy > 0.5
if args.do_val:
assert validation_accuracy > 0.5
print('\n======== Global stats ========')
if not args.pytorch_only: