Move IOBinding and RunOptions to ctx (#7028)

* Liqun/ort module perf1 (#6806)

add mysql script to log perf data
Co-authored-by: liqun <liqun@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>

* Resolve HTTP Error 503: Service Unavailable for MNIST dataset (#6989)

* Reduce logging for ORTModule for the end user (#6982)

* Support none types in forward output (#7001)

* Missed test case for none type output (#7014)

* save iobinding to ctx

* save run_options to ctx

* remove debug tests

* PR comments and clean up

* add RunStateInfo

* remove whitespace edits

* PR comments

* remove test changes

* fix test failure

* Fit unit test test_nesting_forward_backward_calls

Co-authored-by: liqunfu <liqfu@microsoft.com>
Co-authored-by: baijumeswani <bmeswani@microsoft.com>
Co-authored-by: Jingyan Wang <jingywa@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
jingyanwangms 2021-03-24 17:51:00 -07:00 committed by GitHub
parent 2e3bbad19f
commit cd67f12add
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 32 deletions

View file

@ -10,3 +10,4 @@ from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, checkpoint, optim, model_desc_validation
from .training_agent import TrainingAgent
from .ortmodule import ORTModule
from .runstateinfo import RunStateInfo

View file

@ -49,7 +49,6 @@ class Verbosity(IntEnum):
ERROR = 3
FATAL = 4
def _create_iobinding(io_binding, inputs, model, device):
'''Creates IO binding for a `model` inputs and output'''
for idx, value_info in enumerate(model.graph.input):
@ -172,24 +171,23 @@ class ORTModule(torch.nn.Module):
_check_same_device(
self._device, "Input argument to forward", *inputs)
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
# especially the backward graph outputs.
training_io_binding = self._training_session.io_binding()
run_options = C.RunOptions()
# Use IO binding
_create_iobinding(self._training_io_binding,
inputs, self._onnx_training, self._device)
_create_iobinding(training_io_binding, inputs, self._onnx_training, self._device)
# Run and return module outputs.
forward_outputs, run_id = self._training_session.run_forward(
self._training_io_binding, self._run_options)
forward_outputs, run_id = self._training_session.run_forward(training_io_binding, run_options)
user_outputs = tuple(_ortvalue_to_torch_tensor(
forward_output) for forward_output in forward_outputs)
ctx.run_id = run_id
# Disable materializing grads then None object will not be converted
# to a tensor filled with zeros prior to calling backward.
# Also save shape, device and type info to ctx for materializing
# tensor in backward if output grad is None.
# Disable materializing grads then None object will not be converted to a tensor filled with zeros prior to calling backward.
# Also save shape, device and type info to ctx for materializing tensor in backward if output grad is None.
ctx.set_materialize_grads(False)
ctx.output_info = [
(output.shape, output.device, output.dtype) for output in user_outputs]
output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
ctx.run_info = onnxruntime.training.RunStateInfo(run_id, run_options, training_io_binding, output_info)
# Assert that the outputs and model device match
_check_same_device(
@ -201,6 +199,7 @@ class ORTModule(torch.nn.Module):
def backward(ctx, *grad_outputs):
'''Performs backward pass based on grad wrt module output
'''
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
# Assert that the grad_outputs and model device match
_check_same_device(
@ -217,7 +216,7 @@ class ORTModule(torch.nn.Module):
continue
if grad_output is None:
shape, device, dtype = ctx.output_info[idx]
shape, device, dtype = ctx.run_info.output_info[idx]
if idx in self._onnx_graphs_info.output_grad_indices_require_full_shape:
grad_output = torch.zeros(
shape, device=device, dtype=dtype)
@ -231,10 +230,10 @@ class ORTModule(torch.nn.Module):
grad_output) for grad_output in contiguous_grad_outputs]
# Run and get results
run_id = ctx.run_id
self._training_session.run_backward(
backward_grad_output_ortvalue, run_id)
backward_outputs = self._training_io_binding.get_outputs()
run_id = ctx.run_info.run_id
training_io_binding = ctx.run_info.io_binding
self._training_session.run_backward(backward_grad_output_ortvalue, run_id)
backward_outputs = training_io_binding.get_outputs()
# Return input and initializer gradients
num_user_input_grads = len(self._input_names_require_grad)
@ -265,9 +264,9 @@ class ORTModule(torch.nn.Module):
# The OrtValue has a shared_ptr to the data.
# At this point there are two shared_ptrs to the data, one through the
# OrtValue in the output iobinding, and the other through the copy in OrtDLManagedTensor.
# The following call clears the iobinding output, reducing the use_count to 1,
# so that once torch finishes computation on the DLpack tensors, the memory can be freed.
self._training_io_binding.clear_binding_outputs()
# The following call clears the iobinding output, reducing the use_count to 1, so that once torch finishes computation
# on the DLpack tensors, the memory can be freed.
training_io_binding.clear_binding_outputs()
return tuple(results)
return _ortmodule_io.populate_user_output_from_schema_and_outputs(
@ -322,8 +321,6 @@ class ORTModule(torch.nn.Module):
# Training model
self._onnx_training = None
self._training_session = None
self._training_io_binding = None
self._run_options = None
# Log level
self._loglevel = getattr(logging, 'WARNING')
@ -408,14 +405,6 @@ class ORTModule(torch.nn.Module):
self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
session_options, providers, provider_options)
# Use this global run_options for now
self._run_options = C.RunOptions()
# IO binding
# TODO: Reuse output buffers as some of output tensors have same shape,
# especially the backward graph outputs.
self._training_io_binding = self._training_session.io_binding()
def _build_training_graph(self, *inputs, **kwargs):
if self._use_static_shape:
self._module_gradient_graph_builder.build(

View file

@ -0,0 +1,11 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
class RunStateInfo(object):
def __init__(self, run_id, run_options, io_binding, output_info):
self.run_id = run_id
self.run_options = run_options
self.io_binding = io_binding
self.output_info = output_info

View file

@ -503,7 +503,7 @@ def test_gradient_correctness():
x = torch.randn(N, D_in, device=device)
pt_prediction = run_step(pt_model, x)
ort_prediction = run_step(ort_model, x)
assert torch.allclose(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@ -545,6 +545,39 @@ def test_multiple_forward_only_calls():
assert torch.allclose(ort_prediction, pt_prediction)
def test_nesting_forward_backward_calls():
device = 'cuda'
N, D_in, H, D_out = 32, 784, 500, 10
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
# forward1
ort_x1 = torch.randn(N, D_in, device=device, requires_grad=True)
pt_x1 = copy.deepcopy(ort_x1)
ort_prediction1 = ort_model(ort_x1)
pt_prediction1 = pt_model(pt_x1)
assert torch.allclose(ort_prediction1, pt_prediction1)
ort_loss1 = ort_prediction1.sum()
pt_loss1 = pt_prediction1.sum()
# forward2
ort_x2 = torch.randn(N, D_in, device=device, requires_grad=True)
pt_x2 = copy.deepcopy(ort_x2)
ort_prediction2 = ort_model(ort_x2)
ort_loss2 = ort_prediction2.sum()
pt_prediction2 = pt_model(pt_x2)
pt_loss2 = pt_prediction2.sum()
assert torch.allclose(ort_prediction2, pt_prediction2)
# backward2
ort_loss2.backward()
pt_loss2.backward()
assert torch.allclose(ort_x2.grad, ort_x2.grad)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
# backward1
ort_loss1.backward()
pt_loss1.backward()
assert torch.allclose(ort_x1.grad, pt_x1.grad)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
def test_multiple_overlapping_forward_backward_calls():
device = 'cuda'
N, D_in, H, D_out = 32, 784, 500, 10