Refactor IObinding

This commit is contained in:
Thiago Crepaldi 2020-11-16 16:30:00 -08:00
parent 39ac95b2fc
commit f13c2a61d5
6 changed files with 141 additions and 303 deletions

View file

@ -9,50 +9,59 @@ import warnings
import numpy as np
from inspect import signature
# Needed to re-implement PyTorch's cpu,cuda,to methods
from torch import Tensor, device, dtype
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
from onnxruntime.capi import _pybind_state as C
from . import _utils
ONNX_OPSET_VERSION = 12
# Needed to re-implement PyTorch's cpu,cuda,to methods
T = TypeVar('T', bound='Module')
def get_device_index(device):
if type(device) == str:
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
device = torch.device(device)
return 0 if device.index is None else device.index
def _create_iobinding(io_binding, inputs, model, output_buffers, device):
'''Creates IO binding for a `model` inputs and output'''
def _get_device_index(device):
if type(device) == str:
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
device = torch.device(device)
return 0 if device.index is None else device.index
def prepare_io_binding(io_binding, inputs, model, output_buffers, device):
idx = 0
for value_info in model.graph.input:
io_binding.bind_input(value_info.name, inputs[idx].device.type, get_device_index(inputs[idx].device),
_utils.dtype_torch_to_numpy(inputs[idx].dtype), list(inputs[idx].size()),
for idx, value_info in enumerate(model.graph.input):
io_binding.bind_input(value_info.name, inputs[idx].device.type,
_get_device_index(inputs[idx].device),
_utils.dtype_torch_to_numpy(inputs[idx].dtype),
list(inputs[idx].size()),
inputs[idx].data_ptr())
idx += 1
for value_info in model.graph.output:
name = value_info.name
output_tensor = output_buffers[name]
io_binding.bind_output(name, output_tensor.device.type, get_device_index(device),
_utils.dtype_torch_to_numpy(output_tensor.dtype), list(output_tensor.size()),
io_binding.bind_output(name, output_tensor.device.type,
_get_device_index(device),
_utils.dtype_torch_to_numpy(output_tensor.dtype),
list(output_tensor.size()),
output_tensor.data_ptr())
def _onnx_value_info_to_buffer_tensor(value_info, device):
'''Create a torch zeroed tensor with the same shape and type of `value_info`'''
def value_info_to_buffer_tensor(value_info, device):
shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
dtype = _utils.dtype_onnx_to_torch(value_info.type.tensor_type.elem_type)
return torch.zeros(shape, device=device, dtype=dtype)
class ORTModule(torch.nn.Module):
def __init__(self, module, device="cpu", use_iobinding=False):
def __init__(self, module):
assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module"
super(ORTModule, self).__init__()
self._device = device
self._use_iobinding = use_iobinding
self._export_again = False
# TODO: This is incorrect when different layers may be in different devices
self._device = next(module.parameters()).device
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
@ -62,18 +71,78 @@ class ORTModule(torch.nn.Module):
# Forward pass
self._onnx_forward = None
self._forward_session = None
self._forward_io_binding = None
self._forward_output_buffers = {}
# Backward pass
self._onnx_backward = None
self._backward_session = None
self._backward_io_binding = None
self._backward_output_buffers = {}
# Log level
self._loglevel = getattr(logging, 'WARNING')
# TODO: debug flags
# Debug flags
self._save_onnx = False
self._save_onnx_prefix = ''
def cpu(self: T) -> T:
'''Thin layer to capture device for ORTModule IO bindings'''
if self._device != 'cpu':
self._require_export = True
self._device = 'cpu'
return super(ORTModule, self).cpu()
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
'''Thin layer to capture device for ORTModule IO bindings'''
if device:
device = str(device)
else:
device = 'cuda'
if self._device != str(device):
self._require_export = True
self._device = device
return super(ORTModule, self).cuda(device)
@overload
def to(self: T, device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...) -> T:
'''Thin layer to capture device for ORTModule IO bindings'''
if device:
device = str(device)
else:
device = None
if self._device != str(device) and device is not None:
self._require_export = True
self._device = device
return super(ORTModule, self).to(device, dtype, non_blocking)
@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
'''Thin layer to capture device for ORTModule IO bindings'''
# TODO: Should we do anything?
self._require_export = False
return super(ORTModule, self).to(dtype, non_blocking)
@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
'''Thin layer to capture device for ORTModule IO bindings'''
# TODO: Long shot by sending model to tensor's device
device = None
if tensor:
device = str(tensor.device)
if self._device != str(device) and device is not None:
self._require_export = True
self._device = device
return super(ORTModule, self).to(tensor, non_blocking)
def to(self, *args, **kwargs):
'''Thin layer to capture device for ORTModule IO bindings'''
# TODO: Should we do anything?
self._require_export = False
return super(ORTModule, self).to(args, kwargs)
def forward(self, *inputs, **kwargs):
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
@ -82,13 +151,15 @@ class ORTModule(torch.nn.Module):
Next, a full training graph is splitted in forward and backward graph which are used
to instantiate ONNX Runtime InferenceSession`s
'''
if not self._onnx_forward:
if not self._onnx_forward or self._require_export:
self._require_export = False
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
# TODO: PyTorch exporter bug: changes the initializer order
# Use the order in original module
initializer_names = [p[0] for p in self._original_module.named_parameters()]
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names)
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \
ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names)
if self._save_onnx:
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
@ -99,19 +170,18 @@ class ORTModule(torch.nn.Module):
# TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names
self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info)
execution_providers = ['CPUExecutionProvider'] if self._device == 'cpu' else ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=execution_providers)
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=execution_providers)
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
if self._use_iobinding:
self._forward_io_binding = self._forward_session.io_binding()
self._forward_output_buffers = {}
for output in self._onnx_forward.graph.output:
self._forward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device)
self._backward_io_binding = self._backward_session.io_binding()
self._backward_output_buffers = {}
for output in self._onnx_backward.graph.output:
self._backward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device)
# IO binding
self._forward_io_binding = self._forward_session.io_binding()
self._forward_output_buffers = {}
for output in self._onnx_forward.graph.output:
self._forward_output_buffers[output.name] = _onnx_value_info_to_buffer_tensor(output, str(self._device))
self._backward_io_binding = self._backward_session.io_binding()
self._backward_output_buffers = {}
for output in self._onnx_backward.graph.output:
self._backward_output_buffers[output.name] = _onnx_value_info_to_buffer_tensor(output, str(self._device))
# Use a custom torch.autograd.Function to associate self.backward_graph as the
# gradient implementation for self.forward_graph.
@ -129,42 +199,26 @@ class ORTModule(torch.nn.Module):
* Intermediate tensors
'''
if not self._use_iobinding:
# Convert input to dict of torch tensors
data_dict = self._convert_forward_input_list_to_dict(*inputs)
# Use IO binding
_create_iobinding(self._forward_io_binding, inputs,
self._onnx_forward,
self._forward_output_buffers,
str(self._device))
# Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
data_dict_numpy = self._convert_dict_torch_to_numpy(data_dict)
# Feed forward
outputs, intermediate = self._run_forward_graph(data_dict_numpy)
outputs = tuple(torch.from_numpy(item) for item in outputs)
# Save input, initializers and intermediate tensors to be used during backward
user_input = self._onnx_graphs_info.user_input_names
backward_user_input = self._onnx_graphs_info.backward_user_input_names
ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input)
forward_initializer = self._onnx_graphs_info.initializer_names_to_train
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer)
intermediate = tuple(torch.from_numpy(item) for item in intermediate)
ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate])
# TODO: Support original module output (currently dict is not supported)
if len(outputs) == 1:
return outputs[0]
return outputs
# Use IO binding.
prepare_io_binding(self._forward_io_binding, inputs, self._onnx_forward, self._forward_output_buffers, self._device)
# Run
self._forward_session.run_with_iobinding(self._forward_io_binding)
# Stash tensors needed by backward
forward_input_dict = self._convert_forward_input_list_to_dict(*inputs)
ctx_inputs = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_user_input_names)
ctx_initializers = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_intializer_names_as_input)
ctx_intermediates = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.intermediate_tensor_names)
ctx_inputs = tuple(forward_input_dict[name] \
for name in self._onnx_graphs_info.backward_user_input_names)
ctx_initializers = tuple(forward_input_dict[name] \
for name in self._onnx_graphs_info.backward_intializer_names_as_input)
ctx_intermediates = tuple(self._forward_output_buffers[name] \
for name in self._onnx_graphs_info.intermediate_tensor_names)
ctx.save_for_backward(*[*ctx_inputs, *ctx_initializers, *ctx_intermediates])
# Return model output
outputs = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.user_output_names)
return outputs[0] if len(outputs) == 1 else outputs
@ -178,22 +232,22 @@ class ORTModule(torch.nn.Module):
TODO: Input gradient is hard-coded to torch.tensor([1.])
'''
if not self._use_iobinding:
saved_tensors = ctx.saved_tensors
grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names)
result += [torch.from_numpy(grad) for grad in grad_weights]
return tuple(result)
# Use IO binding.
# Use IO binding
grad_output_dict = dict(zip(self._onnx_graphs_info.user_output_grad_names, grad_output))
backward_grad_output = tuple(grad_output_dict[name] for name in self._onnx_graphs_info.backward_output_grad_names)
prepare_io_binding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output], self._onnx_backward, self._backward_output_buffers, self._device)
_create_iobinding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output],
self._onnx_backward,
self._backward_output_buffers,
str(self._device))
# Run
self._backward_session.run_with_iobinding(self._backward_io_binding)
# Return input and initializer gradients
results = [torch.tensor([1])] * len(self._onnx_graphs_info.user_input_names)
results += [self._backward_output_buffers[name] for name in self._onnx_graphs_info.initializer_grad_names_to_train]
results += [self._backward_output_buffers[name] \
for name in self._onnx_graphs_info.initializer_grad_names_to_train]
return tuple(results)
proc_inputs = [data for data in inputs if data is not None]
@ -203,6 +257,7 @@ class ORTModule(torch.nn.Module):
'''Creates forward `*inputs` list from user input and PyTorch initializers
TODO: **kwargs is not supported
TODO: How IO binding model inputs and outputs affects initializer copies?
ONNX Runtime forward requires an order list of:
* User input: computed from forward InferenceSession
@ -224,18 +279,6 @@ class ORTModule(torch.nn.Module):
return result
def _convert_dict_torch_to_numpy(self, tensor_dict):
'''Convert `tensor_dict` PyTorch tensors to numpy tensors
This is a ONNX Runtime requirement
TODO: #UseIOBinding
'''
result = {}
for k,v in tensor_dict.items():
result.update({k : v.detach().cpu().numpy()})
return result
def _convert_forward_input_list_to_dict(self, *inputs):
'''Convert forward `*inputs` list to dict
@ -302,34 +345,6 @@ class ORTModule(torch.nn.Module):
return result
def _run_forward_graph(self, inputs):
'''Execute forward pass on ONNX Runtime
Output order has to be specified to ONNX Runtime backend
to distinguish intermediate from output tensors
'''
forward_output = self._forward_session.run([*self._onnx_graphs_info.user_output_names,
*self._onnx_graphs_info.intermediate_tensor_names], inputs)
output = forward_output[:len(self._onnx_graphs_info.user_output_names)]
intermediates = forward_output[len(self._onnx_graphs_info.user_output_names):]
return output, intermediates
def _run_backward_graph(self, *inputs, **kwargs):
'''Execute backward pass on ONNX Runtime
`*inputs` is converted from list to a list of detached numpy tensors before
being fed to an ONNX Runtime InferenceSession
TODO: **kwargs are not supported
'''
# Convert input to dict of torch tensors
data = self._convert_backward_input_list_to_dict(*inputs)
# Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
data = self._convert_dict_torch_to_numpy(data)
return self._backward_session.run(self._onnx_graphs_info.initializer_grad_names_to_train, data)
@staticmethod
def _get_forward_graph(module, *inputs, **kwargs):

View file

@ -1,161 +0,0 @@
import argparse
import logging
import torch
from torchvision import datasets, transforms
import onnxruntime
from onnxruntime.training import ORTModule
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNet, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = self.fc2(out)
return out
def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
model.train()
for iteration, (data, target) in enumerate(train_loader):
if iteration == args.train_steps:
break
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)
optimizer.zero_grad()
if args.pytorch_only:
probability = model(data)
else:
probability = model(data)
if args.view_graphs:
import torchviz
pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters())))
pytorch_backward_graph.view()
loss = loss_fn(probability, target)
loss.backward()
optimizer.step()
# Stats
if iteration % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, iteration * len(data), len(train_loader.dataset),
100. * iteration / len(train_loader), loss))
def test(args, model, device, loss_fn, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)
output = model(data)
# Stats
test_loss += loss_fn(output, target, False).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def my_loss(x, target, is_train=True):
if is_train:
return torch.nn.CrossEntropyLoss()(x, target)
else:
return torch.nn.CrossEntropyLoss(reduction='sum')(x, target)
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
help='input batch size for training (default: 20)')
parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
help='input batch size for testing (default: 20)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--use_iobinding', action='store_true', default=False,
help='use IO binding')
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--pytorch-only', action='store_true', default=False,
help='disables ONNX Runtime training')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
help='how many batches to wait before logging training status (default: 100)')
parser.add_argument('--view-graphs', action='store_true', default=False,
help='views forward and backward graphs')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
help='Log level (default: WARNING)')
args = parser.parse_args()
# Common setup
torch.manual_seed(args.seed)
onnxruntime.set_seed(args.seed)
# TODO: CUDA support is broken due to copying from PyTorch into ORT
if not args.no_cuda and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# device = 'cpu'
## Data loader
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.batch_size,
shuffle=True)
if args.test_batch_size > 0:
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.test_batch_size, shuffle=True)
# Model architecture
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
if not args.pytorch_only:
print('Training MNIST on ORTModule....')
model = ORTModule(model, device, args.use_iobinding)
# TODO: change it to False to stop saving ONNX models
model._save_onnx = True
model._save_onnx_prefix = 'MNIST'
# Set log level
numeric_level = getattr(logging, args.log_level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError('Invalid log level: %s' % args.log_level)
logging.basicConfig(level=numeric_level)
else:
print('Training MNIST on vanilla PyTorch....')
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
# Train loop
for epoch in range(1, args.epochs + 1):
train(args, model, device, optimizer, my_loss, train_loader, epoch)
if args.test_batch_size > 0:
test(args, model, device, my_loss, test_loader)
if __name__ == '__main__':
main()

View file

@ -1,3 +1,5 @@
import pdb
import argparse
import logging
import torch
@ -111,17 +113,15 @@ def main():
onnxruntime.set_seed(args.seed)
# TODO: CUDA support is broken due to copying from PyTorch into ORT
# if not args.no_cuda and torch.cuda.is_available():
# device = "cuda"
# else:
# device = "cpu"
device = 'cpu'
if not args.no_cuda and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
## Data loader
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.batch_size,
shuffle=True)
if args.test_batch_size > 0:

View file

@ -15,4 +15,5 @@ echo "Copying PyTorch frontend source-code to build folder"
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
echo "Running Flexible API (ORTModule)"
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py --no-cuda --epochs 4 --log-interval 20 --log-level=DEBUG
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py --help
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py $@

View file

@ -15,4 +15,5 @@ echo "Copying PyTorch frontend source-code to build folder"
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
echo "Running Flexible API (ORTModule)"
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py --epochs 10 --log-interval 100 --log-level=DEBUG
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py --help
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py $@

View file

@ -1,18 +0,0 @@
#!/bin/bash
cur_dir=$(basename `pwd`)
if [[ ${cur_dir} != "RelWithDebInfo" ]]
then
echo "Going to build folder (aka build/Linux/RelWithDebInfo)"
cd build/Linux/RelWithDebInfo
fi
echo "Exporting PYTHONPATH to use build dir as onnxruntime package"
export PYTHONPATH=$(pwd)
echo "Copying PyTorch frontend source-code to build folder"
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
echo "Running Flexible API (ORTModule)"
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py --epochs 10 --log-interval 100 --log-level=DEBUG --use_iobinding