From db3a2d751c117fe563bdbc4a1b4f8736c184ca68 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 25 Apr 2024 23:25:20 +0000 Subject: [PATCH] [MPS][BE] Error-check linear (#124952) Validate that all arguments are on MPS devices and dtypes are expected Fixes cryptic messages like ``` % python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32), torch.rand((32, 32), device='mps')))" RuntimeError: Placeholder storage has not been allocated on MPS device! ``` And hard crashes like ``` % python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32, device='mps'), torch.randint(-10, 10, (32, 32), dtype=torch.int8, device='mps')))" ``` Fixes https://github.com/pytorch/pytorch/issues/123995 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124952 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/mps/operations/Linear.mm | 9 ++++++++- test/test_mps.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 6ed98530359..450e24c77c9 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -16,9 +16,16 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::opt auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg; TORCH_CHECK(supportedFloatingType(input), "MPS device does not support linear for non-float inputs"); + TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps"); + TORCH_CHECK(supportedFloatingType(weight_arg), "MPS device does not support linear for non-float weights"); + TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps"); const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt)); - bool is_bias_defined = bias.defined(); + const bool is_bias_defined = bias.defined(); + if (is_bias_defined) { + TORCH_CHECK(bias.is_mps(), "Tensor for argument bias is on ", bias.device(), " but expected on mps"); + TORCH_CHECK(supportedFloatingType(bias), "MPS device does not support linear for non-float bias"); + } auto input_size = input.sizes(); std::vector output_size(input_size.begin(), input_size.end() - 1); diff --git a/test/test_mps.py b/test/test_mps.py index bfac420775a..7f87c1ccd41 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1961,6 +1961,25 @@ class TestMPS(TestCaseMPS): helper(()) helper((2, 4)) + def test_linear_errors(self): + # Mixed CPU<->MPS tensors + size = (3, 3) + + # Unsupported dtypes + with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"): + torch.nn.functional.linear(torch.rand(size, device='mps'), + torch.randint(-10, 10, size, dtype=torch.int8, device='mps')) + + # Weigths on wrong device + with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"): + torch.nn.functional.linear(torch.rand(size, device='mps'), + torch.rand(size, device='cpu')) + + # Input on wrong device + with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"): + torch.nn.functional.linear(torch.rand(size, device='cpu'), + torch.rand(size, device='mps')) + def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False): cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias) mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)