mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Refactor device handling and basic support for PyTorch Lightning (#6758)
This commit is contained in:
parent
65ba51d93e
commit
aa5cd37ac8
7 changed files with 240 additions and 106 deletions
|
|
@ -203,7 +203,6 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# TODO: Single device support for now
|
||||
self._device = _utils.get_device_from_module(module)
|
||||
self._device_changed = False
|
||||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
|
|
@ -239,8 +238,10 @@ class ORTModule(torch.nn.Module):
|
|||
self._torch_free = self._torch_cuda_allocator.cuda_caching_allocator_raw_delete_address()
|
||||
|
||||
def _initialize_module_gradient_graph_builder(self):
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
# TODO: PyTorch exporter bug: changes the initializer order in ONNX model
|
||||
initializer_names = [p[0] for p in self._original_module.named_parameters()]
|
||||
onnx_initializer_names = [p.name for p in self._onnx_inference.graph.initializer]
|
||||
initializer_names = [p for p in initializer_names if p in onnx_initializer_names]
|
||||
|
||||
# Build full training graph
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
|
|
@ -295,58 +296,6 @@ class ORTModule(torch.nn.Module):
|
|||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_training.onnx')
|
||||
|
||||
def cpu(self: T) -> T:
    '''Record a move to CPU for ORTModule IO bindings, then delegate to the parent class cpu()'''

    # Nothing to record when a device is already tracked and it is the CPU
    on_cpu = self._device and self._device.type == 'cpu'
    if not on_cpu:
        self._device_changed = True
        self._device = torch.device('cpu')

    return super(ORTModule, self).cpu()
|
||||
|
||||
def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
    '''Record the target CUDA device for ORTModule IO bindings, then delegate to the parent class cuda()'''

    if device is not None:
        # Explicit target: track it when no device is tracked yet or when it differs from the tracked one
        if not self._device or _utils.get_device_str(self._device) != _utils.get_device_str(device):
            self._device_changed = True
            self._device = torch.device(_utils.get_device_str(device))
    elif self._device and _utils.get_device_str(self._device) != _utils.get_default_device_str('cuda'):
        # No explicit target: fall back to the default CUDA device string.
        # An unset self._device is left untouched on this path.
        self._device_changed = True
        self._device = torch.device(_utils.get_default_device_str('cuda'))

    return super(ORTModule, self).cuda(device)
|
||||
|
||||
@overload
|
||||
def to(self: T, device: Optional[Union[int, torch.device]] = ...,
|
||||
dtype: Optional[Union[torch.dtype, str]] = ...,
|
||||
non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to(self: T, dtype: Union[torch.dtype, str], non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
def to(self, *args, **kwargs):
    '''Record the target device for ORTModule IO bindings, then delegate to the parent class to()

    Accepts the same positional/keyword arguments as torch.nn.Module.to(); they are
    parsed with torch._C._nn._parse_to and forwarded unchanged.
    '''

    device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)
    if device:
        # Resolve the target string outside the try block: previously a RuntimeError raised
        # here reached the handler with device_str unbound and surfaced as UnboundLocalError.
        device_str = _utils.get_device_str(device)
        try:
            device_differs = _utils.get_device_str(self._device) != device_str
        except RuntimeError:
            # NOTE(review): presumably raised when the tracked device cannot be converted
            # (e.g. it is unset) - confirm against _utils.get_device_str. Treated as a change.
            device_differs = True

        if device_differs:
            self._device_changed = True
            self._device = torch.device(device_str)

    return super(ORTModule, self).to(*args, **kwargs)
|
||||
|
||||
def eval(self: T) -> T:
|
||||
self._is_training = False
|
||||
self._original_module.eval()
|
||||
|
|
@ -368,8 +317,9 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# Exporting module to ONNX for the first time
|
||||
if not self._onnx_training:
|
||||
if not self._device:
|
||||
self._device = _utils.get_device_from_input_args_kwargs(self._original_module, *inputs, **kwargs)
|
||||
device_from_module = _utils.get_device_from_module(self._original_module)
|
||||
if not self._device or self._device != device_from_module:
|
||||
self._device = device_from_module
|
||||
if not self._device:
|
||||
raise RuntimeError('A device must be specified in the model or data!')
|
||||
self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs)
|
||||
|
|
@ -385,11 +335,12 @@ class ORTModule(torch.nn.Module):
|
|||
self._current_input_shape = new_input_shape
|
||||
self._build_training_graph()
|
||||
self._create_training_session()
|
||||
# TODO: disabled for now, since it caused a bug in NVBert fp32 run
|
||||
# When creating a new InferenceSession, there is a bug for destructing the original InferenceSession
|
||||
# elif self._device_changed:
|
||||
# self._create_training_session()
|
||||
# self._device_changed = False
|
||||
|
||||
module_device = _utils.get_device_from_module(self._original_module)
|
||||
if self._device != module_device:
|
||||
self._device = module_device
|
||||
self._create_training_session()
|
||||
|
||||
|
||||
# Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
# gradient implementation for self.forward_graph.
|
||||
|
|
|
|||
|
|
@ -62,6 +62,19 @@ def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
|
|||
|
||||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
||||
def run_ortmodule_torch_lightning(cwd, log, data_dir):
    """Run the ORTModule PyTorch Lightning sample script in a subprocess.

    Args:
        cwd: Working directory forwarded to run_subprocess.
        log: Logger used for the debug message and forwarded to run_subprocess.
        data_dir: Dataset directory; forwarded to the script only when truthy.

    The result of run_subprocess has check_returncode() called on it, so a failing
    run is presumably reported as an exception (run_subprocess is defined elsewhere).
    """
    log.debug('Running: ORTModule PyTorch Lightning sample .')

    command = [sys.executable, 'orttraining_test_ortmodule_torch_lightning_basic.py', '--train-steps=470',
               '--epochs=2', '--batch-size=256']

    if data_dir:
        # The sample script declares this option as '--data-dir' (hyphen);
        # argparse would reject '--data_dir' as an unrecognized argument.
        command.extend(['--data-dir', data_dir])

    run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
||||
def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir):
|
||||
log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda))
|
||||
|
||||
|
|
@ -91,6 +104,9 @@ def main():
|
|||
|
||||
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True, data_dir=args.bert_data)
|
||||
|
||||
# TODO: Re-enable when PyTorch Lightning works with newer torchtext (nightlies after 2021-02-19)
|
||||
# run_ortmodule_torch_lightning(cwd, log, args.mnist)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# orttraining_test_ortmodule_api.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
from transformers import AutoConfig, BertForSequenceClassification
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
|
|
@ -231,93 +232,96 @@ def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_s
|
|||
assert ortmodule_result == ortmodule_result_again
|
||||
assert pytorch_result == ortmodule_result
|
||||
|
||||
def test_model_cuda():
|
||||
def test_torch_nn_module_cuda_method():
|
||||
original_device = 'cpu'
|
||||
to_device = 'cuda'
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
|
||||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=to_device)
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == original_device
|
||||
|
||||
x = torch.randn(N, D_in, device=to_device)
|
||||
model = model.cuda()
|
||||
model(x)
|
||||
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == to_device
|
||||
|
||||
def test_model_cpu():
|
||||
@pytest.mark.parametrize("set_gpu_on_original_module", [
|
||||
True,
|
||||
False
|
||||
])
|
||||
def test_torch_nn_module_cpu_method(set_gpu_on_original_module):
|
||||
original_device = 'cuda'
|
||||
to_device = 'cpu'
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
|
||||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in)
|
||||
if set_gpu_on_original_module:
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
|
||||
model = ORTModule(model)
|
||||
else:
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
|
||||
model = ORTModule(model).to(original_device)
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == original_device
|
||||
|
||||
x = torch.randn(N, D_in, device=to_device)
|
||||
model = model.cpu()
|
||||
model(x)
|
||||
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == to_device
|
||||
|
||||
@pytest.mark.parametrize("original_device, to_argument, requires_export, device_type, device_index", [
|
||||
('cpu', torch.device('cuda'), True, 'cuda', 0),
|
||||
('cpu', 'cuda', True, 'cuda', 0),
|
||||
('cpu', 'cuda:0', True, 'cuda', 0),
|
||||
('cpu', 'cuda', True, 'cuda', 0),
|
||||
('cuda', 'cuda', False, 'cuda', 0),
|
||||
('cuda', 'cuda:0', False, 'cuda', 0),
|
||||
('cuda', torch.device('cuda'), False, 'cuda', 0),
|
||||
('cuda', 'cpu', True, 'cpu', 0),
|
||||
('cuda', torch.device('cpu'), True, 'cpu', 0),
|
||||
('cpu', 'cpu', False, 'cpu', None),
|
||||
('cpu', torch.device('cpu'), False, 'cpu', None),
|
||||
('cpu', torch.zeros(2, device=torch.device('cuda')), True, 'cuda', 0),
|
||||
@pytest.mark.parametrize("original_device, to_argument", [
|
||||
('cpu', 'cpu'),
|
||||
('cpu', 'cuda'),
|
||||
('cpu', 'cuda:0'),
|
||||
('cpu', torch.device('cpu')),
|
||||
('cpu', torch.device('cuda')),
|
||||
('cuda', 'cuda'),
|
||||
('cuda', 'cuda:0'),
|
||||
('cuda', 'cpu'),
|
||||
('cuda', torch.device('cuda')),
|
||||
('cuda', torch.device('cpu')),
|
||||
])
|
||||
def test_model_to_device(original_device, to_argument, requires_export, device_type, device_index):
|
||||
def test_torch_nn_module_to_api(original_device, to_argument):
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
|
||||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=device_type)
|
||||
x = torch.randn(N, D_in, device=original_device)
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == original_device
|
||||
|
||||
model = model.to(to_argument)
|
||||
assert model._device_changed == requires_export
|
||||
assert model._device == torch.device(device_type+':'+str(device_index) if device_index is not None else device_type)
|
||||
x = x.to(to_argument)
|
||||
model(x)
|
||||
assert _utils.get_device_str(model._device) == _utils.get_device_str(torch.device(to_argument))
|
||||
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == device_type
|
||||
|
||||
@pytest.mark.parametrize("original_device, to_device", [
|
||||
('cuda', 'cpu'),
|
||||
('cpu', 'cuda')
|
||||
])
|
||||
def test_model_to_device_and_back_to_original(original_device, to_device):
|
||||
def test_model_without_device():
|
||||
# Model doesn't have device (CPU is assumed)
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device)
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
|
||||
model = ORTModule(model)
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == original_device
|
||||
|
||||
model = model.to(to_device)
|
||||
assert model._device_changed == True
|
||||
assert model._device == torch.device(to_device+':0')
|
||||
# User input is on GPU
|
||||
input_device='cuda'
|
||||
x = torch.randn(N, D_in).to(input_device)
|
||||
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == to_device
|
||||
# Neither ORTModule nor PyTorch moves the model to the device where the user input is hosted
|
||||
with pytest.raises(RuntimeError) as type_error:
|
||||
model(x)
|
||||
assert "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" in str(type_error.value)
|
||||
|
||||
model = model.to(original_device)
|
||||
assert model._device_changed == True
|
||||
assert model._device == torch.device(original_device+':0')
|
||||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == original_device
|
||||
def test_model_and_input_without_device():
    # Neither the model nor the input is given an explicit device
    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    model = ORTModule(model)
    x = torch.randn(N, D_in)

    # CPU is assumed for both model and user input
    out = model(x)
    # This was a bare `out is not None` expression, which is evaluated and discarded;
    # the assert makes the test actually verify that forward produced a result.
    assert out is not None
