mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
autograd function fallback perf (#8312)
* fix known issues * Update orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
This commit is contained in:
parent
c254c3c355
commit
6dbfb8db0e
3 changed files with 212 additions and 20 deletions
|
|
@ -26,10 +26,7 @@ def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_fl
|
|||
if tensor_flag:
|
||||
# Got a tensor. Assume it's a DLPack tensor
|
||||
# and convert it to Pytorch tensor.
|
||||
if not inplace_flag:
|
||||
wrapped_arg = from_dlpack(arg)
|
||||
else:
|
||||
wrapped_arg = from_dlpack(arg).detach().contiguous()
|
||||
wrapped_arg = from_dlpack(arg)
|
||||
|
||||
# Only requires gradient when running under training mode
|
||||
# and the associated tensor has grad_flag=True (i.e.,
|
||||
|
|
@ -65,8 +62,8 @@ def call_python_forward_function(
|
|||
inplace: indicates if args can be modified inside the custom function.
|
||||
args: inputs to "backward_function".
|
||||
'''
|
||||
def generate_non_leaf_or_not(grad_flag, tensor_flag, arg):
|
||||
if tensor_flag and grad_flag:
|
||||
def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode):
|
||||
if is_training_mode and tensor_flag and grad_flag:
|
||||
# "multiply one" helps change the torch tensor's is_leaf to be False.
|
||||
# This is required when the torch tensor is updated in-place during forward pass.
|
||||
# We cannot use view here, because PyTorch handles grad_fn for view differently.
|
||||
|
|
@ -112,9 +109,9 @@ def call_python_forward_function(
|
|||
wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
|
||||
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))
|
||||
|
||||
with torch.enable_grad():
|
||||
with torch.set_grad_enabled(is_training_mode):
|
||||
# Another level of wrap to avoid requires_grad=True for leaf variables.
|
||||
new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg)
|
||||
new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode)
|
||||
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args))
|
||||
|
||||
# Run autograd.Function.apply(...).
|
||||
|
|
|
|||
|
|
@ -817,4 +817,210 @@ def test_GeLU_When_Autograd_Func_Fallback_Not_Enabled():
|
|||
inputs_on_device = [x_ort.to(device)]
|
||||
output = model(*inputs_on_device)
|
||||
except RuntimeError as e:
|
||||
assert "Detected autograd functions usage in current model, the run will fail" in str(e)
|
||||
assert "Detected autograd functions usage in current model, the run will fail" in str(e)
|
||||
|
||||
def test_MultipleStream_InForwardFunction():
    """Custom autograd function whose forward uses both the current CUDA stream and a new one.

    The forward pass multiplies by 0.2 on the current stream, then by 2 inside a
    ``torch.cuda.stream`` context for a freshly created stream, with
    ``wait_stream`` calls in both directions around the second multiply.
    """

    class MultipleStreamFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            main_stream = torch.cuda.current_stream()
            ctx.save_for_backward(input)
            side_stream = torch.cuda.Stream()
            # Enqueued before the first multiply; presumably keeps the current
            # stream busy so ordering problems would surface - TODO confirm.
            torch.cuda._sleep(1000 * 1000)
            input = input * 0.2
            # The second multiply is issued under the side stream, after that
            # stream has been told to wait for the main one.
            with torch.cuda.stream(side_stream):
                side_stream.wait_stream(main_stream)
                input = input * 2
            # The main stream waits for the side stream before the result is returned.
            main_stream.wait_stream(side_stream)
            return input

        @staticmethod
        def backward(ctx, grad_output):
            # The saved tensor is unpacked but not used; the incoming gradient
            # is passed through unchanged.
            (saved_input,) = ctx.saved_tensors
            return grad_output

    class MultipleStreamModel(torch.nn.Module):
        def __init__(self, size):
            super().__init__()
            self.linear_a = torch.nn.Linear(size, size)
            self.relu = MultipleStreamFunction.apply

        def forward(self, model_input):
            model_input = model_input * 0.2
            return self.relu(model_input)

    feature_dim = 2

    def build_model():
        return MultipleStreamModel(feature_dim)

    def make_input():
        return torch.tensor([2.8, 3.4], requires_grad=True)

    # Label with the same shape as the forward output.
    labels = torch.ones([feature_dim])

    # Single-input, single-output custom function.
    _cpu_outputs, cuda_outputs = run_training_test_and_compare(build_model, make_input, labels)

    # NOTE(review): these values equal forward output minus label
    # (2.8 * 0.2 * 0.2 * 2 - 1 = -0.776); confirm what the helper returns.
    reference = [torch.tensor([-0.7760, -0.7280])]

    compare_tensor_list(reference, cuda_outputs)
|
||||
|
||||
|
||||
def test_NonDefaultStream_InForwardFunction1():
    """Custom autograd function whose whole forward computation runs under a new CUDA stream.

    The model enqueues ``torch.cuda._sleep`` on the current stream before calling
    the function; the function saves its input and multiplies by 0.4 inside a
    ``torch.cuda.stream`` context for a freshly created stream.
    """

    class MultipleStreamFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            main_stream = torch.cuda.current_stream()
            side_stream = torch.cuda.Stream()
            # Saving and scaling both happen under the side stream, after it
            # has been told to wait for the main stream.
            with torch.cuda.stream(side_stream):
                side_stream.wait_stream(main_stream)
                ctx.save_for_backward(input)
                input = input * 0.4

            # The main stream waits for the side stream before the result is returned.
            main_stream.wait_stream(side_stream)
            return input

        @staticmethod
        def backward(ctx, grad_output):
            # The saved tensor is unpacked but not used; the incoming gradient
            # is passed through unchanged.
            (saved_input,) = ctx.saved_tensors
            return grad_output

    class MultipleStreamModel(torch.nn.Module):
        def __init__(self, size):
            super().__init__()
            self.linear_a = torch.nn.Linear(size, size)
            self.relu = MultipleStreamFunction.apply

        def forward(self, model_input):
            model_input = model_input * 0.2
            # Enqueued between the multiply and the custom function call.
            torch.cuda._sleep(1000 * 1000)
            return self.relu(model_input)

    feature_dim = 2

    def build_model():
        return MultipleStreamModel(feature_dim)

    def make_input():
        return torch.tensor([2.8, 3.4], requires_grad=True)

    # Label with the same shape as the forward output.
    labels = torch.ones([feature_dim])

    # Single-input, single-output custom function.
    _cpu_outputs, cuda_outputs = run_training_test_and_compare(build_model, make_input, labels)

    # NOTE(review): these values equal forward output minus label
    # (2.8 * 0.2 * 0.4 - 1 = -0.776); confirm what the helper returns.
    reference = [torch.tensor([-0.7760, -0.7280])]

    compare_tensor_list(reference, cuda_outputs)
|
||||
|
||||
|
||||
def test_NonDefaultStream_InForwardFunction2():
    """Custom autograd function invoked by the model from inside a new CUDA stream context.

    Here the function itself does no stream handling; the model's ``forward``
    creates a stream and calls the function under ``torch.cuda.stream``.
    """

    class MultipleStreamFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            # Enqueued before the multiply, on whichever stream is current
            # when the function is called.
            torch.cuda._sleep(1000 * 1000)
            input = input * 0.4
            return input

        @staticmethod
        def backward(ctx, grad_output):
            # The saved tensor is unpacked but not used; the incoming gradient
            # is passed through unchanged.
            (saved_input,) = ctx.saved_tensors
            return grad_output

    class MultipleStreamModel(torch.nn.Module):
        def __init__(self, size):
            super().__init__()
            self.linear_a = torch.nn.Linear(size, size)
            self.relu = MultipleStreamFunction.apply

        def forward(self, model_input):
            model_input = model_input * 0.2
            side_stream = torch.cuda.Stream()
            main_stream = torch.cuda.current_stream()
            # The custom function is called under the side stream, after that
            # stream has been told to wait for the main one.
            with torch.cuda.stream(side_stream):
                side_stream.wait_stream(main_stream)
                out = self.relu(model_input)
            # The main stream waits for the side stream before the output is returned.
            main_stream.wait_stream(side_stream)
            return out

    feature_dim = 2

    def build_model():
        return MultipleStreamModel(feature_dim)

    def make_input():
        return torch.tensor([2.8, 3.4], requires_grad=True)

    # Label with the same shape as the forward output.
    labels = torch.ones([feature_dim])

    # Single-input, single-output custom function.
    _cpu_outputs, cuda_outputs = run_training_test_and_compare(build_model, make_input, labels)

    # NOTE(review): these values equal forward output minus label
    # (2.8 * 0.2 * 0.4 - 1 = -0.776); confirm what the helper returns.
    reference = [torch.tensor([-0.7760, -0.7280])]

    compare_tensor_list(reference, cuda_outputs)
|
||||
|
||||
def test_NonDefaultStreamInplaceUpdate_InForwardFunction():
    """Custom autograd function that updates its input in place under a new CUDA stream.

    The forward pass saves the input, applies ``mul_(0.4)`` inside a
    ``torch.cuda.stream`` context for a freshly created stream, marks the input
    dirty and returns it.
    """

    class MultipleStreamFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            main_stream = torch.cuda.current_stream()
            side_stream = torch.cuda.Stream()
            # Saving and the in-place multiply both happen under the side
            # stream, after it has been told to wait for the main stream.
            with torch.cuda.stream(side_stream):
                side_stream.wait_stream(main_stream)
                ctx.save_for_backward(input)
                input.mul_(0.4)

            # Tell autograd the input tensor was modified in place.
            ctx.mark_dirty(input)
            # The main stream waits for the side stream before the result is returned.
            main_stream.wait_stream(side_stream)
            return input

        @staticmethod
        def backward(ctx, grad_output):
            # The saved tensor is unpacked but not used; the incoming gradient
            # is passed through unchanged.
            (saved_input,) = ctx.saved_tensors
            return grad_output

    class MultipleStreamModel(torch.nn.Module):
        def __init__(self, size):
            super().__init__()
            self.linear_a = torch.nn.Linear(size, size)
            self.relu = MultipleStreamFunction.apply

        def forward(self, model_input):
            model_input = model_input * 0.2
            # Enqueued between the multiply and the custom function call.
            torch.cuda._sleep(1000 * 1000)
            return self.relu(model_input)

    feature_dim = 2

    def build_model():
        return MultipleStreamModel(feature_dim)

    def make_input():
        return torch.tensor([2.8, 3.4], requires_grad=True)

    # Label with the same shape as the forward output.
    labels = torch.ones([feature_dim])

    # Single-input, single-output custom function.
    _cpu_outputs, cuda_outputs = run_training_test_and_compare(build_model, make_input, labels)

    # NOTE(review): these values equal forward output minus label
    # (2.8 * 0.2 * 0.4 - 1 = -0.776); confirm what the helper returns.
    reference = [torch.tensor([-0.7760, -0.7280])]

    compare_tensor_list(reference, cuda_outputs)
|
||||
|
|
|
|||
|
|
@ -36,16 +36,10 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
PythonOpGrad);
|
||||
|
||||
Status PythonOp::ComputeInternal(OpKernelContext* context) const {
|
||||
// Todo(pengwa): perf impact and how much, leave it now to guarantee correctness.
|
||||
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
|
||||
|
||||
void* diff_ctx = nullptr;
|
||||
std::vector<OrtValue> returned_ortvalues;
|
||||
RunForward(context, &diff_ctx, returned_ortvalues);
|
||||
|
||||
// todo(pengwa): okay to remove it?
|
||||
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
|
||||
|
||||
SetOutputs(context, diff_ctx, returned_ortvalues);
|
||||
|
||||
RefCountTracker::GetInstance().DumpDetails("Forward Kernel Completed");
|
||||
|
|
@ -53,14 +47,9 @@ Status PythonOp::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
Status PythonOpGrad::ComputeInternal(OpKernelContext* context) const {
|
||||
// Todo(pengwa): perf impact and how much, leave it now to guarantee correctness.
|
||||
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
|
||||
|
||||
std::vector<OrtValue> returned_ortvalues;
|
||||
RunBackward(context, returned_ortvalues);
|
||||
|
||||
// todo(pengwa): okay to remove it?
|
||||
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
|
||||
|
||||
SetOutputs(context, returned_ortvalues);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue