From 8e35df0bf3bfb60b73db5dec23e7de7a1714f418 Mon Sep 17 00:00:00 2001
From: Nikita Vedeneev
Date: Wed, 4 Aug 2021 13:42:39 -0700
Subject: [PATCH] det_backward: return svd path for double backward (so that
 all ci tests pass) (#62570)

Summary:
Potentially fixes https://github.com/pytorch/pytorch/issues/62327 and fixes
https://github.com/pytorch/pytorch/issues/62328.

This PR switches the double backward of det from eig to svd. The latter is
slower but should be more stable.

CC anjali411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62570

Reviewed By: pbelevich

Differential Revision: D30072876

Pulled By: anjali411

fbshipit-source-id: c91b507dbfd6a3ec47dc6d0b0dcfa5f8c8228c30
---
 torch/csrc/autograd/FunctionsManual.cpp | 113 ++++++++++++++++++++++--
 1 file changed, 105 insertions(+), 8 deletions(-)

diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 0b581a9a053..b0ec2458a3d 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -2601,6 +2601,99 @@ Tensor linalg_qr_backward(const std::vector &grads, c
   }
 }
 
+Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
+  if (self.numel() == 0) {
+    return at::empty_like(self);
+  }
+
+  auto det_backward_nonsingular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
+    // Derived from Jacobi's formula for the partial derivative, which can be found
+    // at https://en.wikipedia.org/wiki/Jacobi%27s_formula
+    // i.e. if A is the input matrix, then
+    // A_grad = A^{-H} (grad * det.conj()) I, where
+    // A^{-H} = (A^{-1}).T.conj()
+
+    // create a matrix d := (grad * det.conj()) I
+    auto d = at::zeros_like(self);
+    d.diagonal(0, -2, -1).copy_((grad * det.conj()).unsqueeze(-1));
+    return at::linalg_solve(self.transpose(-2, -1).conj(), d);
+  };
+
+  auto det_backward_singular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
+    // Derived from the gradient formula that would be used if `self`'s
+    // determinant is calculated using SVD, like so:
+    //   u, s, vh = svd(self)
+    //   det(self) = det(u) * prod(s) * det(vh)
+    //
+    // This formula should be correct even if `self` is nonsingular.
+    Tensor u, s, vh;
+    std::tie(u, s, vh) = at::linalg_svd(self);
+    auto u_det = at::linalg_det(u);
+    auto s_prod = s.prod(-1);
+    auto vh_det = at::linalg_det(vh);
+
+    auto u_det_grad = grad * (vh_det * s_prod).conj();
+    auto u_grad = det_backward_nonsingular(u_det_grad, u, u_det);
+
+    auto s_prod_grad = handle_r_to_c(s_prod.scalar_type(), grad * (u_det * vh_det).conj());
+    auto s_grad = prod_backward(s_prod_grad, s, s_prod, -1, false);
+
+    auto vh_det_grad = grad * (u_det * s_prod).conj();
+    auto vh_grad = det_backward_nonsingular(vh_det_grad, vh, vh_det);
+    auto v = vh.transpose(-2, -1).conj();
+    auto v_grad = vh_grad.transpose(-2, -1).conj();
+
+    // svd_backward is written for a function
+    // svd: self -> (U, S, V), which is different
+    // from torch.linalg.svd which is a map self -> (U, S, Vh), where
+    // Vh = V.transpose(-2, -1).conj()
+    return svd_backward({u_grad, s_grad, v_grad}, self, true, true, u, s, v);
+  };
+
+  auto eps = at::native::_get_epsilon(c10::toValueType(self.scalar_type()));
+  auto singular_det_cutoff = eps * at::linalg_matrix_norm(self);
+
+  if (self.dim() == 2) {
+    if (det.abs().lt(singular_det_cutoff).item<bool>()) {
+      return det_backward_singular(grad, self, det);
+    } else {
+      return det_backward_nonsingular(grad, self, det);
+    }
+  } else {
+    auto nonzero_det_mask = det.abs().ge(singular_det_cutoff);
+    if (nonzero_det_mask.all().item<bool>()) {
+      return det_backward_nonsingular(grad, self, det);
+    }
+
+    auto zero_det_mask = nonzero_det_mask.logical_not();
+    if (zero_det_mask.all().item<bool>()) {
+      return det_backward_singular(grad, self, det);
+    }
+
+    Tensor self_grad = self.new_empty(self.sizes(), self.options());
+
+    auto nonzero_det_list = at::native::toListOfOptionalTensors(nonzero_det_mask);
+    self_grad.index_put_(
+      /*indices=*/nonzero_det_list,
+      // NOLINTNEXTLINE(bugprone-argument-comment)
+      /*value=*/det_backward_nonsingular(
+        grad.index(nonzero_det_list),
+        self.index(nonzero_det_list),
+        det.index(nonzero_det_list)));
+
+    auto zero_det_list = at::native::toListOfOptionalTensors(zero_det_mask);
+    self_grad.index_put_(
+      /*indices=*/zero_det_list,
+      // NOLINTNEXTLINE(bugprone-argument-comment)
+      /*value=*/det_backward_singular(
+        grad.index(zero_det_list),
+        self.index(zero_det_list),
+        det.index(zero_det_list)));
+
+    return self_grad;
+  }
+}
+
 // The backward for this function is just a specialized version of
 // lu.backward, which is implemented in /torch/_autograd_functions.py
 Tensor _det_lu_based_helper_backward(
@@ -2617,15 +2710,19 @@ Tensor _det_lu_based_helper_backward(
     return Tensor();
   }
 
-  // lu_solve does not support double backward yet.
-  // Hence, if double backward, we use the eigendecomposition.
+
+  // run det_backward only if backward is run on _det_lu_based_helper_backward.
+  // _det_lu_based_helper_backward is more stable for forward det computing functions,
+  // but it fails with double backward gradient checks (gradgradcheck).
+  // det_backward, on the other hand, is less stable (due to restrictions on svd_backward,
+  // namely, svd_backward requires distinct singular values which are sufficiently different
+  // from each other), yet, if its computation is stable, so is its double backward.
+  // Hence, if only a single backward is run, we use _det_lu_based_helper_backward;
+  // for the double backward case we use det_backward. The latter approach could produce
+  // unstable gradients, therefore we DO NOT recommend double backpropagation through
+  // det computing functions.
   if (at::GradMode::is_enabled()) {
-    Tensor l, v;
-    std::tie(l, v) = at::linalg_eig(self);
-
-    auto l_grad = prod_backward(det_grad, l, l.prod(-1), -1, false);
-
-    return linalg_eig_backward({l_grad, {}}, self, l, v);
+    return det_backward(det_grad, self, det);
   }
 
   // we use a sequence of kernels to avoid memory copies and checks,
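
Not part of the patch itself: the standalone libtorch sketch below illustrates the two things the new code relies on, namely the nonsingular-case gradient given by Jacobi's formula (A_grad = grad * det(A).conj() * A^{-H}, as in det_backward_nonsingular) and a double backward taken through the det gradient, which is the code path this patch reroutes from linalg_eig_backward to det_backward/svd_backward. It assumes torch::linalg_det, torch::inverse and torch::autograd::grad are available in the libtorch build the patch targets; the build setup mentioned in the comments is illustrative only.

// Standalone sketch (not part of this patch); build against libtorch,
// e.g. via CMake's find_package(Torch).
#include <torch/torch.h>

#include <iostream>

int main() {
  torch::manual_seed(0);

  // A random double matrix is nonsingular with probability 1, so the
  // nonsingular branch of det_backward (Jacobi's formula) applies.
  auto a = torch::randn({3, 3}, torch::dtype(torch::kDouble).requires_grad(true));
  auto det = torch::linalg_det(a);

  // First backward with create_graph=true, so that a second backward can be
  // taken through the gradient; that second pass is what this patch reroutes.
  auto grads = torch::autograd::grad(
      /*outputs=*/{det},
      /*inputs=*/{a},
      /*grad_outputs=*/{torch::ones_like(det)},
      /*retain_graph=*/true,
      /*create_graph=*/true);
  auto a_grad = grads[0];

  // Jacobi's formula: d det(A) / dA = det(A) * A^{-T} for real A
  // (A^{-H} in the complex case), i.e. the formula used by
  // det_backward_nonsingular above.
  auto jacobi = (det * torch::inverse(a).transpose(-2, -1)).detach();
  std::cout << std::boolalpha
            << "matches Jacobi's formula: " << torch::allclose(a_grad, jacobi)
            << std::endl;

  // Double backward through the first gradient; with this patch it goes
  // through det_backward -> svd_backward instead of linalg_eig_backward.
  auto grad2 = torch::autograd::grad({a_grad.sum()}, {a});
  std::cout << "double-backward gradient norm: "
            << grad2[0].norm().item<double>() << std::endl;

  return 0;
}

With a well-conditioned input the allclose check should print true and the second backward completes without error; this mirrors the gradcheck/gradgradcheck coverage the commit message refers to, while the comment in the second hunk still cautions against relying on double backpropagation through det-computing functions for near-singular inputs.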