diff --git a/docs/source/tensor_attributes.rst b/docs/source/tensor_attributes.rst
index c6592476d43..b1aa0734fe5 100644
--- a/docs/source/tensor_attributes.rst
+++ b/docs/source/tensor_attributes.rst
@@ -177,6 +177,20 @@ Via a string and device ordinal:
     >>> torch.device('cpu', 0)
     device(type='cpu', index=0)
 
+The device object can also be used as a context manager to change the default
+device tensors are allocated on:
+
+::
+
+    >>> with torch.device('cuda:1'):
+    ...     r = torch.randn(2, 3)
+    >>> r.device
+    device(type='cuda', index=1)
+
+This context manager has no effect if a factory function is passed an explicit,
+non-None device argument. To globally change the default device, see also
+:func:`torch.set_default_device`.
+
 .. note::
    The :class:`torch.device` argument in functions can generally be substituted
    with a string. This allows for fast prototyping of code.
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index fa5ca6948ff..111ee21f6d8 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -17,6 +17,7 @@ Tensors
     is_nonzero
     set_default_dtype
     get_default_dtype
+    set_default_device
     set_default_tensor_type
     numel
     set_printoptions
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 3e394593d48..56116ce609f 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -122,22 +122,36 @@ class TestAutograd(TestCase):
         with self.assertWarnsRegex(UserWarning, "Decorating classes is deprecated"):
             @torch.no_grad()
             class Foo():
-                pass
+                def __init__(self):
+                    assert not torch.is_grad_enabled()
+
+                def foo(self):
+                    # Not applied to methods
+                    assert torch.is_grad_enabled()
+
+            # Show that we can actually construct the class
+            foo = Foo()
+            foo.foo()
 
         # Decorating functions or methods is fine though
         with warnings.catch_warnings(record=True) as w:
             @torch.no_grad()
             def foo():
-                pass
+                assert not torch.is_grad_enabled()
+
+            foo()
 
             class Foo2():
                 @torch.no_grad()
                 def __init__(self):
-                    pass
+                    assert not torch.is_grad_enabled()
 
                 @torch.no_grad()
                 def foo(self):
-                    pass
+                    assert not torch.is_grad_enabled()
+
+            foo2 = Foo2()
+            foo2.foo()
 
         self.assertEqual(len(w), 0)
diff --git a/test/test_overrides.py b/test/test_overrides.py
index d91f505a22a..25d993d7af2 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -760,7 +760,7 @@ class Wrapper:
         val = getattr(self._data, name)
 
         # If it's a method
-        if callable(val):
+        if not isinstance(val, torch.device) and callable(val):
             c = getattr(type(self._data), name)
             # Don't append self to args if classmethod/staticmethod
             if c is val:
diff --git a/test/test_utils.py b/test/test_utils.py
index d78ab2b18e6..ad23cd31f82 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -14,8 +14,17 @@ import torch
 import torch.nn as nn
 import torch.utils.data
 from torch.utils.data import DataLoader
+from torch.testing._internal.common_device_type import (
+    ops,
+    onlyCPU,
+    instantiate_device_type_tests,
+)
+from torch.testing._internal.common_methods_invocations import op_db
 import torch.cuda
+from torch.utils._pytree import tree_any, tree_all_only
 from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+from torch import set_default_device
+from torch.utils._device import set_device
 import torch.utils.cpp_extension
 from torch.autograd._functions.utils import check_onnx_broadcast
 from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
@@ -796,6 +805,74 @@ class TestExtensionUtils(TestCase):
         torch._register_device_module('xpu', DummyXPUModule)
 
 
+class TestDeviceUtils(TestCase):
+    def test_basic(self):
+        with torch.device('meta') as dev:
+            x = torch.empty(3, 3)
+        self.assertEqual(x.device.type, 'meta')
+        self.assertEqual(dev, torch.device('meta'))
+
+    def test_decorator(self):
+        @set_device('meta')
+        def f():
+            return torch.empty(3, 3)
+        self.assertEqual(f().device.type, 'meta')
+
+    def test_decorator_generator(self):
+        @set_device('meta')
+        def f():
+            yield torch.empty(3, 3)
+            yield torch.empty(3, 3)
+        r1, r2 = list(f())
+        self.assertEqual(r1.device.type, 'meta')
+        self.assertEqual(r2.device.type, 'meta')
+
+
+    def test_nn_module(self):
+        with torch.device('meta'):
+            m = nn.Linear(40, 50)
+        self.assertEqual(m.weight.device.type, 'meta')
+
+    def test_set_default_device(self):
+        try:
+            set_default_device('meta')
+            r = torch.empty(2, 2)
+        finally:
+            set_default_device(None)
+
+        self.assertEqual(r.device.type, 'meta')
+
+    @onlyCPU
+    @ops(op_db)
+    def test_device_mode_ops(self, device, dtype, op):
+        func = op.get_op()
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+        for sample in samples:
+            # Only test samples which don't have Tensor inputs. However,
+            # we don't test the factory property on OpInfo as it is very,
+            # very incomplete
+            if tree_any(
+                lambda x: isinstance(x, torch.Tensor),
+                (sample.input, sample.args, sample.kwargs)
+            ):
+                continue
+            # Many OpInfos will explicitly pass in a device. DeviceContext
+            # will respect device if it is explicitly specified. To test
+            # DeviceContext, we have to remove the device kwarg in this case.
+            # NB: Can't pass None to sample_inputs, the function can't
+            # handle it.
+            kwargs = sample.kwargs.copy()
+            kwargs.pop('device', None)
+            with torch.device('meta'):
+                r = func(sample.input, *sample.args, **kwargs)
+            self.assertTrue(
+                tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
+            )
+
+
+instantiate_device_type_tests(TestDeviceUtils, globals())
+
+
 class TestCppExtensionUtils(TestCase):
     def test_cpp_compiler_is_ok(self):
         self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform('c++'))
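
The tests above pin down the intended semantics of the new device utilities. A minimal
standalone repro of ``test_basic``, assuming this patch is applied (the ``meta`` device is
used so no real allocation happens)::

    import torch

    # __enter__ returns the device object itself, so `as dev` binds the device
    with torch.device('meta') as dev:
        x = torch.empty(3, 3)

    assert x.device.type == 'meta'
    assert dev == torch.device('meta')
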
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 3b0fec7d024..3ee54778888 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -45,6 +45,13 @@ class device:
     @overload
     def __init__(self, type: str, index: _int) -> None: ...
 
+    # Uncomment if we ever make torch.device a decorator
+    # def __call__(self, func: T) -> T: ...
+
+    def __enter__(self) -> "device": ...
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
+
     def __reduce__(self) -> Tuple[Any, ...]: ...  # THPDevice_reduce
 
 # Defined in torch/csrc/Stream.cpp
diff --git a/torch/__init__.py b/torch/__init__.py
index d3205b75fbe..f7ab4c7555e 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -35,6 +35,7 @@ import builtins
 
 __all__ = [
     'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
+    'set_default_device',
     'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
     'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
     'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
@@ -444,6 +445,49 @@ def is_storage(obj):
     return type(obj) in _storage_classes
 
 
+_GLOBAL_DEVICE_CONTEXT = None
+
+def set_default_device(device):
+    """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
+    does not affect factory function calls which are called with an explicit
+    ``device`` argument; all other factory calls will be performed as if they
+    were passed ``device`` as an argument.
+
+    To only temporarily change the default device instead of setting it
+    globally, use ``with torch.device(device):`` instead.
+
+    The default device is initially ``cpu``. If you set the default tensor
+    device to another device (e.g., ``cuda``) without a device index, tensors
+    will be allocated on whatever the current device for the device type is,
+    even after :func:`torch.cuda.set_device` is called.
+
+    Args:
+        device (device or string): the device to set as default
+
+    Example::
+
+        >>> # xdoctest: +SKIP("requires cuda, changes global state")
+        >>> torch.tensor([1.2, 3]).device
+        device(type='cpu')
+        >>> torch.set_default_device('cuda')  # current device is 0
+        >>> torch.tensor([1.2, 3]).device
+        device(type='cuda', index=0)
+        >>> torch.set_default_device('cuda:1')
+        >>> torch.tensor([1.2, 3]).device
+        device(type='cuda', index=1)
+
+    """
+    global _GLOBAL_DEVICE_CONTEXT
+    if _GLOBAL_DEVICE_CONTEXT is not None:
+        _GLOBAL_DEVICE_CONTEXT.__exit__(None, None, None)
+    if device is None:
+        _GLOBAL_DEVICE_CONTEXT = None
+        return
+    from torch.utils._device import DeviceContext
+    _GLOBAL_DEVICE_CONTEXT = DeviceContext(device)
+    _GLOBAL_DEVICE_CONTEXT.__enter__()
+
+
 def set_default_tensor_type(t):
     r"""Sets the default ``torch.Tensor`` type to floating point tensor type
     ``t``. This type will also be used as default floating point type for
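
For reference, the global setter and the context manager compose: the innermost active
``DeviceContext`` wins, and an explicit non-None ``device`` argument overrides both. A
sketch using ``meta`` so it runs without accelerators::

    import torch

    torch.set_default_device('meta')
    a = torch.ones(2)                  # meta: the global default applies
    with torch.device('cpu'):
        b = torch.ones(2)              # cpu: the inner context wins
    c = torch.ones(2, device='cpu')    # cpu: an explicit device wins
    torch.set_default_device(None)     # restore the initial behavior

    assert (a.device.type, b.device.type, c.device.type) == ('meta', 'cpu', 'cpu')
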
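
The deletion above moves ``_DecoratorContextManager`` nearly verbatim into
``torch/utils/_contextlib.py`` (added later in this patch) so that ``torch.device`` can
reuse the same decorator machinery. Existing decorator behavior, including generator
wrapping, is unchanged; for example::

    import torch

    @torch.no_grad()
    def f(x):
        return x * 2

    @torch.no_grad()
    def g(x):
        yield x * 2  # grad mode is re-applied at every resume

    x = torch.ones(1, requires_grad=True)
    assert not f(x).requires_grad
    assert not next(g(x)).requires_grad
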
diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp
index 9ea495f2c3f..331f5932b02 100644
--- a/torch/csrc/Device.cpp
+++ b/torch/csrc/Device.cpp
@@ -169,6 +169,36 @@ PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  py::object mode = py::module::import("torch.utils._device")
+                        .attr("DeviceContext")(py::handle(self));
+  at::impl::PythonTorchFunctionTLS::push_onto_stack(
+      std::make_shared<c10::SafePyObject>(
+          mode.release().ptr(), getPyInterpreter()));
+  // So that with torch.device('cuda') as dev: works
+  Py_INCREF(self);
+  return self;
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
+  HANDLE_TH_ERRORS
+  at::impl::PythonTorchFunctionTLS::pop_stack();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  py::object deco =
+      py::module::import("torch.utils._device").attr("device_decorator");
+  return deco(py::handle(self), *py::handle(args), **py::handle(kwargs))
+      .release()
+      .ptr();
+  END_HANDLE_TH_ERRORS
+}
+
 typedef PyObject* (*getter)(PyObject*, void*);
 
 // NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
@@ -182,6 +212,8 @@ static struct PyGetSetDef THPDevice_properties[] = {
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
 static PyMethodDef THPDevice_methods[] = {
     {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
+    {"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
+    {"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
     {nullptr} /* Sentinel */
 };
@@ -199,6 +231,11 @@ PyTypeObject THPDeviceType = {
     nullptr, /* tp_as_sequence */
     nullptr, /* tp_as_mapping */
     (hashfunc)THPDevice_hash, /* tp_hash */
+    // TODO: We're not sure if this is a good idea or not, because making
+    // torch.device callable means that it will start returning true
+    // for callable() queries, and that is unexpected. We can always add
+    // this later, so for now, don't actually implement this
+    // THPDevice_call, /* tp_call */
     nullptr, /* tp_call */
     (reprfunc)THPDevice_str, /* tp_str */
     nullptr, /* tp_getattro */
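
``THPDevice_enter``/``THPDevice_exit`` push and pop a ``DeviceContext`` torch-function
mode. A rough Python-level equivalent of what the C++ does (illustrative only;
``_PyDeviceCM`` is a hypothetical name, not part of the patch)::

    import torch
    from torch.utils._device import DeviceContext

    class _PyDeviceCM:
        def __init__(self, device):
            self.device = device

        def __enter__(self):
            self._mode = DeviceContext(self.device)
            self._mode.__enter__()  # push onto the torch-function mode stack
            return self.device      # so `with torch.device(...) as dev:` binds the device

        def __exit__(self, *exc):
            self._mode.__exit__(*exc)  # pop the mode
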
diff --git a/torch/overrides.py b/torch/overrides.py
index 9907996e01f..8481db3d577 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -75,6 +75,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.is_tensor,
         torch.is_storage,
         torch.set_default_tensor_type,
+        torch.set_default_device,
         torch.set_rng_state,
         torch.get_rng_state,
         torch.manual_seed,
diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py
index 084540edb0a..1b1d4d5f4ca 100644
--- a/torch/signal/windows/windows.py
+++ b/torch/signal/windows/windows.py
@@ -378,7 +378,7 @@ def kaiser(
                          device=device,
                          requires_grad=requires_grad)
 
-    return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta))
+    return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta, device=device))
 
 
 @_add_docstr(
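
The ``kaiser`` change above is the kind of bug the new mode surfaces: ``torch.tensor(beta)``
had no explicit device, so under a default device it would be intercepted and allocated
there, while ``k`` follows the explicit ``device`` argument, producing a cross-device
division. A sketch of the scenario the fix handles::

    import torch

    with torch.device('meta'):
        # The explicit device='cpu' must win over the ambient default.
        # Before the fix, torch.tensor(beta) landed on meta and the
        # division mixed cpu and meta tensors.
        w = torch.signal.windows.kaiser(10, beta=12.0, device='cpu')

    assert w.device.type == 'cpu'
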
diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py
new file mode 100644
index 00000000000..d34d2b9e15e
--- /dev/null
+++ b/torch/utils/_contextlib.py
@@ -0,0 +1,143 @@
+# Extra utilities for working with context managers that should have been
+# in the standard library but are not
+
+import functools
+import inspect
+import warnings
+import sys
+from typing import Any, Callable, TypeVar, cast
+
+# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
+# 'no_grad' and 'enable_grad').
+# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
+FuncType = Callable[..., Any]
+F = TypeVar('F', bound=FuncType)
+
+
+def _wrap_generator(ctx_factory, func):
+    """
+    Wrap each generator invocation with the context manager factory.
+
+    The input should be a function that returns a context manager,
+    not a context manager itself, to handle one-shot context managers.
+    """
+    @functools.wraps(func)
+    def generator_context(*args, **kwargs):
+        gen = func(*args, **kwargs)
+
+        # Generators are suspended and unsuspended at `yield`, hence we
+        # make sure the context is properly set every time the execution
+        # flow returns into the wrapped generator and restored when it
+        # returns through our `yield` to our caller (see PR #49017).
+        try:
+            # Issuing `None` to a generator fires it up
+            with ctx_factory():
+                response = gen.send(None)
+
+            while True:
+                try:
+                    # Forward the response to our caller and get its next request
+                    request = yield response
+
+                except GeneratorExit:
+                    # Inform the still active generator about its imminent closure
+                    with ctx_factory():
+                        gen.close()
+                    raise
+
+                except BaseException:
+                    # Propagate the exception thrown at us by the caller
+                    with ctx_factory():
+                        response = gen.throw(*sys.exc_info())
+
+                else:
+                    # Pass the last request to the generator and get its response
+                    with ctx_factory():
+                        response = gen.send(request)
+
+        # We let the exceptions raised above by the generator's `.throw` or
+        # `.send` methods bubble up to our caller, except for StopIteration
+        except StopIteration as e:
+            # The generator informed us that it is done: take whatever its
+            # returned value (if any) was and indicate that we're done too
+            # by returning it (see docs for python's return-statement).
+            return e.value
+
+    return generator_context
+
+
+def context_decorator(ctx, func):
+    """
+    Like contextlib.ContextDecorator, but:
+
+    1. Is done by wrapping, rather than inheritance, so it works with context
+       managers implemented in C that cannot easily inherit from Python
+       classes
+    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
+    3. Errors out if you try to wrap a class, because it is ambiguous whether
+       or not you intended to wrap only the constructor
+
+    The input argument can either be a context manager (in which case it must
+    be a multi-shot context manager that can be entered multiple times)
+    or a callable that produces a context manager.
+    """
+    assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
+        f"Passed in {ctx} is both callable and also a valid context manager "
+        "(has __enter__), making it ambiguous which interface to use. If you "
+        "intended to pass a context manager factory, rewrite your call as "
+        "context_decorator(lambda: ctx()); if you intended to pass a context "
+        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
+    )
+
+    if not callable(ctx):
+        def ctx_factory():
+            return ctx
+    else:
+        ctx_factory = ctx
+
+    if inspect.isclass(func):
+        raise RuntimeError(
+            "Cannot decorate classes; it is ambiguous whether only the "
+            "constructor or all methods should have the context manager applied; "
+            "additionally, decorating a class at definition-site will prevent "
+            "use of the identifier as a conventional type. "
+            "To specify which methods to decorate, decorate each of them "
+            "individually."
+        )
+
+    if inspect.isgeneratorfunction(func):
+        return _wrap_generator(ctx_factory, func)
+
+    @functools.wraps(func)
+    def decorate_context(*args, **kwargs):
+        with ctx_factory():
+            return func(*args, **kwargs)
+
+    return decorate_context
+
+
+class _DecoratorContextManager:
+    """Allow a context manager to be used as a decorator"""
+
+    def __call__(self, orig_func: F) -> F:
+        if inspect.isclass(orig_func):
+            warnings.warn("Decorating classes is deprecated and will be disabled in "
+                          "future versions. You should only decorate functions or methods. "
+                          "To preserve the current behavior of class decoration, you can "
+                          "directly decorate the `__init__` method and nothing else.")
+            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
+        else:
+            func = orig_func
+
+        return cast(F, context_decorator(self.clone, func))
+
+    def __enter__(self) -> None:
+        raise NotImplementedError
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        raise NotImplementedError
+
+    def clone(self):
+        # override this method if your child class takes __init__ parameters
+        return self.__class__()
diff --git a/torch/utils/_device.py b/torch/utils/_device.py
new file mode 100644
index 00000000000..54fb15df9ab
--- /dev/null
+++ b/torch/utils/_device.py
@@ -0,0 +1,74 @@
+import torch
+from torch.overrides import TorchFunctionMode
+from torch.utils._contextlib import context_decorator
+import functools
+
+@functools.lru_cache(1)
+def _device_constructors():
+    return {
+        # standard ones
+        torch.empty,
+        torch.empty_strided,
+        torch.empty_quantized,
+        torch.ones,
+        torch.arange,
+        torch.bartlett_window,
+        torch.blackman_window,
+        torch.eye,
+        torch.fft.fftfreq,
+        torch.fft.rfftfreq,
+        torch.full,
+        torch.fill,
+        torch.hamming_window,
+        torch.hann_window,
+        torch.kaiser_window,
+        torch.linspace,
+        torch.logspace,
+        torch.nested.nested_tensor,
+        # This function doesn't actually take a device argument
+        # torch.normal,
+        torch.rand,
+        torch.randn,
+        torch.randint,
+        torch.randperm,
+        torch.range,
+        torch.sparse_coo_tensor,
+        torch.sparse_compressed_tensor,
+        torch.sparse_csr_tensor,
+        torch.sparse_csc_tensor,
+        torch.sparse_bsr_tensor,
+        torch.sparse_bsc_tensor,
+        torch.tril_indices,
+        torch.triu_indices,
+        torch.vander,
+        torch.zeros,
+        torch.asarray,
+        # weird ones
+        torch.tensor,
+        torch.as_tensor,
+        torch.scalar_tensor,
+    }
+
+# NB: This is directly called from C++ in torch/csrc/Device.cpp
+class DeviceContext(TorchFunctionMode):
+    def __init__(self, device):
+        self.device = torch.device(device)
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if func in _device_constructors() and kwargs.get('device') is None:
+            kwargs['device'] = self.device
+        return func(*args, **kwargs)
+
+# NB: This is directly called from C++ in torch/csrc/Device.cpp
+def device_decorator(device, func):
+    return context_decorator(lambda: device, func)
+
+def set_device(device):
+    """
+    Decorator which sets the default device inside the wrapped function.
+    To use this as a context manager, use the device as a context manager
+    directly, e.g., ``with torch.device(device):``.
+    """
+    return lambda func: device_decorator(torch.device(device), func)