Refactor device handling and basic support for PyTorch Lightning (#6758)

This commit is contained in:
Thiago Crepaldi 2021-02-24 14:12:55 -08:00 committed by GitHub
parent 65ba51d93e
commit aa5cd37ac8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 240 additions and 106 deletions

View file

@ -203,7 +203,6 @@ class ORTModule(torch.nn.Module):
# TODO: Single device support for now
self._device = _utils.get_device_from_module(module)
self._device_changed = False
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
@ -239,8 +238,10 @@ class ORTModule(torch.nn.Module):
self._torch_free = self._torch_cuda_allocator.cuda_caching_allocator_raw_delete_address()
def _initialize_module_gradient_graph_builder(self):
# TODO: PyTorch exporter bug: changes the initializer order
# TODO: PyTorch exporter bug: changes the initializer order in ONNX model
initializer_names = [p[0] for p in self._original_module.named_parameters()]
onnx_initializer_names = [p.name for p in self._onnx_inference.graph.initializer]
initializer_names = [p for p in initializer_names if p in onnx_initializer_names]
# Build full training graph
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
@ -295,58 +296,6 @@ class ORTModule(torch.nn.Module):
if self._save_onnx:
onnx.save(self._onnx_training, self._save_onnx_prefix + '_training.onnx')
def cpu(self: T) -> T:
    '''Record a switch to CPU for ORTModule IO bindings, then defer to the parent class cpu()'''

    # Nothing to record when a device is already tracked and it is the CPU
    already_on_cpu = bool(self._device) and self._device.type == 'cpu'
    if not already_on_cpu:
        self._device_changed = True
        self._device = torch.device('cpu')

    return super(ORTModule, self).cpu()
def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
    '''Record the target CUDA device for ORTModule IO bindings, then defer to the parent class cuda()'''

    if device is not None:
        # Explicit target: update when nothing is tracked yet or the tracked device differs
        if not self._device or _utils.get_device_str(self._device) != _utils.get_device_str(device):
            self._device_changed = True
            self._device = torch.device(_utils.get_device_str(device))
    elif self._device and _utils.get_device_str(self._device) != _utils.get_default_device_str('cuda'):
        # Implicit target: only an already-tracked device that differs from the default CUDA device is updated.
        # NOTE(review): unlike the explicit branch, an unset self._device is left untouched here - confirm intended.
        self._device_changed = True
        self._device = torch.device(_utils.get_default_device_str('cuda'))

    return super(ORTModule, self).cuda(device)
@overload
def to(self: T, device: Optional[Union[int, torch.device]] = ...,
       dtype: Optional[Union[torch.dtype, str]] = ...,
       non_blocking: bool = ...) -> T:
    ...

@overload
def to(self: T, dtype: Union[torch.dtype, str], non_blocking: bool = ...) -> T:
    ...

@overload
def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T:
    ...

def to(self, *args, **kwargs):
    '''Thin layer to capture device for ORTModule IO bindings

    Accepts the same arguments as the parent class to(). When the arguments carry a
    target device that differs from the tracked one, self._device is updated and
    self._device_changed is set before delegating to the parent class.
    '''

    device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)
    if device is not None:
        # Resolve the target outside the try block: the handler below needs device_str,
        # which used to be unbound whenever this very call was the one that raised.
        device_str = _utils.get_device_str(device)
        try:
            changed = _utils.get_device_str(self._device) != device_str
        except RuntimeError:
            # The currently tracked device could not be converted; treat it as a change.
            # NOTE(review): presumably raised for an unset/invalid self._device - confirm in _utils.
            changed = True

        if changed:
            self._device_changed = True
            self._device = torch.device(device_str)

    return super(ORTModule, self).to(*args, **kwargs)
