mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Move IOBinding and RunOptions to ctx (#7028)
* Liqun/ort module perf1 (#6806) add mysql script to log perf data Co-authored-by: liqun <liqun@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net> * Resolve HTTP Error 503: Service Unavailable for MNIST dataset (#6989) * Reduce logging for ORTModule for the end user (#6982) * Support none types in forward output (#7001) * Missed test case for none type output (#7014) * save iobinding to ctx * save run_options to ctx * remove debug tests * PR comments and clean up * add RunStateInfo * remove whitespace edits * PR comments * remove test changes * fix test failure * Fit unit test test_nesting_forward_backward_calls Co-authored-by: liqunfu <liqfu@microsoft.com> Co-authored-by: baijumeswani <bmeswani@microsoft.com> Co-authored-by: Jingyan Wang <jingywa@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
2e3bbad19f
commit
cd67f12add
4 changed files with 66 additions and 32 deletions
|
|
@ -10,3 +10,4 @@ from .orttrainer import ORTTrainer, TrainStepInfo
|
|||
from . import amp, checkpoint, optim, model_desc_validation
|
||||
from .training_agent import TrainingAgent
|
||||
from .ortmodule import ORTModule
|
||||
from .runstateinfo import RunStateInfo
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ class Verbosity(IntEnum):
|
|||
ERROR = 3
|
||||
FATAL = 4
|
||||
|
||||
|
||||
def _create_iobinding(io_binding, inputs, model, device):
|
||||
'''Creates IO binding for a `model` inputs and output'''
|
||||
for idx, value_info in enumerate(model.graph.input):
|
||||
|
|
@ -172,24 +171,23 @@ class ORTModule(torch.nn.Module):
|
|||
_check_same_device(
|
||||
self._device, "Input argument to forward", *inputs)
|
||||
|
||||
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
|
||||
# especially the backward graph outputs.
|
||||
training_io_binding = self._training_session.io_binding()
|
||||
run_options = C.RunOptions()
|
||||
|
||||
# Use IO binding
|
||||
_create_iobinding(self._training_io_binding,
|
||||
inputs, self._onnx_training, self._device)
|
||||
_create_iobinding(training_io_binding, inputs, self._onnx_training, self._device)
|
||||
|
||||
# Run and return module outputs.
|
||||
forward_outputs, run_id = self._training_session.run_forward(
|
||||
self._training_io_binding, self._run_options)
|
||||
forward_outputs, run_id = self._training_session.run_forward(training_io_binding, run_options)
|
||||
user_outputs = tuple(_ortvalue_to_torch_tensor(
|
||||
forward_output) for forward_output in forward_outputs)
|
||||
ctx.run_id = run_id
|
||||
|
||||
# Disable materializing grads then None object will not be converted
|
||||
# to a tensor filled with zeros prior to calling backward.
|
||||
# Also save shape, device and type info to ctx for materializing
|
||||
# tensor in backward if output grad is None.
|
||||
# Disable materializing grads then None object will not be converted to a tensor filled with zeros prior to calling backward.
|
||||
# Also save shape, device and type info to ctx for materializing tensor in backward if output grad is None.
|
||||
ctx.set_materialize_grads(False)
|
||||
ctx.output_info = [
|
||||
(output.shape, output.device, output.dtype) for output in user_outputs]
|
||||
output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
|
||||
ctx.run_info = onnxruntime.training.RunStateInfo(run_id, run_options, training_io_binding, output_info)
|
||||
|
||||
# Assert that the outputs and model device match
|
||||
_check_same_device(
|
||||
|
|
@ -201,6 +199,7 @@ class ORTModule(torch.nn.Module):
|
|||
def backward(ctx, *grad_outputs):
|
||||
'''Performs backward pass based on grad wrt module output
|
||||
'''
|
||||
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
|
||||
|
||||
# Assert that the grad_outputs and model device match
|
||||
_check_same_device(
|
||||
|
|
@ -217,7 +216,7 @@ class ORTModule(torch.nn.Module):
|
|||
continue
|
||||
|
||||
if grad_output is None:
|
||||
shape, device, dtype = ctx.output_info[idx]
|
||||
shape, device, dtype = ctx.run_info.output_info[idx]
|
||||
if idx in self._onnx_graphs_info.output_grad_indices_require_full_shape:
|
||||
grad_output = torch.zeros(
|
||||
shape, device=device, dtype=dtype)
|
||||
|
|
@ -231,10 +230,10 @@ class ORTModule(torch.nn.Module):
|
|||
grad_output) for grad_output in contiguous_grad_outputs]
|
||||
|
||||
# Run and get results
|
||||
run_id = ctx.run_id
|
||||
self._training_session.run_backward(
|
||||
backward_grad_output_ortvalue, run_id)
|
||||
backward_outputs = self._training_io_binding.get_outputs()
|
||||
run_id = ctx.run_info.run_id
|
||||
training_io_binding = ctx.run_info.io_binding
|
||||
self._training_session.run_backward(backward_grad_output_ortvalue, run_id)
|
||||
backward_outputs = training_io_binding.get_outputs()
|
||||
|
||||
# Return input and initializer gradients
|
||||
num_user_input_grads = len(self._input_names_require_grad)
|
||||
|
|
@ -265,9 +264,9 @@ class ORTModule(torch.nn.Module):
|
|||
# The OrtValue has a shared_ptr to the data.
|
||||
# At this point there are two shared_ptrs to the data, one through the
|
||||
# OrtValue in the output iobinding, and the other through the copy in OrtDLManagedTensor.
|
||||
# The following call clears the iobinding output, reducing the use_count to 1,
|
||||
# so that once torch finishes computation on the DLpack tensors, the memory can be freed.
|
||||
self._training_io_binding.clear_binding_outputs()
|
||||
# The following call clears the iobinding output, reducing the use_count to 1, so that once torch finishes computation
|
||||
# on the DLpack tensors, the memory can be freed.
|
||||
training_io_binding.clear_binding_outputs()
|
||||
return tuple(results)
|
||||
|
||||
return _ortmodule_io.populate_user_output_from_schema_and_outputs(
|
||||
|
|
@ -322,8 +321,6 @@ class ORTModule(torch.nn.Module):
|
|||
# Training model
|
||||
self._onnx_training = None
|
||||
self._training_session = None
|
||||
self._training_io_binding = None
|
||||
self._run_options = None
|
||||
|
||||
# Log level
|
||||
self._loglevel = getattr(logging, 'WARNING')
|
||||
|
|
@ -408,14 +405,6 @@ class ORTModule(torch.nn.Module):
|
|||
self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
|
||||
session_options, providers, provider_options)
|
||||
|
||||
# Use this global run_options for now
|
||||
self._run_options = C.RunOptions()
|
||||
|
||||
# IO binding
|
||||
# TODO: Reuse output buffers as some of output tensors have same shape,
|
||||
# especially the backward graph outputs.
|
||||
self._training_io_binding = self._training_session.io_binding()
|
||||
|
||||
def _build_training_graph(self, *inputs, **kwargs):
|
||||
if self._use_static_shape:
|
||||
self._module_gradient_graph_builder.build(
|
||||
|
|
|
|||
11
orttraining/orttraining/python/training/runstateinfo.py
Normal file
11
orttraining/orttraining/python/training/runstateinfo.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
class RunStateInfo(object):
|
||||
def __init__(self, run_id, run_options, io_binding, output_info):
|
||||
self.run_id = run_id
|
||||
self.run_options = run_options
|
||||
self.io_binding = io_binding
|
||||
self.output_info = output_info
|
||||
|
|
@ -503,7 +503,7 @@ def test_gradient_correctness():
|
|||
x = torch.randn(N, D_in, device=device)
|
||||
pt_prediction = run_step(pt_model, x)
|
||||
ort_prediction = run_step(ort_model, x)
|
||||
|
||||
|
||||
assert torch.allclose(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
|
|
@ -545,6 +545,39 @@ def test_multiple_forward_only_calls():
|
|||
|
||||
assert torch.allclose(ort_prediction, pt_prediction)
|
||||
|
||||
def test_nesting_forward_backward_calls():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
# forward1
|
||||
ort_x1 = torch.randn(N, D_in, device=device, requires_grad=True)
|
||||
pt_x1 = copy.deepcopy(ort_x1)
|
||||
ort_prediction1 = ort_model(ort_x1)
|
||||
pt_prediction1 = pt_model(pt_x1)
|
||||
assert torch.allclose(ort_prediction1, pt_prediction1)
|
||||
ort_loss1 = ort_prediction1.sum()
|
||||
pt_loss1 = pt_prediction1.sum()
|
||||
# forward2
|
||||
ort_x2 = torch.randn(N, D_in, device=device, requires_grad=True)
|
||||
pt_x2 = copy.deepcopy(ort_x2)
|
||||
ort_prediction2 = ort_model(ort_x2)
|
||||
ort_loss2 = ort_prediction2.sum()
|
||||
pt_prediction2 = pt_model(pt_x2)
|
||||
pt_loss2 = pt_prediction2.sum()
|
||||
assert torch.allclose(ort_prediction2, pt_prediction2)
|
||||
# backward2
|
||||
ort_loss2.backward()
|
||||
pt_loss2.backward()
|
||||
assert torch.allclose(ort_x2.grad, ort_x2.grad)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
# backward1
|
||||
ort_loss1.backward()
|
||||
pt_loss1.backward()
|
||||
assert torch.allclose(ort_x1.grad, pt_x1.grad)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
def test_multiple_overlapping_forward_backward_calls():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
|
|
|
|||
Loading…
Reference in a new issue