Filter nones from ctx saved tensors (#9063)

Co-authored-by: Aishwarya Bhandare <aibhanda@5cb7a9c3931a4b19a66ae028b49221a6000001.ahkw4qp232huflxlm4gmpq4nbh.jx.internal.cloudapp.net>
This commit is contained in:
ashbhandare 2021-09-15 10:13:45 -07:00 committed by GitHub
parent 4930320647
commit 98ac341c5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -111,7 +111,9 @@ def call_python_forward_function(
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21).
# The next edges are stored in Node, with which we can get next gradient function.
# https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, ctx.saved_tensors)
# filter out the None in the saved_tensors.
saved_tensors = [t for t in ctx.saved_tensors if t is not None]
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, saved_tensors)
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
else:
# Context must not present under non-training mode.