onnxruntime/orttraining/orttraining/python/training/ortmodule.py
Sherlock a98c2ebb8c
Enable saving optimized models in OrtModule (#7214)
* Enable saving optimized models in OrtModule

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
2021-04-02 12:37:05 -07:00

507 lines
25 KiB
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from . import _utils
from . import _ortmodule_output_transformation as _ortmodule_io
from onnxruntime.training import register_custom_ops_pytorch_exporter
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.capi import _pybind_state as C
import functools
import io
import logging
import onnx
import onnxruntime
import torch
import inspect
from inspect import signature
from enum import IntEnum
from torch.utils.dlpack import from_dlpack, to_dlpack
from torch.utils.cpp_extension import load_inline
# Needed to override PyTorch methods
from typing import TypeVar
T = TypeVar('T', bound='Module')
ONNX_OPSET_VERSION = 12
def _ortvalue_to_torch_tensor(ortvalue):
# PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8,
# and convert the config to torch.uint8 tensor duing from_dlpack().
# So we need to convert the torch tensor to torch.bool type if OrtValue is bool tensor.
torch_tensor = from_dlpack(ortvalue._ortvalue.to_dlpack())
return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor
def _ortvalue_from_torch_tensor(torch_tensor):
return OrtValue(C.OrtValue.from_dlpack(to_dlpack(torch_tensor), torch_tensor.dtype == torch.bool))
class Verbosity(IntEnum):
VERBOSE = 0
INFO = 1
WARNING = 2
ERROR = 3
FATAL = 4
def _create_iobinding(io_binding, inputs, model, device):
'''Creates IO binding for a `model` inputs and output'''
for idx, value_info in enumerate(model.graph.input):
io_binding.bind_ortvalue_input(
value_info.name, _ortvalue_from_torch_tensor(inputs[idx]))
for value_info in model.graph.output:
io_binding.bind_output(value_info.name, device.type,
device_id=_utils.get_device_index(device))
def _check_same_device(device, argument_str, *args):
'''Check that all tensor arguments in *args reside on the same device as the input device'''
for arg in args:
if arg is not None and isinstance(arg, torch.Tensor):
arg_device = torch.device(arg.device)
if arg_device != device:
raise RuntimeError(
f"{argument_str} found on device {arg_device}, but expected it to be on module device {device}.")
def _load_torch_allocator_cpp_extension(verbosity):
torch_cuda_allocator_addresses_cpp_source = """
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
size_t cuda_caching_allocator_raw_alloc_address() {
return reinterpret_cast<size_t>(&c10::cuda::CUDACachingAllocator::raw_alloc);
}
size_t cuda_caching_allocator_raw_delete_address() {
return reinterpret_cast<size_t>(&c10::cuda::CUDACachingAllocator::raw_delete);
}
"""
return load_inline(name='inline_extension', cpp_sources=[torch_cuda_allocator_addresses_cpp_source],
functions=['cuda_caching_allocator_raw_alloc_address',
'cuda_caching_allocator_raw_delete_address'],
verbose=verbosity < Verbosity.WARNING, with_cuda=True)
class ORTModule(torch.nn.Module):
def __init__(self, module):
assert isinstance(
module, torch.nn.Module), "'module' must be a torch.nn.Module"
# Create forward dynamically, so each ORTModule instance will have its own copy.
# This is needed to be able to copy the forward signatures from the original PyTorch models
# and possibly have different signatures for different instances.
def _forward(self, *inputs, **kwargs):
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
ONNX model is exported the first time this method is executed.
Next, we build a full training graph with module_gradient_graph_builder.
Finally, we instantiate the ONNX Runtime InferenceSession.
'''
# TODO: using pytorch for evaluation for now. We will use ORT for evaluation later.
# TODO: If the model is being executed with the gradient disabled (inside torch.no_grad() context for example),
# leverage pytorch model for now.
if not self._is_training():
return self._original_module(*inputs, **kwargs)
# Exporting module to ONNX for the first time
if not self._onnx_training:
device_from_module = _utils.get_device_from_module(
self._original_module)
if not self._device or self._device != device_from_module:
self._device = device_from_module
if not self._device:
raise RuntimeError(
'A device must be specified in the model or data!')
self._get_inference_graph_and_init_gradient_graph_builder(
*inputs, **kwargs)
# Flag to indicate whether the gradient_graph needs to be built
build_gradient_graph = self._current_input_shape is None
_, _, input_names_require_grad, new_input_shape = \
_ortmodule_io.parse_inputs_for_onnx_export(
self._original_module_parameters, self._onnx_inference, *inputs, **kwargs)
initializer_names_to_train_set_user_model = {name for name, param in
self._flattened_output_module.named_parameters() if param.requires_grad}
initializer_names_to_train_set_onnx_graph = set(self._onnx_graphs_info.initializer_names_to_train) \
if self._onnx_graphs_info else None
# If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder
# needs to be reinitialized so it can compute the backward output for the new inputs that require_grad
if input_names_require_grad != self._input_names_require_grad or \
initializer_names_to_train_set_user_model != initializer_names_to_train_set_onnx_graph:
self._input_names_require_grad = input_names_require_grad
self._initialize_module_gradient_graph_builder()
# Trigger the rebuilding of the gradient graph
build_gradient_graph = True
if build_gradient_graph:
self._current_input_shape = new_input_shape
self._build_training_graph()
self._create_training_session()
module_device = _utils.get_device_from_module(
self._original_module)
if self._device != module_device:
self._device = module_device
self._create_training_session()
class _ORTModuleFunction(torch.autograd.Function):
'''Use a custom torch.autograd.Function to associate self.backward_graph as the
gradient implementation for self.forward_graph.'''
@staticmethod
def forward(ctx, *inputs, **kwargs):
'''Performs forward pass based on user input and PyTorch initializer
Autograd Function's apply() doesn't support keyword arguments,
so `*inputs` has all the arguments - keyword arguments converted
to positional by the caller.
Module outputs are returned to the user
'''
# Assert that the input and model device match
_check_same_device(
self._device, "Input argument to forward", *inputs)
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
# especially the backward graph outputs.
training_io_binding = self._training_session.io_binding()
run_options = C.RunOptions()
# Use IO binding
_create_iobinding(training_io_binding, inputs, self._onnx_training, self._device)
# Run and return module outputs.
forward_outputs, run_id = self._training_session.run_forward(training_io_binding, run_options)
user_outputs = tuple(_ortvalue_to_torch_tensor(
forward_output) for forward_output in forward_outputs)
# Disable materializing grads then None object will not be converted to a tensor filled with zeros prior to calling backward.
# Also save shape, device and type info to ctx for materializing tensor in backward if output grad is None.
ctx.set_materialize_grads(False)
output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
ctx.run_info = onnxruntime.training.RunStateInfo(run_id, run_options, training_io_binding, output_info)
# Assert that the outputs and model device match
_check_same_device(
self._device, "Output argument from forward", *user_outputs)
return user_outputs
@staticmethod
def backward(ctx, *grad_outputs):
'''Performs backward pass based on grad wrt module output
'''
assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
# Assert that the grad_outputs and model device match
_check_same_device(
self._device, "Input argument to backward", *grad_outputs)
# Use IO binding
# Push user output grads to ONNX backend.
contiguous_grad_outputs = []
for idx, grad_output in enumerate(grad_outputs):
if idx in self._onnx_graphs_info.output_grad_indices_non_differentiable:
assert grad_output is None, "ORT found the {}-th module output '{}' is non-differentiable according to the onnx graph. " \
"However, the gradient value is still provided by torch's autograd engine." \
.format(idx, self._onnx_graphs_info.user_output_names[idx])
continue
if grad_output is None:
shape, device, dtype = ctx.run_info.output_info[idx]
if idx in self._onnx_graphs_info.output_grad_indices_require_full_shape:
grad_output = torch.zeros(
shape, device=device, dtype=dtype)
else:
grad_output = torch.tensor(
0., device=device, dtype=dtype)
elif not grad_output.is_contiguous():
grad_output = grad_output.contiguous()
contiguous_grad_outputs.append(grad_output)
backward_grad_output_ortvalue = [_ortvalue_from_torch_tensor(
grad_output) for grad_output in contiguous_grad_outputs]
# Run and get results
run_id = ctx.run_info.run_id
training_io_binding = ctx.run_info.io_binding
self._training_session.run_backward(backward_grad_output_ortvalue, run_id)
backward_outputs = training_io_binding.get_outputs()
# Return input and initializer gradients
num_user_input_grads = len(self._input_names_require_grad)
results = []
for input_name in self._onnx_graphs_info.user_input_names:
try:
# Append to the results the backward output for each input that required grad
results.append(_ortvalue_to_torch_tensor(
backward_outputs[self._input_names_require_grad.index(input_name)]))
except ValueError:
# input_name is not found in the self._input_names_require_grad list
# Append None to results for each input that did not require grad
results.append(None)
# Append gradients of initializer to results
# Go over each initializer, check if it required grad and append to results accordingly
initializer_names_to_train_set = set(self._onnx_graphs_info.initializer_names_to_train) \
if self._onnx_graphs_info else None
initializer_index = num_user_input_grads
for initializer_name in self._onnx_graphs_info.initializer_names:
if initializer_name in initializer_names_to_train_set:
results.append(_ortvalue_to_torch_tensor(backward_outputs[initializer_index]))
initializer_index += 1
else:
results.append(None)
# The OrtValue has a shared_ptr to the data.
# At this point there are two shared_ptrs to the data, one through the
# OrtValue in the output iobinding, and the other through the copy in OrtDLManagedTensor.
# The following call clears the iobinding output, reducing the use_count to 1, so that once torch finishes computation
# on the DLpack tensors, the memory can be freed.
training_io_binding.clear_binding_outputs()
return tuple(results)
return _ortmodule_io.populate_user_output_from_schema_and_outputs(
self._original_module_output_schema,
self._onnx_graphs_info.user_output_names,
_ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*inputs, **kwargs)))
# Bind the forward method.
self.forward = _forward.__get__(self)
# Copy the forward signature from the PyTorch module.
functools.update_wrapper(
self.forward.__func__, module.forward.__func__)
super(ORTModule, self).__init__()
# Verbosity for logging
self._verbosity = Verbosity.WARNING
# Support contrib OPs
register_custom_ops_pytorch_exporter.register_custom_op()
# TODO: Single device support for now
self._device = _utils.get_device_from_module(module)
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
# Get the module that flattens the output from the original module into a tuple
self._flattened_output_module = \
_ortmodule_io.get_flattened_output_module(
self._original_module)
self._original_module_parameters = signature(
self._original_module.forward).parameters.values()
# TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters.
for input_parameter in self._original_module_parameters:
if input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
raise NotImplementedError(
"The model's forward method has **kwargs parameter which is currently not supported.")
self._onnx_inference = None
# Related to training graph shape inference
self._current_input_shape = None
# default execution order is priority-based for both dynamic/static shape input for now
# if we observe benefit of static shape, we can expose this flag to user
self._use_static_shape = False
self._module_gradient_graph_builder = None
self._input_names_require_grad = None
self._original_module_output_schema = None
self._onnx_graphs_info = None
# Training model
self._onnx_training = None
self._training_session = None
# Log level
self._loglevel = getattr(logging, 'WARNING')
# Debug flags
self._save_onnx = False
self._save_onnx_prefix = ''
from torch.utils.cpp_extension import ROCM_HOME
self.is_rocm_pytorch = (True if (
(torch.version.hip is not None) and (ROCM_HOME is not None)) else False)
# CPP extension to get torch CUDA allocator's alloc and free function addresses
# Disable external allocator for ROCM EP since external allocator is not supported yet.
self._use_external_cuda_allocator = (
False if self.is_rocm_pytorch else True)
if self._use_external_cuda_allocator:
self._torch_cuda_allocator = _load_torch_allocator_cpp_extension(
self._verbosity)
self._torch_alloc = self._torch_cuda_allocator.cuda_caching_allocator_raw_alloc_address()
self._torch_free = self._torch_cuda_allocator.cuda_caching_allocator_raw_delete_address()
def _is_training(self):
return self._flattened_output_module.training and torch.is_grad_enabled()
def _initialize_module_gradient_graph_builder(self):
# TODO: PyTorch exporter bug: changes the initializer order in ONNX model
initializer_names = [name
for name, _ in self._flattened_output_module.named_parameters()]
initializer_names_to_train = []
if self._is_training():
initializer_names_to_train = [name
for name, param in self._flattened_output_module.named_parameters() if param.requires_grad]
# Build full training graph
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
grad_builder_config.initializer_names = initializer_names
grad_builder_config.initializer_names_to_train = initializer_names_to_train
grad_builder_config.input_names_require_grad = self._input_names_require_grad
self._module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
self._module_gradient_graph_builder.initialize(
self._onnx_inference.SerializeToString(), grad_builder_config)
def _get_inference_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs):
self._onnx_inference = self._get_inference_graph(*inputs, **kwargs)
if self._save_onnx:
onnx.save(self._onnx_inference, self._save_onnx_prefix + '_inference.onnx')
self._initialize_module_gradient_graph_builder()
def _create_training_session(self):
providers = None
provider_options = None
if self._device.type == 'cuda':
# Configure the InferenceSessions to use the specific GPU on which the model is placed.
providers = (["ROCMExecutionProvider"] if self.is_rocm_pytorch else [
"CUDAExecutionProvider"])
providers.append("CPUExecutionProvider")
if self._use_external_cuda_allocator:
provider_options = [{"device_id": str(self._device.index), "cuda_external_alloc": str(
self._torch_alloc), "cuda_external_free": str(self._torch_free)}, {}]
else:
provider_options = [{"device_id": str(self._device.index)}, {}]
elif self._device.type == 'cpu':
providers = ["CPUExecutionProvider"]
provider_options = [{}]
session_options = onnxruntime.SessionOptions()
session_options.enable_mem_pattern = False
session_options.use_deterministic_compute = False
# default to PRIORITY_BASED execution order
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
session_options.log_severity_level = int(self._verbosity)
# enable dumping optimized training graph
if self._save_onnx:
session_options.optimized_model_filepath = self._save_onnx_prefix + '_training_optimized.onnx'
self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
session_options, providers, provider_options)
def _build_training_graph(self, *inputs, **kwargs):
if self._use_static_shape:
self._module_gradient_graph_builder.build(
self._current_input_shape)
else:
self._module_gradient_graph_builder.build()
self._onnx_training = onnx.load_model_from_string(
self._module_gradient_graph_builder.get_training_model())
self._onnx_graphs_info = self._module_gradient_graph_builder.get_training_graph_info()
if self._save_onnx:
inference_optimized_model = onnx.load_model_from_string(
self._module_gradient_graph_builder.get_inference_optimized_model())
onnx.save(inference_optimized_model, self._save_onnx_prefix + '_inference_optimized.onnx')
onnx.save(self._onnx_training, self._save_onnx_prefix + '_training.onnx')
def eval(self: T) -> T:
self._flattened_output_module.eval()
def train(self: T, mode: bool = True) -> T:
self._flattened_output_module.train(mode)
def _convert_training_graph_input_to_list(self, *inputs, **kwargs):
'''Creates forward `*inputs` list from user input and PyTorch initializers
TODO: How IO binding model inputs and outputs affects initializer copies?
ONNX Runtime forward requires an ordered list of:
* User input: computed from forward InferenceSession
* Initializers: computed from original PyTorch model parameters
'''
# User inputs
non_none_inputs = [inp for inp in inputs if inp is not None]
named_buffers_iter = iter(self._flattened_output_module.named_buffers())
result = []
for input_idx, name in enumerate(self._onnx_graphs_info.user_input_names):
inp = None
if input_idx < len(non_none_inputs):
inp = non_none_inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
elif input_idx >= len(non_none_inputs):
# Registered buffers are translated to user_input+initializer in ONNX
# TODO: Check what happens when the number of inputs change form one call to the next
buffer_name, inp = next(named_buffers_iter)
assert buffer_name == name, f'Input name {name} expected, but {buffer_name} found!'
if inp is not None:
result.append(inp)
else:
# TODO: Re-export ONNX if any input from _onnx_graphs_info.user_input_names is None.
raise RuntimeError(
f'Input is present in ONNX graph but not provided: {name}.')
# Initializers
for param in self._flattened_output_module.named_parameters():
result.append(param[1])
return result
def _get_inference_graph(self, *inputs, **kwargs):
'''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input
TODO: How to support dynamic axes? Dimensions are determined by samples
'''
# Setup dynamic axes for onnx model
input_names, dynamic_axes, self._input_names_require_grad, _ = \
_ortmodule_io.parse_inputs_for_onnx_export(
self._original_module_parameters, None, *inputs, **kwargs)
output_names, output_dynamic_axes, self._original_module_output_schema = \
_ortmodule_io.parse_outputs_for_onnx_export_and_extract_output_schema(
self._original_module, inputs, kwargs)
dynamic_axes.update(output_dynamic_axes)
# Export torch.nn.Module to ONNX
f = io.BytesIO()
# Deepcopy inputs, since input values may change after model run.
# NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn).
# Therefore, deepcopy only the data component of the input tensors for export.
sample_inputs_copy, sample_kwargs_copy = \
_ortmodule_io.deepcopy_model_input(
*inputs, **kwargs)
try:
with torch.no_grad():
torch.onnx.export(self._flattened_output_module,
sample_inputs_copy + (sample_kwargs_copy, ),
f,
input_names=input_names,
output_names=output_names,
opset_version=ONNX_OPSET_VERSION,
do_constant_folding=False,
training=torch.onnx.TrainingMode.TRAINING,
dynamic_axes=dynamic_axes,
verbose=self._verbosity < Verbosity.WARNING,
export_params=False,
keep_initializers_as_inputs=True)
except RuntimeError as e:
raise RuntimeError(
'There was an error while exporting the PyTorch model to ONNX: {}'.format(e))
return onnx.load_model_from_string(f.getvalue())