mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Backward operation of torch.eig for real eigenvalues (#33090)
Summary: Another pull request following up on issue https://github.com/pytorch/pytorch/issues/32531. Here I implemented the backward operation for `torch.eig` under the condition that all the eigenvalues are real. This pull request is independent of my other pull request https://github.com/pytorch/pytorch/issues/32932, which means that there is no dependency between this PR and that PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/33090 Differential Revision: D19814347 Pulled By: albanD fbshipit-source-id: 2fae30964e97987abb690544df8240aedeae56e8
This commit is contained in:
parent
c917a247a8
commit
9d94f56ce0
3 changed files with 97 additions and 1 deletions
|
|
@ -2299,6 +2299,39 @@ class TestAutograd(TestCase):
|
|||
[True, False]):
|
||||
_test_with_size(a_size, b_size, upper)
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_eig(self):
|
||||
def func(B):
|
||||
return torch.eig(B, eigenvectors=True)
|
||||
|
||||
def func_eigvals(B):
|
||||
return torch.eig(B, eigenvectors=True)[0]
|
||||
|
||||
def func_eigvecs(B):
|
||||
return torch.eig(B, eigenvectors=True)[1]
|
||||
|
||||
def run_test(dims):
|
||||
# The backward operation for eig only works for real eigenvalues,
|
||||
# so the matrix should be B = U^{-1}*A*U where A is a random
|
||||
# symmetric matrix and U is a random full-rank matrix.
|
||||
# Slight change to the matrix should not make the eigenvalues
|
||||
# complex, so we apply requires_grad_ to B, not A and U
|
||||
|
||||
A = random_symmetric_matrix(dims[-1], *dims[:-2])
|
||||
U = torch.rand(*dims)
|
||||
Uinv = torch.inverse(U)
|
||||
B = torch.matmul(Uinv, torch.matmul(A, U)).requires_grad_()
|
||||
|
||||
gradcheck(func, [B])
|
||||
gradgradcheck(func, [B])
|
||||
gradcheck(func_eigvals, [B])
|
||||
gradgradcheck(func_eigvals, [B])
|
||||
gradcheck(func_eigvecs, [B])
|
||||
gradgradcheck(func_eigvecs, [B])
|
||||
|
||||
for dims in [(3, 3), (5, 5)]:
|
||||
run_test(dims)
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_symeig(self):
|
||||
def func(root, upper):
|
||||
|
|
@ -3138,6 +3171,20 @@ class TestAutograd(TestCase):
|
|||
out.backward()
|
||||
self.assertIn('MyFunc.apply', str(w[0].message))
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_eig_no_eigenvectors(self):
|
||||
A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
|
||||
w, v = torch.eig(A, eigenvectors=False)
|
||||
with self.assertRaisesRegex(RuntimeError, 'cannot compute backward'):
|
||||
torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_eig_complex_eigenvalues(self):
|
||||
A = torch.tensor([[0., -1.], [1., 0.]], dtype=torch.float32, requires_grad=True)
|
||||
w, v = torch.eig(A, eigenvectors=True)
|
||||
with self.assertRaisesRegex(RuntimeError, 'does not support complex eigenvalues'):
|
||||
torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_symeig_no_eigenvectors(self):
|
||||
A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
|
||||
|
|
|
|||
|
|
@ -304,7 +304,7 @@
|
|||
self: _fused_dropout_backward(grad, result1, p)
|
||||
|
||||
- name: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)
|
||||
self: not_implemented("eig")
|
||||
self: eig_backward(grads, self, eigenvectors, eigenvalues, eigenvectors_return)
|
||||
|
||||
- name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
|
||||
self: zeros_like(self)
|
||||
|
|
|
|||
|
|
@ -1792,6 +1792,55 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
|
|||
return u_term + sigma_term + v_term;
|
||||
}
|
||||
|
||||
// "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation"
|
||||
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
|
||||
Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
|
||||
bool eigenvectors, const Tensor& lambda, const Tensor& v) {
|
||||
// This gradient only works for real eigenvalues at the moment.
|
||||
TORCH_CHECK(eigenvectors,
|
||||
"eig_backward: Setting eigenvectors to false in torch.eig doesn't compute eigenvectors ",
|
||||
"and hence we cannot compute backward. Please use torch.eig(eigenvectors=True)");
|
||||
auto zeros = at::zeros({1}, lambda.options());
|
||||
TORCH_CHECK(
|
||||
at::allclose(lambda.slice(/*dim=*/-1, /*start=*/1, /*end=*/2), zeros),
|
||||
"eig_backward: Backward calculation does not support complex eigenvalues at the moment.");
|
||||
|
||||
auto glambda = grads[0];
|
||||
auto gv = grads[1];
|
||||
auto vt = v.transpose(-2, -1);
|
||||
|
||||
Tensor result;
|
||||
// contribution from the eigenvectors
|
||||
if (gv.defined()) {
|
||||
auto rlambda = lambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1);
|
||||
|
||||
auto hm = rlambda.transpose(-2,-1) - rlambda;
|
||||
hm.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
|
||||
hm.pow_(-1.0);
|
||||
|
||||
auto gvortho = gv - at::sum(gv * v, /*dim=*/-2, /*keepdim=*/true) * v;
|
||||
auto B = hm * at::matmul(vt, gvortho);
|
||||
auto A = at::matmul(B, vt);
|
||||
|
||||
std::tie(result, std::ignore) = at::solve(A, vt);
|
||||
}
|
||||
// contribution from eigenvalues
|
||||
if (glambda.defined()) {
|
||||
auto grlambda = glambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1) * vt;
|
||||
auto A = at::matmul(v, grlambda);
|
||||
auto vvt = at::matmul(v, vt);
|
||||
if (result.defined()) {
|
||||
Tensor result1;
|
||||
std::tie(result1, std::ignore) = at::solve(A, vvt);
|
||||
result = result.add(result1);
|
||||
}
|
||||
else {
|
||||
std::tie(result, std::ignore) = at::solve(A, vvt);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
|
||||
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
|
||||
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue