From aa5cd37ac84638660d7b5c99f47f32a05dca9d41 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Wed, 24 Feb 2021 14:12:55 -0800 Subject: [PATCH] Refactor device handling and basic support for PyTorch Lightning (#6758) --- .../orttraining/python/training/ortmodule.py | 73 ++-------- .../python/orttraining_ortmodule_tests.py | 16 +++ .../python/orttraining_test_ortmodule_api.py | 127 +++++++++++------- ...ng_test_ortmodule_torch_lightning_basic.py | 111 +++++++++++++++ run_ortmodule_mvp_lightning.sh | 19 +++ ...e_mvp_mnist.sh => run_ortmodule_mvp_poc.sh | 0 ...d.sh => run_ortmodule_mvp_poc_deepspeed.sh | 0 7 files changed, 240 insertions(+), 106 deletions(-) create mode 100644 orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py create mode 100755 run_ortmodule_mvp_lightning.sh rename run_ortmodule_mvp_mnist.sh => run_ortmodule_mvp_poc.sh (100%) rename run_ortmodule_mvp_mnist_deepspeed.sh => run_ortmodule_mvp_poc_deepspeed.sh (100%) diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 78d668bc14..5ce580e758 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -203,7 +203,6 @@ class ORTModule(torch.nn.Module): # TODO: Single device support for now self._device = _utils.get_device_from_module(module) - self._device_changed = False # User module is wrapped to use its initializers and save computed gradients self._original_module = module @@ -239,8 +238,10 @@ class ORTModule(torch.nn.Module): self._torch_free = self._torch_cuda_allocator.cuda_caching_allocator_raw_delete_address() def _initialize_module_gradient_graph_builder(self): - # TODO: PyTorch exporter bug: changes the initializer order + # TODO: PyTorch exporter bug: changes the initializer order in ONNX model initializer_names = [p[0] for p in self._original_module.named_parameters()] + onnx_initializer_names = [p.name for p in self._onnx_inference.graph.initializer] + initializer_names = [p for p in initializer_names if p in onnx_initializer_names] # Build full training graph grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() @@ -295,58 +296,6 @@ class ORTModule(torch.nn.Module): if self._save_onnx: onnx.save(self._onnx_training, self._save_onnx_prefix + '_training.onnx') - def cpu(self: T) -> T: - '''Thin layer to capture device for ORTModule IO bindings''' - - if not self._device or self._device.type != 'cpu': - self._device_changed = True - self._device = torch.device('cpu') - - return super(ORTModule, self).cpu() - - def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: - '''Thin layer to capture device for ORTModule IO bindings''' - - if device is None: - if self._device and _utils.get_device_str(self._device) != _utils.get_default_device_str('cuda'): - self._device_changed = True - self._device = torch.device(_utils.get_default_device_str('cuda')) - elif not self._device or _utils.get_device_str(self._device) != _utils.get_device_str(device): - self._device_changed = True - self._device = torch.device(_utils.get_device_str(device)) - - return super(ORTModule, self).cuda(device) - - @overload - def to(self: T, device: Optional[Union[int, torch.device]] = ..., - dtype: Optional[Union[torch.dtype, str]] = ..., - non_blocking: bool = ...) -> T: - ... - - @overload - def to(self: T, dtype: Union[torch.dtype, str], non_blocking: bool = ...) -> T: - ... - - @overload - def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: - ... - - def to(self, *args, **kwargs): - '''Thin layer to capture device for ORTModule IO bindings''' - - device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs) - if device: - try: - device_str = _utils.get_device_str(device) - if _utils.get_device_str(self._device) != device_str: - self._device_changed = True - self._device = torch.device(device_str) - except RuntimeError: - self._device_changed = True - self._device = torch.device(device_str) - - return super(ORTModule, self).to(*args, **kwargs) - def eval(self: T) -> T: self._is_training = False self._original_module.eval() @@ -368,8 +317,9 @@ class ORTModule(torch.nn.Module): # Exporting module to ONNX for the first time if not self._onnx_training: - if not self._device: - self._device = _utils.get_device_from_input_args_kwargs(self._original_module, *inputs, **kwargs) + device_from_module = _utils.get_device_from_module(self._original_module) + if not self._device or self._device != device_from_module: + self._device = device_from_module if not self._device: raise RuntimeError('A device must be specified in the model or data!') self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs) @@ -385,11 +335,12 @@ class ORTModule(torch.nn.Module): self._current_input_shape = new_input_shape self._build_training_graph() self._create_training_session() - # TODO: disabled for now, since it caused a bug in NVBert fp32 run - # When creating a new InferenceSession, there is a bug for destructing the original InferenceSession - # elif self._device_changed: - # self._create_training_session() - # self._device_changed = False + + module_device = _utils.get_device_from_module(self._original_module) + if self._device != module_device: + self._device = module_device + self._create_training_session() + # Use a custom torch.autograd.Function to associate self.backward_graph as the # gradient implementation for self.forward_graph. diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 3f6a268550..48d2f531e0 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -62,6 +62,19 @@ def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir): run_subprocess(command, cwd=cwd, log=log).check_returncode() + +def run_ortmodule_torch_lightning(cwd, log, data_dir): + log.debug('Running: ORTModule PyTorch Lightning sample .') + + command = [sys.executable, 'orttraining_test_ortmodule_torch_lightning_basic.py', '--train-steps=470', + '--epochs=2', '--batch-size=256'] + + if data_dir: + command.extend(['--data_dir', data_dir]) + + run_subprocess(command, cwd=cwd, log=log).check_returncode() + + def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir): log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda)) @@ -91,6 +104,9 @@ def main(): run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True, data_dir=args.bert_data) + # TODO: Re-enable when PyTorch Lightning works with newer torchtext (nightlies after 2021-02-19) + # run_ortmodule_torch_lightning(cwd, log, args.args.mnist) + return 0 diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 931ef8a57f..ce0279dd6c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. # orttraining_test_ortmodule_api.py +import math import torch from transformers import AutoConfig, BertForSequenceClassification from transformers.modeling_outputs import SequenceClassifierOutput @@ -231,93 +232,96 @@ def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_s assert ortmodule_result == ortmodule_result_again assert pytorch_result == ortmodule_result -def test_model_cuda(): +def test_torch_nn_module_cuda_method(): original_device = 'cpu' to_device = 'cuda' N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out) model = ORTModule(model) - x = torch.randn(N, D_in, device=to_device) for _, parameter_value in model.named_parameters(): assert parameter_value.device.type == original_device + x = torch.randn(N, D_in, device=to_device) model = model.cuda() model(x) for _, parameter_value in model.named_parameters(): assert parameter_value.device.type == to_device -def test_model_cpu(): +@pytest.mark.parametrize("set_gpu_on_original_module", [ + True, + False + ]) +def test_torch_nn_module_cpu_method(set_gpu_on_original_module): original_device = 'cuda' to_device = 'cpu' N, D_in, H, D_out = 64, 784, 500, 10 - model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device) - model = ORTModule(model) - x = torch.randn(N, D_in) + if set_gpu_on_original_module: + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device) + model = ORTModule(model) + else: + model = NeuralNetSinglePositionalArgument(D_in, H, D_out) + model = ORTModule(model).to(original_device) for _, parameter_value in model.named_parameters(): assert parameter_value.device.type == original_device + x = torch.randn(N, D_in, device=to_device) model = model.cpu() model(x) - for _, parameter_value in model.named_parameters(): assert parameter_value.device.type == to_device -@pytest.mark.parametrize("original_device, to_argument, requires_export, device_type, device_index", [ - ('cpu', torch.device('cuda'), True, 'cuda', 0), - ('cpu', 'cuda', True, 'cuda', 0), - ('cpu', 'cuda:0', True, 'cuda', 0), - ('cpu', 'cuda', True, 'cuda', 0), - ('cuda', 'cuda', False, 'cuda', 0), - ('cuda', 'cuda:0', False, 'cuda', 0), - ('cuda', torch.device('cuda'), False, 'cuda', 0), - ('cuda', 'cpu', True, 'cpu', 0), - ('cuda', torch.device('cpu'), True, 'cpu', 0), - ('cpu', 'cpu', False, 'cpu', None), - ('cpu', torch.device('cpu'), False, 'cpu', None), - ('cpu', torch.zeros(2, device=torch.device('cuda')), True, 'cuda', 0), +@pytest.mark.parametrize("original_device, to_argument", [ + ('cpu', 'cpu'), + ('cpu', 'cuda'), + ('cpu', 'cuda:0'), + ('cpu', torch.device('cpu')), + ('cpu', torch.device('cuda')), + ('cuda', 'cuda'), + ('cuda', 'cuda:0'), + ('cuda', 'cpu'), + ('cuda', torch.device('cuda')), + ('cuda', torch.device('cpu')), ]) -def test_model_to_device(original_device, to_argument, requires_export, device_type, device_index): +def test_torch_nn_module_to_api(original_device, to_argument): N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device) model = ORTModule(model) - x = torch.randn(N, D_in, device=device_type) + x = torch.randn(N, D_in, device=original_device) for _, parameter_value in model.named_parameters(): assert parameter_value.device.type == original_device model = model.to(to_argument) - assert model._device_changed == requires_export - assert model._device == torch.device(device_type+':'+str(device_index) if device_index is not None else device_type) + x = x.to(to_argument) model(x) + assert _utils.get_device_str(model._device) == _utils.get_device_str(torch.device(to_argument)) - for _, parameter_value in model.named_parameters(): - assert parameter_value.device.type == device_type - -@pytest.mark.parametrize("original_device, to_device", [ - ('cuda', 'cpu'), - ('cpu', 'cuda') - ]) -def test_model_to_device_and_back_to_original(original_device, to_device): +def test_model_without_device(): + # Model doesn't have device (CPU is assumed) N, D_in, H, D_out = 64, 784, 500, 10 - model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device) + model = NeuralNetSinglePositionalArgument(D_in, H, D_out) model = ORTModule(model) - for _, parameter_value in model.named_parameters(): - assert parameter_value.device.type == original_device - model = model.to(to_device) - assert model._device_changed == True - assert model._device == torch.device(to_device+':0') + # User input is on GPU + input_device='cuda' + x = torch.randn(N, D_in).to(input_device) - for _, parameter_value in model.named_parameters(): - assert parameter_value.device.type == to_device + # ORTModule and PyTorch does not move model to where user input is hosted + with pytest.raises(RuntimeError) as type_error: + model(x) + assert "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" in str(type_error.value) - model = model.to(original_device) - assert model._device_changed == True - assert model._device == torch.device(original_device+':0') - for _, parameter_value in model.named_parameters(): - assert parameter_value.device.type == original_device +def test_model_and_input_without_device(): + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out) + model = ORTModule(model) + x = torch.randn(N, D_in) + + # CPU is assumed for both model and user input + out = model(x) + out is not None # TODO: Re-enable this Test when .to(), .cpu() and .cuda() are fixed # def test_model_with_different_devices_same_session(): @@ -693,3 +697,36 @@ def test_register_custom_ops_pytorch_exporter_torch_triu(): output = model(user_input) assert list(output.shape) == [1, 10, 10] + +def test_wrap_ortmodule_and_change_device(): + # Basic Sequencial model wrapping ORTModule + x = torch.linspace(-math.pi, math.pi, 2000) + xx = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3])) + y = torch.sin(x) + model = torch.nn.Sequential( + ORTModule(torch.nn.Linear(3, 1)), + torch.nn.Flatten(0, 1) + ) + + # Changing device for fun + model = model.cpu() + xx = xx.cpu() + y = y.cpu() + model = model.cuda() + xx = xx.cuda() + y = y.cuda() + + # Quick train + loss_fn = torch.nn.MSELoss(reduction='sum') + learning_rate = 1e-6 + for t in range(2000): + y_pred = model(xx) + loss = loss_fn(y_pred, y) + model.zero_grad() + loss.backward() + with torch.no_grad(): + for param in model.parameters(): + param -= learning_rate * param.grad + + # Checking training finished normally + assert y_pred is not None and loss is not None diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py new file mode 100644 index 0000000000..7c3b41acc2 --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py @@ -0,0 +1,111 @@ +import argparse +from multiprocessing import cpu_count + +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import transforms +from torchvision.datasets import MNIST +from torch.utils.data import DataLoader +import pytorch_lightning as pl + +import onnxruntime +from onnxruntime.training import ORTModule + + +class LitAutoEncoder(pl.LightningModule): + + def __init__(self, lr, use_ortmodule=True): + super().__init__() + self.lr = lr + self.encoder = nn.Sequential( + nn.Linear(28*28, 64), + nn.ReLU(), + nn.Linear(64, 3) + ) + self.decoder = nn.Sequential( + nn.Linear(3, 64), + nn.ReLU(), + nn.Linear(64, 28*28) + ) + if use_ortmodule: + self.encoder = ORTModule(self.encoder) + # TODO: Remove this comment below when multiple ORTModule instances is supported + # self.decoder = ORTModule(self.decoder) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + embedding = self.encoder(x) + return embedding + + def training_step(self, batch, batch_idx): + # training_step defined the train loop. + # It is independent of forward + x, y = batch + x = x.view(x.size(0), -1) + + z = self.encoder(x) + + x_hat = self.decoder(z) + loss = F.mse_loss(x_hat, x) + # Logging to TensorBoard by default + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--train-steps', type=int, default=-1, metavar='N', + help='number of steps to train. Set -1 to run through whole dataset (default: -1)') + parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument('--batch-size', type=int, default=32, metavar='N', + help='input batch size for training (default: 32)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=42, metavar='S', + help='random seed (default: 42)') + parser.add_argument('--pytorch-only', action='store_true', default=False, + help='disables ONNX Runtime training') + parser.add_argument('--epochs', type=int, default=5, metavar='N', + help='number of epochs to train (default: 10)') + parser.add_argument('--data-dir', type=str, default='./mnist', + help='Path to the mnist data directory') + + args = parser.parse_args() + + # Common setup + torch.manual_seed(args.seed) + onnxruntime.set_seed(args.seed) + + if not args.no_cuda and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # Data loader + dataset = MNIST(args.data_dir, download=True, transform=transforms.ToTensor()) + train_loader = DataLoader(dataset, num_workers=cpu_count(), batch_size=args.batch_size) + + # Model architecture + autoencoder = LitAutoEncoder(lr=args.lr, use_ortmodule=not args.pytorch_only) + + # Train loop + kwargs = {} + if device == 'cuda': + kwargs.update({'gpus': 1}) + if args.train_steps > 0: + kwargs.update({'max_steps': args.train_steps}) + if args.epochs > 0: + kwargs.update({'max_epochs': args.epochs}) + trainer = pl.Trainer(**kwargs) + trainer.fit(autoencoder, train_loader) + + +if __name__ == '__main__': + main() diff --git a/run_ortmodule_mvp_lightning.sh b/run_ortmodule_mvp_lightning.sh new file mode 100755 index 0000000000..809d8d507a --- /dev/null +++ b/run_ortmodule_mvp_lightning.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +cur_dir=$(basename `pwd`) + +if [[ ${cur_dir} != "RelWithDebInfo" ]] +then + echo "Going to build folder (aka build/Linux/RelWithDebInfo)" + cd build/Linux/RelWithDebInfo +fi + +echo "Exporting PYTHONPATH to use build dir as onnxruntime package" +export PYTHONPATH=$(pwd) + +echo "Copying PyTorch frontend source-code to build folder" +cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/ + +echo "Running Flexible API (ORTModule)" +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py --help +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py $@ diff --git a/run_ortmodule_mvp_mnist.sh b/run_ortmodule_mvp_poc.sh similarity index 100% rename from run_ortmodule_mvp_mnist.sh rename to run_ortmodule_mvp_poc.sh diff --git a/run_ortmodule_mvp_mnist_deepspeed.sh b/run_ortmodule_mvp_poc_deepspeed.sh similarity index 100% rename from run_ortmodule_mvp_mnist_deepspeed.sh rename to run_ortmodule_mvp_poc_deepspeed.sh