add __torch_function__ handler to get_device cpp (#132567)

From the issue:
```
import torch

class CustomParameter(torch.nn.Parameter):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
         return func.__name__

x = CustomParameter(torch.rand(2))

print(x.square()) # 'square'
print(torch.square(x)) # 'square'
print(x.get_device()) # 'get_device'
print(torch.get_device(x)) # -1
```
after fix:
```
$ python repro.py
square
square
get_device
get_device
```

Fixes: https://github.com/pytorch/pytorch/issues/131944

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132567
Approved by: https://github.com/ezyang
This commit is contained in:
Gabriel Ferns 2024-08-04 04:26:28 +00:00 committed by PyTorch MergeBot
parent 7f8a384a8f
commit 8ff310392e

View file

@ -279,6 +279,10 @@ static PyObject* THPVariable_get_device(
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.has_torch_function()) {
return handle_torch_function(
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
}
if (r.idx == 0) {
return wrap(r.tensor(0).get_device());