From 441b30b2d26d36ca1db2930ade2fe82622ce0cd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 30 Nov 2022 12:49:41 +0100 Subject: [PATCH] Move a function call outside a loop in ORTModule (#13771) ### Description The proposed change is useful for ORTModule when the output graph has multiple outputs. ### Motivation and Context performance Signed-off-by: xadupre --- orttraining/orttraining/python/training/ortmodule/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 475739a7b5..feb5ed7d12 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -196,8 +196,9 @@ def _create_iobinding(io_binding, inputs, model, device): for idx, value_info in enumerate(model.graph.input): io_binding.bind_ortvalue_input(value_info.name, OrtValue(_ortvalue_from_torch_tensor(inputs[idx]))) + device_id = get_device_index(device) for value_info in model.graph.output: - io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device)) + io_binding.bind_output(value_info.name, device.type, device_id=device_id) def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.Module, user_module: torch.nn.Module):