Fixed an issue where nn.Linear would cause an internal int underflow when trying to reshape a scalar input (#119221)

Fixes #119161

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119221
Approved by: https://github.com/albanD
parent 7fd6b1c558
commit d9a1b25807

2 changed files with 16 additions and 1 deletion
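For context, here is a minimal repro sketch of the reported bug, mirroring the regression test added below (the variable names are mine; before this fix, the call failed deep inside the reshape path rather than with the clear message shown):

    # Minimal repro sketch for issue #119161: nn.Linear on a 0-dim (scalar) input.
    import torch
    import torch.nn as nn

    m = nn.Linear(1, 1)
    scalar_input = torch.ones(1).squeeze()  # squeeze() turns shape (1,) into a 0-dim tensor

    try:
        m(scalar_input)
    except RuntimeError as e:
        # After this commit: "both arguments to linear need to be at least 1D, ..."
        print(e)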
aten/src/ATen/native/Linear.cpp
@@ -70,6 +70,14 @@ static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weigh
 Tensor linear(const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt) {
+  // _matmul_impl checks this again later, but _flatten_nd_linear does not work on scalar inputs,
+  // so let's try to catch this here already
+  const auto input_dim = input.dim();
+  const auto weight_dim = weight.dim();
+  TORCH_CHECK(input_dim != 0 && weight_dim != 0,
+              "both arguments to linear need to be at least 1D, but they are ",
+              input_dim, "D and ", weight_dim, "D");
+
   // See [Note: hacky wrapper removal for optional tensor]
   auto bias = bias_opt.has_value()
       ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
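Why a 0-dim input breaks the flattening path: _flatten_nd_linear folds every leading dimension of the input into one batch dimension and keeps the last dimension for the matmul, so it has to index the size of dimension dim() - 1. For a scalar that index is -1, which underflows once it lands in an unsigned size type. A toy sketch of the arithmetic (an assumption for illustration, not the literal ATen code):

    # Toy sketch of the failure mode; illustrative only, not the ATen implementation.
    sizes = []           # the shape of a 0-dim tensor has no entries
    dim = len(sizes)     # 0
    last = dim - 1       # -1 is harmless as a signed int ...
    print(last % 2**64)  # ... but becomes 18446744073709551615 as an unsigned index

The second hunk drops a now-redundant local recomputation of input_dim, since after this change the value is already computed at the top of linear():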
@@ -82,7 +90,6 @@ Tensor linear(const Tensor& input, const Tensor& weight, const c10::optional<Ten
     return xnnpack::linear(input, weight, *bias);
   }
 #endif
-  const auto input_dim = input.dim();
   if (input_dim == 2 && bias->defined()) {
     // Fused op is marginally faster.
     return at::addmm(*bias, input, weight.t());
test/test_nn.py
@@ -6526,6 +6526,14 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         expected = m(inp.view(6, 5)).view(2, 3, 8)
         self.assertEqual(expected, m(inp))
 
+    def test_linear_raise_on_scalar_input(self):
+        # This used to cause an int underflow issue when reshaping the input,
+        # see https://github.com/pytorch/pytorch/issues/119161
+        m = nn.Linear(1, 1)
+        inp = torch.ones(1).squeeze()
+        with self.assertRaisesRegex(RuntimeError, ".*both arguments.*1D.*"):
+            m(inp)
+
     @parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else []))
     @parametrize_test('bias', [
         subtest(False, name='nobias'), subtest(True, name='bias')])
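For contrast with the new regression test: a 1-D input of size in_features remains a valid call, so the added check only rejects true 0-dim scalars. A quick sketch:

    import torch
    import torch.nn as nn

    m = nn.Linear(1, 1)
    out = m(torch.ones(1))  # 1-D input of size in_features: still accepted
    print(out.shape)        # torch.Size([1])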