fix ortmodule's output device info when it runs on ort device (#10616)

Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
Tang, Cheng 2022-02-24 10:22:55 -08:00 committed by GitHub
parent 446258fa28
commit 7660eeef3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 11 deletions

View file

@ -68,6 +68,8 @@ class OrtModuleEagerTest(unittest.TestCase):
#reload initial state
model.load_state_dict(initial_state)
#run on ort with ORTModule and eager mode
#use device_idx 1 to test non-zero device
torch_ort_eager.set_device(1, 'CPUExecutionProvider', {'dummy':'dummy'})
device = torch.device('ort', index=0)
model.to(device)
model = ORTModule(model)

View file

@ -316,19 +316,28 @@ class TrainingManager(GraphExecutionManager):
session_options, providers, provider_options = self._get_session_config()
fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input]
fw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * (len(self._graph_info.user_output_names) +
len(self._graph_info.frontier_node_arg_map))
device_type = self._device if type(self._device) is str else self._device.type.lower()
if device_type == 'ort':
fw_outputs_device_info = [C.get_ort_device(self._device.index)] * (len(self._graph_info.user_output_names) +
len(self._graph_info.frontier_node_arg_map))
else:
fw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * (len(self._graph_info.user_output_names) +
len(self._graph_info.frontier_node_arg_map))
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
bw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * len(bw_fetches_names)
if device_type == 'ort':
bw_outputs_device_info = [
C.get_ort_device(self._device.index)] * len(bw_fetches_names)
else:
bw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * len(bw_fetches_names)
self._execution_agent = TrainingAgent(self._onnx_models.optimized_model.SerializeToString(),
fw_feed_names,