Handle multiple devices scenarios (#6672)

* Handle multiple devices scenarios
This commit is contained in:
Thiago Crepaldi 2021-02-16 18:22:30 -08:00 committed by GitHub
parent 7ee5baa60d
commit 7f33671ade
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 173 additions and 62 deletions

View file

@ -26,14 +26,13 @@ def timeit(enabled=True):
return inner if enabled else noop_inner
def get_device_index(device):
    '''Returns device index from a device

    Args:
        device: a device string such as 'cuda:0', 'cuda:1' or 'cpu', an
            integer device index, or an object exposing an `index`
            attribute (e.g. a `torch.device`).

    Returns:
        The integer device index. A device without an explicit index
        (e.g. 'cpu' or 'cuda') maps to 0.
    '''
    if isinstance(device, int):
        # An integer already is the device index
        return device
    if isinstance(device, str):
        # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
        device = torch.device(device)
    return 0 if device.index is None else device.index
def get_device_index_from_input(input):
'''Returns device index from a input PyTorch Tensor'''
@ -43,6 +42,53 @@ def get_device_index_from_input(input):
device_index = get_device_index(input.device)
return device_index
def get_device_from_input_args_kwargs(*args, **kwargs):
    '''Returns device from first PyTorch Tensor within *args or **kwargs

    Positional arguments are inspected before keyword arguments. Values that
    are not `torch.Tensor` instances (for example a `torch.nn.Module`, `None`
    or a Python scalar) are skipped instead of being assumed to expose a
    `device` attribute.

    Returns:
        The `torch.device` of the first tensor found, or None when neither
        *args nor **kwargs contain a tensor.
    '''
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, torch.Tensor):
            return candidate.device
    return None
def get_device_from_module(module):
    '''Returns the device shared by all of `module`'s parameters, or None

    None is returned when `module` has no parameters.

    Raises:
        RuntimeError: when the parameters are not all on the same device.
    '''
    found = None
    for parameter in module.parameters():
        if found is None:
            # First parameter defines the expected device for the rest
            found = parameter.device
        elif parameter.device != found:
            raise RuntimeError('ORTModule supports a single device per model for now')
    return found
def get_device_str(device):
    '''Returns a '<type>:<index>' string for a device

    Args:
        device: a device string ('cuda:0', 'cuda', 'cpu', ...), an integer
            (interpreted as a CUDA device index) or a `torch.device`.

    Returns:
        The device string with an explicit index. A string that already
        contains ':' is returned unchanged.

    Raises:
        RuntimeError: when `device` is not a str, int or `torch.device`.
    '''
    if isinstance(device, int):
        return 'cuda:' + str(device)

    if isinstance(device, str):
        if ':' in device:
            # already carries an explicit index, e.g. 'cuda:0' or 'cuda:1'
            return device
        device_type, device_index = device, None
    elif isinstance(device, torch.device):
        device_type, device_index = device.type, device.index
    else:
        raise RuntimeError('Unsupported device type')

    if device_index is None:
        # Only CUDA has a "current device"; with cpu, set index=0. Querying
        # torch.cuda for a CPU device would fail on CPU-only setups and
        # could attach an unrelated CUDA index to the CPU device.
        device_index = torch.cuda.current_device() if device_type == 'cuda' else 0
    return device_type + ':' + str(device_index)
def get_default_device_str(type):
    '''Returns the default device string for a device `type`

    'cuda' resolves to 'cuda:<index of the current CUDA device>'; any other
    string resolves to 'cpu'.

    Raises:
        RuntimeError: when `type` is not a string.
    '''
    if not isinstance(type, str):
        raise RuntimeError('Unsupported device type')
    if type != 'cuda':
        return 'cpu'
    return 'cuda:' + str(torch.cuda.current_device())
def get_all_gradients_finite_name_from_session(session):
'''Find all_gradients_finite node on Session graph and return its name'''

View file