|
||||
|
||||
# TODO: Re-enable this Test when .to(), .cpu() and .cuda() are fixed
|
||||
# def test_model_with_different_devices_same_session():
|
||||
|
|
@ -693,3 +697,36 @@ def test_register_custom_ops_pytorch_exporter_torch_triu():
|
|||
|
||||
output = model(user_input)
|
||||
assert list(output.shape) == [1, 10, 10]
|
||||
|
||||
def test_wrap_ortmodule_and_change_device():
    # Inputs: powers 1..3 of x as features, sin(x) as the regression target
    x = torch.linspace(-math.pi, math.pi, 2000)
    exponents = torch.tensor([1, 2, 3])
    xx = x.unsqueeze(-1).pow(exponents)
    y = torch.sin(x)

    # Basic Sequential model with an ORTModule-wrapped Linear layer inside it
    wrapped_linear = ORTModule(torch.nn.Linear(3, 1))
    model = torch.nn.Sequential(wrapped_linear, torch.nn.Flatten(0, 1))

    # Move everything to CPU first and then to CUDA before the first forward call
    model = model.cpu()
    xx = xx.cpu()
    y = y.cpu()
    model = model.cuda()
    xx = xx.cuda()
    y = y.cuda()

    # Quick train: plain gradient descent with manual parameter updates
    learning_rate = 1e-6
    loss_fn = torch.nn.MSELoss(reduction='sum')
    for _ in range(2000):
        y_pred = model(xx)
        loss = loss_fn(y_pred, y)

        model.zero_grad()
        loss.backward()

        with torch.no_grad():
            for param in model.parameters():
                param -= learning_rate * param.grad

    # Checking training finished normally
    assert y_pred is not None
    assert loss is not None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
