mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
add __torch_function__ handler to get_device cpp (#132567)
From the issue:
```
import torch
class CustomParameter(torch.nn.Parameter):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return func.__name__
x = CustomParameter(torch.rand(2))
print(x.square()) # 'square'
print(torch.square(x)) # 'square'
print(x.get_device()) # 'get_device'
print(torch.get_device(x)) # -1
```
after fix:
```
$ python repro.py
square
square
get_device
get_device
```
Fixes: https://github.com/pytorch/pytorch/issues/131944
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132567
Approved by: https://github.com/ezyang
This commit is contained in:
parent
7f8a384a8f
commit
8ff310392e
1 changed files with 4 additions and 0 deletions
|
|
@ -279,6 +279,10 @@ static PyObject* THPVariable_get_device(
|
|||
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
if (r.has_torch_function()) {
|
||||
return handle_torch_function(
|
||||
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
|
||||
}
|
||||
|
||||
if (r.idx == 0) {
|
||||
return wrap(r.tensor(0).get_device());
|
||||
|
|
|
|||
Loading…
Reference in a new issue