@ -26,48 +26,18 @@ __TEMP_ENABLE_METHOD_TIMING__ = False
T = TypeVar('T', bound='Module')
def _get_device_index(device):
if isinstance(device, str):
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
device = torch.device(device)
elif isinstance(device, int):
return device
return 0 if device.index is None else device.index
def _get_device_str(device):
if isinstance(device, str):
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
if device.find(':') == -1:
device += ':' + str(torch.cuda.current_device())
elif isinstance(device, int):
device = 'cuda:' + str(device)
elif isinstance(device, torch.device):
if device.index is None:
device = device.type + ':' + str(torch.cuda.current_device())
else:
device = device.type + ':' + str(device.index)
else:
raise ('Unsupported device type')
return device
def _get_default_device_str(type):
if type == 'cuda':
return 'cuda:' + str(torch.cuda.current_device())
else:
return 'cpu'
def _create_iobinding(io_binding, inputs, model, device):
    '''Creates IO binding for a `model` inputs and output

    Each entry of `model.graph.input` is bound, by position, to the matching
    entry of `inputs` using that input's device type, device index, dtype,
    size and data pointer. Each entry of `model.graph.output` is bound to
    `device`.
    '''
    for idx, value_info in enumerate(model.graph.input):
        bound_input = inputs[idx]
        io_binding.bind_input(value_info.name, bound_input.device.type,
                              _utils.get_device_index(bound_input.device),
                              _utils.dtype_torch_to_numpy(bound_input.dtype),
                              list(bound_input.size()),
                              bound_input.data_ptr())

    for value_info in model.graph.output:
        io_binding.bind_output(value_info.name, device.type,
                               device_id=_utils.get_device_index(device))
def _deepcopy_model_input(*inputs, **kwargs):
sample_inputs_copy = []
@ -165,8 +135,8 @@ class ORTModule(torch.nn.Module):
assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
super(ORTModule, self).__init__()
# TODO: This is incorrect when different layers may be in different devices
self._device = next(module.parameters()).device
# TODO: Single device support for now
self._device = _utils.get_device_from_module(module)
self._device_changed = False
# User module is wrapped to use its initializers and save computed gradients
@ -204,8 +174,6 @@ class ORTModule(torch.nn.Module):
self._module_gradient_graph_builder.initialize(self._onnx_inference.SerializeToString(), grad_builder_config)
def _get_inference_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs):
input_names, dynamic_axes, self._input_names_require_grad = \
_parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
self._onnx_inference = self._get_inference_graph(*inputs, **kwargs)
if self._save_onnx:
@ -251,7 +219,7 @@ class ORTModule(torch.nn.Module):
def cpu(self: T) -> T:
'''Thin layer to capture device for ORTModule IO bindings'''
if self._device.type != 'cpu':
if not self._device or self._device.type != 'cpu':
self._device_changed = True
self._device = torch.device('cpu')
@ -261,12 +229,12 @@ class ORTModule(torch.nn.Module):
'''Thin layer to capture device for ORTModule IO bindings'''
if device is None:
if _get_device_str(self._device) != _get_default_device_str('cuda'):
if self._device and _utils.get_device_str(self._device) != _utils.get_default_device_str('cuda'):
self._device_changed = True
self._device = torch.device(_get_default_device_str('cuda'))
elif _get_device_str(self._device) != _get_device_str(device):
self._device = torch.device(_utils.get_default_device_str('cuda'))
elif not self._device or _utils.get_device_str(self._device) != _utils.get_device_str(device):
self._device_changed = True
self._device = torch.device(_get_device_str(device))
self._device = torch.device(_utils.get_device_str(device))
return super(ORTModule, self).cuda(device)
@ -289,10 +257,15 @@ class ORTModule(torch.nn.Module):
device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)
if device:
device_str = _get_device_str(device)
if _get_device_str(self._device) != device_str:
try:
device_str = _utils.get_device_str(device)
if _utils.get_device_str(self._device) != device_str:
self._device_changed = True
self._device = torch.device(device_str)
except RuntimeError:
self._device_changed = True
self._device = torch.device(device_str)
return super(ORTModule, self).to(*args, **kwargs)
def eval(self: T) -> T:
@ -315,7 +288,11 @@ class ORTModule(torch.nn.Module):
return self._original_module(*inputs, **kwargs)
# Exporting module to ONNX for the first time
if not self._onnx_inference:
if not self._onnx_training:
if not self._device:
self._device = _utils.get_device_from_input_args_kwargs(self._original_module, *inputs, **kwargs)
if not self._device:
raise RuntimeError('A device must be specified in the model or data!')
self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs)
_, _, input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
@ -366,7 +343,7 @@ class ORTModule(torch.nn.Module):
backward_grad_output_ortvalue = []
for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]:
backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy(
grad_output.dtype), grad_output.device.type, _get_device_index(grad_output.device), grad_output.data_ptr()))
grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr()))
# Run and get results
self._training_session.run_backward(backward_grad_output_ortvalue)
@ -441,16 +418,19 @@ class ORTModule(torch.nn.Module):
# Therefore, deepcopy only the data component of the input tensors for export.
sample_inputs_copy = _deepcopy_model_input(*inputs, **kwargs)
with torch.no_grad():
torch.onnx.export(self._original_module,
sample_inputs_copy,
f,
input_names=input_names,
output_names=output_names,
opset_version=ONNX_OPSET_VERSION,
do_constant_folding=False,
training=torch.onnx.TrainingMode.TRAINING,
dynamic_axes=dynamic_axes)
try:
with torch.no_grad():
torch.onnx.export(self._original_module,
sample_inputs_copy,
f,
input_names=input_names,
output_names=output_names,
opset_version=ONNX_OPSET_VERSION,
do_constant_folding=False,
training=torch.onnx.TrainingMode.TRAINING,
dynamic_axes=dynamic_axes)
except RuntimeError as e:
raise RuntimeError('There was an error while exporting the PyTorch model to ONNX: {}'.format(e))
# TODO: this step might not be needed when we use the torch external allocator
# clear cache after model export