import argparse
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
from torch.utils.data import DataLoader
|
||||
import pytorch_lightning as pl
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training import ORTModule
|
||||
|
||||
|
||||
class LitAutoEncoder(pl.LightningModule):
    """MNIST autoencoder (784 -> 64 -> 3 -> 64 -> 784) whose encoder can be wrapped in ORTModule."""

    def __init__(self, lr, use_ortmodule=True):
        """Build encoder and decoder; wrap the encoder in ORTModule when use_ortmodule is true.

        Args:
            lr: Learning rate handed to the Adam optimizer in configure_optimizers.
            use_ortmodule: When true, the encoder is wrapped in ORTModule.
        """
        super().__init__()
        self.lr = lr

        encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28)
        )

        self.encoder = ORTModule(encoder) if use_ortmodule else encoder
        # TODO: Also wrap the decoder once multiple ORTModule instances are supported
        # self.decoder = ORTModule(self.decoder)
        self.decoder = decoder

    def forward(self, x):
        """Return the encoder embedding of x (in Lightning, forward defines prediction/inference)."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction MSE for one batch; this defines the train loop, independent of forward."""
        images, _ = batch
        flat = images.view(images.size(0), -1)

        reconstruction = self.decoder(self.encoder(flat))
        loss = F.mse_loss(reconstruction, flat)

        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using the stored learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
