mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
fix ortmodule's output device info when it runs on ort device (#10616)
Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
446258fa28
commit
7660eeef3e
2 changed files with 22 additions and 11 deletions
|
|
@ -68,6 +68,8 @@ class OrtModuleEagerTest(unittest.TestCase):
|
|||
#reload initial state
|
||||
model.load_state_dict(initial_state)
|
||||
#run on ort with ORTModule and eager mode
|
||||
#use device_idx 1 to test non-zero device
|
||||
torch_ort_eager.set_device(1, 'CPUExecutionProvider', {'dummy':'dummy'})
|
||||
device = torch.device('ort', index=0)
|
||||
model.to(device)
|
||||
model = ORTModule(model)
|
||||
|
|
|
|||
|
|
@ -316,19 +316,28 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
session_options, providers, provider_options = self._get_session_config()
|
||||
fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input]
|
||||
fw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * (len(self._graph_info.user_output_names) +
|
||||
len(self._graph_info.frontier_node_arg_map))
|
||||
device_type = self._device if type(self._device) is str else self._device.type.lower()
|
||||
if device_type == 'ort':
|
||||
fw_outputs_device_info = [C.get_ort_device(self._device.index)] * (len(self._graph_info.user_output_names) +
|
||||
len(self._graph_info.frontier_node_arg_map))
|
||||
else:
|
||||
fw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * (len(self._graph_info.user_output_names) +
|
||||
len(self._graph_info.frontier_node_arg_map))
|
||||
|
||||
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
|
||||
bw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * len(bw_fetches_names)
|
||||
if device_type == 'ort':
|
||||
bw_outputs_device_info = [
|
||||
C.get_ort_device(self._device.index)] * len(bw_fetches_names)
|
||||
else:
|
||||
bw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * len(bw_fetches_names)
|
||||
|
||||
self._execution_agent = TrainingAgent(self._onnx_models.optimized_model.SerializeToString(),
|
||||
fw_feed_names,
|
||||
|
|
|
|||
Loading…
Reference in a new issue