diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 208fd16e0e..67bef86af1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -111,7 +111,9 @@ def call_python_forward_function( # (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21). # The next edges are stored in Node, with which we can get next gradient function. # https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527 - torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, ctx.saved_tensors) + # filter out the None in the saved_tensors. + saved_tensors = [t for t in ctx.saved_tensors if t is not None] + torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, saved_tensors) torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output) else: # Context must not present under non-training mode.