mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Filter nones from ctx saved tensors (#9063)
Co-authored-by: Aishwarya Bhandare <aibhanda@5cb7a9c3931a4b19a66ae028b49221a6000001.ahkw4qp232huflxlm4gmpq4nbh.jx.internal.cloudapp.net>
This commit is contained in:
parent
4930320647
commit
98ac341c5b
1 changed files with 3 additions and 1 deletions
|
|
@ -111,7 +111,9 @@ def call_python_forward_function(
|
|||
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21).
|
||||
# The next edges are stored in Node, with which we can get next gradient function.
|
||||
# https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527
|
||||
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, ctx.saved_tensors)
|
||||
# filter out the None in the saved_tensors.
|
||||
saved_tensors = [t for t in ctx.saved_tensors if t is not None]
|
||||
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, saved_tensors)
|
||||
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
|
||||
else:
|
||||
# Context must not present under non-training mode.
|
||||
|
|
|
|||
Loading…
Reference in a new issue