mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Eliminate global usage of torch.set_default_dtype in test_autograd (#56446)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56446 Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D28000589 Pulled By: mruberry fbshipit-source-id: c8fb2907d656138e72ecf8fb3e572591f8972900
This commit is contained in:
parent
154eca0309
commit
87242d2393
2 changed files with 138 additions and 126 deletions
|
|
@ -17,10 +17,6 @@ from functools import reduce, partial
|
|||
import torch
|
||||
import json
|
||||
|
||||
# TODO: remove this global setting
|
||||
# Autograd tests use double as the default dtype
|
||||
torch.set_default_dtype(torch.double)
|
||||
|
||||
from torch import nn
|
||||
from torch._six import inf, nan
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
|
@ -327,9 +323,9 @@ class TestAutograd(TestCase):
|
|||
grads[1] * var2 * scale + grads[4] * scale,
|
||||
)
|
||||
|
||||
t1 = torch.rand(10, requires_grad=True)
|
||||
t2 = torch.rand(10, requires_grad=True)
|
||||
t3 = torch.rand(10)
|
||||
t1 = torch.rand(10, dtype=torch.double, requires_grad=True)
|
||||
t2 = torch.rand(10, dtype=torch.double, requires_grad=True)
|
||||
t3 = torch.rand(10, dtype=torch.double)
|
||||
scale = random.randint(0, 10)
|
||||
res = MyFunction.apply(t1, t2, scale, t3)
|
||||
self.assertEqual(scale, res[0])
|
||||
|
|
@ -450,7 +446,7 @@ class TestAutograd(TestCase):
|
|||
|
||||
@skipIfNoLapack
|
||||
def test_slogdet_sign(self):
|
||||
a = torch.randn(3, 3, requires_grad=True)
|
||||
a = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
|
||||
s, logdet = a.slogdet()
|
||||
|
||||
# test that sign should not require grad
|
||||
|
|
@ -763,6 +759,7 @@ class TestAutograd(TestCase):
|
|||
def test_hooks_cpp(self):
|
||||
# Tests hooks for autograd function implemented in C++
|
||||
bn = torch.nn.BatchNorm1d(5, affine=False)
|
||||
bn.double()
|
||||
bn.eval()
|
||||
|
||||
counter = [0]
|
||||
|
|
@ -771,13 +768,13 @@ class TestAutograd(TestCase):
|
|||
counter[0] += 1
|
||||
return grad * 2
|
||||
|
||||
x = torch.ones(5, 5, requires_grad=True)
|
||||
x = torch.ones(5, 5, dtype=torch.double, requires_grad=True)
|
||||
z = bn(x)
|
||||
z.register_hook(bw_hook)
|
||||
z.sum().backward()
|
||||
|
||||
self.assertEqual(counter[0], 1, msg='bw_hook not called')
|
||||
self.assertEqual(x.grad, torch.ones(5, 5) * 2, atol=1e-5, rtol=0)
|
||||
self.assertEqual(x.grad, torch.ones(5, 5, dtype=torch.double) * 2, atol=1e-5, rtol=0)
|
||||
|
||||
def test_hook_none(self):
|
||||
# WARNING: this is a test for autograd internals.
|
||||
|
|
@ -891,15 +888,15 @@ class TestAutograd(TestCase):
|
|||
fn = FixedGradientFunction
|
||||
|
||||
# sparse first
|
||||
x = torch.randn(size, requires_grad=True)
|
||||
x = torch.randn(size, dtype=torch.double, requires_grad=True)
|
||||
(fn.apply(x, sparse_grad1) + fn.apply(x, dense_grad) + fn.apply(x, sparse_grad2)).sum().backward()
|
||||
self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
|
||||
# dense first
|
||||
x = torch.randn(size, requires_grad=True)
|
||||
x = torch.randn(size, dtype=torch.double, requires_grad=True)
|
||||
(fn.apply(x, dense_grad) + fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
|
||||
self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
|
||||
# sparse only
|
||||
x = torch.randn(size, requires_grad=True)
|
||||
x = torch.randn(size, dtype=torch.double, requires_grad=True)
|
||||
(fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
|
||||
self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
|
||||
|
||||
|
|
@ -995,8 +992,8 @@ class TestAutograd(TestCase):
|
|||
self.assertRaises(RuntimeError, call_backwards)
|
||||
|
||||
def test_backward_with_inputs(self):
|
||||
x = torch.randn(2, 2, requires_grad=True)
|
||||
y = torch.randn(2, 2, requires_grad=True)
|
||||
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
|
||||
def fn():
|
||||
return x ** 2 + y * x + y ** 2
|
||||
|
|
@ -1017,31 +1014,31 @@ class TestAutograd(TestCase):
|
|||
reset_grad()
|
||||
torch.autograd.backward(fn(), gradient, inputs=[x])
|
||||
self.assertEqual(x.grad, x_grad_expected)
|
||||
self.assertEqual(y.grad, torch.zeros(2, 2))
|
||||
self.assertEqual(y.grad, torch.zeros(2, 2), exact_dtype=False)
|
||||
|
||||
reset_grad()
|
||||
torch.autograd.backward(fn(), gradient, inputs=[y])
|
||||
self.assertEqual(y.grad, y_grad_expected)
|
||||
self.assertEqual(x.grad, torch.zeros(2, 2))
|
||||
self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
|
||||
|
||||
reset_grad()
|
||||
torch.autograd.backward(fn(), gradient, inputs=y)
|
||||
self.assertEqual(y.grad, y_grad_expected)
|
||||
self.assertEqual(x.grad, torch.zeros(2, 2))
|
||||
self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
|
||||
|
||||
reset_grad()
|
||||
self.assertRaisesRegex(RuntimeError, 'cannot be empty',
|
||||
lambda: torch.autograd.backward(fn(), gradient, inputs=[]))
|
||||
|
||||
def test_backward_with_nonleaf_inputs(self):
|
||||
x = torch.randn(2, 2, requires_grad=True)
|
||||
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
x_nonleaf = x * 1
|
||||
y = torch.randn(2, 2, requires_grad=True)
|
||||
z = torch.randn(2, 2, requires_grad=True)
|
||||
y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
z = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
|
||||
out = x_nonleaf ** 2 + y * x_nonleaf + y ** 2
|
||||
|
||||
out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y])
|
||||
out.backward(torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[x, y])
|
||||
x_grad_expected = 2 * x + y
|
||||
y_grad_expected = x + 2 * y
|
||||
|
||||
|
|
@ -1049,12 +1046,13 @@ class TestAutograd(TestCase):
|
|||
self.assertEqual(x.grad, x_grad_expected)
|
||||
|
||||
self.assertRaisesRegex(RuntimeError, 'not a leaf Tensor',
|
||||
lambda: out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y, x_nonleaf]))
|
||||
lambda: out.backward(torch.ones(2, 2, dtype=torch.double),
|
||||
create_graph=True, inputs=[x, y, x_nonleaf]))
|
||||
|
||||
# backward doesn't have an allow_unused flag, so the behavior of backward
|
||||
# when variable is not part of the graph is as if allow_used were true
|
||||
# x.grad will simply be None.
|
||||
out.backward(torch.ones(2, 2), create_graph=True, inputs=[z])
|
||||
out.backward(torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[z])
|
||||
self.assertIsNone(z.grad)
|
||||
|
||||
def test_dependent_backward(self):
|
||||
|
|
@ -1929,7 +1927,7 @@ class TestAutograd(TestCase):
|
|||
self.assertIs(a, b)
|
||||
return a + b
|
||||
|
||||
x = torch.randn(5, 5, requires_grad=True)
|
||||
x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
|
||||
gradcheck(fn, [x])
|
||||
gradgradcheck(fn, [x])
|
||||
|
||||
|
|
@ -1950,7 +1948,7 @@ class TestAutograd(TestCase):
|
|||
self.assertIs(a, b)
|
||||
return a + b
|
||||
|
||||
x = torch.randn(5, 5, requires_grad=True)
|
||||
x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
|
||||
gradcheck(inplace_fn, [x])
|
||||
gradgradcheck(inplace_fn, [x])
|
||||
|
||||
|
|
@ -2030,7 +2028,7 @@ class TestAutograd(TestCase):
|
|||
|
||||
def test_select_sum(self):
|
||||
# both select and sum return Scalars in ATen; ensure they work together.
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
x = torch.randn(10, dtype=torch.double, requires_grad=True)
|
||||
|
||||
def func(x):
|
||||
return x.select(0, 1).sum()
|
||||
|
|
@ -2041,9 +2039,9 @@ class TestAutograd(TestCase):
|
|||
def test_diagonal_expanded_v(self):
|
||||
value = torch.rand([])
|
||||
v_expanded = torch.tensor(value).expand(10)
|
||||
a = torch.rand(10, 10, requires_grad=True)
|
||||
a = torch.rand(10, 10, dtype=torch.double, requires_grad=True)
|
||||
result, = torch.autograd.grad(a.diagonal(), a, v_expanded)
|
||||
self.assertEqual(result, torch.eye(10) * value)
|
||||
self.assertEqual(result, torch.eye(10, dtype=torch.double) * value)
|
||||
|
||||
def test_select_expanded_v(self):
|
||||
v_expanded = torch.rand(10).expand(10, 10)
|
||||
|
|
@ -2735,28 +2733,28 @@ class TestAutograd(TestCase):
|
|||
d.sum().backward()
|
||||
|
||||
def test_broadcast_tensors(self):
|
||||
f_args_variable = (torch.randn(3, requires_grad=True),
|
||||
torch.randn(1, 2, 1, requires_grad=True),
|
||||
torch.randn(1, 1, requires_grad=True),
|
||||
torch.randn(5, 1, 1, requires_grad=True))
|
||||
f_args_variable = (torch.randn(3, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(1, 2, 1, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(1, 1, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(5, 1, 1, dtype=torch.double, requires_grad=True))
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
run_functional_checks(self, "test_broadcast_tensors", "broadcast",
|
||||
lambda a, b, c, d: torch.broadcast_tensors(a, b, c, d),
|
||||
True, f_args_variable, f_args_tensor)
|
||||
|
||||
def test_block_diag(self):
|
||||
f_args_variable = (torch.randn(1, S, requires_grad=True),
|
||||
torch.randn(2, S, requires_grad=True),
|
||||
torch.randn(3, S, requires_grad=True))
|
||||
f_args_variable = (torch.randn(1, S, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(2, S, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(3, S, dtype=torch.double, requires_grad=True))
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
run_functional_checks(self, "test_block_diag", "block_diag",
|
||||
lambda a, b, c: torch.block_diag(a, b, c),
|
||||
True, f_args_variable, f_args_tensor)
|
||||
|
||||
def test_cat(self):
|
||||
f_args_variable = (torch.randn(1, S, S, requires_grad=True),
|
||||
torch.randn(2, S, S, requires_grad=True),
|
||||
torch.randn(3, S, S, requires_grad=True),
|
||||
f_args_variable = (torch.randn(1, S, S, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(2, S, S, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(3, S, S, dtype=torch.double, requires_grad=True),
|
||||
0)
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
run_functional_checks(self, "test_cat", "cat",
|
||||
|
|
@ -2764,9 +2762,9 @@ class TestAutograd(TestCase):
|
|||
True, f_args_variable, f_args_tensor)
|
||||
|
||||
def test_cat_negdim_1(self):
|
||||
f_args_variable = (torch.randn(S, S, 1, requires_grad=True),
|
||||
torch.randn(S, S, 2, requires_grad=True),
|
||||
torch.randn(S, S, 3, requires_grad=True),
|
||||
f_args_variable = (torch.randn(S, S, 1, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(S, S, 2, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(S, S, 3, dtype=torch.double, requires_grad=True),
|
||||
-1)
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
run_functional_checks(self, "test_cat_negdim_1", "cat",
|
||||
|
|
@ -2774,9 +2772,9 @@ class TestAutograd(TestCase):
|
|||
True, f_args_variable, f_args_tensor)
|
||||
|
||||
def test_cat_negdim_2(self):
|
||||
f_args_variable = (torch.randn(S, 1, S, requires_grad=True),
|
||||
torch.randn(S, 2, S, requires_grad=True),
|
||||
torch.randn(S, 3, S, requires_grad=True),
|
||||
f_args_variable = (torch.randn(S, 1, S, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(S, 2, S, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(S, 3, S, dtype=torch.double, requires_grad=True),
|
||||
-2)
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
run_functional_checks(self, "test_cat_negdim_2", "cat",
|
||||
|
|
@ -2784,8 +2782,8 @@ class TestAutograd(TestCase):
|
|||
True, f_args_variable, f_args_tensor)
|
||||
|
||||
def test_cat_empty_legacy(self):
|
||||
f_args_variable = (torch.randn(0, requires_grad=True),
|
||||
torch.randn(S, S, requires_grad=True))
|
||||
f_args_variable = (torch.randn(0, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(S, S, dtype=torch.double, requires_grad=True))
|
||||
# gradgradcheck doesn't work, probably because legacy size tracking is wrong somewhere,
|
||||
# hence False passed below, but gradcheck checked explicitly.
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
|
|
@ -2795,16 +2793,16 @@ class TestAutograd(TestCase):
|
|||
self.assertTrue(gradcheck(lambda a, b: torch.cat((a, b)), f_args_variable, eps=1e-6, atol=PRECISION))
|
||||
|
||||
def test_cat_empty(self):
|
||||
f_args_variable = (torch.randn(0, S, requires_grad=True),
|
||||
torch.randn(S, S, requires_grad=True))
|
||||
f_args_variable = (torch.randn(0, S, dtype=torch.double, requires_grad=True),
|
||||
torch.randn(S, S, dtype=torch.double, requires_grad=True))
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
run_functional_checks(self, "test_cat_empty", "cat",
|
||||
lambda a, b: torch.cat((a, b)),
|
||||
True, f_args_variable, f_args_tensor)
|
||||
|
||||
def test_trapz(self):
|
||||
f_args_variable = (torch.randn(2, 3, requires_grad=True),
|
||||
torch.tensor([[1.0, 2.0, 5.5], [2.3, 0.5, 6.2]], requires_grad=True))
|
||||
f_args_variable = (torch.randn(2, 3, dtype=torch.double, requires_grad=True),
|
||||
torch.tensor([[1.0, 2.0, 5.5], [2.3, 0.5, 6.2]], dtype=torch.double, requires_grad=True))
|
||||
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
|
||||
run_functional_checks(self, "test_trapz", "trapz",
|
||||
lambda y, x: torch.trapz(y, x),
|
||||
|
|
@ -3379,9 +3377,9 @@ class TestAutograd(TestCase):
|
|||
|
||||
def _test_lerp_tensor_weights(self, cast):
|
||||
def construct_inputs(*shapes):
|
||||
start = cast(torch.randn(shapes[0])).requires_grad_()
|
||||
end = cast(torch.randn(shapes[1])).requires_grad_()
|
||||
weight = cast(torch.randn(shapes[2])).requires_grad_()
|
||||
start = cast(torch.randn(shapes[0], dtype=torch.double)).requires_grad_()
|
||||
end = cast(torch.randn(shapes[1], dtype=torch.double)).requires_grad_()
|
||||
weight = cast(torch.randn(shapes[2], dtype=torch.double)).requires_grad_()
|
||||
return [start, end, weight]
|
||||
|
||||
all_test_shapes = [((3, 3, 3), (3, 3, 3), (3, 3, 3)), # no broadcasting
|
||||
|
|
@ -3945,10 +3943,10 @@ class TestAutograd(TestCase):
|
|||
def fn(sparse):
|
||||
return torch.sparse.sum(sparse)
|
||||
|
||||
gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=True,
|
||||
gradcheck(fn, torch.rand(10, dtype=torch.double).to_sparse().requires_grad_(True), check_sparse_nnz=True,
|
||||
check_batched_grad=False, fast_mode=fast_mode)
|
||||
with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
|
||||
gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=False,
|
||||
gradcheck(fn, torch.rand(10, dtype=torch.double).to_sparse().requires_grad_(True), check_sparse_nnz=False,
|
||||
check_batched_grad=False, fast_mode=fast_mode)
|
||||
check(fast_mode=True)
|
||||
check(fast_mode=False)
|
||||
|
|
@ -3965,7 +3963,7 @@ class TestAutograd(TestCase):
|
|||
return NonDetFunc.apply(grad_out, ctx._jitter) * (1 + torch.rand_like(grad_out) * ctx._jitter), None
|
||||
|
||||
def check(fast_mode):
|
||||
inp = torch.randn(5, 5, requires_grad=True)
|
||||
inp = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
|
||||
gradcheck(lambda x: NonDetFunc.apply(x, 0.0), inp, check_batched_grad=False, fast_mode=fast_mode)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Backward is not reentrant'):
|
||||
gradcheck(lambda x: NonDetFunc.apply(x, 1e-6), inp, check_batched_grad=False, fast_mode=fast_mode)
|
||||
|
|
@ -4040,7 +4038,7 @@ class TestAutograd(TestCase):
|
|||
|
||||
def test_gradcheck_check_batched_grad(self):
|
||||
def check(fast_mode):
|
||||
x = torch.rand(10, requires_grad=True).to_sparse()
|
||||
x = torch.rand(10, dtype=torch.double, requires_grad=True).to_sparse()
|
||||
# runtime error while compute batched grad (print big error)
|
||||
with self.assertRaisesRegex(RuntimeError, 'gradcheck or gradgradcheck failed while testing batched gradient'):
|
||||
gradcheck(lambda x: x.to_dense(), (x,), check_sparse_nnz=True, check_batched_grad=True, fast_mode=fast_mode)
|
||||
|
|
@ -4060,7 +4058,7 @@ class TestAutograd(TestCase):
|
|||
y = x.clone()
|
||||
y.register_hook(hook)
|
||||
return y.to_dense()
|
||||
x = torch.ones((2, 2), requires_grad=True).to_sparse()
|
||||
x = torch.ones((2, 2), dtype=torch.double, requires_grad=True).to_sparse()
|
||||
with self.assertRaisesRegex(RuntimeError, 'grad is sparse tensor, but has incorrect sparse_dim'):
|
||||
gradcheck(fn, (x,), atol=1e-1, check_sparse_nnz=True, check_batched_grad=False, fast_mode=fast_mode)
|
||||
self.assertFalse(gradcheck(fn, (x,), atol=1e-1, check_sparse_nnz=True, check_batched_grad=False,
|
||||
|
|
@ -4071,7 +4069,7 @@ class TestAutograd(TestCase):
|
|||
y = x.clone()
|
||||
y.register_hook(lambda x: x + 1e-2)
|
||||
return y
|
||||
x = torch.ones(1, requires_grad=True)
|
||||
x = torch.ones(1, dtype=torch.double, requires_grad=True)
|
||||
with self.assertRaisesRegex(RuntimeError, 'backward not multiplied by grad_output'):
|
||||
gradcheck(fn2, (x,), atol=1e-1, fast_mode=fast_mode)
|
||||
self.assertFalse(gradcheck(fn2, (x,), atol=1e-1, raise_exception=False, fast_mode=fast_mode))
|
||||
|
|
@ -4081,7 +4079,7 @@ class TestAutograd(TestCase):
|
|||
y = x.clone().to_dense()
|
||||
y.register_hook(lambda x: x + 1e-2)
|
||||
return y
|
||||
x = torch.ones(1, requires_grad=True).to_sparse()
|
||||
x = torch.ones(1, dtype=torch.double, requires_grad=True).to_sparse()
|
||||
with self.assertRaisesRegex(RuntimeError, 'backward not multiplied by grad_output'):
|
||||
gradcheck(fn3, (x,), atol=1e-1, check_sparse_nnz=True, check_batched_grad=False, fast_mode=fast_mode)
|
||||
self.assertFalse(gradcheck(fn3, (x,), atol=1e-1, check_sparse_nnz=True, check_batched_grad=False,
|
||||
|
|
@ -4096,7 +4094,7 @@ class TestAutograd(TestCase):
|
|||
@staticmethod
|
||||
def backward(ctx, x):
|
||||
return x.to_sparse()
|
||||
x = torch.ones(1, requires_grad=True)
|
||||
x = torch.ones(1, dtype=torch.double, requires_grad=True)
|
||||
with self.assertRaisesRegex(RuntimeError, 'grad is incorrect layout'):
|
||||
gradcheck(Test.apply, (x,), check_batched_grad=False, fast_mode=fast_mode)
|
||||
self.assertFalse(gradcheck(Test.apply, (x,), check_batched_grad=False, raise_exception=False, fast_mode=fast_mode))
|
||||
|
|
@ -4113,7 +4111,7 @@ class TestAutograd(TestCase):
|
|||
y = x.clone()
|
||||
y.register_hook(hook)
|
||||
return y
|
||||
x = torch.ones(1, requires_grad=True)
|
||||
x = torch.ones(1, dtype=torch.double, requires_grad=True)
|
||||
with self.assertWarnsRegex(UserWarning, "Backwards compatibility: New undefined gradient support checking feature"):
|
||||
with self.assertRaisesRegex(RuntimeError, 'Expected backward function to handle undefined output grads'):
|
||||
gradcheck(fn, (x,), fast_mode=fast_mode)
|
||||
|
|
@ -4160,8 +4158,8 @@ class TestAutograd(TestCase):
|
|||
def check(fast_mode):
|
||||
def fn(x, y):
|
||||
return x * y.coalesce().to_dense()
|
||||
a = torch.rand(2, 2, requires_grad=True)
|
||||
b = torch.rand(2, 2).to_sparse().requires_grad_(True)
|
||||
a = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
|
||||
b = torch.rand(2, 2, dtype=torch.double,).to_sparse().requires_grad_(True)
|
||||
self.assertTrue(gradcheck(fn, (a, b), check_sparse_nnz=True, check_batched_grad=False, fast_mode=fast_mode))
|
||||
check(fast_mode=True)
|
||||
check(fast_mode=False)
|
||||
|
|
@ -4189,7 +4187,7 @@ class TestAutograd(TestCase):
|
|||
return torch.cat([x, x])
|
||||
else:
|
||||
return x
|
||||
a = torch.ones(1, requires_grad=True)
|
||||
a = torch.ones(1, dtype=torch.double, requires_grad=True)
|
||||
with self.assertRaisesRegex(AssertionError, 'return outputs with the same shape when inputs are perturbed'):
|
||||
self.assertTrue(gradcheck(fn, (a,), fast_mode=fast_mode))
|
||||
|
||||
|
|
@ -4231,12 +4229,12 @@ class TestAutograd(TestCase):
|
|||
|
||||
with self.assertWarnsRegex(UserWarning, "get_numerical_jacobian was part of PyTorch's private API"):
|
||||
jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6)
|
||||
self.assertEqual(jacobian[0], 2 * torch.eye(4))
|
||||
self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "get_numerical_jacobian was part of PyTorch's private API"):
|
||||
jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6, grad_out=1)
|
||||
self.assertEqual(jacobian[0], 2 * torch.eye(4))
|
||||
self.assertEqual(jacobian[1], 1 * torch.eye(4))
|
||||
jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6)
|
||||
self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
|
||||
self.assertEqual(jacobian[1], 1 * torch.eye(4, dtype=torch.double))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
|
||||
jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6, grad_out=2.0)
|
||||
|
|
@ -4253,8 +4251,8 @@ class TestAutograd(TestCase):
|
|||
outputs = fn(a, b)
|
||||
with self.assertWarnsRegex(UserWarning, "get_analytical_jacobian was part of PyTorch's private API"):
|
||||
jacobians, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian((a, b), outputs[0])
|
||||
self.assertEqual(jacobians[0], 2 * torch.eye(4))
|
||||
self.assertEqual(jacobians[1], 1 * torch.eye(4))
|
||||
self.assertEqual(jacobians[0], 2 * torch.eye(4, dtype=torch.double))
|
||||
self.assertEqual(jacobians[1], 1 * torch.eye(4, dtype=torch.double))
|
||||
self.assertTrue(reentrant)
|
||||
|
||||
class NonDetFunc(Function):
|
||||
|
|
@ -5157,7 +5155,7 @@ for shape in [(1,), ()]:
|
|||
self.assertEqual(b.grad, torch.tensor([-inf, 0., 0.]))
|
||||
|
||||
def test_nansum_with_nans(self):
|
||||
a = torch.randn(2, 2, 2, 2)
|
||||
a = torch.randn(2, 2, 2, 2, dtype=torch.double)
|
||||
with torch.no_grad():
|
||||
a[a < 0.2] = float('nan')
|
||||
a.requires_grad = True
|
||||
|
|
@ -5198,7 +5196,7 @@ for shape in [(1,), ()]:
|
|||
test(inp, torch.double, torch.float)
|
||||
|
||||
def test_nan_to_num(self):
|
||||
a = torch.randn(3, 3, 3, 3)
|
||||
a = torch.randn(3, 3, 3, 3, dtype=torch.double)
|
||||
with torch.no_grad():
|
||||
a[torch.rand_like(a) < 0.2] = float('nan')
|
||||
a[torch.rand_like(a) < 0.2] = float('inf')
|
||||
|
|
@ -5891,8 +5889,8 @@ class TestAutogradFunctional(TestCase):
|
|||
def test_vjp_create_graph(self):
|
||||
def reducer(x):
|
||||
return x.sum(dim=1)
|
||||
inputs = torch.rand(2, 2)
|
||||
v = torch.ones(2)
|
||||
inputs = torch.rand(2, 2, dtype=torch.double)
|
||||
v = torch.ones(2, dtype=torch.double)
|
||||
|
||||
inputs.requires_grad_()
|
||||
v.requires_grad_()
|
||||
|
|
@ -5907,8 +5905,10 @@ class TestAutogradFunctional(TestCase):
|
|||
def adder(x, y):
|
||||
return 2 * x + 3 * y, x * y
|
||||
|
||||
inputs = (torch.rand(2, requires_grad=True), torch.rand(2, requires_grad=True))
|
||||
v = (torch.tensor([1., 0.], requires_grad=True), torch.tensor([1., 0.], requires_grad=True))
|
||||
inputs = (torch.rand(2, dtype=torch.double, requires_grad=True),
|
||||
torch.rand(2, dtype=torch.double, requires_grad=True))
|
||||
v = (torch.tensor([1., 0.], dtype=torch.double, requires_grad=True),
|
||||
torch.tensor([1., 0.], dtype=torch.double, requires_grad=True))
|
||||
|
||||
gradcheck(lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[1], inputs + v)
|
||||
gradgradcheck(lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[1], inputs + v)
|
||||
|
|
@ -6061,8 +6061,8 @@ class TestAutogradFunctional(TestCase):
|
|||
def test_jvp_create_graph(self):
|
||||
def reducer(x):
|
||||
return x.sum(dim=1)
|
||||
inputs = torch.rand(2, 2)
|
||||
v = torch.ones(2, 2)
|
||||
inputs = torch.rand(2, 2, dtype=torch.double)
|
||||
v = torch.ones(2, 2, dtype=torch.double)
|
||||
|
||||
inputs.requires_grad_()
|
||||
v.requires_grad_()
|
||||
|
|
@ -6077,8 +6077,10 @@ class TestAutogradFunctional(TestCase):
|
|||
def adder(x, y):
|
||||
return 2 * x + 3 * y, x * y
|
||||
|
||||
inputs = (torch.rand(2, requires_grad=True), torch.rand(2, requires_grad=True))
|
||||
v = (torch.tensor([1., 0.], requires_grad=True), torch.tensor([1., 0.], requires_grad=True))
|
||||
inputs = (torch.rand(2, dtype=torch.double, requires_grad=True),
|
||||
torch.rand(2, dtype=torch.double, requires_grad=True))
|
||||
v = (torch.tensor([1., 0.], dtype=torch.double, requires_grad=True),
|
||||
torch.tensor([1., 0.], dtype=torch.double, requires_grad=True))
|
||||
|
||||
gradcheck(lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[1], inputs + v)
|
||||
gradgradcheck(lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[1], inputs + v)
|
||||
|
|
@ -6293,7 +6295,7 @@ class TestAutogradFunctional(TestCase):
|
|||
def exp_reducer(x):
|
||||
return x.exp().sum(dim=1)
|
||||
|
||||
inputs = torch.rand(4, 4, requires_grad=True)
|
||||
inputs = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
|
||||
res = autogradF.jacobian(exp_reducer, inputs, create_graph=True, vectorize=vectorize)
|
||||
self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
|
||||
self.assertIsNotNone(res.grad_fn)
|
||||
|
|
@ -6304,7 +6306,8 @@ class TestAutogradFunctional(TestCase):
|
|||
def add_exp_reducer(x, y):
|
||||
return (x + y).exp().sum(dim=1)
|
||||
|
||||
inputs = (torch.rand(4, 4, requires_grad=True), torch.rand(4, 4, requires_grad=True))
|
||||
inputs = (torch.rand(4, 4, dtype=torch.double, requires_grad=True),
|
||||
torch.rand(4, 4, dtype=torch.double, requires_grad=True))
|
||||
res = autogradF.jacobian(add_exp_reducer, inputs, create_graph=True, vectorize=vectorize)
|
||||
self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
|
||||
self.assertIsNotNone(res[0].grad_fn)
|
||||
|
|
@ -6603,7 +6606,7 @@ class TestAutogradFunctional(TestCase):
|
|||
def pow_reducer(x):
|
||||
return x.pow(3).sum()
|
||||
|
||||
inputs = torch.rand(2, 2, requires_grad=True)
|
||||
inputs = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
|
||||
res = autogradF.hessian(pow_reducer, inputs, create_graph=True, vectorize=vectorize)
|
||||
self._assert_interleaved_struct(res, inputs, inputs)
|
||||
self.assertIsNotNone(res.grad_fn)
|
||||
|
|
@ -6614,7 +6617,8 @@ class TestAutogradFunctional(TestCase):
|
|||
def add_pow_reducer(x, y):
|
||||
return (x + y).pow(3).sum()
|
||||
|
||||
inputs = (torch.rand(2, 2, requires_grad=True), torch.rand(2, 2, requires_grad=True))
|
||||
inputs = (torch.rand(2, 2, dtype=torch.double, requires_grad=True),
|
||||
torch.rand(2, 2, dtype=torch.double, requires_grad=True))
|
||||
res = autogradF.hessian(add_pow_reducer, inputs, create_graph=True, vectorize=vectorize)
|
||||
self._assert_interleaved_struct(res, inputs, inputs)
|
||||
self.assertIsNotNone(res[0][0].grad_fn)
|
||||
|
|
@ -6783,8 +6787,8 @@ class TestAutogradFunctional(TestCase):
|
|||
def foo(a):
|
||||
return 3 * a.narrow(0, 0, 3).exp().sum()
|
||||
|
||||
inputs = torch.rand(4, 4, requires_grad=True)
|
||||
v = torch.ones(4, 4, requires_grad=True)
|
||||
inputs = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
|
||||
v = torch.ones(4, 4, dtype=torch.double, requires_grad=True)
|
||||
res = autogradF.vhp(foo, inputs, v, create_graph=True)
|
||||
self._assert_same_struct(res[1], inputs)
|
||||
self.assertIsNotNone(res[0].grad_fn)
|
||||
|
|
@ -6796,8 +6800,10 @@ class TestAutogradFunctional(TestCase):
|
|||
def bar(a, b):
|
||||
return (a + 3 * b.narrow(0, 0, 3)).exp().sum()
|
||||
|
||||
inputs = (torch.rand(3, requires_grad=True), torch.rand(4, requires_grad=True))
|
||||
v = (torch.ones(3, requires_grad=True), torch.ones(4, requires_grad=True))
|
||||
inputs = (torch.rand(3, dtype=torch.double, requires_grad=True),
|
||||
torch.rand(4, dtype=torch.double, requires_grad=True))
|
||||
v = (torch.ones(3, dtype=torch.double, requires_grad=True),
|
||||
torch.ones(4, dtype=torch.double, requires_grad=True))
|
||||
out, vhp_val = autogradF.vhp(bar, inputs, v, create_graph=True)
|
||||
self._assert_same_struct(vhp_val, inputs)
|
||||
self.assertIsNotNone(out.grad_fn)
|
||||
|
|
@ -6958,8 +6964,8 @@ class TestAutogradFunctional(TestCase):
|
|||
def foo(a):
|
||||
return 3 * a.narrow(0, 0, 3).exp().sum()
|
||||
|
||||
inputs = torch.rand(4, 4, requires_grad=True)
|
||||
v = torch.ones(4, 4, requires_grad=True)
|
||||
inputs = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
|
||||
v = torch.ones(4, 4, dtype=torch.double, requires_grad=True)
|
||||
res = autogradF.hvp(foo, inputs, v, create_graph=True)
|
||||
self._assert_same_struct(res[1], inputs)
|
||||
self.assertIsNotNone(res[0].grad_fn)
|
||||
|
|
@ -6971,8 +6977,10 @@ class TestAutogradFunctional(TestCase):
|
|||
def bar(a, b):
|
||||
return (a + 3 * b.narrow(0, 0, 3)).exp().sum()
|
||||
|
||||
inputs = (torch.rand(3, requires_grad=True), torch.rand(4, requires_grad=True))
|
||||
v = (torch.ones(3, requires_grad=True), torch.ones(4, requires_grad=True))
|
||||
inputs = (torch.rand(3, dtype=torch.double, requires_grad=True),
|
||||
torch.rand(4, dtype=torch.double, requires_grad=True))
|
||||
v = (torch.ones(3, dtype=torch.double, requires_grad=True),
|
||||
torch.ones(4, dtype=torch.double, requires_grad=True))
|
||||
out, hvp_val = autogradF.hvp(bar, inputs, v, create_graph=True)
|
||||
self._assert_same_struct(hvp_val, inputs)
|
||||
self.assertIsNotNone(out.grad_fn)
|
||||
|
|
@ -7216,8 +7224,8 @@ class TestAutogradForwardMode(TestCase):
|
|||
self.assertEqual(t, bar * 2)
|
||||
|
||||
def test_view_inplace_non_differentiable_views(self):
|
||||
original_foo = torch.rand(2)
|
||||
original_bar = torch.ones(2)
|
||||
original_foo = torch.rand(2, dtype=torch.double)
|
||||
original_bar = torch.ones(2, dtype=torch.double)
|
||||
|
||||
# Do clones to be able to compare the values updated inplace
|
||||
# with the original content of these Tensors
|
||||
|
|
@ -7380,7 +7388,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
assert torch.isfinite(x.grad).all()
|
||||
|
||||
def test_parameter_resize(self, device):
|
||||
asd = torch.nn.Parameter(torch.ones(16, device=device))
|
||||
asd = torch.nn.Parameter(torch.ones(16, dtype=torch.double, device=device))
|
||||
|
||||
for i in range(2):
|
||||
with torch.no_grad():
|
||||
|
|
@ -7398,7 +7406,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
|
||||
i = i.to(torch.long)
|
||||
|
||||
inp = torch.randn(v_size, requires_grad=True)
|
||||
inp = torch.randn(v_size, dtype=torch.double, requires_grad=True)
|
||||
other = self.genSparseTensor(size, sparse_dim, nnz, is_uncoalesced=True, device=device,
|
||||
dtype=torch.double)[0]
|
||||
|
||||
|
|
@ -7445,7 +7453,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
self.assertEqual(integral_conv(l), pyscalar)
|
||||
|
||||
# floating point -> floating point
|
||||
f = Variable(t(torch.randn(1, 1)))
|
||||
f = Variable(t(torch.randn(1, 1, dtype=torch.double)))
|
||||
pyscalar = -12345.1
|
||||
f[0] = pyscalar
|
||||
self.assertEqual(float(f), pyscalar)
|
||||
|
|
@ -7603,8 +7611,8 @@ class TestAutogradDeviceType(TestCase):
|
|||
output.sum().backward()
|
||||
|
||||
def test_where_functional(self, device):
|
||||
x = torch.randn(5, 5, device=device, requires_grad=True)
|
||||
y = torch.randn(5, 5, device=device, requires_grad=True)
|
||||
x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
|
||||
y = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
|
||||
cond = mask_not_all_zeros((5, 5)).to(device=device)
|
||||
|
||||
def where(cond, x, y):
|
||||
|
|
@ -7613,13 +7621,13 @@ class TestAutogradDeviceType(TestCase):
|
|||
gradcheck(where, [cond, x, y], raise_exception=True)
|
||||
gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)])
|
||||
|
||||
x = torch.randn(5, 1, 5, device=device, requires_grad=True)
|
||||
y = torch.randn(5, 5, 1, device=device, requires_grad=True)
|
||||
x = torch.randn(5, 1, 5, dtype=torch.double, device=device, requires_grad=True)
|
||||
y = torch.randn(5, 5, 1, dtype=torch.double, device=device, requires_grad=True)
|
||||
gradcheck(where, [cond, x, y], raise_exception=True)
|
||||
gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)])
|
||||
|
||||
def test_where_scalar(self, device):
|
||||
x = torch.randn(5, 5, device=device, requires_grad=True)
|
||||
x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
|
||||
scalar = 4.
|
||||
cond = mask_not_all_zeros((5, 5)).to(device=device)
|
||||
|
||||
|
|
@ -7662,7 +7670,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
for input_length, vary_lengths, zero_mode in tests:
|
||||
targets = torch.randint(1, num_labels, (batch_size, target_length),
|
||||
device=device, dtype=torch.long)
|
||||
x = torch.randn(gradcheck_input_size, device=device, requires_grad=True)
|
||||
x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
|
||||
tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
|
||||
device=device)
|
||||
input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
|
||||
|
|
@ -7757,7 +7765,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
|
||||
@onlyCUDA
|
||||
def test_pin_memory(self, device):
|
||||
x = torch.randn(2, 2, requires_grad=True)
|
||||
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
self.assertEqual(x, x.pin_memory())
|
||||
self.assertIsNot(x, x.pin_memory())
|
||||
self.assertTrue(x.pin_memory().requires_grad)
|
||||
|
|
@ -8067,7 +8075,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
|
||||
def test_inplace_view_then_no_grad(self, device):
|
||||
# Perform an in-place operation on a view of a non-leaf variable.
|
||||
a = torch.ones(3, 1, device=device, requires_grad=True)
|
||||
a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True)
|
||||
b = a * 2
|
||||
c = b.view_as(b)
|
||||
c[0][0] = 3
|
||||
|
|
@ -8080,8 +8088,8 @@ class TestAutogradDeviceType(TestCase):
|
|||
|
||||
def test_inplace_view_gradcheck(self, device):
|
||||
# gradcheck modifications to views
|
||||
a = torch.randn(4, 4, device=device, requires_grad=True)
|
||||
b = torch.randn(2, 2, device=device, requires_grad=True)
|
||||
a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
|
||||
b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
|
||||
|
||||
def func(root, b):
|
||||
x = root.clone()
|
||||
|
|
@ -8090,25 +8098,25 @@ class TestAutogradDeviceType(TestCase):
|
|||
return x
|
||||
|
||||
gradcheck(func, [a, b], raise_exception=True)
|
||||
go = torch.randn(a.size(), device=device, requires_grad=True)
|
||||
go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
|
||||
gradgradcheck(func, (a, b), (go,))
|
||||
|
||||
def test_inplace_view_multiple_outputs(self, device):
|
||||
root = torch.arange(9.).reshape(3, 3).requires_grad_()
|
||||
root = torch.arange(9., dtype=torch.double).reshape(3, 3).requires_grad_()
|
||||
x = root.clone()
|
||||
v1 = x.unbind()
|
||||
with self.assertRaises(RuntimeError):
|
||||
v1[0].mul_(2)
|
||||
|
||||
def test_inplace_view_of_multiple_output_view(self, device):
|
||||
a = torch.rand(10, device=device, requires_grad=True).clone()
|
||||
a = torch.rand(10, dtype=torch.double, device=device, requires_grad=True).clone()
|
||||
b = a.unbind(0)
|
||||
c = b[0].view_as(b[0])
|
||||
with self.assertRaises(RuntimeError):
|
||||
c.mul_(2)
|
||||
|
||||
def test_inplace_multiple_output_view_of_view(self, device):
|
||||
a = torch.rand(10, device=device, requires_grad=True).clone()
|
||||
a = torch.rand(10, dtype=torch.double, device=device, requires_grad=True).clone()
|
||||
b = a.view_as(a)
|
||||
c = b.unbind(0)
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -8116,8 +8124,8 @@ class TestAutogradDeviceType(TestCase):
|
|||
|
||||
def test_inplace_view_makes_base_require_grad(self, device):
|
||||
# in-place modification to view makes base require grad
|
||||
a = torch.randn(4, 4, device=device, requires_grad=False)
|
||||
b = torch.randn(4, 2, device=device, requires_grad=True)
|
||||
a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False)
|
||||
b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True)
|
||||
|
||||
def func(root, b):
|
||||
x = root.clone()
|
||||
|
|
@ -8127,7 +8135,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
return x
|
||||
|
||||
gradcheck(func, [a, b], raise_exception=True)
|
||||
go = torch.randn(a.size(), device=device, requires_grad=True)
|
||||
go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
|
||||
gradgradcheck(func, (a, b), (go,))
|
||||
|
||||
def test_inplace_view_backprop_view(self, device):
|
||||
|
|
@ -8143,10 +8151,10 @@ class TestAutogradDeviceType(TestCase):
|
|||
# Test that an in-place operation on a base that forced it to require
|
||||
# grad also forces any previous views to require grad and backprop
|
||||
# correctly
|
||||
r = torch.ones(1, device=device, requires_grad=True)
|
||||
r = torch.ones(1, dtype=torch.double, device=device, requires_grad=True)
|
||||
|
||||
def fn(r):
|
||||
x = torch.ones(5, device=device)
|
||||
x = torch.ones(5, dtype=torch.double, device=device)
|
||||
v = x.select(0, 1)
|
||||
self.assertFalse(v.requires_grad)
|
||||
self.assertIsNone(v.grad_fn)
|
||||
|
|
@ -8159,8 +8167,8 @@ class TestAutogradDeviceType(TestCase):
|
|||
|
||||
def test_inplace_view_python(self, device):
|
||||
# in-place modifications of Python-autograd created view
|
||||
a = torch.randn(4, 4, device=device, requires_grad=True)
|
||||
b = torch.randn(2, 2, device=device, requires_grad=True)
|
||||
a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
|
||||
b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
|
||||
|
||||
class PyAdd(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
@ -8180,7 +8188,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
return x
|
||||
|
||||
gradcheck(func, [a, b], raise_exception=True)
|
||||
go = torch.randn(a.size(), device=device, requires_grad=True)
|
||||
go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
|
||||
gradgradcheck(func, (a, b), (go,))
|
||||
|
||||
def test_inplace_view_non_contig(self, device):
|
||||
|
|
@ -8215,8 +8223,8 @@ class TestAutogradDeviceType(TestCase):
|
|||
|
||||
def test_mv_grad_stride_0(self, device):
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/38315
|
||||
mat = torch.randn(2, 2, device=device)
|
||||
vec = torch.randn(1, device=device).requires_grad_(True)
|
||||
mat = torch.randn(2, 2, dtype=torch.double, device=device)
|
||||
vec = torch.randn(1, dtype=torch.double, device=device).requires_grad_(True)
|
||||
|
||||
def fn(vec):
|
||||
# Expand inside the function to make sure the input to
|
||||
|
|
@ -8229,10 +8237,10 @@ class TestAutogradDeviceType(TestCase):
|
|||
|
||||
@onlyCUDA
|
||||
def test_gradcheck_input_output_different_device(self, device):
|
||||
x = torch.ones((1,), device="cuda", requires_grad=True)
|
||||
x = torch.ones((1,), dtype=torch.double, device="cuda", requires_grad=True)
|
||||
gradcheck(lambda x: x.to("cpu"), (x,))
|
||||
|
||||
x = torch.ones((1,), device="cpu", requires_grad=True)
|
||||
x = torch.ones((1,), dtype=torch.double, device="cpu", requires_grad=True)
|
||||
gradcheck(lambda x: x.to("cuda"), (x,))
|
||||
|
||||
def test_logcumsumexp_large_value(self, device):
|
||||
|
|
|
|||
|
|
@ -5888,7 +5888,7 @@ shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)]
|
|||
def index_variable(shape, max_indices, device=torch.device('cpu')):
|
||||
if not isinstance(shape, tuple):
|
||||
shape = (shape,)
|
||||
index = torch.rand(*shape, device=device).mul_(max_indices).floor_().long()
|
||||
index = torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().long()
|
||||
return index
|
||||
|
||||
|
||||
|
|
@ -6216,6 +6216,10 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
|
|||
# double check casting
|
||||
elif isinstance(arg, non_differentiable):
|
||||
if isinstance(arg.tensor, torch.Tensor):
|
||||
if arg.tensor.dtype == torch.float:
|
||||
return maybe_non_contig(arg.tensor.to(dtype=torch.double, device=device))
|
||||
if arg.tensor.dtype == torch.cfloat:
|
||||
return maybe_non_contig(arg.tensor.to(dtype=torch.cdouble, device=device))
|
||||
return maybe_non_contig(arg.tensor.to(device=device))
|
||||
return maybe_non_contig(arg.tensor.to(device=device))
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
|
|
|
|||
Loading…
Reference in a new issue