def eval(self: T) -> T:
self._is_training = False
self._original_module.eval()
@ -368,8 +317,9 @@ class ORTModule(torch.nn.Module):
# Exporting module to ONNX for the first time
if not self._onnx_training:
if not self._device:
self._device = _utils.get_device_from_input_args_kwargs(self._original_module, *inputs, **kwargs)
device_from_module = _utils.get_device_from_module(self._original_module)
if not self._device or self._device != device_from_module:
self._device = device_from_module
if not self._device:
raise RuntimeError('A device must be specified in the model or data!')
self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs)
@ -385,11 +335,12 @@ class ORTModule(torch.nn.Module):
self._current_input_shape = new_input_shape
self._build_training_graph()
self._create_training_session()
# TODO: disabled for now, since it caused a bug in NVBert fp32 run
# When creating a new InferenceSession, there is a bug for destructing the original InferenceSession
# elif self._device_changed:
# self._create_training_session()
# self._device_changed = False
module_device = _utils.get_device_from_module(self._original_module)
if self._device != module_device:
self._device = module_device
self._create_training_session()
# Use a custom torch.autograd.Function to associate self.backward_graph as the
# gradient implementation for self.forward_graph.

View file

@ -62,6 +62,19 @@ def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
run_subprocess(command, cwd=cwd, log=log).check_returncode()
def run_ortmodule_torch_lightning(cwd, log, data_dir):
    """Run the ORTModule PyTorch Lightning sample script in a subprocess.

    Args:
        cwd: Working directory passed to run_subprocess.
        log: Logger used for the debug message and passed to run_subprocess.
        data_dir: Optional MNIST data directory; forwarded to the sample only when truthy.

    Raises whatever check_returncode() raises when the sample exits with a non-zero code.
    """
    log.debug('Running: ORTModule PyTorch Lightning sample .')

    command = [sys.executable, 'orttraining_test_ortmodule_torch_lightning_basic.py', '--train-steps=470',
               '--epochs=2', '--batch-size=256']
    if data_dir:
        # The sample's argparse option is spelled '--data-dir'; '--data_dir' is rejected as unrecognized
        command.extend(['--data-dir', data_dir])

    run_subprocess(command, cwd=cwd, log=log).check_returncode()
def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir):
log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda))
@ -91,6 +104,9 @@ def main():
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True, data_dir=args.bert_data)
# TODO: Re-enable when PyTorch Lightning works with newer torchtext (nightlies after 2021-02-19)
# run_ortmodule_torch_lightning(cwd, log, args.mnist)
return 0

View file

