mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix test_python_dispatch under debug mode (#98609)
The problem for these operators is that they were returning the input directly as the output. This isn't support and will raise debug asserts. Test Plan: - Test locally. The debug build in CI doesn't actually do anything. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98609 Approved by: https://github.com/ezyang, https://github.com/bdhirsh
This commit is contained in:
parent
01b2c45659
commit
618ea6fac3
1 changed files with 4 additions and 4 deletions
|
|
@ -80,7 +80,7 @@ class TestPythonRegistration(TestCase):
|
|||
|
||||
def my_sum(*args, **kwargs):
|
||||
run[0] = True
|
||||
return args[0]
|
||||
return args[0].clone()
|
||||
|
||||
my_lib1 = Library("aten", "IMPL")
|
||||
my_lib1.impl('aten::sum', my_sum, "CPU")
|
||||
|
|
@ -216,7 +216,7 @@ class TestPythonRegistration(TestCase):
|
|||
|
||||
def test_extend_library_with_dispatch_key_arg(self):
|
||||
def my_sum(*args, **kwargs):
|
||||
return args[0]
|
||||
return args[0].clone()
|
||||
my_lib1 = Library("aten", "IMPL", dispatch_key="CPU")
|
||||
|
||||
# RuntimeError: Explicitly provided dispatch key (Conjugate) is
|
||||
|
|
@ -236,7 +236,7 @@ class TestPythonRegistration(TestCase):
|
|||
# Example 1
|
||||
@torch.library.impl(my_lib1, "sum", "CPU")
|
||||
def my_sum(*args, **kwargs):
|
||||
return args[0]
|
||||
return args[0].clone()
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
self.assertEqual(torch.ops.foo.sum(x), x)
|
||||
|
|
@ -249,7 +249,7 @@ class TestPythonRegistration(TestCase):
|
|||
if args[0]._is_zerotensor():
|
||||
return torch._efficientzerotensor(args[0].shape)
|
||||
else:
|
||||
return args[0]
|
||||
return args[0].clone()
|
||||
|
||||
y = torch._efficientzerotensor(3)
|
||||
self.assertTrue(torch.ops.foo.sum(y)._is_zerotensor())
|
||||
|
|
|
|||
Loading…
Reference in a new issue