det_backward: return svd path for double backward (so that all ci tests pass) (#62570)

Summary:
Potentially fixes https://github.com/pytorch/pytorch/issues/62327 and fixes https://github.com/pytorch/pytorch/issues/62328.
This PR replaces the eig-based double backward of det with an svd-based one. The latter is slower but should be more stable.

CC anjali411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62570

Reviewed By: pbelevich

Differential Revision: D30072876

Pulled By: anjali411

fbshipit-source-id: c91b507dbfd6a3ec47dc6d0b0dcfa5f8c8228c30
This commit is contained in:
Nikita Vedeneev 2021-08-04 13:42:39 -07:00 committed by Facebook GitHub Bot
parent 6f0abba04c
commit 8e35df0bf3

View file

@ -2601,6 +2601,99 @@ Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, c
}
}
Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
if (self.numel() == 0) {
return at::empty_like(self);
}
auto det_backward_nonsingular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
// Derived from Jacobi's formula for partial derivative, which can be found
// at https://en.wikipedia.org/wiki/Jacobi%27s_formula
// i.e. if A is the input matrix, then
// A_grad = A^{-H} (grad * det.conj()) I, where
// A^{-H} = (A^{-1}).T.conj()
// create a matrix d := (grad * det.conj()) I
auto d = at::zeros_like(self);
d.diagonal(0, -2, -1).copy_((grad * det.conj()).unsqueeze(-1));
return at::linalg_solve(self.transpose(-2, -1).conj(), d);
};
auto det_backward_singular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
// Derived from the gradient formula that would be used if `self`'s
// determinant is calculated using SVD, like so:
// u, s, vh = svd(self)
// det(self) = det(u) * prod(s) * det(vh)
//
// This formula should be correct even if `self` is nonsingular.
Tensor u, s, vh;
std::tie(u, s, vh) = at::linalg_svd(self);
auto u_det = at::linalg_det(u);
auto s_prod = s.prod(-1);
auto vh_det = at::linalg_det(vh);
auto u_det_grad = grad * (vh_det * s_prod).conj();
auto u_grad = det_backward_nonsingular(u_det_grad, u, u_det);
auto s_prod_grad = handle_r_to_c(s_prod.scalar_type(), grad * (u_det * vh_det).conj());
auto s_grad = prod_backward(s_prod_grad, s, s_prod, -1, false);
auto vh_det_grad = grad * (u_det * s_prod).conj();
auto vh_grad = det_backward_nonsingular(vh_det_grad, vh, vh_det);
auto v = vh.transpose(-2, -1).conj();
auto v_grad = vh_grad.transpose(-2, -1).conj();
// svd_backward is written for a function
// svd: self -> (U, S, V), which is different
// from torch.linalg.svd which is a map self -> (U, S, Vh), where
// Vh = V.transpose(-2, -1).conj()
return svd_backward({u_grad, s_grad, v_grad}, self, true, true, u, s, v);
};
auto eps = at::native::_get_epsilon(c10::toValueType(self.scalar_type()));
auto singular_det_cutoff = eps * at::linalg_matrix_norm(self);
if (self.dim() == 2) {
if (det.abs().lt(singular_det_cutoff).item<bool>()) {
return det_backward_singular(grad, self, det);
} else {
return det_backward_nonsingular(grad, self, det);
}
} else {
auto nonzero_det_mask = det.abs().ge(singular_det_cutoff);
if (nonzero_det_mask.all().item<bool>()) {
return det_backward_nonsingular(grad, self, det);
}
auto zero_det_mask = nonzero_det_mask.logical_not();
if (zero_det_mask.all().item<bool>()) {
return det_backward_singular(grad, self, det);
}
Tensor self_grad = self.new_empty(self.sizes(), self.options());
auto nonzero_det_list = at::native::toListOfOptionalTensors(nonzero_det_mask);
self_grad.index_put_(
/*indices=*/nonzero_det_list,
// NOLINTNEXTLINE(bugprone-argument-comment)
/*value=*/det_backward_nonsingular(
grad.index(nonzero_det_list),
self.index(nonzero_det_list),
det.index(nonzero_det_list)));
auto zero_det_list = at::native::toListOfOptionalTensors(zero_det_mask);
self_grad.index_put_(
/*indices=*/zero_det_list,
// NOLINTNEXTLINE(bugprone-argument-comment)
/*value=*/det_backward_singular(
grad.index(zero_det_list),
self.index(zero_det_list),
det.index(zero_det_list)));
return self_grad;
}
}
// The backward for this function is just a specialized version of
// lu.backward, which is implemented in /torch/_autograd_functions.py
Tensor _det_lu_based_helper_backward(
@ -2617,15 +2710,19 @@ Tensor _det_lu_based_helper_backward(
return Tensor();
}
// lu_solve does not support double backward yet.
// Hence, if double backward, we use the eigendecomposition.
// run det_backward only if backward is run on _det_lu_based_helper_backward.
// _det_lu_based_helper_backward is more stable for forward det computing functions,
// but it fails with double backward gradient checks (gradgradcheck).
// det_backward, on the other hand, is less stable (due to restrictions on svd_backward,
// namely, svd_backward requires distinct singular values which are sufficiently different
// from each other), yet, if its computation is stable, so is its double backward.
// Hence, if only single backward is run, we use _det_lu_based_helper_backward,
// for the double backward case we use det_backward. The latter approach could produce
// unstable gradients, therefore we DO NOT recommend double backpropagation through
// det computing functions.
if (at::GradMode::is_enabled()) {
Tensor l, v;
std::tie(l, v) = at::linalg_eig(self);
auto l_grad = prod_backward(det_grad, l, l.prod(-1), -1, false);
return linalg_eig_backward({l_grad, {}}, self, l, v);
return det_backward(det_grad, self, det);
}
// we use a sequence of kernels to avoid memory copies and checks,