diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index d83fa0a49c..92497441fb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -64,12 +64,13 @@ def call_python_forward_function( inplace: indicates if args can be modified inside the custom function. args: inputs to "backward_function". ''' - def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode): - if is_training_mode and tensor_flag and grad_flag: + + def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, is_inplace): + if is_training_mode and tensor_flag and grad_flag and is_inplace: # "multiply one" helps change the torch tensor's is_leaf to be False. # This is required when the torch tensor is updated in-place during forward pass. - # We cannot use view here, because PyTorch handels grad_fn for view differently. - non_leaf_arg = arg * arg.new_ones((1,)) + # We cannot use view here, because PyTorch handles grad_fn for view differently. + non_leaf_arg = arg * 1 return non_leaf_arg else: return arg @@ -114,7 +115,7 @@ def call_python_forward_function( with torch.set_grad_enabled(is_training_mode): # Another level of wrap to avoid requires_grad=True for leaf variables. - new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode) + new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, inplace) for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args)) # Run autograd.Function.apply(...). diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index 7455a26ea4..9544b09f18 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -196,22 +196,30 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06): def enable_custom_autograd_function(module): for mode in [True, False]: module._torch_module._execution_manager(mode)._enable_custom_autograd_function = True + module._torch_module._execution_manager(mode)._save_onnx = True + module._torch_module._execution_manager(mode)._save_onnx_prefix = "utbench" def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False): - model.to(device) + with torch.no_grad(): + model = copy.deepcopy(model).to(device) if is_eval_mode: model.eval() else: model.train() - inputs_on_device = [input_.to(device) for input_ in input_list] + with torch.no_grad(): + inputs_on_device = [input_.to(device) for input_ in input_list] + for i, val in enumerate(input_list): + if val.requires_grad: + inputs_on_device[i].requires_grad_() + target = label_input.to(device) + output = model(*inputs_on_device) forward_outputs = [output] grad_outputs = [] if not is_eval_mode: criterion = torch.nn.MSELoss() - target = label_input.to(device) loss = criterion(output, target) loss.backward() for name, param in model.named_parameters(): @@ -220,8 +228,9 @@ def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_m return forward_outputs, grad_outputs def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False): - model = copy.deepcopy(model) - model.to(device) + with torch.no_grad(): + model = copy.deepcopy(model) + model.to(device) model = ORTModule(model) enable_custom_autograd_function(model) if is_eval_mode: @@ -229,14 +238,19 @@ def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode= else: model.train() - inputs_on_device = [input_.to(device) for input_ in input_list] + with torch.no_grad(): + inputs_on_device = [input_.to(device) for input_ in input_list] + for i, val in enumerate(input_list): + if val.requires_grad: + inputs_on_device[i].requires_grad_() + + target = label_input.to(device) output = model(*inputs_on_device) forward_outputs = [output] grad_outputs = [] if not is_eval_mode: criterion = torch.nn.MSELoss() - target = label_input.to(device) loss = criterion(output, target) loss.backward() for name, param in model.named_parameters(): @@ -248,28 +262,33 @@ def compare_tensor_list(val_list_a, val_list_b): for val_a, val_b in zip(val_list_a, val_list_b): assert_values_are_close(val_a, val_b, atol=1e-7, rtol=1e-6) -def run_training_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, ignore_grad_compare=False): +def run_training_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, + ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): cpu = torch.device("cpu") def cpu_barrier_func(): pass run_training_test_on_device_and_compare( - cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func, ignore_grad_compare) + cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func, + ignore_grad_compare, expected_outputs, expected_grads) def cuda_barrier_func(): torch.cuda.synchronize() cuda = torch.device('cuda:0') run_training_test_on_device_and_compare( - cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func, ignore_grad_compare) + cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func, + ignore_grad_compare, expected_outputs, expected_grads) -def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func, ignore_grad_compare=False): +def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func, + ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): repeats = 16 for i in range(repeats): m = pt_model_builder_func() x = pt_model_inputs_generator() - m_ort = copy.deepcopy(m) - x_ort = copy.deepcopy(x) + with torch.no_grad(): + m_ort = copy.deepcopy(m) + x_ort = copy.deepcopy(x) outputs, grads = run_with_pytorch_on_device( device, m, [x], pt_model_label_input) @@ -283,12 +302,18 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo val_list_b = [o.detach().cpu() for o in outputs_ort if o is not None] compare_tensor_list(val_list_a, val_list_b) + if len(expected_outputs) > 0: + compare_tensor_list(val_list_a, expected_outputs) + # For some test, it is expected the diff might be big due to inconsistent computation orders. if ignore_grad_compare is False: val_list_a = [o.detach().cpu() for o in grads if o is not None] val_list_b = [o.detach().cpu() for o in grads_ort if o is not None] compare_tensor_list(val_list_a, val_list_b) + if len(expected_grads) > 0: + compare_tensor_list(val_list_a, expected_grads) + def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input): cpu = torch.device("cpu") diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 79748816a2..5be49c86ef 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -359,7 +359,7 @@ def test_InplaceUpdateInputAsOutputNotRequireGradWithMarkDirty(): run_training_test_and_compare(model_builder, input_generator, label_input) - +@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).") def test_InplaceUpdateInputAsOutputRequireGrad(): class InplaceUpdateInputAsOutputRequireGradFunction(torch.autograd.Function): @staticmethod @@ -412,7 +412,7 @@ def test_InplaceUpdateInputAsOutputRequireGrad(): run_training_test_and_compare( model_builder, input_generator, label_input, ignore_grad_compare=True) - +@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).") def test_InplaceUpdateInputNotAsOutputRequireGrad(): class InplaceUpdateInputNotAsOutputRequireGradFunction(torch.autograd.Function): # without mark_ditry, the inner computation graph is extracted into another subgraph, which is a duplicated computation with the PythonOp. @@ -843,12 +843,11 @@ def test_MultipleStream_InForwardFunction(): class MultipleStreamModel(torch.nn.Module): def __init__(self, output_size): super(MultipleStreamModel, self).__init__() - self.linear_a = torch.nn.Linear(output_size, output_size) self.relu = MultipleStreamFunction.apply def forward(self, model_input): - model_input = model_input * 0.2 - out = self.relu(model_input) + b = model_input * 0.2 + out = self.relu(b) return out output_size = 2 @@ -857,19 +856,15 @@ def test_MultipleStream_InForwardFunction(): return MultipleStreamModel(output_size) def input_generator(): - return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + return torch.tensor([2.8, 3.4], requires_grad=True) # generate a label that have same shape as forward output. label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) - - expected_ret_list = [torch.tensor([-0.7760, -0.7280])] - - compare_tensor_list(expected_ret_list, cuda_output_list) - + run_training_test_and_compare(model_builder, input_generator, label_input, + expected_outputs=[torch.tensor([0.224, 0.272])]) def test_NonDefaultStream_InForwardFunction1(): class MultipleStreamFunction(torch.autograd.Function): @@ -894,7 +889,6 @@ def test_NonDefaultStream_InForwardFunction1(): class MultipleStreamModel(torch.nn.Module): def __init__(self, output_size): super(MultipleStreamModel, self).__init__() - self.linear_a = torch.nn.Linear(output_size, output_size) self.relu = MultipleStreamFunction.apply def forward(self, model_input): @@ -909,18 +903,15 @@ def test_NonDefaultStream_InForwardFunction1(): return MultipleStreamModel(output_size) def input_generator(): - return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + return torch.tensor([2.8, 3.4], requires_grad=True) # generate a label that have same shape as forward output. label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) - - expected_ret_list = [torch.tensor([-0.7760, -0.7280])] - - compare_tensor_list(expected_ret_list, cuda_output_list) + run_training_test_and_compare(model_builder, input_generator, label_input, + expected_outputs=[torch.tensor([0.224, 0.272])]) def test_NonDefaultStream_InForwardFunction2(): @@ -940,7 +931,6 @@ def test_NonDefaultStream_InForwardFunction2(): class MultipleStreamModel(torch.nn.Module): def __init__(self, output_size): super(MultipleStreamModel, self).__init__() - self.linear_a = torch.nn.Linear(output_size, output_size) self.relu = MultipleStreamFunction.apply def forward(self, model_input): @@ -960,18 +950,16 @@ def test_NonDefaultStream_InForwardFunction2(): return MultipleStreamModel(output_size) def input_generator(): - return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + return torch.tensor([2.8, 3.4], requires_grad=True) # generate a label that have same shape as forward output. label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) + run_training_test_and_compare(model_builder, input_generator, label_input, + expected_outputs=[torch.tensor([0.224, 0.272])]) - expected_ret_list = [torch.tensor([-0.7760, -0.7280])] - - compare_tensor_list(expected_ret_list, cuda_output_list) def test_NonDefaultStreamInplaceUpdate_InForwardFunction(): class MultipleStreamFunction(torch.autograd.Function): @@ -997,7 +985,6 @@ def test_NonDefaultStreamInplaceUpdate_InForwardFunction(): class MultipleStreamModel(torch.nn.Module): def __init__(self, output_size): super(MultipleStreamModel, self).__init__() - self.linear_a = torch.nn.Linear(output_size, output_size) self.relu = MultipleStreamFunction.apply def forward(self, model_input): @@ -1012,15 +999,12 @@ def test_NonDefaultStreamInplaceUpdate_InForwardFunction(): return MultipleStreamModel(output_size) def input_generator(): - return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + return torch.tensor([2.8, 3.4], requires_grad=True) # generate a label that have same shape as forward output. label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) - - expected_ret_list = [torch.tensor([-0.7760, -0.7280])] - - compare_tensor_list(expected_ret_list, cuda_output_list) + run_training_test_and_compare(model_builder, input_generator, label_input, + expected_outputs=[torch.tensor([0.224, 0.272])]) diff --git a/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc b/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc index eb376760ef..621183007e 100644 --- a/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc +++ b/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc @@ -50,7 +50,6 @@ Status PythonOpGrad::ComputeInternal(OpKernelContext* context) const { std::vector returned_ortvalues; RunBackward(context, returned_ortvalues); - SetOutputs(context, returned_ortvalues); RefCountTracker::GetInstance().DumpDetails("Backward Kernel Completed");