From 425804db2b59fe2653fbf0ece730a522865a9b2a Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Thu, 6 Feb 2025 02:14:11 +0000 Subject: [PATCH] [torch] fix exception types in custom class magic setattr/getattr (#146516) Summary: `c10::AttributeError` is not automatically converted to Python AttributeError, it needs some special macros (e.g. `HANDLE_TH_ERRORS`). Some Python functions like `hasattr` rely on the type of the throw exception to be correct. We don't need the fully generality of those macros, so just do a targeted error type conversion here. Test Plan: added unit test Differential Revision: D69197217 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146516 Approved by: https://github.com/zdevito --- test/jit/test_torchbind.py | 4 ++++ torch/csrc/jit/python/script_init.cpp | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index 726d7ce189d..cccbc1a24ed 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -445,6 +445,10 @@ class TestTorchbind(JitTestCase): self.checkScript(fn, (1,)) + def test_hasattr(self): + ss = torch.classes._TorchScriptTesting._StackString(["foo", "bar"]) + self.assertFalse(hasattr(ss, "baz")) + def test_default_args(self): def fn() -> int: obj = torch.classes._TorchScriptTesting._DefaultArgs() diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 8c5dab959b8..d9a06902c67 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -785,7 +785,8 @@ void initJitScriptBindings(PyObject* module) { try { return toPyObject(self.attr(name)); } catch (const ObjectAttributeError& err) { - throw AttributeError("%s", err.what()); + pybind11::set_error(PyExc_AttributeError, err.what()); + throw py::error_already_set(); } }) .def( @@ -806,7 +807,8 @@ void initJitScriptBindings(PyObject* module) { } return toPyObject(self.attr(name)); } catch (const ObjectAttributeError& err) { - throw AttributeError("%s", err.what()); + pybind11::set_error(PyExc_AttributeError, err.what()); + throw py::error_already_set(); } }) .def( @@ -836,7 +838,8 @@ void initJitScriptBindings(PyObject* module) { auto ivalue = toIValue(std::move(value), type); self.setattr(name, ivalue); } catch (const ObjectAttributeError& err) { - throw AttributeError("%s", err.what()); + pybind11::set_error(PyExc_AttributeError, err.what()); + throw py::error_already_set(); } }) .def(