ENH: Add a force argument to numpy() (#78564)

**Reopened** to help with merge issues. See #59790 for full context.

Fixes #20778. Helps #71688.

Finalizes @martinPasen's force argument for `Tensor.numpy()`. It is set to False by default. If it's set to True then we:
1. detach the Tensor, if requires_grad == True
2. move to cpu, if not on cpu already
3. use .resolve_conj() if .is_conj() == True
4. use .resolve_neg() if .is_neg() == True

cc @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78564
Approved by: https://github.com/albanD
This commit is contained in:
Rohit Goswami 2022-06-06 14:14:15 +00:00 committed by PyTorch MergeBot
parent 5a95b20d0f
commit 3f58dd18dc
6 changed files with 91 additions and 38 deletions

View file

@ -156,6 +156,31 @@ class TestNumPyInterop(TestCase):
self.assertEqual(y.dtype, np.bool_)
self.assertEqual(x[0], y[0])
def test_to_numpy_force_argument(self, device) -> None:
# Sweep every combination of the four factors that affect whether
# Tensor.numpy() succeeds: the new `force` keyword, requires_grad,
# sparse layout, and the lazy conjugate bit.
for force in [False, True]:
for requires_grad in [False, True]:
for sparse in [False, True]:
for conj in [False, True]:
# Complex data so that .conj() actually flips the conjugate bit.
data = [[1 + 2j, -2 + 3j], [-1 - 2j, 3 - 2j]]
x = torch.tensor(data, requires_grad=requires_grad, device=device)
# `y` tracks the expected ndarray value for the success path.
y = x
if sparse:
# Sparse tensors with requires_grad are not constructed here;
# that combination is skipped rather than tested.
if requires_grad:
continue
x = x.to_sparse()
if conj:
x = x.conj()
# Materialize the conjugation so `y` holds the plain values
# that numpy(force=True) is expected to produce.
y = x.resolve_conj()
# Without force=True, conversion must fail for any of: autograd
# tracking, sparse layout, set conjugate bit, or non-CPU device.
expect_error = requires_grad or sparse or conj or not device == 'cpu'
# Error messages all suggest a Tensor method to call first
# (e.g. "Use Tensor.cpu()" / "use tensor.detach().numpy()").
error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
if not force and expect_error:
# Both the implicit default and an explicit force=False must raise.
self.assertRaisesRegex((RuntimeError, TypeError), error_msg, lambda: x.numpy())
self.assertRaisesRegex((RuntimeError, TypeError), error_msg, lambda: x.numpy(force=False))
elif force and sparse:
# force=True does not rescue sparse layout: still a TypeError.
self.assertRaisesRegex(TypeError, error_msg, lambda: x.numpy(force=True))
else:
# Success path: result must equal the resolved expected values.
self.assertEqual(x.numpy(force=force), y)
def test_from_numpy(self, device) -> None:
dtypes = [
np.double,

View file

@ -791,15 +791,22 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
// implemented on the python object bc PyObjects not declarable in native_functions.yaml
// See: ATen/native/README.md for more context
static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
static PyObject * THPVariable_numpy(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "numpy");
}
jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW);
static PythonArgParser parser({
"numpy(*, bool force=False)"
});
auto& self_ = THPVariable_Unpack(self);
return torch::utils::tensor_to_numpy(self_);
ParsedArgs<1> parsed_args;
auto r = parser.parse(self, args, kwargs, parsed_args);
if (r.has_torch_function()) {
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
}
jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW);
return torch::utils::tensor_to_numpy(self_, r.toBool(0));
END_HANDLE_TH_ERRORS
}
@ -1271,7 +1278,7 @@ PyMethodDef variable_methods[] = {
{"new_tensor", castPyCFunctionWithKeywords(THPVariable_new_tensor), METH_VARARGS | METH_KEYWORDS, NULL},
{"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS, NULL},
{"numel", THPVariable_numel, METH_NOARGS, NULL},
{"numpy", THPVariable_numpy, METH_NOARGS, NULL},
{"numpy", castPyCFunctionWithKeywords(THPVariable_numpy), METH_VARARGS | METH_KEYWORDS, NULL},
{"requires_grad_", castPyCFunctionWithKeywords(THPVariable_requires_grad_), METH_VARARGS | METH_KEYWORDS, NULL},
{"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
{"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},

View file

@ -642,7 +642,7 @@ def gen_pyi(
"cuda": [
"def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ..."
],
"numpy": ["def numpy(self) -> Any: ..."],
"numpy": ["def numpy(self, *, force: _bool=False) -> Any: ..."],
"apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
"map_": [
"def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."

View file

@ -2841,11 +2841,26 @@ See :func:`torch.numel`
add_docstr_all('numpy',
r"""
numpy() -> numpy.ndarray
numpy(*, force=False) -> numpy.ndarray
Returns :attr:`self` tensor as a NumPy :class:`ndarray`. This tensor and the
returned :class:`ndarray` share the same underlying storage. Changes to
:attr:`self` tensor will be reflected in the :class:`ndarray` and vice versa.
Returns the tensor as a NumPy :class:`ndarray`.
If :attr:`force` is ``False`` (the default), the conversion
is performed only if the tensor is on the CPU, does not require grad,
does not have its conjugate bit set, and is a dtype and layout that
NumPy supports. The returned ndarray and the tensor will share their
storage, so changes to the tensor will be reflected in the ndarray
and vice versa.
If :attr:`force` is ``True`` this is equivalent to
calling ``t.detach().cpu().resolve_conj().resolve_neg().numpy()``.
If the tensor isn't on the CPU or the conjugate or negative bit is set,
the tensor won't share its storage with the returned ndarray.
Setting :attr:`force` to ``True`` can be a useful shorthand.
Args:
force (bool): if ``True``, the ndarray may be a copy of the tensor
instead of always sharing memory, defaults to ``False``.
""")
add_docstr_all('orgqr',

View file

@ -105,49 +105,55 @@ static std::vector<int64_t> seq_to_aten_shape(PyObject *py_seq) {
return result;
}
PyObject* tensor_to_numpy(const at::Tensor& tensor) {
PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force/*=false*/) {
TORCH_CHECK(is_numpy_available(), "Numpy is not available");
TORCH_CHECK_TYPE(tensor.device().type() == DeviceType::CPU,
"can't convert ", tensor.device().str().c_str(),
" device type tensor to numpy. Use Tensor.cpu() to ",
"copy the tensor to host memory first.");
TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(),
".numpy() is not supported for tensor subclasses.");
TORCH_CHECK_TYPE(tensor.layout() == Layout::Strided,
"can't convert ", c10::str(tensor.layout()).c_str(),
" layout tensor to numpy.",
"convert the tensor to a strided layout first.");
" layout tensor to numpy. ",
"Use Tensor.dense() first.");
TORCH_CHECK(!(at::GradMode::is_enabled() && tensor.requires_grad()),
"Can't call numpy() on Tensor that requires grad. "
"Use tensor.detach().numpy() instead.");
if (!force){
TORCH_CHECK_TYPE(tensor.device().type() == DeviceType::CPU,
"can't convert ", tensor.device().str().c_str(),
" device type tensor to numpy. Use Tensor.cpu() to ",
"copy the tensor to host memory first.");
TORCH_CHECK(!tensor.is_conj(),
"Can't call numpy() on Tensor that has conjugate bit set. ",
"Use tensor.resolve_conj().numpy() instead.");
TORCH_CHECK(!(at::GradMode::is_enabled() && tensor.requires_grad()),
"Can't call numpy() on Tensor that requires grad. "
"Use tensor.detach().numpy() instead.");
TORCH_CHECK(!tensor.is_neg(),
"Can't call numpy() on Tensor that has negative bit set. "
"Use tensor.resolve_neg().numpy() instead.");
TORCH_CHECK(!tensor.is_conj(),
"Can't call numpy() on Tensor that has conjugate bit set. ",
"Use tensor.resolve_conj().numpy() instead.");
TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses.");
TORCH_CHECK(!tensor.is_neg(),
"Can't call numpy() on Tensor that has negative bit set. "
"Use tensor.resolve_neg().numpy() instead.");
}
auto prepared_tensor = tensor.detach().cpu().resolve_conj().resolve_neg();
auto dtype = aten_to_numpy_dtype(prepared_tensor.scalar_type());
auto sizes = to_numpy_shape(prepared_tensor.sizes());
auto strides = to_numpy_shape(prepared_tensor.strides());
auto dtype = aten_to_numpy_dtype(tensor.scalar_type());
auto sizes = to_numpy_shape(tensor.sizes());
auto strides = to_numpy_shape(tensor.strides());
// NumPy strides use bytes. Torch strides use element counts.
auto element_size_in_bytes = tensor.element_size();
auto element_size_in_bytes = prepared_tensor.element_size();
for (auto& stride : strides) {
stride *= element_size_in_bytes;
}
auto array = THPObjectPtr(PyArray_New(
&PyArray_Type,
tensor.dim(),
prepared_tensor.dim(),
sizes.data(),
dtype,
strides.data(),
tensor.data_ptr(),
prepared_tensor.data_ptr(),
0,
NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE,
nullptr));
@ -157,13 +163,13 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
// object of the ndarray to the tensor and disabling resizes on the storage.
// This is not sufficient. For example, the tensor's storage may be changed
// via Tensor.set_, which can free the underlying memory.
PyObject* py_tensor = THPVariable_Wrap(tensor);
PyObject* py_tensor = THPVariable_Wrap(prepared_tensor);
if (!py_tensor) throw python_error();
if (PyArray_SetBaseObject((PyArrayObject*)array.get(), py_tensor) == -1) {
return nullptr;
}
// Use the private storage API
tensor.storage().unsafeGetStorageImpl()->set_resizable(false);
prepared_tensor.storage().unsafeGetStorageImpl()->set_resizable(false);
return array.release();
}

View file

@ -5,7 +5,7 @@
namespace torch { namespace utils {
PyObject* tensor_to_numpy(const at::Tensor& tensor);
PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force=false);
at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable=true);
int aten_to_numpy_dtype(const at::ScalarType scalar_type);