Refactor TrainingManager.forward (#9354)

* Refactor TrainingManager.forward
This commit is contained in:
Xavier Dupré 2021-10-14 12:54:31 +02:00 committed by GitHub
parent 851554536c
commit 22e3f8bf54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -33,6 +33,7 @@ class TrainingManager(GraphExecutionManager):
def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager):
super().__init__(model, debug_options, fallback_manager)
self._export_mode = torch.onnx.TrainingMode.TRAINING
self._forward_class = self._create_autofunction_class()
@staticmethod
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
@ -59,6 +60,133 @@ class TrainingManager(GraphExecutionManager):
# Return user outputs and forward run information
return user_outputs, run_info
def _create_autofunction_class(self):
class _ORTModuleFunction(torch.autograd.Function):
'''Use a custom torch.autograd.Function to associate self.backward_graph as the
gradient implementation for self.forward_graph.'''
@staticmethod
def forward(ctx, *inputs):
'''Performs forward pass based on user input and PyTorch initializer
Autograd Function's apply() doesn't support keyword arguments,
so `*inputs` has all the arguments - keyword arguments converted
to positional/keywords during `TrainingManager.forward`.
Module outputs are returned to the user
'''
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
# Assert that the input and model device match
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(
self._execution_agent,
self._onnx_models.optimized_model,
self._device,
self._gradient_accumulation_manager,
*inputs)
# Disable materializing grads then None object will not be
# converted to a tensor filled with zeros prior to calling backward.
# Save shape/device/type info to ctx for materializing tensor in backward if output grad is None.
ctx.set_materialize_grads(False)
# Mark the outputs tensors needed in backward computation
# ORT is NOT relying on save_for_backward() to actually save the tensor,
# as this tensor is also kept in ORT's PartialGraphState
# This call is to invoke pytorch's version check to detect the potential inplace corruption
# If ORT is caching tensors, the module_output_indices_requires_save_for_backward field
# might also have indices of cached tensors that are not passed over to pytorch, and they don't
# need marking with save_for_backward()
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
if idx < len(self._graph_info.user_output_names):
ctx.save_for_backward(user_outputs[idx])
# Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info
# This will return torch the output tensors with correct requires_grad settings
for idx in self._graph_info.output_grad_indices_non_differentiable:
ctx.mark_non_differentiable(user_outputs[idx])
return user_outputs
@staticmethod
def backward(ctx, *grad_outputs):
'''Performs backward pass based on grad wrt module output'''
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
# Unpack saved_tensor to trigger version detection that catches inplace corruption
_ = ctx.saved_tensors
# Use IO binding
# Push user output grads to ONNX backend.
backward_inputs = C.OrtValueVector()
# Preallocate length of the vector. And then delete as required towards the end.
backward_inputs.reserve(len(grad_outputs))
for idx, grad_output in enumerate(grad_outputs):
if idx in self._graph_info.output_grad_indices_non_differentiable:
assert grad_output is None, "ORT found the {}-th module output '{}' is " \
"non-differentiable according to the onnx graph. " \
"However, the gradient value is still provided by " \
"PyTorch's autograd engine." \
.format(idx, self._graph_info.user_output_names[idx])
continue
if grad_output is None:
shape, device, dtype = ctx.run_info.output_info[idx]
if idx in self._graph_info.output_grad_indices_require_full_shape:
grad_output = torch.zeros(shape, device=device, dtype=dtype)
else:
grad_output = torch.tensor(0., device=device, dtype=dtype)
elif not grad_output.is_contiguous():
grad_output = grad_output.contiguous()
backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output),
grad_output.dtype is torch.bool)
backward_inputs.shrink_to_fit()
# Run and get results
backward_outputs = C.OrtValueVector()
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.
del ctx.run_info.state
# Return input and initializer gradients
num_user_input_grads = len(self._input_info.require_grad_names)
results = []
require_grad_names_set = set(self._input_info.require_grad_names)
require_grad_names_index = 0
for input_name in self._graph_info.user_input_names:
# Append to the results the backward output for each input that required grad
if input_name in require_grad_names_set:
results.append(_utils._torch_tensor_from_dl_pack(
backward_outputs.dlpack_at(require_grad_names_index),
backward_outputs[require_grad_names_index], self._device))
require_grad_names_index += 1
else:
# input_name is not found in the self._input_info.require_grad_names list
# Append None to results for each input that did not require grad
results.append(None)
# Append gradients of initializer to results
# Go over each initializer, check if it required grad and append to results accordingly
initializer_index = num_user_input_grads
for initializer_name in self._graph_info.initializer_names:
if initializer_name in self._graph_initializer_names_to_train:
results.append(_utils._torch_tensor_from_dl_pack(
backward_outputs.dlpack_at(initializer_index),
backward_outputs[initializer_index], self._device))
initializer_index += 1
else:
results.append(None)
return tuple(results)
return _ORTModuleFunction
def forward(self, *inputs, **kwargs):
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
@ -135,131 +263,8 @@ class TrainingManager(GraphExecutionManager):
self._gradient_accumulation_manager.maybe_update_cache_before_run()
class _ORTModuleFunction(torch.autograd.Function):
'''Use a custom torch.autograd.Function to associate self.backward_graph as the
gradient implementation for self.forward_graph.'''
@staticmethod
def forward(ctx, *inputs):
'''Performs forward pass based on user input and PyTorch initializer
Autograd Function's apply() doesn't support keyword arguments,
so `*inputs` has all the arguments - keyword arguments converted
to positional/keywords during `TrainingManager.forward`.
Module outputs are returned to the user
'''
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
# Assert that the input and model device match
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(
self._execution_agent,
self._onnx_models.optimized_model,
self._device,
self._gradient_accumulation_manager,
*inputs)
# Disable materializing grads then None object will not be
# converted to a tensor filled with zeros prior to calling backward.
# Save shape/device/type info to ctx for materializing tensor in backward if output grad is None.
ctx.set_materialize_grads(False)
# Mark the outputs tensors needed in backward computation
# ORT is NOT relying on save_for_backward() to actually save the tensor,
# as this tensor is also kept in ORT's PartialGraphState
# This call is to invoke pytorch's version check to detect the potential inplace corruption
# If ORT is caching tensors, the module_output_indices_requires_save_for_backward field
# might also have indices of cached tensors that are not passed over to pytorch, and they don't
# need marking with save_for_backward()
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
if idx < len(self._graph_info.user_output_names):
ctx.save_for_backward(user_outputs[idx])
# Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info
# This will return torch the output tensors with correct requires_grad settings
for idx in self._graph_info.output_grad_indices_non_differentiable:
ctx.mark_non_differentiable(user_outputs[idx])
return user_outputs
@staticmethod
def backward(ctx, *grad_outputs):
'''Performs backward pass based on grad wrt module output'''
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
# Unpack saved_tensor to trigger version detection that catches inplace corruption
_ = ctx.saved_tensors
# Use IO binding
# Push user output grads to ONNX backend.
backward_inputs = C.OrtValueVector()
# Preallocate length of the vector. And then delete as required towards the end.
backward_inputs.reserve(len(grad_outputs))
for idx, grad_output in enumerate(grad_outputs):
if idx in self._graph_info.output_grad_indices_non_differentiable:
assert grad_output is None, "ORT found the {}-th module output '{}' is " \
"non-differentiable according to the onnx graph. " \
"However, the gradient value is still provided by " \
"PyTorch's autograd engine." \
.format(idx, self._graph_info.user_output_names[idx])
continue
if grad_output is None:
shape, device, dtype = ctx.run_info.output_info[idx]
if idx in self._graph_info.output_grad_indices_require_full_shape:
grad_output = torch.zeros(shape, device=device, dtype=dtype)
else:
grad_output = torch.tensor(0., device=device, dtype=dtype)
elif not grad_output.is_contiguous():
grad_output = grad_output.contiguous()
backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output),
grad_output.dtype is torch.bool)
backward_inputs.shrink_to_fit()
# Run and get results
backward_outputs = C.OrtValueVector()
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.
del ctx.run_info.state
# Return input and initializer gradients
num_user_input_grads = len(self._input_info.require_grad_names)
results = []
require_grad_names_set = set(self._input_info.require_grad_names)
require_grad_names_index = 0
for input_name in self._graph_info.user_input_names:
# Append to the results the backward output for each input that required grad
if input_name in require_grad_names_set:
results.append(_utils._torch_tensor_from_dl_pack(
backward_outputs.dlpack_at(require_grad_names_index),
backward_outputs[require_grad_names_index], self._device))
require_grad_names_index += 1
else:
# input_name is not found in the self._input_info.require_grad_names list
# Append None to results for each input that did not require grad
results.append(None)
# Append gradients of initializer to results
# Go over each initializer, check if it required grad and append to results accordingly
initializer_index = num_user_input_grads
for initializer_name in self._graph_info.initializer_names:
if initializer_name in self._graph_initializer_names_to_train:
results.append(_utils._torch_tensor_from_dl_pack(
backward_outputs.dlpack_at(initializer_index),
backward_outputs[initializer_index], self._device))
initializer_index += 1
else:
results.append(None)
return tuple(results)
return _io.unflatten_user_output(self._module_output_schema,
_ORTModuleFunction.apply(
self._forward_class.apply(
*_io._combine_input_buffers_initializers(
self._graph_initializers,
self._graph_info.user_input_names,