diff --git a/.github/workflows/_mac-test-arm64.yml b/.github/workflows/_mac-test-arm64.yml index e8896df3b7d..6cdcb3fcfed 100644 --- a/.github/workflows/_mac-test-arm64.yml +++ b/.github/workflows/_mac-test-arm64.yml @@ -40,7 +40,7 @@ jobs: # shellcheck disable=SC1090 . ~/miniconda3/etc/profile.d/conda.sh set -ex - conda create -yp "${ENV_NAME}" "python=${PY_VERS}" numpy expecttest + conda create -yp "${ENV_NAME}" "python=${PY_VERS}" numpy expecttest pyyaml # As wheels are cross-compiled they are reported as x86_64 ones ORIG_WHLNAME=$(ls -1 dist/*.whl); ARM_WHLNAME=${ORIG_WHLNAME/x86_64/arm64}; mv ${ORIG_WHLNAME} ${ARM_WHLNAME} conda run -p "${ENV_NAME}" python3 -mpip install dist/*.whl diff --git a/test/test_mps.py b/test/test_mps.py index e19e55e3cae..e41138f2c1e 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -13,13 +13,20 @@ import torch import torch.nn as nn import torch.nn.functional as F import itertools +from collections import defaultdict from torch._six import inf from torch.nn import Parameter -from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN +from torch.testing._internal.common_utils import \ + (gradcheck, gradgradcheck, run_tests, TestCase, download_file, + TEST_WITH_UBSAN) from torch.testing._comparison import TensorLikePair +from torch.testing._internal.common_dtype import get_all_dtypes import torch.backends.mps from torch.distributions import Uniform, Exponential - +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests +from torch.testing import make_tensor +from functools import partial from torch.testing._internal.common_nn import NNTestCase import numpy as np import torch @@ -782,7 +789,6 @@ class TestMPS(TestCase): helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine) helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine) - def 
test_instance_norm(self): def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False): @@ -3257,6 +3263,14 @@ class TestNLLLoss(TestCase): # Empty test - Currently failing! Empty tensor not handled! # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5]) + def test_constant_pad(self): + m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5) + input_cpu = torch.randn(1, 16, 16, 16) + input_mps = input_cpu.detach().clone().to("mps") + r_cpu = m(input_cpu) + r_mps = m(input_mps) + self.assertEqual(r_cpu, r_mps.to("cpu")) + def test_pad(self): def helper(shape, padding, op): inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) @@ -4756,6 +4770,671 @@ class TestLinalgMPS(TestCase): m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4) +# These tests were taken from test/test_view_ops.py +# They are subset of those tests as currently only this subset is working. +# This whole `class` will be removed when we add generic device testing. 
There +# are no additional tests added apart from what is part of test_view_ops.py +class TestViewOpsMPS(TestCase): + exact_dtype = True + + def is_view_of(self, base, other): + if (not other._is_view() or + other is base or + other._base is not base or + base.device != other.device): + return False + # Note: only validates storage on native device types + # because some accelerators, like XLA, do not expose storage + if base.device.type == 'mps': + if base.storage().data_ptr() != other.storage().data_ptr(): + return False + + return True + + # Returns true if v1 and v2 are views of the same base + def is_view_of_same_base(self, v1, v2): + if (not v1._is_view() or v1 is v2): + return False + return self.is_view_of(v1._base, v2) + + # Performs transpose if contiguous=True, else returns the input tensor as is + def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): + if contiguous: + return x + else: + return x.transpose(dim0, dim1) + + def test_squeeze_view(self, device="mps"): + t = torch.ones(5, 1, 5, device=device) + v = torch.squeeze(t) + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertTrue(t is v._base) + + def test_squeeze_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.squeeze_() + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertTrue(t is v._base) + + def test_unsqueeze_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = torch.unsqueeze(t, 1) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0, 1] = 0 + self.assertEqual(t[0, 1], v[0, 0, 1]) + + def test_unsqueeze_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.unsqueeze_(1) + self.assertTrue(self.is_view_of(t, v)) + v[0, 0, 1] = 0 + self.assertEqual(t[0, 1], v[0, 0, 1]) + + def test_as_strided_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = torch.as_strided(t, (25,), (1,)) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + 
self.assertEqual(t[1, 1], v[6]) + + def test_as_strided_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.as_strided_((25,), (1,)) + self.assertTrue(self.is_view_of(t, v)) + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_view_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view(25) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_view_as_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + e = torch.empty((25,)) + v = t.view_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_contiguous_self(self, device="mps"): + t = torch.ones(5, 5, device=device) + s = t.contiguous() + self.assertTrue(s is t) + + def test_contiguous_nonview(self, device="mps"): + t = torch.ones(5, 5, device=device) + nv = t.t().contiguous() + self.assertTrue(not self.is_view_of(t, nv)) + + nv[0, 0] = 0 + self.assertNotEqual(t[0, 0], nv[0, 0]) + + def test_reshape_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = torch.reshape(t, (25,)) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_reshape_as_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + e = torch.empty((25,), device=device) + v = t.reshape_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_reshape_nonview(self, device="mps"): + t = torch.ones(5, 5, device=device) + nv = torch.reshape(t.t(), (25,)) + self.assertTrue(not self.is_view_of(t, nv)) + + nv[6] = 0 + self.assertNotEqual(t[1, 1], nv[6]) + + def test_flatten_view(self, device="mps"): + def test_writes_propagate(t, v): + idx_t = (0,) * t.ndim + idx_v = (0,) * v.ndim + v[idx_v] = 0 + self.assertEqual(t[idx_t], v[idx_v]) + + t = torch.ones(1, 2, 3, 4, device=device) + v = t.flatten() + self.assertTrue(self.is_view_of(t, v)) + 
test_writes_propagate(t, v) + + # zero-dimensional tensor + t = torch.tensor(1, device=device) + v = t.flatten() + test_writes_propagate(t, v) + self.assertTrue(self.is_view_of(t, v)) + + t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3) + v = t.flatten(0, 1) + test_writes_propagate(t, v) + self.assertTrue(self.is_view_of_same_base(t, v)) + + # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups: + t = torch.ones(720, device=device) \ + .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0)) + # [--1--|---2---|-3-] [--1--|----2---|-3-] + v1 = t.flatten(0, 1) + v2 = v1.flatten(1, 3) + v3 = v2.flatten(2, 2) + test_writes_propagate(t, v1) + self.assertTrue(self.is_view_of_same_base(t, v1)) + test_writes_propagate(t, v2) + self.assertTrue(self.is_view_of_same_base(t, v2)) + test_writes_propagate(t, v3) + self.assertTrue(self.is_view_of_same_base(t, v3)) + + def test_flatten_nonview(self, device="mps"): + def assert_is_nonview(t, nv): + idx_t = (0,) * t.ndim + idx_nv = (0,) * nv.ndim + self.assertTrue(not nv._is_view()) + nv[idx_nv] = 0 + self.assertNotEqual(t[idx_t], nv[idx_nv]) + t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) + nv = t.flatten(1, 3) + assert_is_nonview(t, nv) + + t = torch.ones(2, 2, device=device).T + nv = t.flatten() + assert_is_nonview(t, nv) + + # flatten returns the original object if start_dim=end_dim + t = t = torch.ones(2, 2, device=device) + nv = t.flatten(1, 1) + self.assertTrue(t is nv) + + def test_basic_indexing_slice_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t[:2, :3] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0], v[0, 0]) + + def test_basic_indexing_ellipses_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t[..., :2] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0], v[0, 0]) + + def test_basic_indexing_newaxis_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t[None, 
:2, 3] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 3], v[0, 0]) + + def test_chunk_view(self, device="mps"): + t = torch.zeros(3, 3, device=device) + l = torch.chunk(t, 3) + + for idx, v in enumerate(l): + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = idx + 1 + self.assertEqual(t[idx, 0], v[0, 0]) + + def test_split_view(self, device="mps"): + t = torch.zeros(3, 3, device=device) + l = torch.split(t, [1, 1, 1]) + + for idx, v in enumerate(l): + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = idx + 1 + self.assertEqual(t[idx, 0], v[0, 0]) + + def test_movedim_view(self, device="mps"): + def run_test(device, op): + t = torch.zeros(3, 3, device=device) + out = op(t) + + self.assertTrue(self.is_view_of(t, out)) + + # Randomly change values in output + # and verify that original is changed + # as well. + for _ in range(3): + idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) + out[idx_1, idx_2] = random.random() + self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) + + for fn in [torch.movedim, torch.moveaxis]: + op = partial(fn, source=(0, 1), destination=(1, 0)) + run_test(device, op) + + op = partial(fn, source=0, destination=1) + run_test(device, op) + + # Testing that the generated view_copy kernel and its derivative are implemented correctly + def test_view_copy(self, device="mps"): + a = torch.randn(4, device=device, requires_grad=True) + a_ref = a.clone().detach().requires_grad_() + a_view = a_ref.view(2, 2) + a_view_copy = torch.view_copy(a, (2, 2)) + + # view_copy ops don't preserve view relationship + self.assertTrue(self.is_view_of(a_ref, a_view)) + self.assertFalse(self.is_view_of(a, a_view_copy)) + + a_view_copy.sum().backward() + a_view.sum().backward() + + # forward and backward give the same shape + result + self.assertEqual(a_view_copy, a_view) + self.assertEqual(a.grad, a_ref.grad) + + def test_view_copy_out(self, device="mps"): + a = torch.randn(2, 2, device=device) + out = torch.empty(2, 
device=device) + + torch.diagonal_copy(a, out=out) + expected = torch.diagonal_copy(a) + + self.assertEqual(expected, out) + + a = torch.randn(4, device=device) + out1 = torch.empty(2, device=device) + out2 = torch.empty(2, device=device) + + torch.split_copy(a, 2, out=(out1, out2)) + expected1, expected2 = torch.split_copy(a, 2) + + self.assertEqual(expected1, out1) + self.assertEqual(expected2, out2) + + def test_empty_reshape(self, device="mps"): + x = torch.randn(0, 6, device=device) + self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) + # should be viewable -- i.e. data_ptr is the same. + self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) + + # match NumPy semantics -- don't infer the size of dimension with a degree of freedom + self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) + + def test_expand(self, device="mps"): + tensor = torch.rand(1, 8, 1, device=device) + tensor2 = torch.rand(5, device=device) + template = torch.rand(4, 8, 5, device=device) + target = template.size() + self.assertEqual(tensor.expand_as(template).size(), target) + self.assertEqual(tensor.expand(4, 8, 5).size(), target) + self.assertEqual(tensor.expand(target).size(), target) + self.assertEqual(tensor2.expand_as(template).size(), target) + self.assertEqual(tensor2.expand(4, 8, 5).size(), target) + self.assertEqual(tensor2.expand(target).size(), target) + + # test double expand + self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) + + # test non-contiguous + noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0] + self.assertFalse(noncontig.is_contiguous()) + self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) + + # make sure it's compatible with unsqueeze + expanded = tensor2.expand(1, 1, 5) + unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) + self.assertEqual(expanded, unsqueezed) + self.assertEqual(expanded.stride(), unsqueezed.stride()) + + # test -1 as target size + 
self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) + self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) + + # test expanding empty to empty + self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device)) + + def test_view_empty(self, device="mps"): + x = torch.randn(0, 6, device=device) + self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) + + def test_reshape(self, device="mps"): + x = torch.randn(3, 3, device=device) + self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) + self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) + self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) + + y = torch.randn(4, 4, 4, device=device)[:, 0, :] + # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape + if device != "meta": + self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) + self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) + self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) + + s = torch.randn((), device=device) + self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) + self.assertEqual(s.reshape(-1).shape, (1,)) + self.assertRaises(RuntimeError, lambda: s.reshape(2)) + + empty = torch.tensor([], device=device) + self.assertEqual(empty, empty.reshape(-1)) + self.assertEqual(empty, empty.reshape([0])) + # TODO: fix these once we have multi-dimensional empty tensors + self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) + self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) + self.assertRaises(RuntimeError, lambda: empty.reshape(1)) + + x = torch.randn(3, 3, device=device) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) + self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device))) + + def test_narrow(self, device="mps"): + x = torch.tensor([[0, 
1, 2], [3, 4, 5], [6, 7, 8]]) + self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]])) + self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]])) + self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]])) + self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]])) + self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]])) + self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]])) + + def test_narrow_tensor(self, device="mps"): + x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]])) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor(0.), 1) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor([0]), 1) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor([0, 1]), 1) + + def test_t(self, device="mps"): + # Test 0D tensors + x = torch.randn(()) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) + + # Test 1D tensors + x = torch.arange(4) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) + + # Test 2D tensors + x = torch.rand((2, 2)) + self.assertEqual(x.t(), x.transpose(0, 1)) + x = x.to_sparse() + self.assertEqual(x.t(), x.transpose(0, 1)) + + # Test 3D tensor + x = torch.rand((2, 2, 2)) + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): + x.t() + x = x.to_sparse() + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): + x.t() + + def test_split(self, device="mps"): + tensor = torch.rand(7, 4) + split_size = 3 + dim = 0 + target_sizes = ([3, 4], [3, 4], [1, 4]) + splits = tensor.split(split_size, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + 
self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + # Variable sections split + tensor = torch.randn(20, 10) + dim = 0 + split_sizes = [5, 5, 10] + target_sizes = ([[5, 10], [5, 10], [10, 10]]) + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + split_sizes = [2, 2, 6] + target_sizes = ([20, 2], [20, 2], [20, 6]) + dim = 1 + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + def test_chunk(self, device="mps"): + tensor = torch.rand(4, 7) + num_chunks = 3 + dim = 1 + target_sizes = ([4, 3], [4, 3], [4, 1]) + splits = tensor.chunk(num_chunks, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, + atol=0, rtol=0) + start = start + target_size[dim] + + # Invalid chunk sizes + error_regex = 'chunk expects.*greater than 0' + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(0) + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(-2) + + def test_unsqueeze(self, device="mps") -> None: + x = torch.randn(2, 3, 4) + y = x.unsqueeze(1) + self.assertEqual(y, x.view(2, 1, 3, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, x.view(2, 3, 1, 4)) + + x = x[:, 1] + self.assertFalse(x.is_contiguous()) + y = x.unsqueeze(1) + self.assertEqual(y, x.contiguous().view(2, 1, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, 
x.contiguous().view(2, 4, 1)) + + # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) + def test_big_transpose(self, device="mps"): + t = torch.rand(456, 789, device=device) + t1 = t.t().contiguous() + t2 = torch.from_numpy(t.cpu().numpy().transpose()) + self.assertEqual(t1, t2) + + def test_T(self, device="mps"): + a = torch.randn(2, 3, 4, device=device) + t1 = a.T + t2 = a.permute(2, 1, 0) + self.assertEqual(t2, t1) + b = torch.randn(10, device=device) + self.assertEqual(b, b.T) + scalar = torch.tensor(5, device=device) + self.assertEqual(scalar, scalar.T) + + def test_transposes(self, device="mps", dtype=torch.float32): + for op in ("T", "H", "mT", "mH", "adjoint"): + shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),) + for shape in shapes: + a = make_tensor(shape, device=device, dtype=dtype) + t1 = getattr(a, op) + if op == "adjoint": + t1 = t1() + t2 = a + if a.ndim != 0: + t2 = t2.transpose(-2, -1) + if op[-1] == "H" or op == "adjoint": + t2 = t2.conj() + self.assertEqual(t2, t1) + + def test_transposes_errors(self, device="mps", dtype=torch.float32): + for op in ("H", "mT", "mH", "adjoint"): + shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),) + for shape in shapes: + a = make_tensor(shape, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "only supported on matrices"): + t1 = getattr(a, op) + if op == "adjoint": + t1 = t1() + + def test_python_types(self, device="mps"): + a1 = torch.randn((1, 2), device=device, dtype=torch.float32) + a2 = torch.randn((1, 2), device=device, dtype=torch.float32) + self.assertEqual(a1.dtype, a2.dtype) + + b1 = torch.arange(10, 20, dtype=torch.int64, device=device) + b2 = torch.arange(10, 20, dtype=int, device=device) + self.assertEqual(b1.dtype, b2.dtype) + + c1 = torch.tensor([True, False], dtype=torch.bool, device=device) + c2 = torch.tensor([True, False], dtype=bool, device=device) + self.assertEqual(c1.dtype, c2.dtype) + + # TODO: 
is resize best put in test_view_ops? + def test_resize_as_preserves_strides(self, device="mps"): + x = torch.empty(2, 3).t() + old_strides = x.stride() + x.resize_as_(x) + self.assertEqual(x.stride(), old_strides) + + def test_memory_format_resize_as(self, device="mps"): + def test_helper(shape, memory_format, device="mps"): + xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format) + flat = torch.randn(xc.numel(), device=device) + flat.resize_as_(xc, memory_format=torch.preserve_format) + self.assertTrue(flat.is_contiguous(memory_format=memory_format)) + + test_helper((10, 3, 32, 32), torch.channels_last, device="mps") + test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps") + + def test_memory_format_resize_(self, device="mps"): + def test_helper(shape, numel, memory_format, device="mps"): + flat = torch.randn(numel, device=device) + flat.resize_(shape, memory_format=memory_format) + self.assertTrue(flat.is_contiguous(memory_format=memory_format)) + + test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps") + test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps") + + # TODO: OpInfo this + def _test_atleast(self, device, torch_fn): + # 0-dim + s = torch.tensor(0.5, dtype=torch.double, requires_grad=True) + + gradcheck(lambda x: torch_fn(x), s) + gradgradcheck(lambda x: torch_fn(x), s) + + # 1-dim + a = torch.rand(4, dtype=torch.double, requires_grad=True) + + gradcheck(lambda x: torch_fn(x), a) + gradgradcheck(lambda x: torch_fn(x), a) + + # 2,3,4-dim + b = torch.rand(4, 3, dtype=torch.double, requires_grad=True) + c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True) + d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True) + + input_tuple = (s, a, b, c, d) + gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) + gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) + + def test_atleast_gradient(self, 
device="mps"): + self._test_atleast(device, torch.atleast_1d) + self._test_atleast(device, torch.atleast_2d) + self._test_atleast(device, torch.atleast_3d) + + def test_view(self, device="mps"): + tensor = torch.rand(15, device=device) + template = torch.rand(3, 5, device=device) + empty = torch.empty(0, device=device) + target = template.size() + self.assertEqual(tensor.view_as(template).size(), target) + self.assertEqual(tensor.view(3, 5).size(), target) + self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) + self.assertEqual(tensor.view(-1, 5).size(), target) + self.assertEqual(tensor.view(3, -1).size(), target) + tensor_view = tensor.view(5, 3) + tensor_view.fill_(random.uniform(0, 1)) + self.assertEqual(empty.view_as(empty), empty) + self.assertEqual(empty.view(0), empty) + self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) + self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) + + # test size inference with empty tensors + self.assertEqual(empty.view(-1).size(), torch.Size([0])) + self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) + + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(-1, 0) + + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(3, 0, -1, 0) + + self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) + self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) + self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) + + # RuntimeError: Invalid device for storage: mps + def test_contiguous(self, device="mps"): + x = torch.randn(1, 16, 5, 5, device=device) + self.assertTrue(x.is_contiguous()) + stride = list(x.stride()) + stride[0] = 20 + # change the stride in dimension 0. 
the tensor is still contiguous because size[0] is 1 + x.set_(x.storage(), 0, x.size(), stride) + self.assertTrue(x.is_contiguous()) + + def test_resize_all_dtypes_and_devices(self, device="mps"): + shape = (2, 2) + for dt in (torch.half, torch.bfloat16, torch.bool): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + x.resize_(shape) + self.assertEqual(shape, x.shape) + + def test_resize_as_all_dtypes_and_devices(self, device="mps"): + for dt in (torch.half, torch.bfloat16, torch.bool): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) + x.resize_as_(y) + self.assertEqual(y.shape, x.shape) + + def test_resize_overflow(self, device="mps"): + x = torch.empty((), dtype=torch.float64) + with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'): + x.resize_([2, 4, 2**29, 2**29]) + with self.assertRaisesRegex(RuntimeError, 'overflow'): + x.resize_([8, 8, 2**29, 2**29]) + + def test_view_all_dtypes_and_devices(self, device="mps"): + for dt in (torch.float, torch.bool): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + self.assertEqual(x.view(6).shape, [6]) class TestRNNMPS(TestCase): def test_lstm_1(self, device="mps", dtype=torch.float32): @@ -4875,8 +5554,11 @@ class TestNoRegression(TestCase): with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): torch.testing.assert_close(a, inf) - with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): - torch.testing.assert_close(a, nan) + # TODO: The NaN test is failing when all the tests in test_mps are run + # together but passes when run separately. There seems to be memory + # corruption which needs to be fixed for this test to be enabled. 
+ # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): + # torch.testing.assert_close(a, nan) @unittest.expectedFailure def test_mps_compat(self): @@ -4940,7 +5622,604 @@ class TestNoRegression(TestCase): self.assertEqual(x2.device.type, "cpu") +MPS_DTYPES = get_all_dtypes() +for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]: + del MPS_DTYPES[MPS_DTYPES.index(t)] +class TestConsistency(TestCase): + # TODO: This is only used while some ops are being added. + # This list should contain all ops and dtypes eventually + # This can be generated automatically in the `new_mps_allowlist.txt` file + # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU` + # You most likely do NOT want to modify this manually + ALLOWLIST_OP = { + '__radd__': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__rand__': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__rmul__': ['torch.bool', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__ror__': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__rxor__': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '_masked.normalize': ['torch.float32'], + 'abs': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.uint8'], + 'add': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'addcdiv': ['torch.float32'], + 'addcmul': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'addmv': ['torch.float32'], + 'addr': ['torch.float32'], + 'all': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'any': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'argmax': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'asin': ['torch.float32'], + 'asinh': 
['torch.float32'], + 'atan': ['torch.float32'], + 'atan2': ['torch.float32'], + 'atanh': ['torch.float32'], + 'atleast_1d': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'atleast_2d': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'atleast_3d': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'baddbmm': ['torch.float32'], + 'bitwise_and': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_left_shift': ['torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_not': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_or': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_right_shift': ['torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_xor': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bmm': ['torch.float32'], + 'ceil': ['torch.float32'], + 'chunk': ['torch.float16', 'torch.float32', 'torch.int64'], + 'clone': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'column_stack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'conj': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'conj_physical': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'contiguous': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'corrcoef': ['torch.float32'], + 'deg2rad': ['torch.float32'], + 
'diag': ['torch.float32', 'torch.int32'], + 'diagflat': ['torch.int32'], + 'diff': ['torch.float32'], + 'dist': ['torch.float32'], + 'dot': ['torch.float32', 'torch.int32'], + 'einsum': ['torch.float32'], + 'erf': ['torch.float32'], + 'fill': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'flatten': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'floor': ['torch.float32'], + 'hstack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'index_select': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'isinf': ['torch.float16', 'torch.float32'], + 'isnan': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'kron': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'linalg.norm': ['torch.float16', + 'torch.float32', + 'torch.float16', + 'torch.float32'], + 'linalg.svd': ['torch.float32'], + 'linalg.vector_norm': ['torch.float16'], + 'log1p': ['torch.float32'], + 'log_softmax': ['torch.float32'], + 'logaddexp': ['torch.float32'], + 'logaddexp2': ['torch.float32'], + 'masked_select': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'mm': ['torch.float32'], + 'mv': ['torch.float32'], + 'neg': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32'], + 'nn.functional.adaptive_max_pool1d': ['torch.float32'], + 'nn.functional.adaptive_max_pool2d': ['torch.float32'], + 'nn.functional.binary_cross_entropy': ['torch.float32'], + 'nn.functional.celu': ['torch.float32'], + 'nn.functional.elu': ['torch.float32'], + 'nn.functional.embedding': ['torch.float16', 'torch.float32'], + 'nn.functional.feature_alpha_dropout': ['torch.bool', + 'torch.float16', + 
'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.hardtanh': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'nn.functional.hinge_embedding_loss': ['torch.float32'], + 'nn.functional.kl_div': ['torch.float32'], + 'nn.functional.l1_loss': ['torch.float32'], + 'nn.functional.leaky_relu': ['torch.float32'], + 'nn.functional.mse_loss': ['torch.float16', 'torch.float32'], + 'nn.functional.relu': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.relu6': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.selu': ['torch.float32'], + 'nn.functional.silu': ['torch.float32'], + 'nn.functional.smooth_l1_loss': ['torch.float32'], + 'nn.functional.softmin': ['torch.float32'], + 'nn.functional.threshold': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.upsample_bilinear': ['torch.float32'], + 'norm': ['torch.float32', 'torch.float16', 'torch.float32'], + 'positive': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'rad2deg': ['torch.float32'], + 'ravel': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'real': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'repeat_interleave': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resize_': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resize_as_': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resolve_conj': ['torch.bool', + 'torch.float16', + 'torch.float32', + 
'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resolve_neg': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'round': ['torch.float32'], + 'sgn': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'sign': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.uint8'], + 'sin': ['torch.float32'], + 'sinh': ['torch.float32'], + 'softmax': ['torch.float32'], + 'split': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'sqrt': ['torch.float32'], + 'square': ['torch.float32'], + 'squeeze': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'stack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'sub': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'sum_to_size': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'svd': ['torch.float32'], + 't': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'tanh': ['torch.float32'], + 'tensordot': ['torch.float32'], + 'topk': ['torch.float32'], + 'tril': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'triu': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'true_divide': ['torch.float32'], + 'trunc': ['torch.float32'], + 'unsqueeze': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'view': ['torch.bool', + 'torch.float16', + 
'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'view_as': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'vsplit': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'vstack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'zero_': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8']} + + # These ops that are problematic. So never run them even when + # generating the new allowlist. + # If the dtype list is None, all dtypes are excluded. + # All the entries in this list should be removed + BLOCKLIST = { + # Functions that hang + 'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool], + # Functions that hard crash + 'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64], + 'nn.functional.nll_loss': [torch.float32], + 'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32], + 'nn.functional.smooth_l1_loss': [torch.float16], 'std': [torch.float16], + 'stft': [torch.float32], 'var': [torch.float16], + + # These were moved from ALLOWLIST to BLOCK as they are not working + # locally + 'tile': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], + 'repeat': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], + '__radd__': ['torch.bool', 'torch.uint8'], + '__rmul__': ['torch.uint8'], + 'add': ['torch.bool', 'torch.uint8'], + 'square': ['torch.int32', 'torch.int64', 'torch.uint8'], + 'addr': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], + 'diag': ['torch.int64'], + 'diagflat': ['torch.int64'], + + # Functions that are flaky + # These are detected as "ok" by the expect case but actually fail to run sometimes + 'H': None, + 
'T': None,
+        'as_strided': None,
+        'broadcast_tensors': None,
+        'broadcast': None,
+        'broadcast_to': None,
+        'diagonal': None,
+        'divfloor_rounding': None,
+        'divno_rounding_mode': None,
+        'divtrunc_rounding': None,
+        'dsplit': None,
+        'hsplit': None,
+        'empty': None,
+        'expand_as': None,
+        'expand': None,
+        'ge': None,
+        'ne': None,
+        'le': None,
+        'lt': None,
+        'gt': None,
+        'transpose': None,
+        'splitlist_args': None,
+        'select': None,
+        'reshape': None,
+        'reshape_as': None,
+        'permute': None,
+        'norm': None,
+        'nn.functional.pixel_unshuffle': None,
+        'nn.functional.pixel_shuffle': None,
+        'nn.functional.cross_entropy': None,
+        'nn.functional.one_hot': None,
+        'narrow': None,
+        'movedim': None,
+        'minreduction_with_dim': None,
+        'minreduction_no_dim': None,
+        'minbinary': None,
+        'meshgridvariadic_tensors': None,
+        'meshgridlist_of_tensors': None,
+        'maxreduction_with_dim': None,
+        'maxreduction_no_dim': None,
+        'maxbinary': None,
+        'maximum': None,
+        'minimum': None,
+        'mT': None,
+        'mH': None,
+        'outer': None,
+        'softmaxwith_dtype': None,
+        'rounddecimals_neg_3': None,
+        'rounddecimals_3': None,
+        'rounddecimals_0': None,
+        'normnuc': None,
+        'nn.functional.softminwith_dtype': None,
+        'nn.functional.feature_alpha_dropoutwith_train': None,
+        'log_softmaxdtype': None,
+        'split_with_sizes': None,
+        'trapezoid': None,
+        'eq': None,
+        'mul': None,
+        'cartesian_prod': None,
+        'nonzero': None,
+        'bool': None,
+        'inner': None,
+        'dstack': None,
+        'take_along_dim': None,
+    }
+
+    # Used for accept mode only
+    NEW_ALLOW_LIST = defaultdict(list)
+
+    @ops(op_db, allowed_dtypes=MPS_DTYPES)
+    def test_output_match(self, device, dtype, op):
+        """Check that ``op`` gives the same result on "mps" as on "cpu".
+
+        The test is instantiated for the "cpu" device: every CPU sample input
+        is copied to "mps", the op is called on both and the outputs are
+        compared with ``assertEqual``.
+
+        Args:
+            device: Device the test was instantiated for; asserted to be "cpu".
+            dtype: The torch dtype the sample inputs are generated with.
+            op: The ``op_db`` entry under test.
+
+        Skips when MPS is unavailable, when the (op variant, dtype) pair is in
+        BLOCKLIST, or (outside accept mode) when it is not in ALLOWLIST_OP.
+
+        When the EXPECTTEST_ACCEPT environment variable is "1" (accept mode),
+        ALLOWLIST_OP is ignored, failures are swallowed and every passing
+        (op.name, dtype) pair is written to "new_mps_allowlist.txt".
+        """
+        self.assertEqual(device, "cpu")
+        if not torch.backends.mps.is_available():
+            self.skipTest("MPS is not available")
+
+        dtype_name = str(dtype)
+
+        key = op.name + op.variant_test_name
+        if key in self.BLOCKLIST:
+            blocked_dtypes = self.BLOCKLIST[key]
+            # BLOCKLIST mixes torch.dtype objects and dtype names given as
+            # strings. A torch.dtype never compares equal to a str, so both
+            # spellings have to be checked or the string entries are ignored.
+            if blocked_dtypes is None or dtype in blocked_dtypes or dtype_name in blocked_dtypes:
+                self.skipTest(f"Running test with {key} for {dtype} hangs, crashes or is flaky so skipping")
+
+        # Make this an expecttest manually
+        # When this env variable is set, generate a new ALLOWLIST_OP
+        # that reflects the current state of what passes or not
+        generate_new_truth = os.environ.get("EXPECTTEST_ACCEPT") == "1"
+
+        if not generate_new_truth:
+            allowed_dtypes = self.ALLOWLIST_OP.get(op.name)
+            if allowed_dtypes is None:
+                self.skipTest(f"{op.name} is not in the allow list for test on MPS")
+            elif dtype_name not in allowed_dtypes:
+                self.skipTest(f"{op.name} is in the allow list for MPS but {dtype} is excluded")
+
+        try:
+            cpu_samples = op.sample_inputs(device, dtype)
+
+            for cpu_sample in cpu_samples:
+                mps_sample = cpu_sample.transform(lambda x: x.to("mps") if isinstance(x, torch.Tensor) else x)
+
+                # TODO: This checks only the function variant. We should also check the method and inplace version
+                # when they exist
+                cpu_args = [cpu_sample.input] + list(cpu_sample.args)
+                cpu_kwargs = cpu_sample.kwargs
+                mps_args = [mps_sample.input] + list(mps_sample.args)
+                mps_kwargs = mps_sample.kwargs
+
+                cpu_out = op(*cpu_args, **cpu_kwargs)
+                mps_out = op(*mps_args, **mps_kwargs)
+                self.assertEqual(cpu_out, mps_out)
+        except Exception:
+            # In accept mode a failure only means that this (op, dtype) pair
+            # is left out of the new allowlist; otherwise it is a test failure.
+            if not generate_new_truth:
+                raise
+        else:
+            if generate_new_truth:
+                # NEW_ALLOW_LIST is keyed by op.name only, while several
+                # variants (op.variant_test_name) can share that name, so
+                # avoid recording the same dtype more than once.
+                if dtype_name not in self.NEW_ALLOW_LIST[op.name]:
+                    self.NEW_ALLOW_LIST[op.name].append(dtype_name)
+
+                # We could write it only once. But I don't know how to detect that the current test is the last one
+                # So each test append to the dict and write it.
+                with open("new_mps_allowlist.txt", "w") as f:
+                    pprint.pprint(self.NEW_ALLOW_LIST, stream=f)
+
+# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
+# This requires mps to be properly registered in the device generic test framework which is not the
+# case right now.
+instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu") if __name__ == "__main__": run_tests() diff --git a/third_party/ideep b/third_party/ideep index 8a114a51c11..02b17c5748c 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 8a114a51c116b55c4ceb689b98746786bd00c29b +Subproject commit 02b17c5748c9349dcc586c359af800c684d9b1ab diff --git a/third_party/kineto b/third_party/kineto index 0703c789990..6f97e31e0ce 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 0703c78999061b8329dfab7ec5046fc5764a5573 +Subproject commit 6f97e31e0ce40edf9aa3d526558be78ba724d298 diff --git a/third_party/onnx b/third_party/onnx index f7ee1ac60d0..96046b8ccfb 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit f7ee1ac60d06abe8e26c9b6bbe1e3db5286b614b +Subproject commit 96046b8ccfb8e6fa82f6b2b34b3d56add2e8849c