mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143918 Approved by: https://github.com/Skylion007
33 lines
734 B
Python
33 lines
734 B
Python
"""This module converts objects into NumPy arrays."""
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
def make_np(x: object) -> np.ndarray:
    """
    Convert an object into a numpy array.

    Args:
        x: One of the supported inputs:
            - a ``np.ndarray``, which is returned unchanged;
            - a scalar (anything ``np.isscalar`` accepts), which is wrapped
              in a 1-element array;
            - a ``torch.Tensor``, which is converted to a numpy array.

    Returns:
        numpy.ndarray: The converted array.

    Raises:
        NotImplementedError: If ``x`` is none of the supported types.
    """
    if isinstance(x, np.ndarray):
        return x
    if np.isscalar(x):
        # Scalars become shape-(1,) arrays, not 0-d arrays.
        return np.array([x])
    if isinstance(x, torch.Tensor):
        return _prepare_pytorch(x)
    raise NotImplementedError(
        f"Got {type(x)}, but numpy array or torch tensor are expected."
    )
|
|
|
|
|
|
def _prepare_pytorch(x: torch.Tensor) -> np.ndarray:
|
|
if x.dtype == torch.bfloat16:
|
|
x = x.to(torch.float16)
|
|
x = x.detach().cpu().numpy()
|
|
return x
|