Fix test_python_dispatch under debug mode (#98609)

The problem for these operators is that they were returning the input
directly as the output. This isn't support and will raise debug asserts.

Test Plan:
- Test locally. The debug build in CI doesn't actually do anything.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98609
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
This commit is contained in:
Richard Zou 2023-04-07 11:26:35 -07:00 committed by PyTorch MergeBot
parent 01b2c45659
commit 618ea6fac3

View file

@ -80,7 +80,7 @@ class TestPythonRegistration(TestCase):
def my_sum(*args, **kwargs):
run[0] = True
return args[0]
return args[0].clone()
my_lib1 = Library("aten", "IMPL")
my_lib1.impl('aten::sum', my_sum, "CPU")
@ -216,7 +216,7 @@ class TestPythonRegistration(TestCase):
def test_extend_library_with_dispatch_key_arg(self):
def my_sum(*args, **kwargs):
return args[0]
return args[0].clone()
my_lib1 = Library("aten", "IMPL", dispatch_key="CPU")
# RuntimeError: Explicitly provided dispatch key (Conjugate) is
@ -236,7 +236,7 @@ class TestPythonRegistration(TestCase):
# Example 1
@torch.library.impl(my_lib1, "sum", "CPU")
def my_sum(*args, **kwargs):
return args[0]
return args[0].clone()
x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum(x), x)
@ -249,7 +249,7 @@ class TestPythonRegistration(TestCase):
if args[0]._is_zerotensor():
return torch._efficientzerotensor(args[0].shape)
else:
return args[0]
return args[0].clone()
y = torch._efficientzerotensor(3)
self.assertTrue(torch.ops.foo.sum(y)._is_zerotensor())