|
||||
|
||||
def main():
    """Parse command-line options and train LitAutoEncoder on MNIST with a PyTorch Lightning Trainer.

    CUDA is used when available unless --no-cuda is given; --pytorch-only disables
    the ORTModule wrapper around the encoder.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
                        help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--pytorch-only', action='store_true', default=False,
                        help='disables ONNX Runtime training')
    # The help text used to advertise a default of 10 while the actual default is 5
    parser.add_argument('--epochs', type=int, default=5, metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument('--data-dir', type=str, default='./mnist',
                        help='Path to the mnist data directory')

    args = parser.parse_args()

    # Common setup
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Data loader
    dataset = MNIST(args.data_dir, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset, num_workers=cpu_count(), batch_size=args.batch_size)

    # Model architecture
    autoencoder = LitAutoEncoder(lr=args.lr, use_ortmodule=not args.pytorch_only)

    # Train loop: only pass Trainer options the user actually enabled
    kwargs = {}
    if device == 'cuda':
        kwargs.update({'gpus': 1})
    if args.train_steps > 0:
        kwargs.update({'max_steps': args.train_steps})
    if args.epochs > 0:
        kwargs.update({'max_epochs': args.epochs})
    trainer = pl.Trainer(**kwargs)
    trainer.fit(autoencoder, train_loader)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
19
run_ortmodule_mvp_lightning.sh
Executable file
19
run_ortmodule_mvp_lightning.sh
Executable file
|
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash

# Runs the ORTModule PyTorch Lightning sample from the RelWithDebInfo build folder.
# All command-line arguments are forwarded to the sample script.

cur_dir=$(basename "$(pwd)")

if [[ ${cur_dir} != "RelWithDebInfo" ]]
then
    echo "Going to build folder (aka build/Linux/RelWithDebInfo)"
    # Stop if the build folder is missing: the relative paths below assume we are inside it
    cd build/Linux/RelWithDebInfo || exit 1
fi

echo "Exporting PYTHONPATH to use build dir as onnxruntime package"
export PYTHONPATH="$(pwd)"

echo "Copying PyTorch frontend source-code to build folder"
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/

echo "Running Flexible API (ORTModule)"
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py --help
# "$@" is quoted so that arguments containing spaces reach the script as single arguments
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py "$@"
|
||||
Loading…
Reference in a new issue