mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Implement Python Array API asarray function. (#60627)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60627 In this PR, the core of `frombuffer` and `fromDLPack` onto _tensor_new.cpp_. `asarray` uses such refactored functions for interpreting the object as a tensor. We follow the Python Array API standard found: https://data-apis.org/array-api/latest/API_specification/creation_functions.html?highlight=asarray Test Plan: Imported from OSS Reviewed By: H-Huang Differential Revision: D31640510 Pulled By: mruberry fbshipit-source-id: d0869e0d73cb50023d5866b001dac5d34ca30dfd
This commit is contained in:
parent
9e3a2babfa
commit
8854817f44
12 changed files with 749 additions and 276 deletions
|
|
@ -54,6 +54,7 @@ Creation Ops
|
|||
|
||||
tensor
|
||||
sparse_coo_tensor
|
||||
asarray
|
||||
as_tensor
|
||||
as_strided
|
||||
from_numpy
|
||||
|
|
|
|||
|
|
@ -1,178 +0,0 @@
|
|||
import torch.testing._internal.common_utils as common
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
dtypes
|
||||
)
|
||||
|
||||
import torch
|
||||
import numpy
|
||||
|
||||
def get_dtype_size(dtype):
|
||||
return int(torch.empty((), dtype=dtype).element_size())
|
||||
|
||||
SIZE = 5
|
||||
SHAPE = (SIZE,)
|
||||
|
||||
# Tests for the `frombuffer` function (only work on CPU):
|
||||
# Constructs tensors from Python objects that implement the buffer protocol,
|
||||
# without copying data.
|
||||
class TestBufferProtocol(common.TestCase):
|
||||
def _run_test(self, shape, dtype, count=-1, first=0, offset=None, **kwargs):
|
||||
numpy_dtype = common.torch_to_numpy_dtype_dict[dtype]
|
||||
|
||||
if offset is None:
|
||||
offset = first * get_dtype_size(dtype)
|
||||
|
||||
numpy_original = make_tensor(shape, torch.device("cpu"), dtype).numpy()
|
||||
original = memoryview(numpy_original)
|
||||
# First call PyTorch's version in case of errors.
|
||||
# If this call exits successfully, the NumPy version must also do so.
|
||||
torch_frombuffer = torch.frombuffer(original, dtype=dtype, count=count, offset=offset, **kwargs)
|
||||
numpy_frombuffer = numpy.frombuffer(original, dtype=numpy_dtype, count=count, offset=offset)
|
||||
|
||||
self.assertEqual(numpy_frombuffer, torch_frombuffer)
|
||||
self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
|
||||
return (numpy_original, torch_frombuffer)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_same_type(self, device, dtype):
|
||||
self._run_test((), dtype)
|
||||
self._run_test((4,), dtype)
|
||||
self._run_test((10, 10), dtype)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_requires_grad(self, device, dtype):
|
||||
def _run_test_and_check_grad(requires_grad, *args, **kwargs):
|
||||
kwargs["requires_grad"] = requires_grad
|
||||
_, tensor = self._run_test(*args, **kwargs)
|
||||
self.assertTrue(tensor.requires_grad == requires_grad)
|
||||
|
||||
requires_grad = dtype.is_floating_point or dtype.is_complex
|
||||
_run_test_and_check_grad(requires_grad, (), dtype)
|
||||
_run_test_and_check_grad(requires_grad, (4,), dtype)
|
||||
_run_test_and_check_grad(requires_grad, (10, 10), dtype)
|
||||
_run_test_and_check_grad(False, (), dtype)
|
||||
_run_test_and_check_grad(False, (4,), dtype)
|
||||
_run_test_and_check_grad(False, (10, 10), dtype)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_with_offset(self, device, dtype):
|
||||
# Offset should be valid whenever there is, at least,
|
||||
# one remaining element
|
||||
for i in range(SIZE):
|
||||
self._run_test(SHAPE, dtype, first=i)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_with_count(self, device, dtype):
|
||||
# Count should be valid for any valid in the interval
|
||||
# [-1, len(input)], except for 0
|
||||
for i in range(-1, SIZE + 1):
|
||||
if i != 0:
|
||||
self._run_test(SHAPE, dtype, count=i)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_with_count_and_offset(self, device, dtype):
|
||||
# Explicit default count [-1, 1, 2, ..., len]
|
||||
for i in range(-1, SIZE + 1):
|
||||
if i != 0:
|
||||
self._run_test(SHAPE, dtype, count=i)
|
||||
# Explicit default offset [0, 1, ..., len - 1]
|
||||
for i in range(SIZE):
|
||||
self._run_test(SHAPE, dtype, first=i)
|
||||
# All possible combinations of count and dtype aligned
|
||||
# offset for 'input'
|
||||
# count:[1, 2, ..., len - 1] x first:[0, 1, ..., len - count]
|
||||
for i in range(1, SIZE):
|
||||
for j in range(SIZE - i + 1):
|
||||
self._run_test(SHAPE, dtype, count=i, first=j)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_invalid_positional_args(self, device, dtype):
|
||||
bytes = get_dtype_size(dtype)
|
||||
in_bytes = SIZE * bytes
|
||||
# Empty array
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"both buffer length \(0\) and count"):
|
||||
empty = numpy.array([])
|
||||
torch.frombuffer(empty, dtype=dtype)
|
||||
# Count equals 0
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"both buffer length .* and count \(0\)"):
|
||||
self._run_test(SHAPE, dtype, count=0)
|
||||
# Offset negative and bigger than total length
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"offset \(-{bytes} bytes\) must be"):
|
||||
self._run_test(SHAPE, dtype, first=-1)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"offset \({in_bytes} bytes\) must be .* "
|
||||
rf"buffer length \({in_bytes} bytes\)"):
|
||||
self._run_test(SHAPE, dtype, first=SIZE)
|
||||
# Non-multiple offset with all elements
|
||||
if bytes > 1:
|
||||
offset = bytes - 1
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"buffer length \({in_bytes - offset} bytes\) after "
|
||||
rf"offset \({offset} bytes\) must be"):
|
||||
self._run_test(SHAPE, dtype, offset=bytes - 1)
|
||||
# Count too big for each good first element
|
||||
for first in range(SIZE):
|
||||
count = SIZE - first + 1
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"requested buffer length \({count} \* {bytes} bytes\) "
|
||||
rf"after offset \({first * bytes} bytes\) must .*"
|
||||
rf"buffer length \({in_bytes} bytes\)"):
|
||||
self._run_test(SHAPE, dtype, count=count, first=first)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_shared_buffer(self, device, dtype):
|
||||
x = make_tensor((1,), device, dtype)
|
||||
# Modify the whole tensor
|
||||
arr, tensor = self._run_test(SHAPE, dtype)
|
||||
tensor[:] = x
|
||||
self.assertEqual(arr, tensor)
|
||||
self.assertTrue((tensor == x).all().item())
|
||||
|
||||
# Modify the whole tensor from all valid offsets, given
|
||||
# a count value
|
||||
for count in range(-1, SIZE + 1):
|
||||
if count == 0:
|
||||
continue
|
||||
|
||||
actual_count = count if count > 0 else SIZE
|
||||
for first in range(SIZE - actual_count):
|
||||
last = first + actual_count
|
||||
arr, tensor = self._run_test(SHAPE, dtype, first=first, count=count)
|
||||
tensor[:] = x
|
||||
self.assertEqual(arr[first:last], tensor)
|
||||
self.assertTrue((tensor == x).all().item())
|
||||
|
||||
# Modify the first value in the array
|
||||
arr[first] = x.item() - 1
|
||||
self.assertEqual(arr[first:last], tensor)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_not_a_buffer(self, device, dtype):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"object does not implement Python buffer protocol."):
|
||||
torch.frombuffer([1, 2, 3, 4], dtype=dtype)
|
||||
|
||||
@dtypes(*common.torch_to_numpy_dtype_dict.keys())
|
||||
def test_non_writable_buffer(self, device, dtype):
|
||||
numpy_arr = make_tensor((1,), device, dtype).numpy()
|
||||
byte_arr = numpy_arr.tobytes()
|
||||
with self.assertWarnsOnceRegex(UserWarning,
|
||||
r"The given buffer is not writable."):
|
||||
torch.frombuffer(byte_arr, dtype=dtype)
|
||||
|
||||
def test_byte_to_int(self):
|
||||
byte_array = numpy.array([-1, 0, 0, 0, -1, 0, 0, 0], dtype=numpy.byte)
|
||||
tensor = torch.frombuffer(byte_array, dtype=torch.int32)
|
||||
self.assertEqual(tensor.numel(), 2)
|
||||
# Assuming little endian machine
|
||||
self.assertSequenceEqual(tensor, [255, 255])
|
||||
|
||||
instantiate_device_type_tests(TestBufferProtocol, globals(), only_for="cpu")
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
|
@ -21,6 +21,8 @@ from torch.testing._internal.common_dtype import (
|
|||
get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
|
||||
)
|
||||
|
||||
from torch.utils.dlpack import to_dlpack
|
||||
|
||||
# TODO: refactor tri_tests_args, _compare_trilu_indices, run_additional_tri_tests
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
tri_tests_args, _compare_trilu_indices, run_additional_tri_tests)
|
||||
|
|
@ -3612,9 +3614,416 @@ class TestLikeTensorCreation(TestCase):
|
|||
self.assertEqual(torch.full_like(like, 1., dtype=torch.complex64).dtype,
|
||||
torch.complex64)
|
||||
|
||||
# Tests for the `frombuffer` function (only work on CPU):
|
||||
# Constructs tensors from Python objects that implement the buffer protocol,
|
||||
# without copying data.
|
||||
SIZE = 5
|
||||
SHAPE = (SIZE,)
|
||||
|
||||
def may_require_grad(dtype):
|
||||
return dtype.is_floating_point or dtype.is_complex
|
||||
|
||||
def get_dtype_size(dtype):
|
||||
return int(torch.empty((), dtype=dtype).element_size())
|
||||
|
||||
class TestBufferProtocol(TestCase):
|
||||
def _run_test(self, shape, dtype, count=-1, first=0, offset=None, **kwargs):
|
||||
numpy_dtype = torch_to_numpy_dtype_dict[dtype]
|
||||
|
||||
if offset is None:
|
||||
offset = first * get_dtype_size(dtype)
|
||||
|
||||
numpy_original = make_tensor(shape, torch.device("cpu"), dtype).numpy()
|
||||
original = memoryview(numpy_original)
|
||||
# First call PyTorch's version in case of errors.
|
||||
# If this call exits successfully, the NumPy version must also do so.
|
||||
torch_frombuffer = torch.frombuffer(original, dtype=dtype, count=count, offset=offset, **kwargs)
|
||||
numpy_frombuffer = np.frombuffer(original, dtype=numpy_dtype, count=count, offset=offset)
|
||||
|
||||
self.assertEqual(numpy_frombuffer, torch_frombuffer)
|
||||
self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
|
||||
return (numpy_original, torch_frombuffer)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_same_type(self, device, dtype):
|
||||
self._run_test((), dtype)
|
||||
self._run_test((4,), dtype)
|
||||
self._run_test((10, 10), dtype)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_requires_grad(self, device, dtype):
|
||||
def _run_test_and_check_grad(requires_grad, *args, **kwargs):
|
||||
kwargs["requires_grad"] = requires_grad
|
||||
_, tensor = self._run_test(*args, **kwargs)
|
||||
self.assertTrue(tensor.requires_grad == requires_grad)
|
||||
|
||||
requires_grad = may_require_grad(dtype)
|
||||
_run_test_and_check_grad(requires_grad, (), dtype)
|
||||
_run_test_and_check_grad(requires_grad, (4,), dtype)
|
||||
_run_test_and_check_grad(requires_grad, (10, 10), dtype)
|
||||
_run_test_and_check_grad(False, (), dtype)
|
||||
_run_test_and_check_grad(False, (4,), dtype)
|
||||
_run_test_and_check_grad(False, (10, 10), dtype)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_with_offset(self, device, dtype):
|
||||
# Offset should be valid whenever there is, at least,
|
||||
# one remaining element
|
||||
for i in range(SIZE):
|
||||
self._run_test(SHAPE, dtype, first=i)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_with_count(self, device, dtype):
|
||||
# Count should be valid for any valid in the interval
|
||||
# [-1, len(input)], except for 0
|
||||
for i in range(-1, SIZE + 1):
|
||||
if i != 0:
|
||||
self._run_test(SHAPE, dtype, count=i)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_with_count_and_offset(self, device, dtype):
|
||||
# Explicit default count [-1, 1, 2, ..., len]
|
||||
for i in range(-1, SIZE + 1):
|
||||
if i != 0:
|
||||
self._run_test(SHAPE, dtype, count=i)
|
||||
# Explicit default offset [0, 1, ..., len - 1]
|
||||
for i in range(SIZE):
|
||||
self._run_test(SHAPE, dtype, first=i)
|
||||
# All possible combinations of count and dtype aligned
|
||||
# offset for 'input'
|
||||
# count:[1, 2, ..., len - 1] x first:[0, 1, ..., len - count]
|
||||
for i in range(1, SIZE):
|
||||
for j in range(SIZE - i + 1):
|
||||
self._run_test(SHAPE, dtype, count=i, first=j)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_invalid_positional_args(self, device, dtype):
|
||||
bytes = get_dtype_size(dtype)
|
||||
in_bytes = SIZE * bytes
|
||||
# Empty array
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"both buffer length \(0\) and count"):
|
||||
empty = np.array([])
|
||||
torch.frombuffer(empty, dtype=dtype)
|
||||
# Count equals 0
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"both buffer length .* and count \(0\)"):
|
||||
self._run_test(SHAPE, dtype, count=0)
|
||||
# Offset negative and bigger than total length
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"offset \(-{bytes} bytes\) must be"):
|
||||
self._run_test(SHAPE, dtype, first=-1)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"offset \({in_bytes} bytes\) must be .* "
|
||||
rf"buffer length \({in_bytes} bytes\)"):
|
||||
self._run_test(SHAPE, dtype, first=SIZE)
|
||||
# Non-multiple offset with all elements
|
||||
if bytes > 1:
|
||||
offset = bytes - 1
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"buffer length \({in_bytes - offset} bytes\) after "
|
||||
rf"offset \({offset} bytes\) must be"):
|
||||
self._run_test(SHAPE, dtype, offset=bytes - 1)
|
||||
# Count too big for each good first element
|
||||
for first in range(SIZE):
|
||||
count = SIZE - first + 1
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
rf"requested buffer length \({count} \* {bytes} bytes\) "
|
||||
rf"after offset \({first * bytes} bytes\) must .*"
|
||||
rf"buffer length \({in_bytes} bytes\)"):
|
||||
self._run_test(SHAPE, dtype, count=count, first=first)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_shared_buffer(self, device, dtype):
|
||||
x = make_tensor((1,), device, dtype)
|
||||
# Modify the whole tensor
|
||||
arr, tensor = self._run_test(SHAPE, dtype)
|
||||
tensor[:] = x
|
||||
self.assertEqual(arr, tensor)
|
||||
self.assertTrue((tensor == x).all().item())
|
||||
|
||||
# Modify the whole tensor from all valid offsets, given
|
||||
# a count value
|
||||
for count in range(-1, SIZE + 1):
|
||||
if count == 0:
|
||||
continue
|
||||
|
||||
actual_count = count if count > 0 else SIZE
|
||||
for first in range(SIZE - actual_count):
|
||||
last = first + actual_count
|
||||
arr, tensor = self._run_test(SHAPE, dtype, first=first, count=count)
|
||||
tensor[:] = x
|
||||
self.assertEqual(arr[first:last], tensor)
|
||||
self.assertTrue((tensor == x).all().item())
|
||||
|
||||
# Modify the first value in the array
|
||||
arr[first] = x.item() - 1
|
||||
self.assertEqual(arr[first:last], tensor)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_not_a_buffer(self, device, dtype):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"object does not implement Python buffer protocol."):
|
||||
torch.frombuffer([1, 2, 3, 4], dtype=dtype)
|
||||
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_non_writable_buffer(self, device, dtype):
|
||||
numpy_arr = make_tensor((1,), device, dtype).numpy()
|
||||
byte_arr = numpy_arr.tobytes()
|
||||
with self.assertWarnsOnceRegex(UserWarning,
|
||||
r"The given buffer is not writable."):
|
||||
torch.frombuffer(byte_arr, dtype=dtype)
|
||||
|
||||
def test_byte_to_int(self):
|
||||
byte_array = np.array([-1, 0, 0, 0, -1, 0, 0, 0], dtype=np.byte)
|
||||
tensor = torch.frombuffer(byte_array, dtype=torch.int32)
|
||||
self.assertEqual(tensor.numel(), 2)
|
||||
# Assuming little endian machine
|
||||
self.assertSequenceEqual(tensor, [255, 255])
|
||||
|
||||
# Tests for the `asarray` function:
|
||||
# Constructs tensors from a Python object that has one of the following
|
||||
# characteristics:
|
||||
# 1. is a Tensor
|
||||
# 2. is a DLPack capsule
|
||||
# 3. implements the Python Buffer protocol
|
||||
# 4. is an arbitrary list
|
||||
# The implementation itself is based on the Python Array API:
|
||||
# https://data-apis.org/array-api/latest/API_specification/creation_functions.html
|
||||
def get_another_device(device):
|
||||
return "cuda" if torch.device(device).type == "cpu" else "cpu"
|
||||
|
||||
def identity(tensor):
|
||||
return tensor
|
||||
def to_numpy(tensor):
|
||||
return tensor.numpy()
|
||||
def to_memview(tensor):
|
||||
return memoryview(to_numpy(tensor))
|
||||
|
||||
class TestAsArray(TestCase):
|
||||
def _check(self, original, cvt=lambda t: t, is_alias=True, same_dtype=True, same_device=True, **kwargs):
|
||||
"""Check the output of 'asarray', given its input and assertion informations.
|
||||
|
||||
Besides calling 'asarray' itself, this function does 4 different checks:
|
||||
1. Whether the result is aliased or not, depending on 'is_alias'
|
||||
2. Whether the result has the expected dtype and elements
|
||||
3. Whether the result lives in the expected device
|
||||
4. Whether the result has its 'requires_grad' set or not
|
||||
"""
|
||||
result = torch.asarray(cvt(original), **kwargs)
|
||||
self.assertTrue(isinstance(result, torch.Tensor))
|
||||
|
||||
# 1. The storage pointers should be equal only if 'is_alias' is set
|
||||
if is_alias:
|
||||
self.assertEqual(result.data_ptr(), original.data_ptr())
|
||||
else:
|
||||
self.assertNotEqual(result.data_ptr(), original.data_ptr())
|
||||
|
||||
# 2. Comparison of the elements only takes place if the original
|
||||
# sequence and the resulting tensor have the same data type
|
||||
if same_dtype:
|
||||
self.assertEqual(original, result)
|
||||
else:
|
||||
dtype = kwargs.get("dtype", torch.get_default_dtype())
|
||||
self.assertEqual(original.shape, result.shape)
|
||||
self.assertEqual(dtype, result.dtype)
|
||||
|
||||
# 3. Given the specified target device, we first check whether
|
||||
# its type is the same, and then if its index is the same (if it
|
||||
# is not None)
|
||||
if same_device:
|
||||
device = original.device
|
||||
else:
|
||||
device = torch.device(kwargs.get("device", "cpu"))
|
||||
|
||||
# Compare the target device type, and its index
|
||||
self.assertEqual(device.type, result.device.type)
|
||||
if device.index is not None:
|
||||
self.assertEqual(device.index, result.device.index)
|
||||
|
||||
# 4. By default, 'requires_grad' is unset
|
||||
self.assertEqual(result.requires_grad, kwargs.get("requires_grad", False))
|
||||
|
||||
def _test_alias_with_cvt(self, cvt, device, dtype, shape=(5, 5), only_with_dtype=False):
|
||||
original = make_tensor(shape, device, dtype)
|
||||
|
||||
def check(**kwargs):
|
||||
self._check(original, cvt=cvt, **kwargs)
|
||||
|
||||
if not only_with_dtype:
|
||||
check(copy=False)
|
||||
check(device=device)
|
||||
check(device=device, copy=False)
|
||||
|
||||
check(dtype=dtype)
|
||||
check(dtype=dtype, copy=False)
|
||||
check(requires_grad=False, dtype=dtype)
|
||||
check(requires_grad=may_require_grad(dtype), dtype=dtype)
|
||||
check(device=device, dtype=dtype)
|
||||
check(device=device, dtype=dtype, copy=False)
|
||||
|
||||
# Skipping 'meta' devices, since there's no point in comparing their
|
||||
# data pointer (which is basically the point here), since they all
|
||||
# return 0.
|
||||
@skipMeta
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_alias_from_tensor(self, device, dtype):
|
||||
self._test_alias_with_cvt(identity, device, dtype)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_alias_from_numpy(self, device, dtype):
|
||||
self._test_alias_with_cvt(to_numpy, device, dtype)
|
||||
|
||||
# Skipping 'meta', since 'to_dlpack' does not work for them.
|
||||
@skipMeta
|
||||
@dtypes(*get_all_dtypes(include_bool=False))
|
||||
def test_alias_from_dlpack(self, device, dtype):
|
||||
self._test_alias_with_cvt(to_dlpack, device, dtype)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_alias_from_buffer(self, device, dtype):
|
||||
self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
|
||||
|
||||
def _test_copy_with_cvt(self, cvt, device, dtype, shape=(5, 5), only_with_dtype=False):
|
||||
original = make_tensor(shape, device, dtype)
|
||||
|
||||
def check(**kwargs):
|
||||
self._check(original, cvt=cvt, is_alias=False, **kwargs)
|
||||
|
||||
if not only_with_dtype:
|
||||
check(copy=True)
|
||||
check(device=device, copy=True)
|
||||
|
||||
check(requires_grad=False, dtype=dtype, copy=True)
|
||||
check(requires_grad=may_require_grad(dtype), dtype=dtype, copy=True)
|
||||
check(dtype=dtype, copy=True)
|
||||
check(device=device, dtype=dtype, copy=True)
|
||||
|
||||
# Copy is forced because of different device
|
||||
if torch.cuda.is_available():
|
||||
other = get_another_device(device)
|
||||
check(same_device=False, device=other, dtype=dtype)
|
||||
check(same_device=False, device=other, dtype=dtype, copy=True)
|
||||
|
||||
# Copy is forced because of different dtype
|
||||
if not only_with_dtype:
|
||||
for other in get_all_dtypes():
|
||||
if dtype != other:
|
||||
check(same_dtype=False, dtype=other)
|
||||
check(same_dtype=False, dtype=other, copy=True)
|
||||
|
||||
@skipMeta
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_copy_tensor(self, device, dtype):
|
||||
self._test_copy_with_cvt(identity, device, dtype)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_copy_from_numpy(self, device, dtype):
|
||||
self._test_copy_with_cvt(to_numpy, device, dtype)
|
||||
|
||||
@skipMeta
|
||||
@dtypes(*get_all_dtypes(include_bool=False))
|
||||
def test_copy_from_dlpack(self, device, dtype):
|
||||
self._test_copy_with_cvt(to_dlpack, device, dtype)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
||||
def test_copy_from_buffer(self, device, dtype):
|
||||
self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
|
||||
|
||||
def _test_copy_mult_devices(self, devices, dtype, cvt):
|
||||
cuda1 = devices[0]
|
||||
cuda2 = devices[1]
|
||||
original = make_tensor((5, 5), cuda1, dtype)
|
||||
|
||||
def check(**kwargs):
|
||||
self._check(original, cvt, is_alias=False, same_device=False, device=cuda2, **kwargs)
|
||||
|
||||
check()
|
||||
check(copy=True)
|
||||
check(dtype=dtype, copy=True)
|
||||
|
||||
@onlyCUDA
|
||||
@deviceCountAtLeast(2)
|
||||
@dtypes(*get_all_dtypes(include_bool=False))
|
||||
def test_copy_from_tensor_mult_devices(self, devices, dtype):
|
||||
self._test_copy_mult_devices(devices, dtype, identity)
|
||||
|
||||
@onlyCUDA
|
||||
@deviceCountAtLeast(2)
|
||||
@dtypes(*get_all_dtypes(include_bool=False))
|
||||
def test_copy_from_dlpack_mult_devices(self, devices, dtype):
|
||||
self._test_copy_mult_devices(devices, dtype, to_dlpack)
|
||||
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_copy_list(self, device, dtype):
|
||||
original = make_tensor((5, 5), torch.device("cpu"), dtype)
|
||||
|
||||
def check(**kwargs):
|
||||
self._check(original, torch.Tensor.tolist, is_alias=False, **kwargs)
|
||||
|
||||
same_device = torch.device("cpu") == device
|
||||
check(same_device=same_device, device=device, dtype=dtype)
|
||||
check(same_device=same_device, device=device, dtype=dtype, requires_grad=False)
|
||||
check(same_device=same_device, device=device, dtype=dtype, requires_grad=may_require_grad(dtype))
|
||||
check(same_device=same_device, device=device, dtype=dtype, copy=True)
|
||||
|
||||
@dtypes(torch.float32)
|
||||
def test_unsupported_alias(self, device, dtype):
|
||||
original = make_tensor((5, 5), device, dtype)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
other_device = get_another_device(device)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
f"from device '{device}' to '{other_device}'"):
|
||||
torch.asarray(original, device=other_device, copy=False)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"with dtype '.*' into dtype '.*'"):
|
||||
torch.asarray(original, dtype=torch.float64, copy=False)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"can't alias arbitrary sequence"):
|
||||
torch.asarray(original.tolist(), copy=False)
|
||||
|
||||
@onlyCUDA
|
||||
@deviceCountAtLeast(2)
|
||||
@dtypes(torch.float32)
|
||||
def test_unsupported_alias_mult_devices(self, devices, dtype):
|
||||
dev1, dev2 = devices[:2]
|
||||
original = make_tensor((5, 5), dev1, dtype)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
f"from device '{dev1}' to '{dev2}'"):
|
||||
torch.asarray(original, device=dev2, copy=False)
|
||||
|
||||
@dtypes(torch.float32, torch.complex64)
|
||||
def test_retain_autograd_history(self, device, dtype):
|
||||
original = make_tensor((5, 5), device, dtype, requires_grad=True)
|
||||
# 'cloned' has 'grad_fn=<CloneBackwards>'
|
||||
cloned = original.clone()
|
||||
|
||||
def check(**kwargs):
|
||||
a = torch.asarray(cloned, **kwargs)
|
||||
requires_grad = kwargs.get("requires_grad", False)
|
||||
self.assertEqual(a.requires_grad, requires_grad)
|
||||
# Autograd history shouldn't be retained when requires_grad is False
|
||||
self.assertEqual(a.grad_fn is None, not requires_grad)
|
||||
|
||||
check()
|
||||
check(requires_grad=True)
|
||||
check(copy=True)
|
||||
check(requires_grad=True, copy=True)
|
||||
check(requires_grad=False)
|
||||
check(requires_grad=False, copy=True)
|
||||
|
||||
instantiate_device_type_tests(TestTensorCreation, globals())
|
||||
instantiate_device_type_tests(TestRandomTensorCreation, globals())
|
||||
instantiate_device_type_tests(TestLikeTensorCreation, globals())
|
||||
instantiate_device_type_tests(TestBufferProtocol, globals(), only_for="cpu")
|
||||
instantiate_device_type_tests(TestAsArray, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -282,6 +282,9 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
|
|||
unsorted_function_hints.update({
|
||||
'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
|
||||
'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
|
||||
'asarray': ['def asarray(obj: Any, *, dtype: Optional[_dtype]=None, '
|
||||
'device: Union[_device, str, None]=None, copy: Optional[_bool]=None, '
|
||||
'requires_grad: _bool=False) -> Tensor: ...'],
|
||||
'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
|
||||
'frombuffer': ['def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, '
|
||||
'offset: int=0, device: Union[_device, str, None]=None, '
|
||||
|
|
|
|||
|
|
@ -961,6 +961,101 @@ arctanh(input, *, out=None) -> Tensor
|
|||
Alias for :func:`torch.atanh`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.asarray,
|
||||
r"""
|
||||
asarray(obj, *, dtype=None, device=None, copy=None, requires_grad=False) -> Tensor
|
||||
|
||||
Converts :attr:`obj` into a tensor, sharing data and preserving autograd history
|
||||
if possible.
|
||||
|
||||
:attr:`obj` can be one of:
|
||||
|
||||
1. a tensor
|
||||
2. a NumPy array
|
||||
3. a DLPack capsule
|
||||
4. a Python object that implements the buffer protocol
|
||||
5. a Python sequence
|
||||
|
||||
For each of the mentioned options, in order, this functions will assume :attr:`obj`
|
||||
is of that type and try, first, sharing memory. Only then, it will make a copy (if
|
||||
necessary).
|
||||
|
||||
The dtype of the result tensor is inferred from the input object, except when
|
||||
object is (4): an object that implements the buffer protocol (see :func:`torch.frombuffer`).
|
||||
In that case, the buffer is interpreted as an array of bytes, which are grouped
|
||||
according to the size of the given :attr:`dtype` or the global default
|
||||
(see :func:`torch.set_default_tensor_type`) if `None` is given.
|
||||
|
||||
For example: NumPy arrays also implement the buffer protocol. However, since NumPy
|
||||
arrays have higher priority than objects implementing the buffer protocol, this function
|
||||
will handle them as NumPy arrays. In other words, it will infer its dtype as if using
|
||||
``torch.from_numpy`` (instead of ``torch.frombuffer``).
|
||||
|
||||
.. seealso::
|
||||
:func:`torch.as_tensor` tries to avoid copies for tensors and NumPy arrays.
|
||||
:func:`torch.tensor` always copies the data from the input object.
|
||||
:func:`torch.from_numpy` creates a tensor that shares its memory with a NumPy array.
|
||||
:func:`torch.frombuffer` creates a tensor that shares its memory with an object
|
||||
that implements the buffer protocol.
|
||||
:func:`torch.utils.dlpack.from_dlpack` creates a tensor that shares its memory
|
||||
with the object represented in the dlpack.
|
||||
|
||||
Args:
|
||||
obj (object): a Python object that satisfies, at least, one of the five options
|
||||
mentioned above.
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
|
||||
Default: if ``None``, it will be inferred from :attr:`obj`.
|
||||
copy (bool, optional): flags whether the object memory should be copied or not.
|
||||
If ``None``, then the result tensor shares memory with the Python object
|
||||
whenever possible. If ``True``, then the object memory is copied. If ``False``,
|
||||
then the object memory is shared. If the object memory cannot be shared
|
||||
and this flag is ``False``, then an error is thrown.
|
||||
device (:class:`torch.device`, optional): the device of the constructed tensor.
|
||||
If `None`, then the device of :attr:`obj` is used. Else, it either copies
|
||||
the data, if :attr:`obj` lives in a different device, or it shares the
|
||||
memory, if :attr:`obj` lives in the same device.
|
||||
requires_grad (bool, optional): If autograd should record operations on the
|
||||
returned tensor. However, if this flag is ``False`` and the input object
|
||||
is a non-leaf :class:`Tensor`, this function will call :func:`torch.Tensor.detach`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.tensor([1, 2, 3])
|
||||
>>> # Shares memory with tensor 'a'
|
||||
>>> b = torch.asarray(a)
|
||||
>>> a.data_ptr() == b.data_ptr()
|
||||
True
|
||||
>>> # Forces memory copy
|
||||
>>> c = torch.asarray(a, copy=True)
|
||||
>>> a.data_ptr() == c.data_ptr()
|
||||
False
|
||||
|
||||
>>> a = torch.tensor([1, 2, 3], requires_grad=True).float()
|
||||
>>> b = a + 2
|
||||
>>> b
|
||||
tensor([1., 2., 3.], grad_fn=<AddBackward0>)
|
||||
>>> # Shares memory with tensor 'b', with no grad
|
||||
>>> c = torch.asarray(b)
|
||||
>>> c
|
||||
tensor([1., 2., 3.])
|
||||
>>> # Shares memory with tensor 'b', retaining autograd history
|
||||
>>> d = torch.asarray(b, requires_grad=True)
|
||||
>>> d
|
||||
tensor([1., 2., 3.], grad_fn=<AddBackward0>)
|
||||
|
||||
>>> array = numpy.array([1, 2, 3])
|
||||
>>> # Shares memory with array 'array'
|
||||
>>> t1 = torch.asarray(array)
|
||||
>>> array.__array_interface__['data'][0] == t1.data_ptr()
|
||||
True
|
||||
>>> # Copies memory due to dtype mismatch
|
||||
>>> t2 = torch.asarray(array, dtype=torch.float32)
|
||||
>>> array.__array_interface__['data'][0] == t1.data_ptr()
|
||||
False
|
||||
""")
|
||||
|
||||
add_docstr(torch.baddbmm,
|
||||
r"""
|
||||
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@
|
|||
#include <torch/csrc/utils/tensor_layouts.h>
|
||||
#include <torch/csrc/utils/tensor_memoryformats.h>
|
||||
#include <torch/csrc/utils/tensor_qschemes.h>
|
||||
#include <torch/csrc/utils/tensor_new.h>
|
||||
#include <torch/csrc/utils/tensor_numpy.h>
|
||||
#include <torch/csrc/utils/python_dispatch.h>
|
||||
#include <torch/csrc/utils/crash_handler.h>
|
||||
|
|
@ -374,28 +375,8 @@ PyObject *THPModule_fromDLPack(PyObject *_unused, PyObject *data)
|
|||
{
|
||||
using namespace torch::autograd;
|
||||
HANDLE_TH_ERRORS
|
||||
DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor");
|
||||
THPUtils_assert(dlMTensor, "from_dlpack received an invalid capsule. "
|
||||
"Note that DLTensor capsules can be consumed only once, "
|
||||
"so you might have already constructed a tensor from it once.")
|
||||
// atensor steals the ownership of the underlying storage. It also passes a
|
||||
// destructor function that will be called when the underlying storage goes
|
||||
// out of scope. When the destructor is called, the dlMTensor is destructed too.
|
||||
auto atensor = at::fromDLPack(dlMTensor);
|
||||
|
||||
// Make sure this capsule will never be used again.
|
||||
PyCapsule_SetName(data, "used_dltensor");
|
||||
|
||||
// It is possible that the call to at::fromDLPack is the very first
|
||||
// call to create a Tensor in PyTorch. If so, then _lazy_init has
|
||||
// not been called, and the attempt to call createPyObject will fail
|
||||
// because cuda ATen types have not been registered in Python yet.
|
||||
// so if we have a cuda tensor, then we need to make sure
|
||||
// we have called _lazy_init here
|
||||
if(atensor.is_cuda()) {
|
||||
py::module::import("torch.cuda").attr("init")();
|
||||
}
|
||||
return THPVariable_Wrap(std::move(atensor));
|
||||
auto tensor = torch::utils::tensor_fromDLPack(data);
|
||||
return THPVariable_Wrap(tensor);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -450,14 +450,15 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObje
|
|||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}static PyObject * THPVariable_frombuffer(PyObject* self_, PyObject* args, PyObject* kwargs)
|
||||
}
|
||||
|
||||
static PyObject * THPVariable_frombuffer(PyObject* self_, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)",
|
||||
}, /*traceable=*/false);
|
||||
|
||||
PyObject* ret = nullptr;
|
||||
ParsedArgs<5> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
|
|
@ -468,76 +469,35 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObje
|
|||
auto offset = r.toInt64(3);
|
||||
auto requires_grad = r.toBool(4);
|
||||
|
||||
auto elsize = at::elementSize(dtype);
|
||||
size_t actual_count = 0;
|
||||
Py_buffer view;
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
PyObject_CheckBuffer(buffer) != 0,
|
||||
"object does not implement Python buffer protocol.");
|
||||
|
||||
if (PyObject_GetBuffer(buffer, &view, PyBUF_WRITABLE) < 0) {
|
||||
TORCH_CHECK(
|
||||
PyObject_GetBuffer(buffer, &view, PyBUF_SIMPLE) >= 0,
|
||||
"could not retrieve buffer from object");
|
||||
TORCH_WARN_ONCE(
|
||||
"The given buffer is not writable, and PyTorch does "
|
||||
"not support non-writable tensors. This means you can write to the "
|
||||
"underlying (supposedly non-writable) buffer using the tensor. "
|
||||
"You may want to copy the buffer to protect its data or make it writable "
|
||||
"before converting it to a tensor. This type of warning will be "
|
||||
"suppressed for the rest of this program.");
|
||||
PyErr_Clear();
|
||||
}
|
||||
|
||||
Py_INCREF(view.obj);
|
||||
THPObjectPtr obj(view.obj);
|
||||
|
||||
auto len = view.len;
|
||||
auto buf = view.buf;
|
||||
PyBuffer_Release(&view);
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
len > 0 && count != 0,
|
||||
"both buffer length (", len, ") and count (", count, ") must not be 0");
|
||||
TORCH_CHECK_VALUE(
|
||||
offset >= 0 && offset < len,
|
||||
"offset (", offset, " bytes) must be non-negative and no greater than "
|
||||
"buffer length (", len, " bytes) minus 1");
|
||||
TORCH_CHECK_VALUE(
|
||||
count > 0 || (len - offset) % elsize == 0,
|
||||
"buffer length (", len - offset, " bytes) after offset (", offset, " bytes) "
|
||||
"must be a multiple of element size (", elsize, ")");
|
||||
|
||||
if (count < 0) {
|
||||
actual_count = (len - offset) / elsize;
|
||||
} else {
|
||||
actual_count = static_cast<size_t>(count);
|
||||
}
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
static_cast<size_t>(offset) + actual_count * elsize <= len,
|
||||
"requested buffer length (", actual_count, " * ", elsize, " bytes) "
|
||||
"after offset (", offset, " bytes) must not be greater than actual "
|
||||
"buffer length (", len, " bytes)");
|
||||
|
||||
auto offset_buf = static_cast<char*>(buf) + offset;
|
||||
auto options = TensorOptions()
|
||||
.dtype(dtype)
|
||||
.device(c10::kCPU);
|
||||
|
||||
auto tensor = at::for_blob(offset_buf, static_cast<int64_t>(actual_count))
|
||||
.options(options)
|
||||
.deleter([obj = obj.release()](void*) {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
Py_DECREF(obj);
|
||||
})
|
||||
.make_tensor();
|
||||
tensor.set_requires_grad(requires_grad);
|
||||
ret = wrap(tensor);
|
||||
return wrap(torch::utils::tensor_frombuffer(
|
||||
buffer, dtype, count, offset, requires_grad));
|
||||
}
|
||||
|
||||
return ret;
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPVariable_asarray(PyObject* self_, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)",
|
||||
}, /*traceable=*/false);
|
||||
|
||||
ParsedArgs<5> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
if (r.idx == 0) {
|
||||
auto obj = r.pyobject(0);
|
||||
auto dtype = r.scalartypeOptional(1);
|
||||
auto device = r.deviceOptional(2);
|
||||
auto copy = r.toBoolOptional(3);
|
||||
auto requires_grad = r.toBool(4);
|
||||
return wrap(torch::utils::asarray(obj, dtype, device, copy, requires_grad));
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
|
|
@ -648,6 +608,8 @@ static PyObject * THPVariable_logspace(PyObject* self_, PyObject* args, PyObject
|
|||
static PyMethodDef torch_functions_manual[] = {
|
||||
{"arange", castPyCFunctionWithKeywords(THPVariable_arange),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"asarray", castPyCFunctionWithKeywords(THPVariable_asarray),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr},
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@
|
|||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/TracerMode.h>
|
||||
|
|
@ -978,4 +980,195 @@ Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyO
|
|||
throw std::runtime_error("new_tensor(): invalid arguments");
|
||||
}
|
||||
|
||||
Tensor tensor_frombuffer(PyObject* buffer, ScalarType dtype, int64_t count, int64_t offset, bool requires_grad) {
|
||||
auto elsize = at::elementSize(dtype);
|
||||
size_t actual_count = 0;
|
||||
|
||||
Py_buffer view;
|
||||
if (PyObject_GetBuffer(buffer, &view, PyBUF_WRITABLE) < 0) {
|
||||
TORCH_CHECK(
|
||||
PyObject_GetBuffer(buffer, &view, PyBUF_SIMPLE) >= 0,
|
||||
"could not retrieve buffer from object");
|
||||
TORCH_WARN_ONCE(
|
||||
"The given buffer is not writable, and PyTorch does "
|
||||
"not support non-writable tensors. This means you can write to the "
|
||||
"underlying (supposedly non-writable) buffer using the tensor. "
|
||||
"You may want to copy the buffer to protect its data or make it writable "
|
||||
"before converting it to a tensor. This type of warning will be "
|
||||
"suppressed for the rest of this program.");
|
||||
PyErr_Clear();
|
||||
}
|
||||
|
||||
Py_INCREF(view.obj);
|
||||
THPObjectPtr obj(view.obj);
|
||||
|
||||
auto len = view.len;
|
||||
auto buf = view.buf;
|
||||
PyBuffer_Release(&view);
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
len > 0 && count != 0,
|
||||
"both buffer length (", len, ") and count (", count, ") must not be 0");
|
||||
TORCH_CHECK_VALUE(
|
||||
offset >= 0 && offset < len,
|
||||
"offset (", offset, " bytes) must be non-negative and no greater than "
|
||||
"buffer length (", len, " bytes) minus 1");
|
||||
TORCH_CHECK_VALUE(
|
||||
count > 0 || (len - offset) % elsize == 0,
|
||||
"buffer length (", len - offset, " bytes) after offset (", offset, " bytes) "
|
||||
"must be a multiple of element size (", elsize, ")");
|
||||
|
||||
if (count < 0) {
|
||||
actual_count = (len - offset) / elsize;
|
||||
} else {
|
||||
actual_count = static_cast<size_t>(count);
|
||||
}
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
static_cast<size_t>(offset) + actual_count * elsize <= len,
|
||||
"requested buffer length (", actual_count, " * ", elsize, " bytes) "
|
||||
"after offset (", offset, " bytes) must not be greater than actual "
|
||||
"buffer length (", len, " bytes)");
|
||||
|
||||
auto offset_buf = static_cast<char*>(buf) + offset;
|
||||
auto options = TensorOptions()
|
||||
.dtype(dtype)
|
||||
.device(c10::kCPU);
|
||||
|
||||
auto tensor = at::for_blob(offset_buf, static_cast<int64_t>(actual_count))
|
||||
.options(options)
|
||||
.deleter([obj = obj.release()](void*) {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
Py_DECREF(obj);
|
||||
})
|
||||
.make_tensor();
|
||||
tensor.set_requires_grad(requires_grad);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
Tensor tensor_fromDLPack(PyObject *data) {
|
||||
DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor");
|
||||
TORCH_CHECK(dlMTensor,
|
||||
"from_dlpack received an invalid capsule. "
|
||||
"Note that DLTensor capsules can be consumed only once, "
|
||||
"so you might have already constructed a tensor from it once.");
|
||||
|
||||
// atensor steals the ownership of the underlying storage. It also passes a
|
||||
// destructor function that will be called when the underlying storage goes
|
||||
// out of scope. When the destructor is called, the dlMTensor is destructed too.
|
||||
auto atensor = at::fromDLPack(dlMTensor);
|
||||
|
||||
// Make sure this capsule will never be used again.
|
||||
PyCapsule_SetName(data, "used_dltensor");
|
||||
|
||||
// It is possible that the call to at::fromDLPack is the very first
|
||||
// call to create a Tensor in PyTorch. If so, then _lazy_init has
|
||||
// not been called, and the attempt to call createPyObject will fail
|
||||
// because cuda ATen types have not been registered in Python yet.
|
||||
// so if we have a cuda tensor, then we need to make sure
|
||||
// we have called _lazy_init here
|
||||
if(atensor.is_cuda()) {
|
||||
py::module::import("torch.cuda").attr("init")();
|
||||
}
|
||||
return atensor;
|
||||
}
|
||||
|
||||
Tensor asarray(
|
||||
PyObject* obj,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> copy,
|
||||
bool requires_grad) {
|
||||
Tensor tensor;
|
||||
|
||||
bool force_copy = copy.value_or(false);
|
||||
bool force_alias = !copy.value_or(true);
|
||||
bool should_warn_numpy_not_writable = false;
|
||||
|
||||
auto dtype_unwrapped =
|
||||
dtype.value_or(torch::tensors::get_default_scalar_type());
|
||||
|
||||
// Check whether 'obj' is a 'Tensor'
|
||||
if (THPVariable_Check(obj)) {
|
||||
tensor = THPVariable_Unpack(obj);
|
||||
}
|
||||
|
||||
#ifdef USE_NUMPY
|
||||
// Check whether 'obj' is a NumPy Array
|
||||
if (is_numpy_available() && PyArray_Check(obj)) {
|
||||
tensor = tensor_from_numpy(obj, /*warn_if_not_writeable=*/false);
|
||||
should_warn_numpy_not_writable = !PyArray_ISWRITEABLE((PyArrayObject*) obj);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Check whether 'obj' is a 'DLPack' capsule
|
||||
if (!tensor.defined() && PyCapsule_IsValid(obj, "dltensor") != 0) {
|
||||
tensor = tensor_fromDLPack(obj);
|
||||
}
|
||||
|
||||
// Check whether 'obj' implements the buffer protocol
|
||||
if (!tensor.defined() && PyObject_CheckBuffer(obj) != 0) {
|
||||
tensor = tensor_frombuffer(obj, dtype_unwrapped, -1, 0, requires_grad);
|
||||
}
|
||||
|
||||
if (tensor.defined()) {
|
||||
// Given an aliasable tensor, should we copy it?
|
||||
bool wrong_device = device.has_value() && device.value() != tensor.device();
|
||||
bool wrong_dtype =
|
||||
dtype.has_value() && dtype.value() != tensor.scalar_type();
|
||||
bool needs_copying = !copy.has_value() && (wrong_device || wrong_dtype);
|
||||
|
||||
// Given a defined tensor, we copy it if either we have to (copy=True) or
|
||||
// if we need to (copy=None) because of mismatched device or dtype.
|
||||
if (force_copy || needs_copying) {
|
||||
if (wrong_device || wrong_dtype) {
|
||||
tensor = tensor.to(
|
||||
device.value_or(tensor.device()),
|
||||
dtype.value_or(tensor.scalar_type()));
|
||||
} else {
|
||||
tensor = tensor.clone();
|
||||
}
|
||||
} else {
|
||||
// If we are not copying, we have to check whther we have the tensor
|
||||
// in the right device, with the right dtype.
|
||||
TORCH_CHECK_VALUE(
|
||||
!wrong_device,
|
||||
"can't alias tensor from device '", tensor.device(),
|
||||
"' to '", device.value(), "'.");
|
||||
TORCH_CHECK_VALUE(
|
||||
!wrong_dtype,
|
||||
"can't alias tensor with dtype '", tensor.scalar_type(),
|
||||
"' into dtype '", dtype.value(), "'.");
|
||||
// If tensor is a NumPy Array view, we warn the user about non-writeable
|
||||
// arrays if this is the case.
|
||||
if (should_warn_numpy_not_writable) {
|
||||
warn_numpy_not_writeable();
|
||||
}
|
||||
}
|
||||
|
||||
// Setting 'requires_grad' when the tensor is not a leaf does not work.
|
||||
// Whenever that happens, we have to use 'detach'.
|
||||
if (!tensor.is_leaf() && !requires_grad) {
|
||||
tensor = tensor.detach();
|
||||
} else {
|
||||
tensor.set_requires_grad(requires_grad);
|
||||
}
|
||||
} else {
|
||||
// Undefined tensor means it does not implement neither DLPack nor
|
||||
// the buffer protocol. Last case is a sequence, in which case we must
|
||||
// copy (copy can't be false).
|
||||
TORCH_CHECK_VALUE(
|
||||
!force_alias, "can't alias arbitrary sequence into a tensor.");
|
||||
|
||||
// Make tensor from sequence, inferring its type, and then convert
|
||||
// it to the desired type.
|
||||
tensor = internal_new_from_data(
|
||||
TensorOptions(), dtype_unwrapped, device, obj, false, false, true);
|
||||
tensor = tensor.to(dtype_unwrapped);
|
||||
tensor.set_requires_grad(requires_grad);
|
||||
}
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
}} // namespace torch::utils
|
||||
|
|
|
|||
|
|
@ -23,5 +23,7 @@ at::Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type
|
|||
at::Tensor as_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor new_ones(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
|
||||
at::Tensor tensor_frombuffer(PyObject* buffer, at::ScalarType dtype, int64_t count, int64_t offset, bool requires_grad);
|
||||
at::Tensor tensor_fromDLPack(PyObject *data);
|
||||
at::Tensor asarray(PyObject* obj, c10::optional<c10::ScalarType> dtype, c10::optional<c10::Device> device, c10::optional<bool> copy, bool requires_grad);
|
||||
}} // namespace torch::utils
|
||||
|
|
|
|||
|
|
@ -168,6 +168,16 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
|
|||
return array.release();
|
||||
}
|
||||
|
||||
void warn_numpy_not_writeable() {
|
||||
TORCH_WARN_ONCE(
|
||||
"The given NumPy array is not writeable, and PyTorch does "
|
||||
"not support non-writeable tensors. This means you can write to the "
|
||||
"underlying (supposedly non-writeable) NumPy array using the tensor. "
|
||||
"You may want to copy the array to protect its data or make it writeable "
|
||||
"before converting it to a tensor. This type of warning will be "
|
||||
"suppressed for the rest of this program.");
|
||||
}
|
||||
|
||||
at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) {
|
||||
if (!is_numpy_available()) {
|
||||
throw std::runtime_error("Numpy is not available");
|
||||
|
|
@ -180,14 +190,7 @@ at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/)
|
|||
// warn_if_not_writable is true when a copy of numpy variable is created.
|
||||
// the warning is suppressed when a copy is being created.
|
||||
if (!PyArray_ISWRITEABLE(array) && warn_if_not_writeable) {
|
||||
TORCH_WARN_ONCE(
|
||||
"The given NumPy array is not writeable, and PyTorch does "
|
||||
"not support non-writeable tensors. This means you can write to the "
|
||||
"underlying (supposedly non-writeable) NumPy array using the tensor. "
|
||||
"You may want to copy the array to protect its data or make it writeable "
|
||||
"before converting it to a tensor. This type of warning will be "
|
||||
"suppressed for the rest of this program.");
|
||||
|
||||
warn_numpy_not_writeable();
|
||||
}
|
||||
|
||||
int ndim = PyArray_NDIM(array);
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ bool is_numpy_available();
|
|||
bool is_numpy_int(PyObject* obj);
|
||||
bool is_numpy_scalar(PyObject* obj);
|
||||
|
||||
void warn_numpy_not_writeable();
|
||||
at::Tensor tensor_from_cuda_array_interface(PyObject* obj);
|
||||
|
||||
}} // namespace torch::utils
|
||||
|
|
|
|||
|
|
@ -206,6 +206,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
torch.set_vital,
|
||||
torch.read_vitals,
|
||||
torch.frombuffer,
|
||||
torch.asarray,
|
||||
Tensor.__delitem__,
|
||||
Tensor.__dir__,
|
||||
Tensor.__getattribute__,
|
||||
|
|
|
|||
Loading…
Reference in a new issue