mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Handle multiple devices scenarios (#6672)
* Handle multiple devices scenarios
This commit is contained in:
parent
7ee5baa60d
commit
7f33671ade
3 changed files with 173 additions and 62 deletions
|
|
@ -26,14 +26,13 @@ def timeit(enabled=True):
|
|||
return inner if enabled else noop_inner
|
||||
|
||||
def get_device_index(device):
    '''Returns device index from a device

    Args:
        device: a device string such as 'cuda:0', 'cuda:1' or 'cpu',
            an integer index, or a `torch.device`.

    Returns:
        The integer device index. An integer input is returned unchanged;
        a device without an explicit index (e.g. 'cpu' or 'cuda') maps to 0.
    '''

    if isinstance(device, int):
        return device
    if isinstance(device, str):
        # Could be 'cuda:0', 'cuda:1', or 'cpu'. With cpu, set index=0
        device = torch.device(device)
    return 0 if device.index is None else device.index
|
||||
|
||||
|
||||
def get_device_index_from_input(input):
|
||||
'''Returns device index from a input PyTorch Tensor'''
|
||||
|
||||
|
|
@ -43,6 +42,53 @@ def get_device_index_from_input(input):
|
|||
device_index = get_device_index(input.device)
|
||||
return device_index
|
||||
|
||||
def get_device_from_input_args_kwargs(*args, **kwargs):
    '''Returns the device of the first input within *args or **kwargs that has one

    Positional arguments are searched before keyword arguments. Values that
    do not expose a `device` attribute (e.g. None or Python scalars) are
    skipped instead of being dereferenced, so a non-tensor first input does
    not make the lookup fail.

    Returns:
        A `torch.device`, or None when no input carries a device.
    '''

    for value in (*args, *kwargs.values()):
        device = getattr(value, 'device', None)
        if device is not None:
            return torch.device(device)
    return None
|
||||
|
||||
def get_device_from_module(module):
    '''Returns the device shared by all of the `module`'s parameters, or None

    None is returned when the module has no parameters. A RuntimeError is
    raised as soon as a parameter lives on a different device than the
    first one.
    '''
    found = None
    for parameter in module.parameters():
        if found is None:
            # The first parameter defines the device every other one must match
            found = parameter.device
        elif parameter.device != found:
            raise RuntimeError('ORTModule supports a single device per model for now')
    return found
|
||||
|
||||
def get_device_str(device):
    '''Returns a device string in the '<type>:<index>' form

    Args:
        device: a device string ('cuda:0', 'cuda', 'cpu', ...), an integer
            (interpreted as a CUDA device index) or a `torch.device`.

    Returns:
        The device string with an explicit index. A string that already
        contains ':' is returned unchanged. When the index is missing, CUDA
        devices use `torch.cuda.current_device()` and any other device type
        uses index 0.

    Raises:
        RuntimeError: if `device` is not a str, int or `torch.device`.
    '''
    if isinstance(device, str):
        if ':' in device:
            return device
        device_type, index = device, None
    elif isinstance(device, int):
        return 'cuda:' + str(device)
    elif isinstance(device, torch.device):
        device_type, index = device.type, device.index
    else:
        raise RuntimeError('Unsupported device type')

    if index is None:
        # Only CUDA has a notion of "current device"; querying it for a cpu
        # device is meaningless and fails on machines without CUDA.
        index = torch.cuda.current_device() if device_type == 'cuda' else 0
    return device_type + ':' + str(index)
|
||||
|
||||
def get_default_device_str(type):
    '''Returns the default device string for a device `type` name

    'cuda' resolves to the current CUDA device ('cuda:<index>'); every other
    string resolves to 'cpu'. A non-string `type` raises RuntimeError.
    '''
    if not isinstance(type, str):
        raise RuntimeError('Unsupported device type')
    if type != 'cuda':
        return 'cpu'
    return 'cuda:' + str(torch.cuda.current_device())
|
||||
|
||||
def get_all_gradients_finite_name_from_session(session):
|
||||
'''Find all_gradients_finite node on Session graph and return its name'''
|
||||
|
|
|
|||
|
|
@ -26,48 +26,18 @@ __TEMP_ENABLE_METHOD_TIMING__ = False
|
|||
T = TypeVar('T', bound='Module')
|
||||
|
||||
|
||||
def _get_device_index(device):
|
||||
if isinstance(device, str):
|
||||
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
return device
|
||||
return 0 if device.index is None else device.index
|
||||
|
||||
def _get_device_str(device):
|
||||
if isinstance(device, str):
|
||||
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
|
||||
if device.find(':') == -1:
|
||||
device += ':' + str(torch.cuda.current_device())
|
||||
elif isinstance(device, int):
|
||||
device = 'cuda:' + str(device)
|
||||
elif isinstance(device, torch.device):
|
||||
if device.index is None:
|
||||
device = device.type + ':' + str(torch.cuda.current_device())
|
||||
else:
|
||||
device = device.type + ':' + str(device.index)
|
||||
else:
|
||||
raise ('Unsupported device type')
|
||||
return device
|
||||
|
||||
def _get_default_device_str(type):
|
||||
if type == 'cuda':
|
||||
return 'cuda:' + str(torch.cuda.current_device())
|
||||
else:
|
||||
return 'cpu'
|
||||
|
||||
def _create_iobinding(io_binding, inputs, model, device):
    '''Creates IO binding for a `model` inputs and output

    Each graph input is bound to the tensor at the same position in `inputs`
    using that tensor's own device, dtype, size and data pointer. Each graph
    output is bound by name to `device`.
    '''
    for idx, value_info in enumerate(model.graph.input):
        tensor = inputs[idx]
        io_binding.bind_input(value_info.name, tensor.device.type,
                              _utils.get_device_index(tensor.device),
                              _utils.dtype_torch_to_numpy(tensor.dtype),
                              list(tensor.size()),
                              tensor.data_ptr())

    for value_info in model.graph.output:
        io_binding.bind_output(value_info.name, device.type,
                               device_id=_utils.get_device_index(device))
|
||||
|
||||
def _deepcopy_model_input(*inputs, **kwargs):
|
||||
sample_inputs_copy = []
|
||||
|
|
@ -165,8 +135,8 @@ class ORTModule(torch.nn.Module):
|
|||
assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
# TODO: This is incorrect when different layers may be in different devices
|
||||
self._device = next(module.parameters()).device
|
||||
# TODO: Single device support for now
|
||||
self._device = _utils.get_device_from_module(module)
|
||||
self._device_changed = False
|
||||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
|
|
@ -204,8 +174,6 @@ class ORTModule(torch.nn.Module):
|
|||
self._module_gradient_graph_builder.initialize(self._onnx_inference.SerializeToString(), grad_builder_config)
|
||||
|
||||
def _get_inference_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs):
|
||||
input_names, dynamic_axes, self._input_names_require_grad = \
|
||||
_parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
|
||||
self._onnx_inference = self._get_inference_graph(*inputs, **kwargs)
|
||||
|
||||
if self._save_onnx:
|
||||
|
|
@ -251,7 +219,7 @@ class ORTModule(torch.nn.Module):
|
|||
def cpu(self: T) -> T:
|
||||
'''Thin layer to capture device for ORTModule IO bindings'''
|
||||
|
||||
if self._device.type != 'cpu':
|
||||
if not self._device or self._device.type != 'cpu':
|
||||
self._device_changed = True
|
||||
self._device = torch.device('cpu')
|
||||
|
||||
|
|
@ -261,12 +229,12 @@ class ORTModule(torch.nn.Module):
|
|||
'''Thin layer to capture device for ORTModule IO bindings'''
|
||||
|
||||
if device is None:
|
||||
if _get_device_str(self._device) != _get_default_device_str('cuda'):
|
||||
if self._device and _utils.get_device_str(self._device) != _utils.get_default_device_str('cuda'):
|
||||
self._device_changed = True
|
||||
self._device = torch.device(_get_default_device_str('cuda'))
|
||||
elif _get_device_str(self._device) != _get_device_str(device):
|
||||
self._device = torch.device(_utils.get_default_device_str('cuda'))
|
||||
elif not self._device or _utils.get_device_str(self._device) != _utils.get_device_str(device):
|
||||
self._device_changed = True
|
||||
self._device = torch.device(_get_device_str(device))
|
||||
self._device = torch.device(_utils.get_device_str(device))
|
||||
|
||||
return super(ORTModule, self).cuda(device)
|
||||
|
||||
|
|
@ -289,10 +257,15 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device:
|
||||
device_str = _get_device_str(device)
|
||||
if _get_device_str(self._device) != device_str:
|
||||
try:
|
||||
device_str = _utils.get_device_str(device)
|
||||
if _utils.get_device_str(self._device) != device_str:
|
||||
self._device_changed = True
|
||||
self._device = torch.device(device_str)
|
||||
except RuntimeError:
|
||||
self._device_changed = True
|
||||
self._device = torch.device(device_str)
|
||||
|
||||
return super(ORTModule, self).to(*args, **kwargs)
|
||||
|
||||
def eval(self: T) -> T:
|
||||
|
|
@ -315,7 +288,11 @@ class ORTModule(torch.nn.Module):
|
|||
return self._original_module(*inputs, **kwargs)
|
||||
|
||||
# Exporting module to ONNX for the first time
|
||||
if not self._onnx_inference:
|
||||
if not self._onnx_training:
|
||||
if not self._device:
|
||||
self._device = _utils.get_device_from_input_args_kwargs(self._original_module, *inputs, **kwargs)
|
||||
if not self._device:
|
||||
raise RuntimeError('A device must be specified in the model or data!')
|
||||
self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs)
|
||||
|
||||
_, _, input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
|
||||
|
|
@ -366,7 +343,7 @@ class ORTModule(torch.nn.Module):
|
|||
backward_grad_output_ortvalue = []
|
||||
for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]:
|
||||
backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy(
|
||||
grad_output.dtype), grad_output.device.type, _get_device_index(grad_output.device), grad_output.data_ptr()))
|
||||
grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr()))
|
||||
|
||||
# Run and get results
|
||||
self._training_session.run_backward(backward_grad_output_ortvalue)
|
||||
|
|
@ -441,16 +418,19 @@ class ORTModule(torch.nn.Module):
|
|||
# Therefore, deepcopy only the data component of the input tensors for export.
|
||||
sample_inputs_copy = _deepcopy_model_input(*inputs, **kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(self._original_module,
|
||||
sample_inputs_copy,
|
||||
f,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamic_axes=dynamic_axes)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(self._original_module,
|
||||
sample_inputs_copy,
|
||||
f,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamic_axes=dynamic_axes)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError('There was an error while exporting the PyTorch model to ONNX: {}'.format(e))
|
||||
|
||||
# TODO: this step might not be needed when we use the torch external allocator
|
||||
# clear cache after model export
|
||||
|
|
|
|||
|
|
@ -5,12 +5,13 @@
|
|||
import torch
|
||||
from transformers import AutoConfig, BertForSequenceClassification
|
||||
import pytest
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training import ORTModule
|
||||
from onnxruntime.training import _utils, ORTModule
|
||||
import _test_helpers
|
||||
|
||||
|
||||
# PyTorch model definitions for tests
|
||||
|
||||
class NeuralNetSinglePositionalArgument(torch.nn.Module):
|
||||
|
|
@ -446,3 +447,87 @@ def test_dynamic_axes_config():
|
|||
output = model_with_no_grad(x, y, None, None, None, None, z)
|
||||
assert output is not None
|
||||
assert _test_helpers.is_dynamic_axes(model_with_no_grad)
|
||||
|
||||
def test_model_with_multiple_devices_cpu_cuda():
    '''ORTModule must reject a model whose layers were placed with .cpu() and .cuda()'''
    class MultipleDeviceModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).cpu()
            self.fc2 = torch.nn.Linear(10, 10).cuda()

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = MultipleDeviceModel()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    expected_message = 'ORTModule supports a single device per model for now'
    assert str(excinfo.value) == expected_message
