2021-10-29 19:15:30 +00:00
# Owner(s): ["high priority"]
Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760
See https://github.com/pytorch/pytorch/issues/59049
There are some moving parts to this PR; I'll structure this explanation so the straightforward parts go first, followed by the less straightforward parts.
**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser` which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing if Tensors have custom handling or not. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.
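As a rough illustration of the scan-then-defer flow described above, here is a toy pure-Python model. It is a sketch, not the C++ code in `python_variable.cpp`: `collect_overloaded_args`, `toy_dispatch`, and the `__toy_dispatch__` method name are hypothetical, and the subclass-before-base ordering is assumed by analogy with the `__torch_function__` protocol.

```python
# Toy model only: every name here is hypothetical and does not mirror PyTorch's C++ code.
MAGIC = "__toy_dispatch__"  # stand-in for the magic method name passed to the handler


def collect_overloaded_args(args, magic=MAGIC):
    """Return the args whose type defines `magic`, subclasses before their bases."""
    overloaded = []
    for arg in args:
        if not hasattr(type(arg), magic):
            continue  # plain values never take part in dispatch
        if any(type(arg) is type(seen) for seen in overloaded):
            continue  # keep one entry per distinct type
        # Place a subclass in front of the first already-seen base class.
        index = len(overloaded)
        for i, seen in enumerate(overloaded):
            if issubclass(type(arg), type(seen)):
                index = i
                break
        overloaded.insert(index, arg)
    return overloaded


def toy_dispatch(func_name, args, magic=MAGIC):
    """Try each candidate's handler until one does not return NotImplemented."""
    for candidate in collect_overloaded_args(args, magic):
        result = getattr(type(candidate), magic)(func_name, args)
        if result is not NotImplemented:
            return result
    raise TypeError(f"no handler for {func_name!r} under {magic}")


class Base:
    @classmethod
    def __toy_dispatch__(cls, func_name, args):
        return f"Base handled {func_name}"


class Child(Base):
    @classmethod
    def __toy_dispatch__(cls, func_name, args):
        if func_name == "mul":
            return NotImplemented  # defer to the next candidate
        return f"Child handled {func_name}"
```

With these definitions, `toy_dispatch("mul", (1, Base(), Child()))` tries `Child` first, gets `NotImplemented`, and falls through to `Base`.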
**Maintaining the Python dispatch key.** In order to get to the dispatch to Python logic, we must tag Tensors with the `__torch_dispatch__` magic method with the newly added Python dispatch key (separated from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging if a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing if `__torch_dispatch__` exists with the newly added `check_has_torch_dispatch`.
**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception to this is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor is Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead directly call into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatching to the Python interpreter associated with a TensorImpl.
**torchdeploy compatibility.** The dispatch to Python logic cannot be directly registered to the dispatcher as it is compiled in the Python library, which will get loaded multiple times per torchdeploy interpreter. Thus, we must employ a two phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch" which will go to the correct torchdeploy interpreter to handle dispatching to actual Python.
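The two-phase process above can be sketched with a toy pure-Python model. This is only an illustration under stated assumptions: `ToyInterpreter`, `ToyTensor`, and `python_fallback` are hypothetical names, and the real fallback is C++ code operating on an IValue stack rather than a Python tuple.

```python
class ToyInterpreter:
    """Hypothetical stand-in for a PyInterpreter: one per embedded Python."""

    def __init__(self, name):
        self.name = name

    def dispatch(self, op, args):
        # Second phase: the interpreter-specific handler runs here.
        return f"{self.name} handled {op}"


class ToyTensor:
    """Hypothetical tensor that may or may not be owned by an interpreter."""

    def __init__(self, interpreter=None):
        self.interpreter = interpreter


def python_fallback(op, args):
    """First phase: registered once, outside any Python library."""
    for arg in args:
        interpreter = getattr(arg, "interpreter", None)
        if interpreter is not None:
            # The first argument with an interpreter decides where to go.
            return interpreter.dispatch(op, args)
    raise RuntimeError("no argument is associated with a Python interpreter")
```

In this sketch the fallback itself knows nothing about Python handlers; it only picks the interpreter and hands the call over.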
**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when a Tensor is invoked. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.
**Known limitations.**
* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl`; otherwise things don't work quite correctly (in particular, what is being disabled is the default subclass preservation behavior).
* We don't ever populate kwargs, even when an argument is kwarg-only
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D29017912
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Pulled By: ezyang
fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 18:49:20 +00:00
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
2021-11-09 22:29:55 +00:00
from torch.testing._internal.logging_tensor import LoggingTensor, log_input, capture_logs, no_dispatch
from torch.utils._pytree import tree_map
[Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360
This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.
Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```
There are quite a few changes that were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_function_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_tensor`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_function_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.
Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
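
The mode mechanism described above can be modelled in pure Python as follows. This is a hedged sketch, not the actual implementation: `toy_python_mode`, `toy_call`, `Logging`, and `__toy_dispatch__` are hypothetical names, and a `threading.local` stands in for PythonModeTLS. Like the PR, the sketch refuses nested modes.

```python
import contextlib
import threading

_tls = threading.local()  # toy stand-in for PythonModeTLS


@contextlib.contextmanager
def toy_python_mode(cls):
    """Record `cls` as the active mode type for the current thread."""
    if getattr(_tls, "mode_type", None) is not None:
        raise RuntimeError("nesting is not supported in this sketch")
    _tls.mode_type = cls
    try:
        yield
    finally:
        _tls.mode_type = None  # always clear the mode on exit


def toy_call(op, *args):
    """Dispatch `op`, considering the active mode type before any argument."""
    candidates = [type(a) for a in args if hasattr(type(a), "__toy_dispatch__")]
    mode_type = getattr(_tls, "mode_type", None)
    if mode_type is not None:
        candidates.insert(0, mode_type)  # the mode type is prepended
    for cls in candidates:
        result = cls.__toy_dispatch__(op, args)
        if result is not NotImplemented:
            return result
    return f"default {op}"


class Logging:
    @classmethod
    def __toy_dispatch__(cls, op, args):
        return f"Logging saw {op}"
```

Inside `with toy_python_mode(Logging):`, a call such as `toy_call("empty")` reaches `Logging.__toy_dispatch__` even though no argument is a `Logging` instance, which mirrors the `torch.empty([])` example above.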
There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.
Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.
Test Plan: - new tests
Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 16:00:34 +00:00
from torch.utils._python_dispatch import enable_python_mode
2021-08-12 18:39:31 +00:00
import logging
class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
        with capture_logs() as logs:
2022-02-14 20:05:41 +00:00
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
2021-08-12 18:39:31 +00:00
            g, = torch.autograd.grad((y,), (x,), (grad_y,))
        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
2021-08-12 18:39:31 +00:00
        self.assertExpectedInline('\n'.join(logs), '''\
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl` otherwise things don't work quite correctly (in particular, what is being disabled is default subclass preservation behavior.)
* We don't ever populate kwargs, even when an argument is kwarg-only
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision:
D29017912
D29017912
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Pulled By: ezyang
fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 18:49:20 +00:00
|
|
|
$0 = input('x')
|
|
|
|
|
$1 = torch._ops.aten.mul($0, $0)
|
|
|
|
|
$2 = input('grad_y')
|
|
|
|
|
$3 = torch._ops.aten.mul($2, $0)
|
|
|
|
|
$4 = torch._ops.aten.mul($2, $0)
|
2021-08-12 18:39:31 +00:00
|
|
|
$5 = torch._ops.aten.add($4, $3)''')

    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs($0, out=$1)''')

    def test_kwarg_only(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)

        # The expectation is that beta/alpha don't show up when they're
        # defaulted. This holds even if the user explicitly passes the
        # default value.
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = input('z')
$3 = torch._ops.aten.addmv($0, $1, $2)
$4 = torch._ops.aten.addmv($0, $1, $2)
$5 = torch._ops.aten.addmv($0, $1, $2, beta=2)
$6 = torch._ops.aten.addmv($0, $1, $2, alpha=2)
$7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')

    def test_kwarg_only_and_positional_default(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            torch.ops.aten.kl_div(x, y)
            torch.ops.aten.kl_div(x, y, 2)
            torch.ops.aten.kl_div(x, y, log_target=True)
            torch.ops.aten.kl_div(x, y, 2, log_target=True)

        # What we are testing here is that we omit reduction
        # if it is defaulted, even if a kwarg is set
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.kl_div($0, $1)
$3 = torch._ops.aten.kl_div($0, $1, 2)
$4 = torch._ops.aten.kl_div($0, $1, log_target=True)
$5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):
            class A(torch._C._TensorBase):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2)
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch._C._TensorBase):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegex(
            RuntimeError, "Unable to cast", lambda: A(torch.zeros(1)).neg(),
        )
        self.assertRaisesRegex(
            RuntimeError, "Unable to cast", lambda: A(torch.zeros(1)).detach(),
        )

    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.detach($0)
$2 = torch._ops.aten.detach($1)''')

    def test_metadata_change_not_allowed(self) -> None:
        x = LoggingTensor(torch.ones(1))
        y = x.data
        self.assertIsInstance(y, LoggingTensor)
        self.assertRaises(RuntimeError, lambda: y.resize_(4))

    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash.  Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        self.assertRaises(RuntimeError, lambda: x.storage())

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # This is ludicrously big (8TB) and this should pass because wrapper
        # subclasses don't allocate
        torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        self.assertRaises(ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1))))

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

def test_custom_autograd(self) -> None:
|
|
|
|
|
escape = [None]
|
|
|
|
|
|
|
|
|
|
class Square(torch.autograd.Function):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def forward(ctx, x):
|
|
|
|
|
y = x ** 2
|
|
|
|
|
ctx.save_for_backward(x)
|
|
|
|
|
return y
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def backward(ctx, grad_output):
|
|
|
|
|
assert isinstance(grad_output, LoggingTensor)
|
2021-08-12 18:39:31 +00:00
|
|
|
x, = ctx.saved_tensors
|
Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760
See https://github.com/pytorch/pytorch/issues/59049
There are some moving parts to this PR, I'll structure this explanation so the straightforward parts go first, and then the less straightforward parts.
**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser` which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing if Tensors have custom handling or not. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.
**Maintaining the Python dispatch key.** In order to get to the dispatch to Python logic, we must tag Tensors with the `__torch_dispatch__` magic method with the newly added Python dispatch key (separated from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging if a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing if `__torch_dispatch__` exists with then newly added `check_has_torch_dispatch`.
**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception to this is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor is Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead directly call into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatching to the Python interpreter associated with a TensorImpl.
**torchdeploy compatibility.** The dispatch to Python logic cannot be directly registered to the dispatcher as it is compiled in the Python library, which will get loaded multiple times per torchdeploy interpreter. Thus, we must employ a two phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch" which will go to the correct torchdeploy interpreter to handle dispatching to actual Python.
**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when a Tensor is invoked. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.
**Known limitations.**
* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl` otherwise things don't work quite correctly (in particular, what is being disabled is default subclass preservation behavior.)
* We don't ever populate kwargs, even when an argument is kwarg-only
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision:
D29017912
D29017912
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Pulled By: ezyang
fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 18:49:20 +00:00
|
|
|
assert isinstance(x, LoggingTensor)
|
|
|
|
|
escape[0] = x
|
|
|
|
|
return grad_output * 2 * x
|
|
|
|
|
|
|
|
|
|
with capture_logs() as logs:
|
2022-02-14 20:05:41 +00:00
|
|
|
x = LoggingTensor(torch.ones(1), requires_grad=True)
|
Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760
See https://github.com/pytorch/pytorch/issues/59049
There are some moving parts to this PR, I'll structure this explanation so the straightforward parts go first, and then the less straightforward parts.
**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser` which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing if Tensors have custom handling or not. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.
**Maintaining the Python dispatch key.** In order to get to the dispatch to Python logic, we must tag Tensors with the `__torch_dispatch__` magic method with the newly added Python dispatch key (separated from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging if a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing if `__torch_dispatch__` exists with then newly added `check_has_torch_dispatch`.
**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception to this is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor is Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead directly call into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatching to the Python interpreter associated with a TensorImpl.
**torchdeploy compatibility.** The dispatch to Python logic cannot be directly registered to the dispatcher as it is compiled in the Python library, which will get loaded multiple times per torchdeploy interpreter. Thus, we must employ a two phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch" which will go to the correct torchdeploy interpreter to handle dispatching to actual Python.
**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when a Tensor is invoked. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.
**Known limitations.**
* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl` otherwise things don't work quite correctly (in particular, what is being disabled is default subclass preservation behavior.)
* We don't ever populate kwargs, even when an argument is kwarg-only
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision:
D29017912
D29017912
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Pulled By: ezyang
fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 18:49:20 +00:00
|
|
|
log_input("x", x)
|
|
|
|
|
x.grad = LoggingTensor(torch.zeros(1))
|
|
|
|
|
log_input("x.grad", x.grad)
|
|
|
|
|
y = Square.apply(x)
|
|
|
|
|
grad_output = LoggingTensor(torch.ones(1))
|
|
|
|
|
log_input("grad_output", grad_output)
|
|
|
|
|
y.backward(grad_output)
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
self.assertEqual(escape[0], x)
|
|
|
|
|
self.assertEqual(escape[0]._version, x._version)
|
|
|
|
|
# TODO: figure out why x.requires_grad = False doesn't
|
|
|
|
|
# trigger an error for LoggingTensor
|
|
|
|
|
x.add_(2)
|
|
|
|
|
self.assertEqual(escape[0], x)
|
|
|
|
|
# TODO: figure out why this is broken
|
|
|
|
|
# self.assertEqual(escape[0]._version, x._version)
|
|
|
|
|
|
2021-08-12 18:39:31 +00:00
|
|
|
self.assertExpectedInline('\n'.join(logs), '''\
|
Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760
See https://github.com/pytorch/pytorch/issues/59049
There are some moving parts to this PR, I'll structure this explanation so the straightforward parts go first, and then the less straightforward parts.
**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser` which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing if Tensors have custom handling or not. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.
**Maintaining the Python dispatch key.** In order to get to the dispatch to Python logic, we must tag Tensors with the `__torch_dispatch__` magic method with the newly added Python dispatch key (separated from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging if a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing if `__torch_dispatch__` exists with then newly added `check_has_torch_dispatch`.
$0 = input('x')
$1 = input('x.grad')
$2 = torch._ops.aten.pow($0, 2)
$3 = input('grad_output')
$4 = torch._ops.aten.mul($3, tensor(2))
$5 = torch._ops.aten.mul($4, $0)
$6 = torch._ops.aten.add_($1, $5)''')

    def test_subclass_creation(self):
        # Make sure these statements run without error
        # In particular checking that when internal detach returns
        # subclasses, these are cleanly overwritten.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

    def test_new_ones(self) -> None:
        class MyTensor(torch.Tensor):
            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)

    def test_like(self) -> None:
        class MyTensor(torch.Tensor):
            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        for f in ["empty", "ones", "rand", "randn", "zeros"]:
            f_name = f + "_like"
            self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)

        self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
        self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

    def test_make_wrapper_subclass_propagates_metadata(self) -> None:
        class WrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ['elem']

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls, elem.size(),
                    dtype=elem.dtype, layout=elem.layout,
                    device=elem.device, requires_grad=elem.requires_grad,
                    strides=elem.stride(), storage_offset=elem.storage_offset())
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise RuntimeError("NYI")

        # non-contiguous strides, non-zero storage offset
        x = torch.randn(4, 6).t().diagonal(offset=2)
        y = WrapperTensor(x)
        self.assertEqual(y.size(), x.size())
        self.assertEqual(y.stride(), x.stride())
        self.assertEqual(y.storage_offset(), x.storage_offset())

    def test_index_put_where_only_index_is_subclass(self) -> None:
        called_funcs = []

        class MyTensor(torch.Tensor):
            __torch_function__ = torch._C._disabled_torch_function_impl
            elem: torch.Tensor
            __slots__ = ['elem']

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls, elem.size(),
                    dtype=elem.dtype, layout=elem.layout,
                    device=elem.device, requires_grad=elem.requires_grad
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called_funcs.append(func)
                return MyTensor(torch.tensor(3))

        x = torch.randn(3, 3)
        idxs = (MyTensor(torch.tensor(0)),)
        v = torch.randn(1)
        res = x.index_put_(idxs, v)
        self.assertEqual(called_funcs, [torch.ops.aten.index_put_])

[Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360
This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.
Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```
There are quite a few changes that were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_function_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_tensor`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_function_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.
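To make the ordering concrete, here is a toy pure-Python model of an overloaded args list that accepts types as well as instances. It is only a sketch and not the PyTorch implementation: `__toy_dispatch__`, `class_of`, `append_overloaded` and `toy_dispatch` are hypothetical names invented for illustration, and the real logic lives in C++. It assumes that a subclass is tried before its superclass, which is the behavior the subclass priority test in this file expects.

```python
from typing import Any, List, Optional


def class_of(entry: Any) -> type:
    # An entry is either a class (pushed by a mode) or an instance (an argument).
    return entry if isinstance(entry, type) else type(entry)


def append_overloaded(overloaded: List[Any], entry: Any) -> None:
    cls = class_of(entry)
    # Each class is considered at most once.
    if any(class_of(e) is cls for e in overloaded):
        return
    # A subclass goes in front of the first superclass already in the list.
    for i, existing in enumerate(overloaded):
        if issubclass(cls, class_of(existing)):
            overloaded.insert(i, entry)
            return
    overloaded.append(entry)


def toy_dispatch(func_name: str, args: List[Any],
                 mode_type: Optional[type] = None) -> Any:
    overloaded: List[Any] = []
    if mode_type is not None:
        # The mode's type is added first, before any argument is scanned.
        append_overloaded(overloaded, mode_type)
    for arg in args:
        if hasattr(type(arg), "__toy_dispatch__"):
            append_overloaded(overloaded, arg)
    # Try handlers in list order until one does not decline.
    for entry in overloaded:
        result = class_of(entry).__toy_dispatch__(func_name, args)
        if result is not NotImplemented:
            return result
    raise TypeError(f"no handler accepted {func_name}")


class Plain:
    """Stands in for an ordinary tensor with no handler."""


class A:
    @classmethod
    def __toy_dispatch__(cls, func_name, args):
        return "A"


class B(A):
    @classmethod
    def __toy_dispatch__(cls, func_name, args):
        return "B"


print(toy_dispatch("add", [Plain(), Plain()], mode_type=A))  # A
print(toy_dispatch("add", [B(), B()], mode_type=A))          # B
print(toy_dispatch("add", [A(), A()], mode_type=B))          # B
```

In this model the mode's type only wins outright when no argument carries a more derived class; otherwise the most derived class in the list answers first.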
Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
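The enter/exit behavior described above can be modeled in a few lines of pure Python. This is a sketch under stated assumptions, not the actual PythonMode code: `toy_enable_mode`, `current_mode_type`, `call`, `Logger` and `__toy_dispatch__` are hypothetical names, and a `threading.local` stands in for the thread-local state. As in this PR, the toy mode stores a single type and refuses to nest.

```python
import threading
from contextlib import contextmanager
from typing import Iterator, Optional

_tls = threading.local()  # toy stand-in for the thread-local mode state


def current_mode_type() -> Optional[type]:
    return getattr(_tls, "mode_type", None)


@contextmanager
def toy_enable_mode(cls: type) -> Iterator[None]:
    if not isinstance(cls, type):
        raise ValueError("argument must be the type of a handler class, not an instance")
    if not hasattr(cls, "__toy_dispatch__"):
        raise ValueError(f"{cls.__name__} must define __toy_dispatch__")
    if current_mode_type() is not None:
        # Only one slot is stored, so nesting is rejected.
        raise RuntimeError("toy mode has already been set")
    _tls.mode_type = cls
    try:
        yield
    finally:
        # Always clear the slot, even if the body raised.
        _tls.mode_type = None


class Logger:
    @classmethod
    def __toy_dispatch__(cls, func_name, args):
        return f"logged {func_name}"


def call(func_name: str, *args):
    # Route through the active mode's handler if there is one.
    mode = current_mode_type()
    if mode is not None:
        return mode.__toy_dispatch__(func_name, args)
    return f"plain {func_name}"


print(call("empty"))            # plain empty
with toy_enable_mode(Logger):
    print(call("empty"))        # logged empty
print(current_mode_type())      # None
```

Replacing the single slot with a list used as a stack is one way nesting could be allowed in such a model.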
There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.
Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.
Test Plan: - new tests
Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
    def test_enable_python_mode_error(self) -> None:
        with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
            with enable_python_mode(torch.Tensor):
                pass
        z = LoggingTensor(torch.empty([]))
        with self.assertRaisesRegex(ValueError, "must be the type"):
            with enable_python_mode(z):
                pass

    def test_enable_python_mode_basic(self) -> None:
        with enable_python_mode(LoggingTensor):
            z = torch.empty([])
            self.assertTrue(isinstance(z, LoggingTensor))

    def test_enable_python_mode_unrelated_tensors(self) -> None:
        x = torch.randn([])
        y = torch.randn([])
        with enable_python_mode(LoggingTensor):
            z = x + y
            self.assertTrue(isinstance(z, LoggingTensor))

    def test_enable_python_mode_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        a = A(torch.empty(1))
        b = B(torch.empty(1))
        with self.assertRaises(ErrorA):
            a + a

        # B has precedence over A due to the subclass relationship
        with self.assertRaises(ErrorB):
            with enable_python_mode(A):
                b + b
        with self.assertRaises(ErrorB):
            with enable_python_mode(B):
                a + a
        with self.assertRaises(ErrorB):
            with enable_python_mode(B):
                a + b

    def test_enable_python_mode_respects_no_dispatch(self) -> None:
        with enable_python_mode(LoggingTensor):
            z = torch.ones([2, 3])
            self.assertTrue(isinstance(z, LoggingTensor))
            with no_dispatch():
                expected = torch.ones([2, 3])
                self.assertEqual(z.elem, expected)

    def test_nested_enable_python_mode(self) -> None:
        with self.assertRaisesRegex(RuntimeError, "has already been set"):
            with enable_python_mode(LoggingTensor):
                with enable_python_mode(LoggingTensor):
                    pass

    def test_tolist_numpy_with_python_mode(self) -> None:
        x = LoggingTensor(torch.tensor([2.0, 3.0]))
        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
            x.tolist()
        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
            x.numpy()
        with self.assertRaises(AssertionError):
            self.assertEqual(x, None)

    def test_enable_python_mode_subclass_autograd_device_check(self) -> None:
        class NonWrapperSubclass(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ['elem']

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # Wrong device here!
                r = torch.Tensor._make_subclass(cls, elem.to("meta"), elem.requires_grad)
                # ...the real tensor is held as an element on the tensor.
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSubclass) else e

                def wrap(e):
                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e

                # no_dispatch is only needed if you use enable_python_mode.
                # It prevents infinite recursion.
                with no_dispatch():
                    rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                logging.getLogger("NonWrapperSubclass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
                return rs

        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
        y = torch.randn(2, requires_grad=True)
        z = x * y
        self.assertIsInstance(z, NonWrapperSubclass)
        z.sum().backward(torch.tensor(1))
        self.assertEqual(x.grad, y)
        self.assertEqual(y.grad, x)

    def test_none_wrapping(self):
        # A Tensor subclass that returns None when doing add
        # See LoggingTensor above for more details on the subclass
        class SubclassWithNone(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls, elem.size(),
                    dtype=elem.dtype, layout=elem.layout,
                    device=elem.device, requires_grad=elem.requires_grad
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, SubclassWithNone) else e

                def wrap(e):
                    return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e

                # no_dispatch is only needed if you use enable_python_mode.
                # It prevents infinite recursion.
                with no_dispatch():
                    rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                if func.__name__ == "add":
                    return None
                else:
                    return rs

        x = SubclassWithNone(torch.rand(2))
        # Make sure both run without error
        self.assertIsInstance(x * 2, SubclassWithNone)
        self.assertIsNone(x + 2)

        x.requires_grad_()
        out = x.acos().sum()

        # The backward of acos does add then rsqrt so here we make sure that the
        # undefined Tensor generated by the user code is nicely handled.
        # If acos formula changes in the future, this can be replaced by any other
        # function that does add then something in the backward in a composite way
        with self.assertRaisesRegex(RuntimeError, "but got None"):
            out.backward()

    def test_storage_can_be_converted_to_python_object(self):
        with enable_python_mode(LoggingTensor):
            s = torch.Storage()
            z = LoggingTensor(torch.empty([]))
            z.set_(s)

    def test_autograd_in_attr(self):
        # We want the wrapped Tensor to require gradients!
        true_t = torch.rand(2, requires_grad=True)
        t = LoggingTensor(true_t)

        out = t + 2

        self.assertFalse(out.requires_grad)
        self.assertIsNone(out.grad_fn)

        # TODO: this should be True
        self.assertFalse(out.elem.requires_grad)
        # TODO: this should be not None
        self.assertIsNone(out.elem.grad_fn)

        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            out.backward()

        # TODO: this should not raise
        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            out.elem.backward()

        self.assertIsNone(t.grad)
        # TODO: this should not be None
        self.assertIsNone(t.elem.grad)


if __name__ == '__main__':
    run_tests()