# Patch summary: unified diff against ortmodule/_training_manager.py.
# In this copy the original line breaks and indentation are collapsed into
# single spaces, so the text below presumably will not apply with
# `git apply`/`patch` until the line structure is restored -- TODO confirm.
#
# What the three hunks show:
#   * Hunk 1 (-33,6 +33,7): TrainingManager.__init__ gains one statement,
#       self._forward_class = self._create_autofunction_class()
#   * Hunk 2 (-59,6 +60,133): a new method _create_autofunction_class(self) is
#     added just before forward(). It defines the nested
#     torch.autograd.Function subclass _ORTModuleFunction and returns the class.
#   * Hunk 3 (-135,131 +263,8): the _ORTModuleFunction definition is deleted
#     from the body of forward(), and the call site changes from
#     _ORTModuleFunction.apply(...) to self._forward_class.apply(...).
#   The removed class text appears to match the added class text word for
#   word; indentation cannot be compared in this rendering.
#
# What _ORTModuleFunction does, as visible in the hunks:
#   * forward(ctx, *inputs): unless _SkipCheck.SKIP_CHECK_DEVICE is set, checks
#     inputs against self._device; runs
#     TrainingManager.execution_session_run_forward and stores the returned
#     run_info on ctx; calls ctx.set_materialize_grads(False); calls
#     ctx.save_for_backward on selected user outputs; marks the outputs listed
#     in output_grad_indices_non_differentiable; returns user_outputs.
#   * backward(ctx, *grad_outputs): fills a C.OrtValueVector with the output
#     grads (a None grad becomes torch.zeros(shape) or a scalar 0. tensor,
#     depending on output_grad_indices_require_full_shape), calls
#     self._execution_agent.run_backward, deletes ctx.run_info.state, and
#     returns a tuple with one entry per user input followed by one entry per
#     initializer (None where no gradient is required).
#
# NOTE(review): ctx.save_for_backward() is invoked once per index inside a
#   loop. PyTorch documents that it should be called at most once per
#   forward(); confirm whether later calls replace earlier ones here. The
#   pattern is present in both the removed and the added copy.
# NOTE(review): forward/backward are static methods that read `self` from the
#   enclosing method's scope, and the class is stored on self._forward_class,
#   so the manager and the class reference each other -- confirm this cycle is
#   acceptable for object lifetime.
# NOTE(review): backward() relies on `assert` for validation; asserts are
#   stripped under `python -O`.
diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index f3f80983a7..182c8dbf57 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -33,6 +33,7 @@ class TrainingManager(GraphExecutionManager): def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager): super().__init__(model, debug_options, fallback_manager) self._export_mode = torch.onnx.TrainingMode.TRAINING + self._forward_class = self._create_autofunction_class() @staticmethod def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs): @@ -59,6 +60,133 @@ class TrainingManager(GraphExecutionManager): # Return user outputs and forward run information return user_outputs, run_info + def _create_autofunction_class(self): + + class _ORTModuleFunction(torch.autograd.Function): + '''Use a custom torch.autograd.Function to associate self.backward_graph as the + gradient implementation for self.forward_graph.''' + + @staticmethod + def forward(ctx, *inputs): + '''Performs forward pass based on user input and PyTorch initializer + + Autograd Function's apply() doesn't support keyword arguments, + so `*inputs` has all the arguments - keyword arguments converted + to positional/keywords during `TrainingManager.forward`. 
+ + Module outputs are returned to the user + ''' + + if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: + # Assert that the input and model device match + _utils._check_same_device(self._device, "Input argument to forward", *inputs) + + user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward( + self._execution_agent, + self._onnx_models.optimized_model, + self._device, + self._gradient_accumulation_manager, + *inputs) + + # Disable materializing grads then None object will not be + # converted to a tensor filled with zeros prior to calling backward. + # Save shape/device/type info to ctx for materializing tensor in backward if output grad is None. + ctx.set_materialize_grads(False) + + # Mark the outputs tensors needed in backward computation + # ORT is NOT relying on save_for_backward() to actually save the tensor, + # as this tensor is also kept in ORT's PartialGraphState + # This call is to invoke pytorch's version check to detect the potential inplace corruption + # If ORT is caching tensors, the module_output_indices_requires_save_for_backward field + # might also have indices of cached tensors that are not passed over to pytorch, and they don't + # need marking with save_for_backward() + for idx in self._graph_info.module_output_indices_requires_save_for_backward: + if idx < len(self._graph_info.user_output_names): + ctx.save_for_backward(user_outputs[idx]) + + # Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info + # This will return torch the output tensors with correct requires_grad settings + for idx in self._graph_info.output_grad_indices_non_differentiable: + ctx.mark_non_differentiable(user_outputs[idx]) + + return user_outputs + + @staticmethod + def backward(ctx, *grad_outputs): + '''Performs backward pass based on grad wrt module output''' + + assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()' + if 
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: + _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs) + + # Unpack saved_tensor to trigger version detection that catches inplace corruption + _ = ctx.saved_tensors + + # Use IO binding + # Push user output grads to ONNX backend. + backward_inputs = C.OrtValueVector() + # Preallocate length of the vector. And then delete as required towards the end. + backward_inputs.reserve(len(grad_outputs)) + for idx, grad_output in enumerate(grad_outputs): + if idx in self._graph_info.output_grad_indices_non_differentiable: + assert grad_output is None, "ORT found the {}-th module output '{}' is " \ + "non-differentiable according to the onnx graph. " \ + "However, the gradient value is still provided by " \ + "PyTorch's autograd engine." \ + .format(idx, self._graph_info.user_output_names[idx]) + continue + + if grad_output is None: + shape, device, dtype = ctx.run_info.output_info[idx] + if idx in self._graph_info.output_grad_indices_require_full_shape: + grad_output = torch.zeros(shape, device=device, dtype=dtype) + else: + grad_output = torch.tensor(0., device=device, dtype=dtype) + elif not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output), + grad_output.dtype is torch.bool) + backward_inputs.shrink_to_fit() + + # Run and get results + backward_outputs = C.OrtValueVector() + self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) + # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not + # affect peak memory usage in a subsequent graph run. 
+ del ctx.run_info.state + # Return input and initializer gradients + num_user_input_grads = len(self._input_info.require_grad_names) + results = [] + require_grad_names_set = set(self._input_info.require_grad_names) + require_grad_names_index = 0 + for input_name in self._graph_info.user_input_names: + # Append to the results the backward output for each input that required grad + if input_name in require_grad_names_set: + results.append(_utils._torch_tensor_from_dl_pack( + backward_outputs.dlpack_at(require_grad_names_index), + backward_outputs[require_grad_names_index], self._device)) + require_grad_names_index += 1 + else: + # input_name is not found in the self._input_info.require_grad_names list + # Append None to results for each input that did not require grad + results.append(None) + + # Append gradients of initializer to results + # Go over each initializer, check if it required grad and append to results accordingly + initializer_index = num_user_input_grads + for initializer_name in self._graph_info.initializer_names: + if initializer_name in self._graph_initializer_names_to_train: + results.append(_utils._torch_tensor_from_dl_pack( + backward_outputs.dlpack_at(initializer_index), + backward_outputs[initializer_index], self._device)) + initializer_index += 1 + else: + results.append(None) + + return tuple(results) + + return _ORTModuleFunction + def forward(self, *inputs, **kwargs): '''Forward pass starts here and continues at `_ORTModuleFunction.forward` @@ -135,131 +263,8 @@ class TrainingManager(GraphExecutionManager): self._gradient_accumulation_manager.maybe_update_cache_before_run() - class _ORTModuleFunction(torch.autograd.Function): - '''Use a custom torch.autograd.Function to associate self.backward_graph as the - gradient implementation for self.forward_graph.''' - - @staticmethod - def forward(ctx, *inputs): - '''Performs forward pass based on user input and PyTorch initializer - - Autograd Function's apply() doesn't support keyword 
arguments, - so `*inputs` has all the arguments - keyword arguments converted - to positional/keywords during `TrainingManager.forward`. - - Module outputs are returned to the user - ''' - - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: - # Assert that the input and model device match - _utils._check_same_device(self._device, "Input argument to forward", *inputs) - - user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward( - self._execution_agent, - self._onnx_models.optimized_model, - self._device, - self._gradient_accumulation_manager, - *inputs) - - # Disable materializing grads then None object will not be - # converted to a tensor filled with zeros prior to calling backward. - # Save shape/device/type info to ctx for materializing tensor in backward if output grad is None. - ctx.set_materialize_grads(False) - - # Mark the outputs tensors needed in backward computation - # ORT is NOT relying on save_for_backward() to actually save the tensor, - # as this tensor is also kept in ORT's PartialGraphState - # This call is to invoke pytorch's version check to detect the potential inplace corruption - # If ORT is caching tensors, the module_output_indices_requires_save_for_backward field - # might also have indices of cached tensors that are not passed over to pytorch, and they don't - # need marking with save_for_backward() - for idx in self._graph_info.module_output_indices_requires_save_for_backward: - if idx < len(self._graph_info.user_output_names): - ctx.save_for_backward(user_outputs[idx]) - - # Mark the outputs tensors non-differentiable if requires_grad is False in _graph_info - # This will return torch the output tensors with correct requires_grad settings - for idx in self._graph_info.output_grad_indices_non_differentiable: - ctx.mark_non_differentiable(user_outputs[idx]) - - return user_outputs - - @staticmethod - def backward(ctx, *grad_outputs): - '''Performs backward pass based on grad wrt module output''' - - 
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()' - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: - _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs) - - # Unpack saved_tensor to trigger version detection that catches inplace corruption - _ = ctx.saved_tensors - - # Use IO binding - # Push user output grads to ONNX backend. - backward_inputs = C.OrtValueVector() - # Preallocate length of the vector. And then delete as required towards the end. - backward_inputs.reserve(len(grad_outputs)) - for idx, grad_output in enumerate(grad_outputs): - if idx in self._graph_info.output_grad_indices_non_differentiable: - assert grad_output is None, "ORT found the {}-th module output '{}' is " \ - "non-differentiable according to the onnx graph. " \ - "However, the gradient value is still provided by " \ - "PyTorch's autograd engine." \ - .format(idx, self._graph_info.user_output_names[idx]) - continue - - if grad_output is None: - shape, device, dtype = ctx.run_info.output_info[idx] - if idx in self._graph_info.output_grad_indices_require_full_shape: - grad_output = torch.zeros(shape, device=device, dtype=dtype) - else: - grad_output = torch.tensor(0., device=device, dtype=dtype) - elif not grad_output.is_contiguous(): - grad_output = grad_output.contiguous() - backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output), - grad_output.dtype is torch.bool) - backward_inputs.shrink_to_fit() - - # Run and get results - backward_outputs = C.OrtValueVector() - self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) - # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not - # affect peak memory usage in a subsequent graph run. 
- del ctx.run_info.state - # Return input and initializer gradients - num_user_input_grads = len(self._input_info.require_grad_names) - results = [] - require_grad_names_set = set(self._input_info.require_grad_names) - require_grad_names_index = 0 - for input_name in self._graph_info.user_input_names: - # Append to the results the backward output for each input that required grad - if input_name in require_grad_names_set: - results.append(_utils._torch_tensor_from_dl_pack( - backward_outputs.dlpack_at(require_grad_names_index), - backward_outputs[require_grad_names_index], self._device)) - require_grad_names_index += 1 - else: - # input_name is not found in the self._input_info.require_grad_names list - # Append None to results for each input that did not require grad - results.append(None) - - # Append gradients of initializer to results - # Go over each initializer, check if it required grad and append to results accordingly - initializer_index = num_user_input_grads - for initializer_name in self._graph_info.initializer_names: - if initializer_name in self._graph_initializer_names_to_train: - results.append(_utils._torch_tensor_from_dl_pack( - backward_outputs.dlpack_at(initializer_index), - backward_outputs[initializer_index], self._device)) - initializer_index += 1 - else: - results.append(None) - - return tuple(results) - return _io.unflatten_user_output(self._module_output_schema, - _ORTModuleFunction.apply( + self._forward_class.apply( *_io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names,