Fixed an issue where nn.Linear would cause an internal int underflow … (#119221)

…when trying to reshape a scalar input.

Fixes #119161

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119221
Approved by: https://github.com/albanD
This commit is contained in:
Tobias Ringwald 2024-02-08 21:06:29 +00:00 committed by PyTorch MergeBot
parent 7fd6b1c558
commit d9a1b25807
2 changed files with 16 additions and 1 deletions

View file

@ -70,6 +70,14 @@ static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weigh
Tensor linear(const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt) {
// _matmul_impl checks this again later, but _flatten_nd_linear does not work on scalars inputs,
// so let's try to catch this here already
const auto input_dim = input.dim();
const auto weight_dim = weight.dim();
TORCH_CHECK(input_dim != 0 && weight_dim != 0,
"both arguments to linear need to be at least 1D, but they are ",
input_dim, "D and ", weight_dim, "D");
// See [Note: hacky wrapper removal for optional tensor]
auto bias = bias_opt.has_value()
? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
@ -82,7 +90,6 @@ Tensor linear(const Tensor& input, const Tensor& weight, const c10::optional<Ten
return xnnpack::linear(input, weight, *bias);
}
#endif
const auto input_dim = input.dim();
if (input_dim == 2 && bias->defined()) {
// Fused op is marginally faster.
return at::addmm(*bias, input, weight.t());

View file

@ -6526,6 +6526,14 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
expected = m(inp.view(6, 5)).view(2, 3, 8)
self.assertEqual(expected, m(inp))
def test_linear_raise_on_scalar_input(self):
# This used to cause an int underflow issue when reshaping the input
# see https://github.com/pytorch/pytorch/issues/119161
m = nn.Linear(1, 1)
inp = torch.ones(1).squeeze()
with self.assertRaisesRegex(RuntimeError, ".*both arguments.*1D.*"):
m(inp)
@parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else []))
@parametrize_test('bias', [
subtest(False, name='nobias'), subtest(True, name='bias')])