diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 91637fd7f9..4e8da631b8 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -2497,9 +2497,13 @@ Return true if all elements are true and false otherwise. .Input(0, "module_outputs", "Module outputs to be returned to pytorch.", "T", OpSchema::Variadic, /*is_homogeneous*/ false, /*min_arity*/ 1) + /* + For a situation where there are no trainable parameters in a model, the YieldOp minimum + number of arguments expected for module_output_grad should be 0. + */ .Output(0, "module_outputs_grad", "Gradient of module outputs returned from pytorch.", "T", OpSchema::Variadic, /*is_homogeneous*/ false, - /*min_arity*/ 1) + /*min_arity*/ 0) .Attr("non_differentiable_outputs", "The indices of the module outputs that doesn't have a gradient.", AttributeProto::INTS, OPTIONAL_VALUE) .Attr("full_shape_outputs", "The indices of the module outputs that must have full shape.", AttributeProto::INTS) .TypeConstraint("T", OpSchema::all_tensor_types(), "Allow inputs and outputs to be any kind of tensor.") diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 439d780683..e7b7b60613 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -207,7 +207,7 @@ class GraphExecutionManager(ABC): # All required models have already been exported previously return False - self._set_device_from_module() + self._set_device_from_module(inputs, kwargs) self._onnx_model = self._get_exported_model(*inputs, **kwargs) _utils._load_aten_op_executor_cpp_extension_if_needed(self._onnx_model, self._loglevel < _logger.LogLevel.WARNING, self.is_rocm_pytorch) if self._save_onnx: @@ -269,14 +269,15 @@ class GraphExecutionManager(ABC): return onnx.load_model_from_string(f.getvalue()) - def _set_device_from_module(self): + def _set_device_from_module(self, inputs, kwargs): """Get the device from the module and save it to self._device""" - device_from_module = _utils.get_device_from_module(self._original_module) - if not self._device or self._device != device_from_module: - self._device = device_from_module + device = _utils.get_device_from_module(self._original_module) or \ + _utils.get_device_from_inputs(inputs, kwargs) + if not self._device or self._device != device: + self._device = device if not self._device: - raise RuntimeError('A device must be specified in the model!') + raise RuntimeError('A device must be specified in the model or inputs!') def _get_graph_transformer_config(self): graph_transformer_config = C.TrainingGraphTransformerConfiguration() diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index f6e618d2f8..21faf6d6ea 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -81,12 +81,13 @@ class TrainingManager(GraphExecutionManager): if build_gradient_graph: self._build_graph() - module_device = _utils.get_device_from_module(self._original_module) + device = _utils.get_device_from_module(self._original_module) or \ + _utils.get_device_from_inputs(inputs, kwargs) # The _training_session/_inference_session should be created every time # the graph was built or if the device changed between calls to forward - create_execution_session = build_gradient_graph or self._device != module_device - if self._device != module_device: - self._device = module_device + create_execution_session = build_gradient_graph or self._device != device + if self._device != device: + self._device = device if create_execution_session: # Create execution session creates the training_session self._create_execution_agent() diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index cf850d4b8d..67ae17f403 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -99,6 +99,18 @@ def get_device_from_module(module): pass return device + +def get_device_from_inputs(args, kwargs): + '''Returns device from first PyTorch Tensor within args or kwargs''' + + device = None + if args: + device = torch.device(args[0].device) + elif kwargs: + device = torch.device(next(iter(kwargs.values())).device) + return device + + def _create_iobinding(io_binding, inputs, model, device): '''Creates IO binding for a `model` inputs and output''' for idx, value_info in enumerate(model.graph.input): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 3cdd36afc0..20963c38b9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -250,6 +250,13 @@ class UnusedMiddleParameterNet(torch.nn.Module): out = out + self.buffer return out +class StatelessModel(torch.nn.Module): + def __init__(self): + super(StatelessModel, self).__init__() + + def forward(self, x): + return x + # TODO: This is a workaround for the problem that pytest is still cleaning up the previous test # while the next task already start. @pytest.fixture(autouse=True) @@ -2562,3 +2569,32 @@ def test_output_order(): assert len(out_pt) == len(out_ort) for x, y in zip(out_pt, out_ort): _test_helpers.assert_values_are_close(x, y) + +@pytest.mark.parametrize("device", ['cuda', 'cpu', None]) +def test_stateless_model_specified_device(device): + + N, D_in, H, D_out = 32, 784, 500, 10 + pt_model = StatelessModel().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(N, D_in, device=device) + ort_x = pt_x.clone() + + pt_y = pt_model(pt_x) + ort_y = ort_model(ort_x) + + _test_helpers.assert_values_are_close(pt_y, ort_y) + +def test_stateless_model_unspecified_device(): + + N, D_in, H, D_out = 32, 784, 500, 10 + pt_model = StatelessModel() + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(N, D_in) + ort_x = pt_x.clone() + + pt_y = pt_model(pt_x) + ort_y = ort_model(ort_x) + + _test_helpers.assert_values_are_close(pt_y, ort_y)