mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Refactor TrainingManager.forward (#9354)
* Refactor TrainingManager.forward
This commit is contained in:
parent
851554536c
commit
22e3f8bf54
1 changed files with 129 additions and 124 deletions
|
|
@ -33,6 +33,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager):
|
||||
super().__init__(model, debug_options, fallback_manager)
|
||||
self._export_mode = torch.onnx.TrainingMode.TRAINING
|
||||
self._forward_class = self._create_autofunction_class()
|
||||
|
||||
@staticmethod
|
||||
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
|
||||
|
|
@ -59,6 +60,133 @@ class TrainingManager(GraphExecutionManager):
|
|||
# Return user outputs and forward run information
|
||||
return user_outputs, run_info
|
||||
|
||||
def _create_autofunction_class(self):
|
||||
|
||||
class _ORTModuleFunction(torch.autograd.Function):
|
||||
'''Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
gradient implementation for self.forward_graph.'''
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *inputs):
|
||||
'''Performs forward pass based on user input and PyTorch initializer
|
||||
|
||||
Autograd Function's apply() doesn't support keyword arguments,
|
||||
so `*inputs` has all the arguments - keyword arguments converted
|
||||
to positional/keywords during `TrainingManager.forward`.
|
||||
|
||||
Module outputs are returned to the user
|
||||
'''
|
||||
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
# Assert that the input and model device match
|
||||
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
|
||||
|
||||
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(
|
||||
self._execution_agent,
|
||||
self._onnx_models.optimized_model,
|
||||
self._device,
|
||||
self._gradient_accumulation_manager,
|
||||
*inputs)
|
||||
|
||||
# Disable materializing grads then None object will not be
|
||||
# converted to a tensor filled with zeros prior to calling backward.
|
||||
# Save shape/device/type info to ctx for materializing tensor in backward if output grad is None.
|
||||
ctx.set_materialize_grads(False)
|
||||
|
||||
# Mark the outputs tensors needed in backward computation
|
||||
# ORT is NOT relying on save_for_backward() to actually save the tensor,
|
||||
# as this tensor is also kept in ORT's PartialGraphState
|
||||
# This call is to invoke pytorch's version check to detect the potential inplace corruption
|
||||
# If ORT is caching tensors, the module_output_indices_requires_save_for_backward field
|
||||
# might also have indices of cached tensors that are not passed over to pytorch, and they don't
|
||||
# need marking with save_for_backward()
|
||||
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
|
||||
if idx < len(self._graph_info.user_output_names):
|
||||
ctx.save_for_backward(user_outputs[idx])
|
||||
|
||||
# Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info
|
||||
# This will return torch the output tensors with correct requires_grad settings
|
||||
for idx in self._graph_info.output_grad_indices_non_differentiable:
|
||||
ctx.mark_non_differentiable(user_outputs[idx])
|
||||
|
||||
return user_outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
'''Performs backward pass based on grad wrt module output'''
|
||||
|
||||
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
|
||||
|
||||
# Unpack saved_tensor to trigger version detection that catches inplace corruption
|
||||
_ = ctx.saved_tensors
|
||||
|
||||
# Use IO binding
|
||||
# Push user output grads to ONNX backend.
|
||||
backward_inputs = C.OrtValueVector()
|
||||
# Preallocate length of the vector. And then delete as required towards the end.
|
||||
backward_inputs.reserve(len(grad_outputs))
|
||||
for idx, grad_output in enumerate(grad_outputs):
|
||||
if idx in self._graph_info.output_grad_indices_non_differentiable:
|
||||
assert grad_output is None, "ORT found the {}-th module output '{}' is " \
|
||||
"non-differentiable according to the onnx graph. " \
|
||||
"However, the gradient value is still provided by " \
|
||||
"PyTorch's autograd engine." \
|
||||
.format(idx, self._graph_info.user_output_names[idx])
|
||||
continue
|
||||
|
||||
if grad_output is None:
|
||||
shape, device, dtype = ctx.run_info.output_info[idx]
|
||||
if idx in self._graph_info.output_grad_indices_require_full_shape:
|
||||
grad_output = torch.zeros(shape, device=device, dtype=dtype)
|
||||
else:
|
||||
grad_output = torch.tensor(0., device=device, dtype=dtype)
|
||||
elif not grad_output.is_contiguous():
|
||||
grad_output = grad_output.contiguous()
|
||||
backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output),
|
||||
grad_output.dtype is torch.bool)
|
||||
backward_inputs.shrink_to_fit()
|
||||
|
||||
# Run and get results
|
||||
backward_outputs = C.OrtValueVector()
|
||||
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
|
||||
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
|
||||
# affect peak memory usage in a subsequent graph run.
|
||||
del ctx.run_info.state
|
||||
# Return input and initializer gradients
|
||||
num_user_input_grads = len(self._input_info.require_grad_names)
|
||||
results = []
|
||||
require_grad_names_set = set(self._input_info.require_grad_names)
|
||||
require_grad_names_index = 0
|
||||
for input_name in self._graph_info.user_input_names:
|
||||
# Append to the results the backward output for each input that required grad
|
||||
if input_name in require_grad_names_set:
|
||||
results.append(_utils._torch_tensor_from_dl_pack(
|
||||
backward_outputs.dlpack_at(require_grad_names_index),
|
||||
backward_outputs[require_grad_names_index], self._device))
|
||||
require_grad_names_index += 1
|
||||
else:
|
||||
# input_name is not found in the self._input_info.require_grad_names list
|
||||
# Append None to results for each input that did not require grad
|
||||
results.append(None)
|
||||
|
||||
# Append gradients of initializer to results
|
||||
# Go over each initializer, check if it required grad and append to results accordingly
|
||||
initializer_index = num_user_input_grads
|
||||
for initializer_name in self._graph_info.initializer_names:
|
||||
if initializer_name in self._graph_initializer_names_to_train:
|
||||
results.append(_utils._torch_tensor_from_dl_pack(
|
||||
backward_outputs.dlpack_at(initializer_index),
|
||||
backward_outputs[initializer_index], self._device))
|
||||
initializer_index += 1
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return tuple(results)
|
||||
|
||||
return _ORTModuleFunction
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
|
||||
|
||||
|
|
@ -135,131 +263,8 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
self._gradient_accumulation_manager.maybe_update_cache_before_run()
|
||||
|
||||
class _ORTModuleFunction(torch.autograd.Function):
|
||||
'''Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
gradient implementation for self.forward_graph.'''
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *inputs):
|
||||
'''Performs forward pass based on user input and PyTorch initializer
|
||||
|
||||
Autograd Function's apply() doesn't support keyword arguments,
|
||||
so `*inputs` has all the arguments - keyword arguments converted
|
||||
to positional/keywords during `TrainingManager.forward`.
|
||||
|
||||
Module outputs are returned to the user
|
||||
'''
|
||||
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
# Assert that the input and model device match
|
||||
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
|
||||
|
||||
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(
|
||||
self._execution_agent,
|
||||
self._onnx_models.optimized_model,
|
||||
self._device,
|
||||
self._gradient_accumulation_manager,
|
||||
*inputs)
|
||||
|
||||
# Disable materializing grads then None object will not be
|
||||
# converted to a tensor filled with zeros prior to calling backward.
|
||||
# Save shape/device/type info to ctx for materializing tensor in backward if output grad is None.
|
||||
ctx.set_materialize_grads(False)
|
||||
|
||||
# Mark the outputs tensors needed in backward computation
|
||||
# ORT is NOT relying on save_for_backward() to actually save the tensor,
|
||||
# as this tensor is also kept in ORT's PartialGraphState
|
||||
# This call is to invoke pytorch's version check to detect the potential inplace corruption
|
||||
# If ORT is caching tensors, the module_output_indices_requires_save_for_backward field
|
||||
# might also have indices of cached tensors that are not passed over to pytorch, and they don't
|
||||
# need marking with save_for_backward()
|
||||
for idx in self._graph_info.module_output_indices_requires_save_for_backward:
|
||||
if idx < len(self._graph_info.user_output_names):
|
||||
ctx.save_for_backward(user_outputs[idx])
|
||||
|
||||
# Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info
|
||||
# This will return torch the output tensors with correct requires_grad settings
|
||||
for idx in self._graph_info.output_grad_indices_non_differentiable:
|
||||
ctx.mark_non_differentiable(user_outputs[idx])
|
||||
|
||||
return user_outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
'''Performs backward pass based on grad wrt module output'''
|
||||
|
||||
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
|
||||
|
||||
# Unpack saved_tensor to trigger version detection that catches inplace corruption
|
||||
_ = ctx.saved_tensors
|
||||
|
||||
# Use IO binding
|
||||
# Push user output grads to ONNX backend.
|
||||
backward_inputs = C.OrtValueVector()
|
||||
# Preallocate length of the vector. And then delete as required towards the end.
|
||||
backward_inputs.reserve(len(grad_outputs))
|
||||
for idx, grad_output in enumerate(grad_outputs):
|
||||
if idx in self._graph_info.output_grad_indices_non_differentiable:
|
||||
assert grad_output is None, "ORT found the {}-th module output '{}' is " \
|
||||
"non-differentiable according to the onnx graph. " \
|
||||
"However, the gradient value is still provided by " \
|
||||
"PyTorch's autograd engine." \
|
||||
.format(idx, self._graph_info.user_output_names[idx])
|
||||
continue
|
||||
|
||||
if grad_output is None:
|
||||
shape, device, dtype = ctx.run_info.output_info[idx]
|
||||
if idx in self._graph_info.output_grad_indices_require_full_shape:
|
||||
grad_output = torch.zeros(shape, device=device, dtype=dtype)
|
||||
else:
|
||||
grad_output = torch.tensor(0., device=device, dtype=dtype)
|
||||
elif not grad_output.is_contiguous():
|
||||
grad_output = grad_output.contiguous()
|
||||
backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output),
|
||||
grad_output.dtype is torch.bool)
|
||||
backward_inputs.shrink_to_fit()
|
||||
|
||||
# Run and get results
|
||||
backward_outputs = C.OrtValueVector()
|
||||
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
|
||||
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
|
||||
# affect peak memory usage in a subsequent graph run.
|
||||
del ctx.run_info.state
|
||||
# Return input and initializer gradients
|
||||
num_user_input_grads = len(self._input_info.require_grad_names)
|
||||
results = []
|
||||
require_grad_names_set = set(self._input_info.require_grad_names)
|
||||
require_grad_names_index = 0
|
||||
for input_name in self._graph_info.user_input_names:
|
||||
# Append to the results the backward output for each input that required grad
|
||||
if input_name in require_grad_names_set:
|
||||
results.append(_utils._torch_tensor_from_dl_pack(
|
||||
backward_outputs.dlpack_at(require_grad_names_index),
|
||||
backward_outputs[require_grad_names_index], self._device))
|
||||
require_grad_names_index += 1
|
||||
else:
|
||||
# input_name is not found in the self._input_info.require_grad_names list
|
||||
# Append None to results for each input that did not require grad
|
||||
results.append(None)
|
||||
|
||||
# Append gradients of initializer to results
|
||||
# Go over each initializer, check if it required grad and append to results accordingly
|
||||
initializer_index = num_user_input_grads
|
||||
for initializer_name in self._graph_info.initializer_names:
|
||||
if initializer_name in self._graph_initializer_names_to_train:
|
||||
results.append(_utils._torch_tensor_from_dl_pack(
|
||||
backward_outputs.dlpack_at(initializer_index),
|
||||
backward_outputs[initializer_index], self._device))
|
||||
initializer_index += 1
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return tuple(results)
|
||||
|
||||
return _io.unflatten_user_output(self._module_output_schema,
|
||||
_ORTModuleFunction.apply(
|
||||
self._forward_class.apply(
|
||||
*_io._combine_input_buffers_initializers(
|
||||
self._graph_initializers,
|
||||
self._graph_info.user_input_names,
|
||||
|
|
|
|||
Loading…
Reference in a new issue