From cd67f12add8bece5c50454e96ccd3c4021dca83e Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Wed, 24 Mar 2021 17:51:00 -0700 Subject: [PATCH] Move IOBinding and RunOptions to ctx (#7028) * Liqun/ort module perf1 (#6806) add mysql script to log perf data Co-authored-by: liqun * Resolve HTTP Error 503: Service Unavailable for MNIST dataset (#6989) * Reduce logging for ORTModule for the end user (#6982) * Support none types in forward output (#7001) * Missed test case for none type output (#7014) * save iobinding to ctx * save run_options to ctx * remove debug tests * PR comments and clean up * add RunStateInfo * remove whitespace edits * PR comments * remove test changes * fix test failure * Fit unit test test_nesting_forward_backward_calls Co-authored-by: liqunfu Co-authored-by: baijumeswani Co-authored-by: Jingyan Wang --- .../orttraining/python/training/__init__.py | 1 + .../orttraining/python/training/ortmodule.py | 51 ++++++++----------- .../python/training/runstateinfo.py | 11 ++++ .../python/orttraining_test_ortmodule_api.py | 35 ++++++++++++- 4 files changed, 66 insertions(+), 32 deletions(-) create mode 100644 orttraining/orttraining/python/training/runstateinfo.py diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py index e8f5eff157..6f32161e16 100644 --- a/orttraining/orttraining/python/training/__init__.py +++ b/orttraining/orttraining/python/training/__init__.py @@ -10,3 +10,4 @@ from .orttrainer import ORTTrainer, TrainStepInfo from . 
import amp, checkpoint, optim, model_desc_validation from .training_agent import TrainingAgent from .ortmodule import ORTModule +from .runstateinfo import RunStateInfo diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 98b26f29a9..26e9e1bf7c 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -49,7 +49,6 @@ class Verbosity(IntEnum): ERROR = 3 FATAL = 4 - def _create_iobinding(io_binding, inputs, model, device): '''Creates IO binding for a `model` inputs and output''' for idx, value_info in enumerate(model.graph.input): @@ -172,24 +171,23 @@ class ORTModule(torch.nn.Module): _check_same_device( self._device, "Input argument to forward", *inputs) + # TODO: Try to reuse the output buffers as some of the output tensors are same sizes, + # especially the backward graph outputs. + training_io_binding = self._training_session.io_binding() + run_options = C.RunOptions() + # Use IO binding - _create_iobinding(self._training_io_binding, - inputs, self._onnx_training, self._device) + _create_iobinding(training_io_binding, inputs, self._onnx_training, self._device) # Run and return module outputs. - forward_outputs, run_id = self._training_session.run_forward( - self._training_io_binding, self._run_options) + forward_outputs, run_id = self._training_session.run_forward(training_io_binding, run_options) user_outputs = tuple(_ortvalue_to_torch_tensor( forward_output) for forward_output in forward_outputs) - ctx.run_id = run_id - - # Disable materializing grads then None object will not be converted - # to a tensor filled with zeros prior to calling backward. - # Also save shape, device and type info to ctx for materializing - # tensor in backward if output grad is None. + # Disable materializing grads then None object will not be converted to a tensor filled with zeros prior to calling backward. 
+ # Also save shape, device and type info to ctx for materializing tensor in backward if output grad is None. ctx.set_materialize_grads(False) - ctx.output_info = [ - (output.shape, output.device, output.dtype) for output in user_outputs] + output_info = [(output.shape, output.device, output.dtype) for output in user_outputs] + ctx.run_info = onnxruntime.training.RunStateInfo(run_id, run_options, training_io_binding, output_info) # Assert that the outputs and model device match _check_same_device( @@ -201,6 +199,7 @@ class ORTModule(torch.nn.Module): def backward(ctx, *grad_outputs): '''Performs backward pass based on grad wrt module output ''' + assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()' # Assert that the grad_outputs and model device match _check_same_device( @@ -217,7 +216,7 @@ class ORTModule(torch.nn.Module): continue if grad_output is None: - shape, device, dtype = ctx.output_info[idx] + shape, device, dtype = ctx.run_info.output_info[idx] if idx in self._onnx_graphs_info.output_grad_indices_require_full_shape: grad_output = torch.zeros( shape, device=device, dtype=dtype) @@ -231,10 +230,10 @@ class ORTModule(torch.nn.Module): grad_output) for grad_output in contiguous_grad_outputs] # Run and get results - run_id = ctx.run_id - self._training_session.run_backward( - backward_grad_output_ortvalue, run_id) - backward_outputs = self._training_io_binding.get_outputs() + run_id = ctx.run_info.run_id + training_io_binding = ctx.run_info.io_binding + self._training_session.run_backward(backward_grad_output_ortvalue, run_id) + backward_outputs = training_io_binding.get_outputs() # Return input and initializer gradients num_user_input_grads = len(self._input_names_require_grad) @@ -265,9 +264,9 @@ class ORTModule(torch.nn.Module): # The OrtValue has a shared_ptr to the data. 
# At this point there are two shared_ptrs to the data, one through the # OrtValue in the output iobinding, and the other through the copy in OrtDLManagedTensor. - # The following call clears the iobinding output, reducing the use_count to 1, - # so that once torch finishes computation on the DLpack tensors, the memory can be freed. - self._training_io_binding.clear_binding_outputs() + # The following call clears the iobinding output, reducing the use_count to 1, so that once torch finishes computation + # on the DLpack tensors, the memory can be freed. + training_io_binding.clear_binding_outputs() return tuple(results) return _ortmodule_io.populate_user_output_from_schema_and_outputs( @@ -322,8 +321,6 @@ class ORTModule(torch.nn.Module): # Training model self._onnx_training = None self._training_session = None - self._training_io_binding = None - self._run_options = None # Log level self._loglevel = getattr(logging, 'WARNING') @@ -408,14 +405,6 @@ class ORTModule(torch.nn.Module): self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(), session_options, providers, provider_options) - # Use this global run_options for now - self._run_options = C.RunOptions() - - # IO binding - # TODO: Reuse output buffers as some of output tensors have same shape, - # especially the backward graph outputs. - self._training_io_binding = self._training_session.io_binding() - def _build_training_graph(self, *inputs, **kwargs): if self._use_static_shape: self._module_gradient_graph_builder.build( diff --git a/orttraining/orttraining/python/training/runstateinfo.py b/orttraining/orttraining/python/training/runstateinfo.py new file mode 100644 index 0000000000..0da7a2b3ce --- /dev/null +++ b/orttraining/orttraining/python/training/runstateinfo.py @@ -0,0 +1,11 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +class RunStateInfo(object): + def __init__(self, run_id, run_options, io_binding, output_info): + self.run_id = run_id + self.run_options = run_options + self.io_binding = io_binding + self.output_info = output_info diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 90b8bf419a..44646be72b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -503,7 +503,7 @@ def test_gradient_correctness(): x = torch.randn(N, D_in, device=device) pt_prediction = run_step(pt_model, x) ort_prediction = run_step(ort_model, x) - + assert torch.allclose(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) @@ -545,6 +545,39 @@ def test_multiple_forward_only_calls(): assert torch.allclose(ort_prediction, pt_prediction) +def test_nesting_forward_backward_calls(): + device = 'cuda' + N, D_in, H, D_out = 32, 784, 500, 10 + pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + # forward1 + ort_x1 = torch.randn(N, D_in, device=device, requires_grad=True) + pt_x1 = copy.deepcopy(ort_x1) + ort_prediction1 = ort_model(ort_x1) + pt_prediction1 = pt_model(pt_x1) + assert torch.allclose(ort_prediction1, pt_prediction1) + ort_loss1 = ort_prediction1.sum() + pt_loss1 = pt_prediction1.sum() + # forward2 + ort_x2 = torch.randn(N, D_in, device=device, requires_grad=True) + pt_x2 = copy.deepcopy(ort_x2) + ort_prediction2 = ort_model(ort_x2) + ort_loss2 = ort_prediction2.sum() + pt_prediction2 = pt_model(pt_x2) + pt_loss2 = pt_prediction2.sum() + assert torch.allclose(ort_prediction2, pt_prediction2) + # backward2 + ort_loss2.backward() + pt_loss2.backward() + assert 
torch.allclose(ort_x2.grad, pt_x2.grad) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + # backward1 + ort_loss1.backward() + pt_loss1.backward() + assert torch.allclose(ort_x1.grad, pt_x1.grad) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + def test_multiple_overlapping_forward_backward_calls(): device = 'cuda' N, D_in, H, D_out = 32, 784, 500, 10