mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Handle model with no parameters (#7736)
* Handle model with no parameters * Set the minimum module_output_grads as 0 to handle parameterless models
This commit is contained in:
parent
96deec596f
commit
e161213f8e
5 changed files with 65 additions and 11 deletions
|
|
@ -2497,9 +2497,13 @@ Return true if all elements are true and false otherwise.
|
|||
.Input(0, "module_outputs", "Module outputs to be returned to pytorch.", "T", OpSchema::Variadic,
|
||||
/*is_homogeneous*/ false,
|
||||
/*min_arity*/ 1)
|
||||
/*
|
||||
For a situation where there are no trainable parameters in a model, the YieldOp minimum
|
||||
number of arguments expected for module_outputs_grad should be 0.
|
||||
*/
|
||||
.Output(0, "module_outputs_grad", "Gradient of module outputs returned from pytorch.", "T", OpSchema::Variadic,
|
||||
/*is_homogeneous*/ false,
|
||||
/*min_arity*/ 1)
|
||||
/*min_arity*/ 0)
|
||||
.Attr("non_differentiable_outputs", "The indices of the module outputs that doesn't have a gradient.", AttributeProto::INTS, OPTIONAL_VALUE)
|
||||
.Attr("full_shape_outputs", "The indices of the module outputs that must have full shape.", AttributeProto::INTS)
|
||||
.TypeConstraint("T", OpSchema::all_tensor_types(), "Allow inputs and outputs to be any kind of tensor.")
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class GraphExecutionManager(ABC):
|
|||
# All required models have already been exported previously
|
||||
return False
|
||||
|
||||
self._set_device_from_module()
|
||||
self._set_device_from_module(inputs, kwargs)
|
||||
self._onnx_model = self._get_exported_model(*inputs, **kwargs)
|
||||
_utils._load_aten_op_executor_cpp_extension_if_needed(self._onnx_model, self._loglevel < _logger.LogLevel.WARNING, self.is_rocm_pytorch)
|
||||
if self._save_onnx:
|
||||
|
|
@ -269,14 +269,15 @@ class GraphExecutionManager(ABC):
|
|||
|
||||
return onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
def _set_device_from_module(self):
|
||||
def _set_device_from_module(self, inputs, kwargs):
|
||||
"""Get the device from the module and save it to self._device"""
|
||||
|
||||
device_from_module = _utils.get_device_from_module(self._original_module)
|
||||
if not self._device or self._device != device_from_module:
|
||||
self._device = device_from_module
|
||||
device = _utils.get_device_from_module(self._original_module) or \
|
||||
_utils.get_device_from_inputs(inputs, kwargs)
|
||||
if not self._device or self._device != device:
|
||||
self._device = device
|
||||
if not self._device:
|
||||
raise RuntimeError('A device must be specified in the model!')
|
||||
raise RuntimeError('A device must be specified in the model or inputs!')
|
||||
|
||||
def _get_graph_transformer_config(self):
|
||||
graph_transformer_config = C.TrainingGraphTransformerConfiguration()
|
||||
|
|
|
|||
|
|
@ -81,12 +81,13 @@ class TrainingManager(GraphExecutionManager):
|
|||
if build_gradient_graph:
|
||||
self._build_graph()
|
||||
|
||||
module_device = _utils.get_device_from_module(self._original_module)
|
||||
device = _utils.get_device_from_module(self._original_module) or \
|
||||
_utils.get_device_from_inputs(inputs, kwargs)
|
||||
# The _training_session/_inference_session should be created every time
|
||||
# the graph was built or if the device changed between calls to forward
|
||||
create_execution_session = build_gradient_graph or self._device != module_device
|
||||
if self._device != module_device:
|
||||
self._device = module_device
|
||||
create_execution_session = build_gradient_graph or self._device != device
|
||||
if self._device != device:
|
||||
self._device = device
|
||||
if create_execution_session:
|
||||
# Create execution session creates the training_session
|
||||
self._create_execution_agent()
|
||||
|
|
|
|||
|
|
@ -99,6 +99,18 @@ def get_device_from_module(module):
|
|||
pass
|
||||
return device
|
||||
|
||||
|
||||
def get_device_from_inputs(args, kwargs):
    '''Returns device from first PyTorch Tensor within args or kwargs

    Positional inputs are scanned first, in order, followed by the keyword
    input values in insertion order. Entries that are not `torch.Tensor`
    instances (Python scalars, None, strings, ...) are skipped instead of
    being dereferenced, so a non-tensor first input no longer raises
    AttributeError and no longer hides tensors that come after it.

    Args:
        args: Positional inputs; may be empty or None.
        kwargs: Keyword inputs; may be empty or None.

    Returns:
        The `torch.device` of the first tensor found, or None when none of
        the top-level inputs is a tensor.
    '''

    # `or ()` / `or {}` keep the original tolerance for None or empty containers.
    candidates = (*(args or ()), *(kwargs or {}).values())
    for value in candidates:
        if isinstance(value, torch.Tensor):
            return value.device
    return None
|
||||
|
||||
|
||||
def _create_iobinding(io_binding, inputs, model, device):
|
||||
'''Creates IO binding for a `model` inputs and output'''
|
||||
for idx, value_info in enumerate(model.graph.input):
|
||||
|
|
|
|||
|
|
@ -250,6 +250,13 @@ class UnusedMiddleParameterNet(torch.nn.Module):
|
|||
out = out + self.buffer
|
||||
return out
|
||||
|
||||
class StatelessModel(torch.nn.Module):
    """Identity module that registers no parameters or buffers of its own."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the input `x` unchanged."""
        return x
|
||||
|
||||
# TODO: This is a workaround for the problem that pytest is still cleaning up the previous test
|
||||
# while the next test has already started.
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
@ -2562,3 +2569,32 @@ def test_output_order():
|
|||
assert len(out_pt) == len(out_ort)
|
||||
for x, y in zip(out_pt, out_ort):
|
||||
_test_helpers.assert_values_are_close(x, y)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu', None])
def test_stateless_model_specified_device(device):
    """Compare ORTModule with PyTorch for a parameterless model.

    The model and the input are both placed on `device`; the ORTModule
    wrapper around a deep copy of the model must produce values close to
    the plain PyTorch result.
    """

    # Input shape only; the identity model has no hidden or output sizes.
    N, D_in = 32, 784
    pt_model = StatelessModel().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_x = torch.randn(N, D_in, device=device)
    ort_x = pt_x.clone()

    pt_y = pt_model(pt_x)
    ort_y = ort_model(ort_x)

    _test_helpers.assert_values_are_close(pt_y, ort_y)
|
||||
|
||||
def test_stateless_model_unspecified_device():
    """Compare ORTModule with PyTorch for a parameterless model, no device given.

    Neither the model nor the input is moved to an explicit device; the
    ORTModule wrapper around a deep copy of the model must produce values
    close to the plain PyTorch result.
    """

    # Input shape only; the identity model has no hidden or output sizes.
    N, D_in = 32, 784
    pt_model = StatelessModel()
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_x = torch.randn(N, D_in)
    ort_x = pt_x.clone()

    pt_y = pt_model(pt_x)
    ort_y = ort_model(ort_x)

    _test_helpers.assert_values_are_close(pt_y, ort_y)
|
||||
|
|
|
|||
Loading…
Reference in a new issue