From 7660eeef3eff9a58b4d807dd951a34bafcb03ac0 Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Thu, 24 Feb 2022 10:22:55 -0800 Subject: [PATCH] fix ortmodule's output device info when it runs on ort device (#10616) Co-authored-by: Cheng Tang --- .../test/linux_only_ortmodule_eager_test.py | 2 ++ .../training/ortmodule/_training_manager.py | 31 ++++++++++++------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py b/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py index 8134a116bf..d010c95134 100644 --- a/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py +++ b/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py @@ -68,6 +68,8 @@ class OrtModuleEagerTest(unittest.TestCase): #reload initial state model.load_state_dict(initial_state) #run on ort with ORTModule and eager mode + #use device_idx 1 to test non-zero device + torch_ort_eager.set_device(1, 'CPUExecutionProvider', {'dummy':'dummy'}) device = torch.device('ort', index=0) model.to(device) model = ORTModule(model) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 29ffb0dc70..eb8d1ec3a9 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -316,19 +316,28 @@ class TrainingManager(GraphExecutionManager): session_options, providers, provider_options = self._get_session_config() fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input] - fw_outputs_device_info = [ - C.OrtDevice(get_ort_device_type(self._device), - C.OrtDevice.default_memory(), - _utils.get_device_index(self._device) - )] * (len(self._graph_info.user_output_names) + - len(self._graph_info.frontier_node_arg_map)) + device_type = self._device if type(self._device) is str else self._device.type.lower() + if device_type == 'ort': + fw_outputs_device_info = [C.get_ort_device(self._device.index)] * (len(self._graph_info.user_output_names) + + len(self._graph_info.frontier_node_arg_map)) + else: + fw_outputs_device_info = [ + C.OrtDevice(get_ort_device_type(self._device), + C.OrtDevice.default_memory(), + _utils.get_device_index(self._device) + )] * (len(self._graph_info.user_output_names) + + len(self._graph_info.frontier_node_arg_map)) bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output] - bw_outputs_device_info = [ - C.OrtDevice(get_ort_device_type(self._device), - C.OrtDevice.default_memory(), - _utils.get_device_index(self._device) - )] * len(bw_fetches_names) + if device_type == 'ort': + bw_outputs_device_info = [ + C.get_ort_device(self._device.index)] * len(bw_fetches_names) + else: + bw_outputs_device_info = [ + C.OrtDevice(get_ort_device_type(self._device), + C.OrtDevice.default_memory(), + _utils.get_device_index(self._device) + )] * len(bw_fetches_names) self._execution_agent = TrainingAgent(self._onnx_models.optimized_model.SerializeToString(), fw_feed_names,