diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 200e43f258..d091b76f97 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -26,10 +26,7 @@ def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_fl if tensor_flag: # Got a tensor. Assume it's a DLPack tensor # and convert it to Pytorch tensor. - if not inplace_flag: - wrapped_arg = from_dlpack(arg) - else: - wrapped_arg = from_dlpack(arg).detach().contiguous() + wrapped_arg = from_dlpack(arg) # Only requires gradient when running under training mode # and the associated tensor has grad_flag=True (i.e., @@ -65,8 +62,8 @@ def call_python_forward_function( inplace: indicates if args can be modified inside the custom function. args: inputs to "backward_function". ''' - def generate_non_leaf_or_not(grad_flag, tensor_flag, arg): - if tensor_flag and grad_flag: + def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode): + if is_training_mode and tensor_flag and grad_flag: # "multiply one" helps change the torch tensor's is_leaf to be False. # This is required when the torch tensor is updated in-place during forward pass. # We cannot use view here, because PyTorch handels grad_fn for view differently. @@ -112,9 +109,9 @@ def call_python_forward_function( wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg) for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args)) - with torch.enable_grad(): + with torch.set_grad_enabled(is_training_mode): # Another level of wrap to avoid requires_grad=True for leaf variables. 
- new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg) + new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode) for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args)) # Run autograd.Function.apply(...). diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 607ebb6246..79748816a2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -817,4 +817,210 @@ def test_GeLU_When_Autograd_Func_Fallback_Not_Enabled(): inputs_on_device = [x_ort.to(device)] output = model(*inputs_on_device) except RuntimeError as e: - assert "Detected autograd functions usage in current model, the run will fail" in str(e) \ No newline at end of file + assert "Detected autograd functions usage in current model, the run will fail" in str(e) + +def test_MultipleStream_InForwardFunction(): + class MultipleStreamFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + default_stream = torch.cuda.current_stream() + ctx.save_for_backward(input) + stream = torch.cuda.Stream() + torch.cuda._sleep(1000 * 1000) + input = input * 0.2 + # on different stream + with torch.cuda.stream(stream): + stream.wait_stream(default_stream) + input= input * 2 + default_stream.wait_stream(stream) + return input + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output + + class MultipleStreamModel(torch.nn.Module): + def __init__(self, output_size): + super(MultipleStreamModel, self).__init__() + self.linear_a = torch.nn.Linear(output_size, output_size) + self.relu = MultipleStreamFunction.apply + + def forward(self, model_input): + model_input = model_input * 0.2 + out = self.relu(model_input) + return out + + output_size 
= 2 + + def model_builder(): + return MultipleStreamModel(output_size) + + def input_generator(): + return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + + + # Generate a label that has the same shape as the forward output. + label_input = torch.ones([output_size]) + + # Test a custom function whose forward does part of its work on a second CUDA stream. + cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) + + expected_ret_list = [torch.tensor([-0.7760, -0.7280])] + + compare_tensor_list(expected_ret_list, cuda_output_list) + + +def test_NonDefaultStream_InForwardFunction1(): + class MultipleStreamFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + default_stream = torch.cuda.current_stream() + stream = torch.cuda.Stream() + # on different stream + with torch.cuda.stream(stream): + stream.wait_stream(default_stream) + ctx.save_for_backward(input) + input = input * 0.4 + + default_stream.wait_stream(stream) + return input + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output + + class MultipleStreamModel(torch.nn.Module): + def __init__(self, output_size): + super(MultipleStreamModel, self).__init__() + self.linear_a = torch.nn.Linear(output_size, output_size) + self.relu = MultipleStreamFunction.apply + + def forward(self, model_input): + model_input = model_input * 0.2 + torch.cuda._sleep(1000 * 1000) + out = self.relu(model_input) + return out + + output_size = 2 + + def model_builder(): + return MultipleStreamModel(output_size) + + def input_generator(): + return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + + + # Generate a label that has the same shape as the forward output. + label_input = torch.ones([output_size]) + + # Test a custom function whose forward runs on a non-default CUDA stream.
+ cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) + + expected_ret_list = [torch.tensor([-0.7760, -0.7280])] + + compare_tensor_list(expected_ret_list, cuda_output_list) + + +def test_NonDefaultStream_InForwardFunction2(): + class MultipleStreamFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + torch.cuda._sleep(1000 * 1000) + input = input * 0.4 + return input + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output + + class MultipleStreamModel(torch.nn.Module): + def __init__(self, output_size): + super(MultipleStreamModel, self).__init__() + self.linear_a = torch.nn.Linear(output_size, output_size) + self.relu = MultipleStreamFunction.apply + + def forward(self, model_input): + model_input = model_input * 0.2 + stream = torch.cuda.Stream() + default_stream = torch.cuda.current_stream() + # on different stream + with torch.cuda.stream(stream): + stream.wait_stream(default_stream) + out = self.relu(model_input) + default_stream.wait_stream(stream) + return out + + output_size = 2 + + def model_builder(): + return MultipleStreamModel(output_size) + + def input_generator(): + return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + + + # Generate a label that has the same shape as the forward output. + label_input = torch.ones([output_size]) + + # Test a custom function that the model invokes inside a non-default CUDA stream.
+ cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) + + expected_ret_list = [torch.tensor([-0.7760, -0.7280])] + + compare_tensor_list(expected_ret_list, cuda_output_list) + +def test_NonDefaultStreamInplaceUpdate_InForwardFunction(): + class MultipleStreamFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + default_stream = torch.cuda.current_stream() + stream = torch.cuda.Stream() + # on different stream + with torch.cuda.stream(stream): + stream.wait_stream(default_stream) + ctx.save_for_backward(input) + input.mul_(0.4) + + ctx.mark_dirty(input) + default_stream.wait_stream(stream) + return input + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output + + class MultipleStreamModel(torch.nn.Module): + def __init__(self, output_size): + super(MultipleStreamModel, self).__init__() + self.linear_a = torch.nn.Linear(output_size, output_size) + self.relu = MultipleStreamFunction.apply + + def forward(self, model_input): + model_input = model_input * 0.2 + torch.cuda._sleep(1000 * 1000) + out = self.relu(model_input) + return out + + output_size = 2 + + def model_builder(): + return MultipleStreamModel(output_size) + + def input_generator(): + return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float) + + + # Generate a label that has the same shape as the forward output. + label_input = torch.ones([output_size]) + + # Test a custom function whose forward updates its input in-place on a non-default CUDA stream.
+ cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input) + + expected_ret_list = [torch.tensor([-0.7760, -0.7280])] + + compare_tensor_list(expected_ret_list, cuda_output_list) diff --git a/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc b/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc index d5f72af221..e49b62ef6e 100644 --- a/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc +++ b/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc @@ -36,16 +36,10 @@ ONNX_OPERATOR_KERNEL_EX( PythonOpGrad); Status PythonOp::ComputeInternal(OpKernelContext* context) const { - // Todo(pengwa): perf impact and how much, leave it now to guarantee correctness. - CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize()); - void* diff_ctx = nullptr; std::vector returned_ortvalues; RunForward(context, &diff_ctx, returned_ortvalues); - // todo(pengwa): okay to remove it? - CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize()); - SetOutputs(context, diff_ctx, returned_ortvalues); RefCountTracker::GetInstance().DumpDetails("Forward Kernel Completed"); @@ -53,14 +47,9 @@ Status PythonOp::ComputeInternal(OpKernelContext* context) const { } Status PythonOpGrad::ComputeInternal(OpKernelContext* context) const { - // Todo(pengwa): perf impact and how much, leave it now to guarantee correctness. - CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize()); - std::vector returned_ortvalues; RunBackward(context, returned_ortvalues); - // todo(pengwa): okay to remove it? - CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize()); SetOutputs(context, returned_ortvalues);