View file

@ -5,12 +5,13 @@
import torch
from transformers import AutoConfig, BertForSequenceClassification
import pytest
import warnings
from unittest.mock import patch
import onnxruntime
from onnxruntime.training import ORTModule
from onnxruntime.training import _utils, ORTModule
import _test_helpers
# PyTorch model definitions for tests
class NeuralNetSinglePositionalArgument(torch.nn.Module):
@ -446,3 +447,87 @@ def test_dynamic_axes_config():
output = model_with_no_grad(x, y, None, None, None, None, z)
assert output is not None
assert _test_helpers.is_dynamic_axes(model_with_no_grad)
def test_model_with_multiple_devices_cpu_cuda():
    '''ORTModule rejects a model mixing a `.cpu()` layer with a `.cuda()` layer'''
    class CpuThenCudaNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).cpu()
            self.fc2 = torch.nn.Linear(10, 10).cuda()

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = CpuThenCudaNet()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    assert str(excinfo.value) == 'ORTModule supports a single device per model for now'
def test_model_with_multiple_devices_to_to():
    '''ORTModule rejects a model mixing `.to('cpu')` and `.to('cuda')` layers'''
    class ToCpuThenToCudaNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).to('cpu')
            self.fc2 = torch.nn.Linear(10, 10).to('cuda')

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = ToCpuThenToCudaNet()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    assert str(excinfo.value) == 'ORTModule supports a single device per model for now'
def test_model_with_multiple_devices_to_cpu():
    '''ORTModule rejects a model mixing a `.to('cuda')` layer with a `.cpu()` layer'''
    class ToCudaThenCpuNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).to('cuda')
            self.fc2 = torch.nn.Linear(10, 10).cpu()

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = ToCudaThenCpuNet()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    assert str(excinfo.value) == 'ORTModule supports a single device per model for now'
def test_model_with_multiple_devices_to_cuda():
    '''ORTModule rejects a model mixing a `.to('cpu')` layer with a `.cuda()` layer'''
    class ToCpuThenCudaNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).to('cpu')
            self.fc2 = torch.nn.Linear(10, 10).cuda()

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = ToCpuThenCudaNet()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    assert str(excinfo.value) == 'ORTModule supports a single device per model for now'
@pytest.mark.parametrize("device", ['cuda', 'cuda:0', 'cuda:1', 'cuda:2'])
def test_model_with_different_cuda_devices(device):
    '''Runs an ORTModule forward pass with model and input placed on `device`

    The case is skipped when the requested CUDA index is not available on
    the machine, which allows running this test in single GPU machines.
    '''
    device_id = _utils.get_device_index(device)
    if device_id >= torch.cuda.device_count():
        # Report the case as skipped rather than returning early, which
        # would count it as a pass even though nothing was exercised.
        pytest.skip('Skipping test_model_with_different_cuda_devices(cuda:{})'.format(device_id))

    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
    model = ORTModule(model)
    model.to(device)
    x = torch.randn(N, D_in, device=device)
    model(x)