@ -2,6 +2,7 @@
# Licensed under the MIT License.
# orttraining_test_ortmodule_api.py
import math
import torch
from transformers import AutoConfig, BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
@ -231,93 +232,96 @@ def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_s
assert ortmodule_result == ortmodule_result_again
assert pytorch_result == ortmodule_result
def test_model_cuda():
def test_torch_nn_module_cuda_method():
original_device = 'cpu'
to_device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
model = ORTModule(model)
x = torch.randn(N, D_in, device=to_device)
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == original_device
x = torch.randn(N, D_in, device=to_device)
model = model.cuda()
model(x)
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == to_device
def test_model_cpu():
@pytest.mark.parametrize("set_gpu_on_original_module", [
True,
False
])
def test_torch_nn_module_cpu_method(set_gpu_on_original_module):
original_device = 'cuda'
to_device = 'cpu'
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
model = ORTModule(model)
x = torch.randn(N, D_in)
if set_gpu_on_original_module:
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
model = ORTModule(model)
else:
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
model = ORTModule(model).to(original_device)
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == original_device
x = torch.randn(N, D_in, device=to_device)
model = model.cpu()
model(x)
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == to_device
@pytest.mark.parametrize("original_device, to_argument, requires_export, device_type, device_index", [
('cpu', torch.device('cuda'), True, 'cuda', 0),
('cpu', 'cuda', True, 'cuda', 0),
('cpu', 'cuda:0', True, 'cuda', 0),
('cpu', 'cuda', True, 'cuda', 0),
('cuda', 'cuda', False, 'cuda', 0),
('cuda', 'cuda:0', False, 'cuda', 0),
('cuda', torch.device('cuda'), False, 'cuda', 0),
('cuda', 'cpu', True, 'cpu', 0),
('cuda', torch.device('cpu'), True, 'cpu', 0),
('cpu', 'cpu', False, 'cpu', None),
('cpu', torch.device('cpu'), False, 'cpu', None),
('cpu', torch.zeros(2, device=torch.device('cuda')), True, 'cuda', 0),
@pytest.mark.parametrize("original_device, to_argument", [
('cpu', 'cpu'),
('cpu', 'cuda'),
('cpu', 'cuda:0'),
('cpu', torch.device('cpu')),
('cpu', torch.device('cuda')),
('cuda', 'cuda'),
('cuda', 'cuda:0'),
('cuda', 'cpu'),
('cuda', torch.device('cuda')),
('cuda', torch.device('cpu')),
])
def test_model_to_device(original_device, to_argument, requires_export, device_type, device_index):
def test_torch_nn_module_to_api(original_device, to_argument):
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
model = ORTModule(model)
x = torch.randn(N, D_in, device=device_type)
x = torch.randn(N, D_in, device=original_device)
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == original_device
model = model.to(to_argument)
assert model._device_changed == requires_export
assert model._device == torch.device(device_type+':'+str(device_index) if device_index is not None else device_type)
x = x.to(to_argument)
model(x)
assert _utils.get_device_str(model._device) == _utils.get_device_str(torch.device(to_argument))
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == device_type
@pytest.mark.parametrize("original_device, to_device", [
('cuda', 'cpu'),
('cpu', 'cuda')
])
def test_model_to_device_and_back_to_original(original_device, to_device):
def test_model_without_device():
# Model doesn't have device (CPU is assumed)
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
model = ORTModule(model)
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == original_device
model = model.to(to_device)
assert model._device_changed == True
assert model._device == torch.device(to_device+':0')
# User input is on GPU
input_device='cuda'
x = torch.randn(N, D_in).to(input_device)
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == to_device
# ORTModule and PyTorch does not move model to where user input is hosted
with pytest.raises(RuntimeError) as type_error:
model(x)
assert "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" in str(type_error.value)
model = model.to(original_device)
assert model._device_changed == True
assert model._device == torch.device(original_device+':0')
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == original_device
def test_model_and_input_without_device():
    """Run ORTModule when neither the model nor the input was given an explicit device."""
    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    model = ORTModule(model)
    x = torch.randn(N, D_in)

    # CPU is assumed for both model and user input
    out = model(x)
    # The bare expression `out is not None` checked nothing; it has to be asserted
    assert out is not None
# TODO: Re-enable this Test when .to(), .cpu() and .cuda() are fixed
# def test_model_with_different_devices_same_session():
@ -693,3 +697,36 @@ def test_register_custom_ops_pytorch_exporter_torch_triu():
output = model(user_input)
assert list(output.shape) == [1, 10, 10]
def test_wrap_ortmodule_and_change_device():
    """Train a Sequential that contains an ORTModule after moving it to CPU and then to CUDA."""
    # Fit sin(x) from the polynomial features (x, x^2, x^3)
    samples = torch.linspace(-math.pi, math.pi, 2000)
    features = samples.unsqueeze(-1).pow(torch.tensor([1, 2, 3]))
    targets = torch.sin(samples)

    # Plain Sequential container whose linear layer is wrapped by ORTModule
    net = torch.nn.Sequential(
        ORTModule(torch.nn.Linear(3, 1)),
        torch.nn.Flatten(0, 1)
    )

    # Bounce model and data through CPU first, then land on CUDA
    net = net.cpu()
    features = features.cpu()
    targets = targets.cpu()
    net = net.cuda()
    features = features.cuda()
    targets = targets.cuda()

    # Short manual gradient-descent loop
    learning_rate = 1e-6
    loss_fn = torch.nn.MSELoss(reduction='sum')
    for _ in range(2000):
        y_pred = net(features)
        loss = loss_fn(y_pred, targets)

        net.zero_grad()
        loss.backward()
        with torch.no_grad():
            for weight in net.parameters():
                weight -= learning_rate * weight.grad

    # Training is considered successful when it produced a prediction and a loss
    assert y_pred is not None and loss is not None

View file

@ -0,0 +1,111 @@
import argparse
from multiprocessing import cpu_count
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import onnxruntime
from onnxruntime.training import ORTModule
class LitAutoEncoder(pl.LightningModule):
    """Fully-connected autoencoder (784 -> 64 -> 3 -> 64 -> 784) as a LightningModule.

    When use_ortmodule is True the encoder is wrapped by ORTModule; the decoder
    always stays a plain torch.nn.Sequential.
    """

    def __init__(self, lr, use_ortmodule=True):
        super().__init__()
        self.lr = lr

        # Build both halves first (encoder, then decoder) so layer creation order is unchanged
        encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28)
        )

        # TODO: Also wrap the decoder once multiple ORTModule instances are supported
        self.encoder = ORTModule(encoder) if use_ortmodule else encoder
        self.decoder = decoder

    def forward(self, x):
        # Prediction/inference path: only the encoder runs, yielding the embedding
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        # The training loop body is defined here and does not go through forward()
        images, _ = batch
        flat = images.view(images.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat))
        loss = F.mse_loss(reconstruction, flat)

        # Reported through Lightning's logger (TensorBoard by default)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # Adam over every parameter of the module, using the configured learning rate
        return torch.optim.Adam(self.parameters(), lr=self.lr)