|
||||
|
||||
def test_model_with_multiple_devices_to_to():
    '''ORTModule must reject a model whose layers were placed with .to('cpu') and .to('cuda')'''
    class MultipleDeviceModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).to('cpu')
            self.fc2 = torch.nn.Linear(10, 10).to('cuda')

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = MultipleDeviceModel()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    expected_message = 'ORTModule supports a single device per model for now'
    assert str(excinfo.value) == expected_message
|
||||
|
||||
def test_model_with_multiple_devices_to_cpu():
    '''ORTModule must reject a model whose layers were placed with .to('cuda') and .cpu()'''
    class MultipleDeviceModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).to('cuda')
            self.fc2 = torch.nn.Linear(10, 10).cpu()

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = MultipleDeviceModel()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    expected_message = 'ORTModule supports a single device per model for now'
    assert str(excinfo.value) == expected_message
|
||||
|
||||
def test_model_with_multiple_devices_to_cuda():
    '''ORTModule must reject a model whose layers were placed with .to('cpu') and .cuda()'''
    class MultipleDeviceModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10).to('cpu')
            self.fc2 = torch.nn.Linear(10, 10).cuda()

        def forward(self, x):
            return self.fc2(self.fc1(x))

    pt_model = MultipleDeviceModel()
    with pytest.raises(RuntimeError) as excinfo:
        ORTModule(pt_model)
    expected_message = 'ORTModule supports a single device per model for now'
    assert str(excinfo.value) == expected_message
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cuda:0', 'cuda:1', 'cuda:2'])
def test_model_with_different_cuda_devices(device):
    '''Runs an ORTModule forward pass on each available CUDA device

    Parametrized cases that target a device index the machine does not have
    are skipped, so the test can run on single-GPU machines.
    '''

    device_id = _utils.get_device_index(device)
    if device_id >= torch.cuda.device_count():
        # Report the case as skipped instead of letting it silently pass
        pytest.skip('Skipping test_model_with_different_cuda_devices(cuda:{})'.format(device_id))

    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
    model = ORTModule(model)
    # Calling .to() on the wrapper as well goes through ORTModule's own device tracking
    model.to(device)
    x = torch.randn(N, D_in, device=device)
    model(x)
|
||||
|
|
|
|||
Loading…
Reference in a new issue