mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
[MPS][BE] Error-check linear (#124952)
Validate that all arguments are on MPS devices and dtypes are expected Fixes cryptic messages like ``` % python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32), torch.rand((32, 32), device='mps')))" RuntimeError: Placeholder storage has not been allocated on MPS device! ``` And hard crashes like ``` % python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32, device='mps'), torch.randint(-10, 10, (32, 32), dtype=torch.int8, device='mps')))" ``` Fixes https://github.com/pytorch/pytorch/issues/123995 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124952 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
973d724e21
commit
db3a2d751c
2 changed files with 27 additions and 1 deletions
|
|
@ -16,9 +16,16 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::opt
|
|||
auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg;
|
||||
|
||||
TORCH_CHECK(supportedFloatingType(input), "MPS device does not support linear for non-float inputs");
|
||||
TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps");
|
||||
TORCH_CHECK(supportedFloatingType(weight_arg), "MPS device does not support linear for non-float weights");
|
||||
TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps");
|
||||
|
||||
const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
|
||||
bool is_bias_defined = bias.defined();
|
||||
const bool is_bias_defined = bias.defined();
|
||||
if (is_bias_defined) {
|
||||
TORCH_CHECK(bias.is_mps(), "Tensor for argument bias is on ", bias.device(), " but expected on mps");
|
||||
TORCH_CHECK(supportedFloatingType(bias), "MPS device does not support linear for non-float bias");
|
||||
}
|
||||
|
||||
auto input_size = input.sizes();
|
||||
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
|
||||
|
|
|
|||
|
|
@ -1961,6 +1961,25 @@ class TestMPS(TestCaseMPS):
|
|||
helper(())
|
||||
helper((2, 4))
|
||||
|
||||
def test_linear_errors(self):
|
||||
# Mixed CPU<->MPS tensors
|
||||
size = (3, 3)
|
||||
|
||||
# Unsupported dtypes
|
||||
with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
|
||||
torch.nn.functional.linear(torch.rand(size, device='mps'),
|
||||
torch.randint(-10, 10, size, dtype=torch.int8, device='mps'))
|
||||
|
||||
# Weigths on wrong device
|
||||
with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
|
||||
torch.nn.functional.linear(torch.rand(size, device='mps'),
|
||||
torch.rand(size, device='cpu'))
|
||||
|
||||
# Input on wrong device
|
||||
with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
|
||||
torch.nn.functional.linear(torch.rand(size, device='cpu'),
|
||||
torch.rand(size, device='mps'))
|
||||
|
||||
def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
|
||||
cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
|
||||
mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)
|
||||
|
|
|
|||
Loading…
Reference in a new issue