autograd function fallback perf (#8312)

* fix known issues

* Update orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py

Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
This commit is contained in:
pengwa 2021-07-09 00:29:40 +08:00 committed by GitHub
parent c254c3c355
commit 6dbfb8db0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 212 additions and 20 deletions

View file

@ -26,10 +26,7 @@ def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_fl
if tensor_flag:
# Got a tensor. Assume it's a DLPack tensor
# and convert it to Pytorch tensor.
if not inplace_flag:
wrapped_arg = from_dlpack(arg)
else:
wrapped_arg = from_dlpack(arg).detach().contiguous()
wrapped_arg = from_dlpack(arg)
# Only requires gradient when running under training mode
# and the associated tensor has grad_flag=True (i.e.,
@ -65,8 +62,8 @@ def call_python_forward_function(
inplace: indicates if args can be modified inside the custom function.
args: inputs to "backward_function".
'''
def generate_non_leaf_or_not(grad_flag, tensor_flag, arg):
if tensor_flag and grad_flag:
def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode):
if is_training_mode and tensor_flag and grad_flag:
# "multiply one" helps change the torch tensor's is_leaf to be False.
# This is required when the torch tensor is updated in-place during forward pass.
# We cannot use view here, because PyTorch handels grad_fn for view differently.
@ -112,9 +109,9 @@ def call_python_forward_function(
wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))
with torch.enable_grad():
with torch.set_grad_enabled(is_training_mode):
# Another level of wrap to avoid requires_grad=True for leaf variables.
new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg)
new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode)
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args))
# Run autograd.Function.apply(...).

View file

@ -817,4 +817,210 @@ def test_GeLU_When_Autograd_Func_Fallback_Not_Enabled():
inputs_on_device = [x_ort.to(device)]
output = model(*inputs_on_device)
except RuntimeError as e:
assert "Detected autograd functions usage in current model, the run will fail" in str(e)
assert "Detected autograd functions usage in current model, the run will fail" in str(e)
def test_MultipleStream_InForwardFunction():
class MultipleStreamFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
default_stream = torch.cuda.current_stream()
ctx.save_for_backward(input)
stream = torch.cuda.Stream()
torch.cuda._sleep(1000 * 1000)
input = input * 0.2
# on different stream
with torch.cuda.stream(stream):
stream.wait_stream(default_stream)
input= input * 2
default_stream.wait_stream(stream)
return input
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output
class MultipleStreamModel(torch.nn.Module):
def __init__(self, output_size):
super(MultipleStreamModel, self).__init__()
self.linear_a = torch.nn.Linear(output_size, output_size)
self.relu = MultipleStreamFunction.apply
def forward(self, model_input):
model_input = model_input * 0.2
out = self.relu(model_input)
return out
output_size = 2
def model_builder():
return MultipleStreamModel(output_size)
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float)
# generate a label that have same shape as forward output.
label_input = torch.ones([output_size])
# Test multi-input and multi-output custom function.
cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input)
expected_ret_list = [torch.tensor([-0.7760, -0.7280])]
compare_tensor_list(expected_ret_list, cuda_output_list)
def test_NonDefaultStream_InForwardFunction1():
class MultipleStreamFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
default_stream = torch.cuda.current_stream()
stream = torch.cuda.Stream()
# on different stream
with torch.cuda.stream(stream):
stream.wait_stream(default_stream)
ctx.save_for_backward(input)
input = input * 0.4
default_stream.wait_stream(stream)
return input
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output
class MultipleStreamModel(torch.nn.Module):
def __init__(self, output_size):
super(MultipleStreamModel, self).__init__()
self.linear_a = torch.nn.Linear(output_size, output_size)
self.relu = MultipleStreamFunction.apply
def forward(self, model_input):
model_input = model_input * 0.2
torch.cuda._sleep(1000 * 1000)
out = self.relu(model_input)
return out
output_size = 2
def model_builder():
return MultipleStreamModel(output_size)
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float)
# generate a label that have same shape as forward output.
label_input = torch.ones([output_size])
# Test multi-input and multi-output custom function.
cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input)
expected_ret_list = [torch.tensor([-0.7760, -0.7280])]
compare_tensor_list(expected_ret_list, cuda_output_list)
def test_NonDefaultStream_InForwardFunction2():
class MultipleStreamFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
torch.cuda._sleep(1000 * 1000)
input = input * 0.4
return input
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output
class MultipleStreamModel(torch.nn.Module):
def __init__(self, output_size):
super(MultipleStreamModel, self).__init__()
self.linear_a = torch.nn.Linear(output_size, output_size)
self.relu = MultipleStreamFunction.apply
def forward(self, model_input):
model_input = model_input * 0.2
stream = torch.cuda.Stream()
default_stream = torch.cuda.current_stream()
# on different stream
with torch.cuda.stream(stream):
stream.wait_stream(default_stream)
out = self.relu(model_input)
default_stream.wait_stream(stream)
return out
output_size = 2
def model_builder():
return MultipleStreamModel(output_size)
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float)
# generate a label that have same shape as forward output.
label_input = torch.ones([output_size])
# Test multi-input and multi-output custom function.
cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input)
expected_ret_list = [torch.tensor([-0.7760, -0.7280])]
compare_tensor_list(expected_ret_list, cuda_output_list)
def test_NonDefaultStreamInplaceUpdate_InForwardFunction():
class MultipleStreamFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
default_stream = torch.cuda.current_stream()
stream = torch.cuda.Stream()
# on different stream
with torch.cuda.stream(stream):
stream.wait_stream(default_stream)
ctx.save_for_backward(input)
input.mul_(0.4)
ctx.mark_dirty(input)
default_stream.wait_stream(stream)
return input
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output
class MultipleStreamModel(torch.nn.Module):
def __init__(self, output_size):
super(MultipleStreamModel, self).__init__()
self.linear_a = torch.nn.Linear(output_size, output_size)
self.relu = MultipleStreamFunction.apply
def forward(self, model_input):
model_input = model_input * 0.2
torch.cuda._sleep(1000 * 1000)
out = self.relu(model_input)
return out
output_size = 2
def model_builder():
return MultipleStreamModel(output_size)
def input_generator():
return torch.tensor([2.8, 3.4], requires_grad=True) #torch.randn(output_size, dtype=torch.float)
# generate a label that have same shape as forward output.
label_input = torch.ones([output_size])
# Test multi-input and multi-output custom function.
cpu_output_list, cuda_output_list = run_training_test_and_compare(model_builder, input_generator, label_input)
expected_ret_list = [torch.tensor([-0.7760, -0.7280])]
compare_tensor_list(expected_ret_list, cuda_output_list)

View file

@ -36,16 +36,10 @@ ONNX_OPERATOR_KERNEL_EX(
PythonOpGrad);
Status PythonOp::ComputeInternal(OpKernelContext* context) const {
// Todo(pengwa): perf impact and how much, leave it now to guarantee correctness.
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
void* diff_ctx = nullptr;
std::vector<OrtValue> returned_ortvalues;
RunForward(context, &diff_ctx, returned_ortvalues);
// todo(pengwa): okay to remove it?
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
SetOutputs(context, diff_ctx, returned_ortvalues);
RefCountTracker::GetInstance().DumpDetails("Forward Kernel Completed");
@ -53,14 +47,9 @@ Status PythonOp::ComputeInternal(OpKernelContext* context) const {
}
Status PythonOpGrad::ComputeInternal(OpKernelContext* context) const {
// Todo(pengwa): perf impact and how much, leave it now to guarantee correctness.
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
std::vector<OrtValue> returned_ortvalues;
RunBackward(context, returned_ortvalues);
// todo(pengwa): okay to remove it?
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
SetOutputs(context, returned_ortvalues);