mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
asarray: take the default device into consideration. (#106779)
Fix: #106773 This PR makes it so `asarray` takes the default device into consideration when called with a Python sequence as the data. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106779 Approved by: https://github.com/rgommers, https://github.com/lezcano
This commit is contained in:
parent
171341ee65
commit
a5d841ef01
5 changed files with 34 additions and 2 deletions
|
|
@ -9,6 +9,7 @@ import warnings
|
|||
import unittest
|
||||
from itertools import product, combinations, combinations_with_replacement, permutations
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
|
|
@ -3967,6 +3968,27 @@ class TestAsArray(TestCase):
|
|||
self.assertEqual(tensor.item(), zerodim_arr.item())
|
||||
self.assertEqual(tensor.dtype, torch.int32)
|
||||
|
||||
def test_default_device(self, device):
|
||||
original = torch.arange(5)
|
||||
|
||||
examples: List[Tuple[Any, Dict]] = [
|
||||
(3, {}),
|
||||
(original, {}),
|
||||
(to_numpy(original), {}),
|
||||
(to_memview(original), {"dtype": original.dtype}),
|
||||
]
|
||||
|
||||
for data, kwargs in examples:
|
||||
with torch.device(device):
|
||||
tensor = torch.asarray(data, **kwargs)
|
||||
self.assertEqual(tensor.device, torch.device(device))
|
||||
|
||||
# Check the contents of the tensor.
|
||||
if isinstance(data, int):
|
||||
self.assertEqual(data, tensor.item())
|
||||
else:
|
||||
self.assertEqual(data, tensor)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestTensorCreation, globals())
|
||||
instantiate_device_type_tests(TestRandomTensorCreation, globals())
|
||||
|
|
|
|||
|
|
@ -1257,7 +1257,7 @@ be the PyTorch datatype corresponding to the NumPy's scalar's datatype.
|
|||
|
||||
When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the
|
||||
returned tensor will, by default, infer its datatype from the scalar values, be on the
|
||||
CPU device, and not share its memory.
|
||||
current default device, and not share its memory.
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
|
@ -1282,7 +1282,8 @@ Keyword args:
|
|||
If ``False`` then the returned tensor shares its memory with :attr:`obj` and an
|
||||
error is thrown if it cannot.
|
||||
device (:class:`torch.device`, optional): the device of the returned tensor.
|
||||
Default: ``None``, which causes the device of :attr:`obj` to be used.
|
||||
Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if
|
||||
:attr:`obj` is a Python sequence, the current default device will be used.
|
||||
requires_grad (bool, optional): whether the returned tensor requires grad.
|
||||
Default: ``False``, which causes the returned tensor not to require a gradient.
|
||||
If ``True``, then the returned tensor will require a gradient, and if :attr:`obj`
|
||||
|
|
|
|||
|
|
@ -333,6 +333,11 @@ static PyObject* THPVariable_asarray(
|
|||
ParsedArgs<5> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
if (r.has_torch_function()) {
|
||||
return handle_torch_function(
|
||||
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
|
||||
}
|
||||
|
||||
if (r.idx == 0) {
|
||||
auto obj = r.pyobject(0);
|
||||
auto dtype = r.scalartypeOptional(1);
|
||||
|
|
|
|||
|
|
@ -1633,6 +1633,9 @@ Tensor asarray(
|
|||
bool force_alias = !copy.value_or(true);
|
||||
bool should_warn_numpy_not_writable = false;
|
||||
|
||||
// Used when:
|
||||
// 1. 'obj' implements the buffer protocol and no type is given.
|
||||
// 2. creating a new tensor from a Python sequence.
|
||||
auto dtype_unwrapped =
|
||||
dtype.value_or(torch::tensors::get_default_scalar_type());
|
||||
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ def _device_constructors():
|
|||
torch.tensor,
|
||||
torch.as_tensor,
|
||||
torch.scalar_tensor,
|
||||
torch.asarray,
|
||||
}
|
||||
|
||||
# NB: This is directly called from C++ in torch/csrc/Device.cpp
|
||||
|
|
|
|||
Loading…
Reference in a new issue