From 8ff310392ed4ae87cf393bcc19fae4e25a26a3db Mon Sep 17 00:00:00 2001 From: Gabriel Ferns Date: Sun, 4 Aug 2024 04:26:28 +0000 Subject: [PATCH] add __torch_function__ handler to get_device cpp (#132567) From the issue: ``` import torch class CustomParameter(torch.nn.Parameter): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return func.__name__ x = CustomParameter(torch.rand(2)) print(x.square()) # 'square' print(torch.square(x)) # 'square' print(x.get_device()) # 'get_device' print(torch.get_device(x)) # -1 ``` after fix: ``` $ python repro.py square square get_device get_device ``` Fixes: https://github.com/pytorch/pytorch/issues/131944 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132567 Approved by: https://github.com/ezyang --- torch/csrc/autograd/python_torch_functions_manual.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 2817506ebdc..c36cf275a6b 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -279,6 +279,10 @@ static PyObject* THPVariable_get_device( ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function( + r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); + } if (r.idx == 0) { return wrap(r.tensor(0).get_device());