Handle model with no parameters (#7736)

* Handle model with no parameters

* Set the minimum arity of module_outputs_grad to 0 to handle parameterless models
This commit is contained in:
baijumeswani 2021-05-18 09:33:57 -07:00 committed by GitHub
parent 96deec596f
commit e161213f8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 11 deletions

View file

@ -2497,9 +2497,13 @@ Return true if all elements are true and false otherwise.
.Input(0, "module_outputs", "Module outputs to be returned to pytorch.", "T", OpSchema::Variadic,
/*is_homogeneous*/ false,
/*min_arity*/ 1)
/*
When a model has no trainable parameters, the minimum number of values the YieldOp
expects for module_outputs_grad should be 0.
*/
.Output(0, "module_outputs_grad", "Gradient of module outputs returned from pytorch.", "T", OpSchema::Variadic,
/*is_homogeneous*/ false,
/*min_arity*/ 1)
/*min_arity*/ 0)
.Attr("non_differentiable_outputs", "The indices of the module outputs that doesn't have a gradient.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("full_shape_outputs", "The indices of the module outputs that must have full shape.", AttributeProto::INTS)
.TypeConstraint("T", OpSchema::all_tensor_types(), "Allow inputs and outputs to be any kind of tensor.")

View file

@ -207,7 +207,7 @@ class GraphExecutionManager(ABC):
# All required models have already been exported previously
return False
self._set_device_from_module()
self._set_device_from_module(inputs, kwargs)
self._onnx_model = self._get_exported_model(*inputs, **kwargs)
_utils._load_aten_op_executor_cpp_extension_if_needed(self._onnx_model, self._loglevel < _logger.LogLevel.WARNING, self.is_rocm_pytorch)
if self._save_onnx:
@ -269,14 +269,15 @@ class GraphExecutionManager(ABC):
return onnx.load_model_from_string(f.getvalue())
def _set_device_from_module(self):
def _set_device_from_module(self, inputs, kwargs):
"""Get the device from the module, falling back to the inputs, and save it to self._device"""
device_from_module = _utils.get_device_from_module(self._original_module)
if not self._device or self._device != device_from_module:
self._device = device_from_module
device = _utils.get_device_from_module(self._original_module) or \
_utils.get_device_from_inputs(inputs, kwargs)
if not self._device or self._device != device:
self._device = device
if not self._device:
raise RuntimeError('A device must be specified in the model!')
raise RuntimeError('A device must be specified in the model or inputs!')
def _get_graph_transformer_config(self):
graph_transformer_config = C.TrainingGraphTransformerConfiguration()

View file

@ -81,12 +81,13 @@ class TrainingManager(GraphExecutionManager):
if build_gradient_graph:
self._build_graph()
module_device = _utils.get_device_from_module(self._original_module)
device = _utils.get_device_from_module(self._original_module) or \
_utils.get_device_from_inputs(inputs, kwargs)
# The _training_session/_inference_session should be created every time
# the graph was built or if the device changed between calls to forward
create_execution_session = build_gradient_graph or self._device != module_device
if self._device != module_device:
self._device = module_device
create_execution_session = build_gradient_graph or self._device != device
if self._device != device:
self._device = device
if create_execution_session:
# Create execution session creates the training_session
self._create_execution_agent()

View file

@ -99,6 +99,18 @@ def get_device_from_module(module):
pass
return device
def get_device_from_inputs(args, kwargs):
    '''Returns device from the first input within args or kwargs that has a device

    Positional inputs are inspected first, in order, followed by the keyword
    input values in their dict order. Only the top-level values are inspected;
    containers of tensors are not searched.

    Args:
        args: Sequence of positional inputs (may be empty or None).
        kwargs: Mapping of keyword inputs (may be empty or None).

    Returns:
        A `torch.device` built from the first input exposing a non-None
        `device` attribute, or None when no such input exists.
    '''

    positional = list(args) if args else []
    named = list(kwargs.values()) if kwargs else []

    for value in positional + named:
        # Non-tensor inputs (ints, bools, strings, None, ...) carry no device
        # and must be skipped instead of raising AttributeError.
        device = getattr(value, 'device', None)
        if device is not None:
            return torch.device(device)

    return None
def _create_iobinding(io_binding, inputs, model, device):
'''Creates IO binding for a `model` inputs and output'''
for idx, value_info in enumerate(model.graph.input):

View file

@ -250,6 +250,13 @@ class UnusedMiddleParameterNet(torch.nn.Module):
out = out + self.buffer
return out
class StatelessModel(torch.nn.Module):
    '''Module that registers no parameters or buffers and returns its input as-is.'''

    def __init__(self):
        super().__init__()

    def forward(self, x):
        '''Return `x` unchanged.'''
        return x
# TODO: This is a workaround for the problem that pytest is still cleaning up the previous test
# while the next task has already started.
@pytest.fixture(autouse=True)
@ -2562,3 +2569,32 @@ def test_output_order():
assert len(out_pt) == len(out_ort)
for x, y in zip(out_pt, out_ort):
_test_helpers.assert_values_are_close(x, y)
@pytest.mark.parametrize("device", ['cuda', 'cpu', None])
def test_stateless_model_specified_device(device):
    '''Check that ORTModule output is close to PyTorch output for a StatelessModel moved to `device`.'''
    N, D_in = 32, 784

    pt_model = StatelessModel().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    # Both models receive the same values; the input is created on `device`.
    pt_x = torch.randn(N, D_in, device=device)
    ort_x = pt_x.clone()

    pt_y = pt_model(pt_x)
    ort_y = ort_model(ort_x)

    _test_helpers.assert_values_are_close(pt_y, ort_y)
def test_stateless_model_unspecified_device():
    '''Check that ORTModule output is close to PyTorch output for a StatelessModel with no explicit device.'''
    N, D_in = 32, 784

    pt_model = StatelessModel()
    ort_model = ORTModule(copy.deepcopy(pt_model))

    # Neither the model nor the input is given an explicit device here.
    pt_x = torch.randn(N, D_in)
    ort_x = pt_x.clone()

    pt_y = pt_model(pt_x)
    ort_y = ort_model(ort_x)

    _test_helpers.assert_values_are_close(pt_y, ort_y)