diff --git a/test/minioptest_failures_dict.json b/test/minioptest_failures_dict.json index 60e36ad796a..e32256a37e9 100644 --- a/test/minioptest_failures_dict.json +++ b/test/minioptest_failures_dict.json @@ -7,6 +7,10 @@ "MiniOpTest.test_aot_dispatch_static__test_nonzero": { "comment": "", "status": "xfail" + }, + "MiniOpTestOther.test_aot_dispatch_static__test_nonzero_again": { + "comment": "", + "status": "xfail" } }, "aten::sin_": {}, diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index e233010bdc7..3273d2a7d87 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -1776,6 +1776,15 @@ class MiniOpTest(CustomOpTestCaseBase): y = op(x) +class MiniOpTestOther(CustomOpTestCaseBase): + test_ns = "mini_op_test" + + def test_nonzero_again(self): + x = torch.tensor([0, 1, 2, 0, 0]) + y = torch.ops.aten.nonzero.default(x) + self.assertEqual(y, torch.tensor([[1], [2]])) + + mini_op_test_checks = [ "test_schema", "test_autograd_registration", @@ -1795,6 +1804,17 @@ optests.generate_opcheck_tests( mini_op_test_checks, ) +optests.generate_opcheck_tests( + MiniOpTestOther, + ["aten", "mini_op_test"], + get_file_path_2( + os.path.dirname(__file__), + "minioptest_failures_dict.json", + ), + [], + mini_op_test_checks, +) + class TestGenerateOpcheckTests(CustomOpTestCaseBase): def test_MiniOpTest(self): diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index 3e1da88ea4e..9a8d61d6210 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -302,9 +302,9 @@ def validate_failures_dict_structure( if not actual_test_name.startswith(test): continue base_test_name = actual_test_name[len(test) + 2 :] - if testcase.__name__ == test_class and hasattr( - testcase, base_test_name - ): + if testcase.__name__ != test_class: + continue + if hasattr(testcase, base_test_name): continue raise RuntimeError( f"In failures dict, got test name '{test_name}'. We parsed this as "