From a5d841ef01e615e2a654fb12cf0cd08697d12ccf Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 9 Aug 2023 18:23:40 -0300 Subject: [PATCH] `asarray`: take the default device into consideration. (#106779) Fix: #106773 This PR makes it so `asarray` takes the default device into consideration when called with a Python sequence as the data. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106779 Approved by: https://github.com/rgommers, https://github.com/lezcano --- test/test_tensor_creation_ops.py | 22 +++++++++++++++++++ torch/_torch_docs.py | 5 +++-- .../python_torch_functions_manual.cpp | 5 +++++ torch/csrc/utils/tensor_new.cpp | 3 +++ torch/utils/_device.py | 1 + 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 20bbe8423e7..9b7372bf3fd 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -9,6 +9,7 @@ import warnings import unittest from itertools import product, combinations, combinations_with_replacement, permutations import random +from typing import Any, Dict, List, Tuple from torch.testing import make_tensor from torch.testing._internal.common_utils import ( @@ -3967,6 +3968,27 @@ class TestAsArray(TestCase): self.assertEqual(tensor.item(), zerodim_arr.item()) self.assertEqual(tensor.dtype, torch.int32) + def test_default_device(self, device): + original = torch.arange(5) + + examples: List[Tuple[Any, Dict]] = [ + (3, {}), + (original, {}), + (to_numpy(original), {}), + (to_memview(original), {"dtype": original.dtype}), + ] + + for data, kwargs in examples: + with torch.device(device): + tensor = torch.asarray(data, **kwargs) + self.assertEqual(tensor.device, torch.device(device)) + + # Check the contents of the tensor. + if isinstance(data, int): + self.assertEqual(data, tensor.item()) + else: + self.assertEqual(data, tensor) + instantiate_device_type_tests(TestTensorCreation, globals()) instantiate_device_type_tests(TestRandomTensorCreation, globals()) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 1b335244faa..ee8d3351407 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1257,7 +1257,7 @@ be the PyTorch datatype corresponding to the NumPy's scalar's datatype. When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the returned tensor will, by default, infer its datatype from the scalar values, be on the -CPU device, and not share its memory. +current default device, and not share its memory. .. seealso:: @@ -1282,7 +1282,8 @@ Keyword args: If ``False`` then the returned tensor shares its memory with :attr:`obj` and an error is thrown if it cannot. device (:class:`torch.device`, optional): the device of the returned tensor. - Default: ``None``, which causes the device of :attr:`obj` to be used. + Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if + :attr:`obj` is a Python sequence, the current default device will be used. requires_grad (bool, optional): whether the returned tensor requires grad. Default: ``False``, which causes the returned tensor not to require a gradient. If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index e3e0a96ad58..21881980a19 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -333,6 +333,11 @@ static PyObject* THPVariable_asarray( ParsedArgs<5> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function( + r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); + } + if (r.idx == 0) { auto obj = r.pyobject(0); auto dtype = r.scalartypeOptional(1); diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 8eafc0aae59..6a68cc674a8 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -1633,6 +1633,9 @@ Tensor asarray( bool force_alias = !copy.value_or(true); bool should_warn_numpy_not_writable = false; + // Used when: + // 1. 'obj' implements the buffer protocol and no type is given. + // 2. creating a new tensor from a Python sequence. auto dtype_unwrapped = dtype.value_or(torch::tensors::get_default_scalar_type()); diff --git a/torch/utils/_device.py b/torch/utils/_device.py index 8d34a711567..bc548612f23 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -51,6 +51,7 @@ def _device_constructors(): torch.tensor, torch.as_tensor, torch.scalar_tensor, + torch.asarray, } # NB: This is directly called from C++ in torch/csrc/Device.cpp