def main():
    """Parse command-line options and train LitAutoEncoder on MNIST with a Lightning Trainer."""
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
                        help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--pytorch-only', action='store_true', default=False,
                        help='disables ONNX Runtime training')
    # The help text used to advertise a default of 10 while the real default is 5
    parser.add_argument('--epochs', type=int, default=5, metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument('--data-dir', type=str, default='./mnist',
                        help='Path to the mnist data directory')
    args = parser.parse_args()

    # Common setup: seed both PyTorch and ONNX Runtime
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Data loader
    dataset = MNIST(args.data_dir, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset, num_workers=cpu_count(), batch_size=args.batch_size)

    # Model architecture
    autoencoder = LitAutoEncoder(lr=args.lr, use_ortmodule=not args.pytorch_only)

    # Train loop: only pass the Trainer options that were actually requested
    kwargs = {}
    if device == 'cuda':
        kwargs.update({'gpus': 1})
    if args.train_steps > 0:
        kwargs.update({'max_steps': args.train_steps})
    if args.epochs > 0:
        kwargs.update({'max_epochs': args.epochs})

    trainer = pl.Trainer(**kwargs)
    trainer.fit(autoencoder, train_loader)


if __name__ == '__main__':
    main()

19
run_ortmodule_mvp_lightning.sh Executable file
View file

@ -0,0 +1,19 @@
#!/bin/bash

# Runs the ORTModule PyTorch Lightning sample against a local RelWithDebInfo build.
# Every command-line argument is forwarded unchanged to the sample script.

cur_dir=$(basename "$(pwd)")
if [[ ${cur_dir} != "RelWithDebInfo" ]]
then
    echo "Going to build folder (aka build/Linux/RelWithDebInfo)"
    # Stop here if the build folder is missing; otherwise PYTHONPATH and the copy below
    # would be resolved relative to the wrong directory
    cd build/Linux/RelWithDebInfo || exit 1
fi

echo "Exporting PYTHONPATH to use build dir as onnxruntime package"
export PYTHONPATH="$(pwd)"

echo "Copying PyTorch frontend source-code to build folder"
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/

echo "Running Flexible API (ORTModule)"
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py --help
# "$@" is quoted so arguments containing spaces (e.g. a data directory path) are not re-split
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py "$@"