mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
det_backward: return svd path for double backward (so that all ci tests pass) (#62570)
Summary:
Potentially fixes https://github.com/pytorch/pytorch/issues/62327 and fixes https://github.com/pytorch/pytorch/issues/62328.

This PR replaces eig with svd in the double backward of det. The latter is slower but should be more stable.

CC anjali411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62570

Reviewed By: pbelevich

Differential Revision: D30072876

Pulled By: anjali411

fbshipit-source-id: c91b507dbfd6a3ec47dc6d0b0dcfa5f8c8228c30
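A minimal sketch of the case this targets (an editorial illustration, not a test from this PR): the linked issues exercise double backward through torch.linalg.det via gradgradcheck, and it is that second-derivative path that now goes through svd.

# Sketch (not from this PR): exercise first and second derivatives of det.
import torch
from torch.autograd import gradcheck, gradgradcheck

torch.manual_seed(0)
# gradcheck/gradgradcheck need double precision for reliable finite differences.
a = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)

assert gradcheck(torch.linalg.det, (a,))       # first backward
assert gradgradcheck(torch.linalg.det, (a,))   # double backward -- the path changed here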
This commit is contained in:
parent
6f0abba04c
commit
8e35df0bf3
1 changed file with 105 additions and 8 deletions
@@ -2601,6 +2601,99 @@ Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, c
   }
 }
 
+Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
+  if (self.numel() == 0) {
+    return at::empty_like(self);
+  }
+
+  auto det_backward_nonsingular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
+    // Derived from Jacobi's formula for partial derivative, which can be found
+    // at https://en.wikipedia.org/wiki/Jacobi%27s_formula
+    // i.e. if A is the input matrix, then
+    // A_grad = A^{-H} (grad * det.conj()) I, where
+    // A^{-H} = (A^{-1}).T.conj()
+
+    // create a matrix d := (grad * det.conj()) I
+    auto d = at::zeros_like(self);
+    d.diagonal(0, -2, -1).copy_((grad * det.conj()).unsqueeze(-1));
+    return at::linalg_solve(self.transpose(-2, -1).conj(), d);
+  };
+
+  auto det_backward_singular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
+    // Derived from the gradient formula that would be used if `self`'s
+    // determinant is calculated using SVD, like so:
+    //   u, s, vh = svd(self)
+    //   det(self) = det(u) * prod(s) * det(vh)
+    //
+    // This formula should be correct even if `self` is nonsingular.
+    Tensor u, s, vh;
+    std::tie(u, s, vh) = at::linalg_svd(self);
+    auto u_det = at::linalg_det(u);
+    auto s_prod = s.prod(-1);
+    auto vh_det = at::linalg_det(vh);
+
+    auto u_det_grad = grad * (vh_det * s_prod).conj();
+    auto u_grad = det_backward_nonsingular(u_det_grad, u, u_det);
+
+    auto s_prod_grad = handle_r_to_c(s_prod.scalar_type(), grad * (u_det * vh_det).conj());
+    auto s_grad = prod_backward(s_prod_grad, s, s_prod, -1, false);
+
+    auto vh_det_grad = grad * (u_det * s_prod).conj();
+    auto vh_grad = det_backward_nonsingular(vh_det_grad, vh, vh_det);
+    auto v = vh.transpose(-2, -1).conj();
+    auto v_grad = vh_grad.transpose(-2, -1).conj();
+
+    // svd_backward is written for a function
+    // svd: self -> (U, S, V), which is different
+    // from torch.linalg.svd which is a map self -> (U, S, Vh), where
+    // Vh = V.transpose(-2, -1).conj()
+    return svd_backward({u_grad, s_grad, v_grad}, self, true, true, u, s, v);
+  };
+
+  auto eps = at::native::_get_epsilon(c10::toValueType(self.scalar_type()));
+  auto singular_det_cutoff = eps * at::linalg_matrix_norm(self);
+
+  if (self.dim() == 2) {
+    if (det.abs().lt(singular_det_cutoff).item<bool>()) {
+      return det_backward_singular(grad, self, det);
+    } else {
+      return det_backward_nonsingular(grad, self, det);
+    }
+  } else {
+    auto nonzero_det_mask = det.abs().ge(singular_det_cutoff);
+    if (nonzero_det_mask.all().item<bool>()) {
+      return det_backward_nonsingular(grad, self, det);
+    }
+
+    auto zero_det_mask = nonzero_det_mask.logical_not();
+    if (zero_det_mask.all().item<bool>()) {
+      return det_backward_singular(grad, self, det);
+    }
+
+    Tensor self_grad = self.new_empty(self.sizes(), self.options());
+
+    auto nonzero_det_list = at::native::toListOfOptionalTensors(nonzero_det_mask);
+    self_grad.index_put_(
+        /*indices=*/nonzero_det_list,
+        // NOLINTNEXTLINE(bugprone-argument-comment)
+        /*value=*/det_backward_nonsingular(
+            grad.index(nonzero_det_list),
+            self.index(nonzero_det_list),
+            det.index(nonzero_det_list)));
+
+    auto zero_det_list = at::native::toListOfOptionalTensors(zero_det_mask);
+    self_grad.index_put_(
+        /*indices=*/zero_det_list,
+        // NOLINTNEXTLINE(bugprone-argument-comment)
+        /*value=*/det_backward_singular(
+            grad.index(zero_det_list),
+            self.index(zero_det_list),
+            det.index(zero_det_list)));
+
+    return self_grad;
+  }
+}
 
 // The backward for this function is just a specialized version of
 // lu.backward, which is implemented in /torch/_autograd_functions.py
 Tensor _det_lu_based_helper_backward(
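
The det_backward_nonsingular branch above implements Jacobi's formula, d det(A)/dA = det(A) * A^{-T} (conjugated appropriately for complex inputs). A quick numerical check of that formula against autograd, written for real inputs (a sketch for illustration, not PR code):

# Sketch: compare autograd's gradient of det with the Jacobi-formula expression
# used by det_backward_nonsingular (real-valued case for simplicity).
import torch

torch.manual_seed(0)
a = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
det = torch.linalg.det(a)

grad = torch.tensor(0.3, dtype=torch.float64)  # arbitrary upstream gradient w.r.t. det
(a_grad_autograd,) = torch.autograd.grad(det, a, grad_outputs=grad)

# A_grad = A^{-T} @ ((grad * det) I): build d = (grad * det) I and solve A^T x = d.
d = (grad * det.detach()) * torch.eye(4, dtype=torch.float64)
a_grad_formula = torch.linalg.solve(a.detach().transpose(-2, -1), d)

assert torch.allclose(a_grad_autograd, a_grad_formula)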
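The det_backward_singular branch relies on factoring the determinant through the SVD and applying the product rule; the identity itself is easy to check numerically (again a sketch, not PR code):

# Sketch: the factorization behind det_backward_singular.
# For A = U diag(S) Vh:  det(A) = det(U) * prod(S) * det(Vh),
# so an upstream gradient g of det(A) splits by the product rule into
#   g_det(U)  = g * conj(prod(S) * det(Vh))
#   g_prod(S) = g * conj(det(U) * det(Vh))
#   g_det(Vh) = g * conj(det(U) * prod(S))
import torch

torch.manual_seed(0)
a = torch.randn(5, 5, dtype=torch.float64)
u, s, vh = torch.linalg.svd(a)

assert torch.allclose(torch.linalg.det(a),
                      torch.linalg.det(u) * s.prod() * torch.linalg.det(vh))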
@@ -2617,15 +2710,19 @@ Tensor _det_lu_based_helper_backward(
     return Tensor();
   }
 
-  // lu_solve does not support double backward yet.
-  // Hence, if double backward, we use the eigendecomposition.
+  // run det_backward only if backward is run on _det_lu_based_helper_backward.
+  // _det_lu_based_helper_backward is more stable for forward det computing functions,
+  // but it fails with double backward gradient checks (gradgradcheck).
+  // det_backward, on the other hand, is less stable (due to restrictions on svd_backward,
+  // namely, svd_backward requires distinct singular values which are sufficiently different
+  // from each other), yet, if its computation is stable, so is its double backward.
+  // Hence, if only single backward is run, we use _det_lu_based_helper_backward,
+  // for the double backward case we use det_backward. The latter approach could produce
+  // unstable gradients, therefore we DO NOT recommend double backpropagation through
+  // det computing functions.
   if (at::GradMode::is_enabled()) {
-    Tensor l, v;
-    std::tie(l, v) = at::linalg_eig(self);
-
-    auto l_grad = prod_backward(det_grad, l, l.prod(-1), -1, false);
-
-    return linalg_eig_backward({l_grad, {}}, self, l, v);
+    return det_backward(det_grad, self, det);
   }
 
   // we use a sequence of kernels to avoid memory copies and checks,
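
In Python terms, the GradMode::is_enabled() branch above corresponds (as far as the autograd engine's behavior goes; this is an illustrative sketch, not PR code) to whether the backward pass is asked to build a graph, i.e. create_graph=True:

# Sketch: a plain first backward runs with grad mode disabled (LU-based path),
# while a backward with create_graph=True runs under grad mode and is routed to
# det_backward, so the second backward can differentiate through it.
import torch

torch.manual_seed(0)
a = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
det = torch.linalg.det(a)

# Single backward: grad mode is off inside the backward kernels.
(g1,) = torch.autograd.grad(det, a, retain_graph=True)

# Double backward: the first-order gradient itself carries a graph.
(g,) = torch.autograd.grad(det, a, create_graph=True)
(gg,) = torch.autograd.grad(g.sum(), a)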