diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index f79c72a9c8..f6fbb93f0a 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -26,14 +26,13 @@ def timeit(enabled=True): return inner if enabled else noop_inner def get_device_index(device): - '''Returns device index from a device''' - - if type(device) == str: - # Could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 + if isinstance(device, str): + # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 device = torch.device(device) + elif isinstance(device, int): + return device return 0 if device.index is None else device.index - def get_device_index_from_input(input): '''Returns device index from a input PyTorch Tensor''' @@ -43,6 +42,53 @@ def get_device_index_from_input(input): device_index = get_device_index(input.device) return device_index +def get_device_from_input_args_kwargs(*args, **kwargs): + '''Returns device index from first PyTorch Tensor within *args or **kwargs''' + + device = None + if args: + device = torch.device(args[0].device) + if not device and kwargs: + device = torch.device(next(iter(kwargs.values())).device) + return device + +def get_device_from_module(module): + '''Returns the first device found in the `module`'s parameters or None''' + device = None + try: + device = next(module.parameters()).device + for param in module.parameters(): + if param.device != device: + raise RuntimeError('ORTModule supports a single device per model for now') + except StopIteration: + # Model doesn't have a device set to any of the model parameters + pass + return device + +def get_device_str(device): + if isinstance(device, str): + # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 + if device.find(':') == -1: + device += ':' + str(torch.cuda.current_device()) + elif isinstance(device, int): + device = 'cuda:' + str(device) + elif isinstance(device, torch.device): + if device.index is None: + device = device.type + ':' + str(torch.cuda.current_device()) + else: + device = device.type + ':' + str(device.index) + else: + raise RuntimeError('Unsupported device type') + return device + +def get_default_device_str(type): + if isinstance(type, str): + if type == 'cuda': + return 'cuda:' + str(torch.cuda.current_device()) + else: + return 'cpu' + else: + raise RuntimeError('Unsupported device type') def get_all_gradients_finite_name_from_session(session): '''Find all_gradients_finite node on Session graph and return its name''' diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index b9e3fe5824..98be8ea092 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -26,48 +26,18 @@ __TEMP_ENABLE_METHOD_TIMING__ = False T = TypeVar('T', bound='Module') -def _get_device_index(device): - if isinstance(device, str): - # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - device = torch.device(device) - elif isinstance(device, int): - return device - return 0 if device.index is None else device.index - -def _get_device_str(device): - if isinstance(device, str): - # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - if device.find(':') == -1: - device += ':' + str(torch.cuda.current_device()) - elif isinstance(device, int): - device = 'cuda:' + str(device) - elif isinstance(device, torch.device): - if device.index is None: - device = device.type + ':' + str(torch.cuda.current_device()) - else: - device = device.type + ':' + str(device.index) - else: - raise ('Unsupported device type') - return device - -def _get_default_device_str(type): - if type == 'cuda': - return 'cuda:' + str(torch.cuda.current_device()) - else: - return 'cpu' - def _create_iobinding(io_binding, inputs, model, device): '''Creates IO binding for a `model` inputs and output''' for idx, value_info in enumerate(model.graph.input): io_binding.bind_input(value_info.name, inputs[idx].device.type, - _get_device_index(inputs[idx].device), + _utils.get_device_index(inputs[idx].device), _utils.dtype_torch_to_numpy(inputs[idx].dtype), list(inputs[idx].size()), inputs[idx].data_ptr()) for value_info in model.graph.output: io_binding.bind_output(value_info.name, device.type, - device_id=_get_device_index(device)) + device_id=_utils.get_device_index(device)) def _deepcopy_model_input(*inputs, **kwargs): sample_inputs_copy = [] @@ -165,8 +135,8 @@ class ORTModule(torch.nn.Module): assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module" super(ORTModule, self).__init__() - # TODO: This is incorrect when different layers may be in different devices - self._device = next(module.parameters()).device + # TODO: Single device support for now + self._device = _utils.get_device_from_module(module) self._device_changed = False # User module is wrapped to use its initializers and save computed gradients @@ -204,8 +174,6 @@ class ORTModule(torch.nn.Module): self._module_gradient_graph_builder.initialize(self._onnx_inference.SerializeToString(), grad_builder_config) def _get_inference_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs): - input_names, dynamic_axes, self._input_names_require_grad = \ - _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs) self._onnx_inference = self._get_inference_graph(*inputs, **kwargs) if self._save_onnx: @@ -251,7 +219,7 @@ class ORTModule(torch.nn.Module): def cpu(self: T) -> T: '''Thin layer to capture device for ORTModule IO bindings''' - if self._device.type != 'cpu': + if not self._device or self._device.type != 'cpu': self._device_changed = True self._device = torch.device('cpu') @@ -261,12 +229,12 @@ class ORTModule(torch.nn.Module): '''Thin layer to capture device for ORTModule IO bindings''' if device is None: - if _get_device_str(self._device) != _get_default_device_str('cuda'): + if self._device and _utils.get_device_str(self._device) != _utils.get_default_device_str('cuda'): self._device_changed = True - self._device = torch.device(_get_default_device_str('cuda')) - elif _get_device_str(self._device) != _get_device_str(device): + self._device = torch.device(_utils.get_default_device_str('cuda')) + elif not self._device or _utils.get_device_str(self._device) != _utils.get_device_str(device): self._device_changed = True - self._device = torch.device(_get_device_str(device)) + self._device = torch.device(_utils.get_device_str(device)) return super(ORTModule, self).cuda(device) @@ -289,10 +257,15 @@ class ORTModule(torch.nn.Module): device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs) if device: - device_str = _get_device_str(device) - if _get_device_str(self._device) != device_str: + try: + device_str = _utils.get_device_str(device) + if _utils.get_device_str(self._device) != device_str: + self._device_changed = True + self._device = torch.device(device_str) + except RuntimeError: self._device_changed = True self._device = torch.device(device_str) + return super(ORTModule, self).to(*args, **kwargs) def eval(self: T) -> T: @@ -315,7 +288,11 @@ class ORTModule(torch.nn.Module): return self._original_module(*inputs, **kwargs) # Exporting module to ONNX for the first time - if not self._onnx_inference: + if not self._onnx_training: + if not self._device: + self._device = _utils.get_device_from_input_args_kwargs(self._original_module, *inputs, **kwargs) + if not self._device: + raise RuntimeError('A device must be specified in the model or data!') self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs) _, _, input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs) @@ -366,7 +343,7 @@ class ORTModule(torch.nn.Module): backward_grad_output_ortvalue = [] for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]: backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy( - grad_output.dtype), grad_output.device.type, _get_device_index(grad_output.device), grad_output.data_ptr())) + grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr())) # Run and get results self._training_session.run_backward(backward_grad_output_ortvalue) @@ -441,16 +418,19 @@ class ORTModule(torch.nn.Module): # Therefore, deepcopy only the data component of the input tensors for export. sample_inputs_copy = _deepcopy_model_input(*inputs, **kwargs) - with torch.no_grad(): - torch.onnx.export(self._original_module, - sample_inputs_copy, - f, - input_names=input_names, - output_names=output_names, - opset_version=ONNX_OPSET_VERSION, - do_constant_folding=False, - training=torch.onnx.TrainingMode.TRAINING, - dynamic_axes=dynamic_axes) + try: + with torch.no_grad(): + torch.onnx.export(self._original_module, + sample_inputs_copy, + f, + input_names=input_names, + output_names=output_names, + opset_version=ONNX_OPSET_VERSION, + do_constant_folding=False, + training=torch.onnx.TrainingMode.TRAINING, + dynamic_axes=dynamic_axes) + except RuntimeError as e: + raise RuntimeError('There was an error while exporting the PyTorch model to ONNX: {}'.format(e)) # TODO: this step might not be needed when we use the torch external allocator # clear cache after model export diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index ca19a8ddf3..346a888306 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5,12 +5,13 @@ import torch from transformers import AutoConfig, BertForSequenceClassification import pytest +import warnings from unittest.mock import patch -import onnxruntime -from onnxruntime.training import ORTModule +from onnxruntime.training import _utils, ORTModule import _test_helpers + # PyTorch model definitions for tests class NeuralNetSinglePositionalArgument(torch.nn.Module): @@ -446,3 +447,87 @@ def test_dynamic_axes_config(): output = model_with_no_grad(x, y, None, None, None, None, z) assert output is not None assert _test_helpers.is_dynamic_axes(model_with_no_grad) + +def test_model_with_multiple_devices_cpu_cuda(): + class MultipleDeviceModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10).cpu() + self.fc2 = torch.nn.Linear(10, 10).cuda() + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = MultipleDeviceModel() + with pytest.raises(RuntimeError) as e: + model = ORTModule(model) + assert str(e.value) == 'ORTModule supports a single device per model for now' + +def test_model_with_multiple_devices_to_to(): + class MultipleDeviceModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10).to('cpu') + self.fc2 = torch.nn.Linear(10, 10).to('cuda') + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = MultipleDeviceModel() + with pytest.raises(RuntimeError) as e: + model = ORTModule(model) + assert str(e.value) == 'ORTModule supports a single device per model for now' + +def test_model_with_multiple_devices_to_cpu(): + class MultipleDeviceModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10).to('cuda') + self.fc2 = torch.nn.Linear(10, 10).cpu() + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = MultipleDeviceModel() + with pytest.raises(RuntimeError) as e: + model = ORTModule(model) + assert str(e.value) == 'ORTModule supports a single device per model for now' + +def test_model_with_multiple_devices_to_cuda(): + class MultipleDeviceModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10).to('cpu') + self.fc2 = torch.nn.Linear(10, 10).cuda() + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = MultipleDeviceModel() + with pytest.raises(RuntimeError) as e: + model = ORTModule(model) + assert str(e.value) == 'ORTModule supports a single device per model for now' + +@pytest.mark.parametrize("device", ['cuda', 'cuda:0', 'cuda:1', 'cuda:2']) +def test_model_with_different_cuda_devices(device): + + # Trick to run this test in single GPU machines + device_id = _utils.get_device_index(device) + if device_id >= torch.cuda.device_count(): + warnings.warn('Skipping test_model_with_different_cuda_devices(cuda:{})'.format(device_id)) + return + + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + model = ORTModule(model) + model.to(device) + x = torch.randn(N, D_in, device=device) + model(x)