mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Original RFC https://github.com/pytorch/pytorch/issues/19092

To ensure that we are not introducing a BC-breaking change, `empty_like` returns a contiguous tensor by default.

```python
nCwh = torch.randn(N, C, H, W)
nhwC = nCwh.contiguous(memory_format=torch.channels_last)

new_nCwh = torch.empty_like(nhwC)
new_nCwh.is_contiguous(memory_format=torch.channels_last) == False
```

Now we need a way to preserve memory format in `empty_like`:

```python
nCwh = torch.randn(N, C, H, W)
nhwC = nCwh.contiguous(memory_format=torch.channels_last)

new_nhwC = torch.empty_like(nhwC, memory_format=torch.preserve_format)
new_nhwC.is_contiguous(memory_format=torch.channels_last) == True

like_nCwh = torch.empty_like(nCwh, memory_format=torch.preserve_format)
like_nCwh.is_contiguous(memory_format=torch.channels_last) == False
```

Usage of `torch.preserve_format` allows us to avoid `if` constructs.

We can also generate different memory format outputs:

```python
nCwh = torch.randn(N, C, H, W)
nhwC = nCwh.contiguous(memory_format=torch.channels_last)

new_nhwC = torch.empty_like(nCwh, memory_format=torch.channels_last)
new_nhwC.is_contiguous(memory_format=torch.channels_last) == True

new_nCwh = torch.empty_like(nhwC, memory_format=torch.contiguous_format)
new_nCwh.is_contiguous(memory_format=torch.channels_last) == False
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20558

Differential Revision: D15502474

Pulled By: VitalyFedyunin

fbshipit-source-id: 2e120d57eefad6fb8e04b8322c79871392f64331
92 lines
2.7 KiB
C++
92 lines
2.7 KiB
C++
#include <torch/extension.h>
|
|
|
|
#include <ATen/Type.h>
|
|
#include <ATen/core/VariableHooksInterface.h>
|
|
#include <ATen/detail/ComplexHooksInterface.h>
|
|
|
|
#include <ATen/CPUTypeDefault.h>
|
|
#include <c10/core/Allocator.h>
|
|
#include <ATen/CPUGenerator.h>
|
|
#include <ATen/DeviceGuard.h>
|
|
#include <ATen/NativeFunctions.h>
|
|
#include <ATen/Utils.h>
|
|
#include <ATen/WrapDimUtils.h>
|
|
#include <c10/util/Half.h>
|
|
#include <c10/core/TensorImpl.h>
|
|
#include <c10/core/UndefinedTensorImpl.h>
|
|
#include <c10/util/Optional.h>
|
|
#include <ATen/core/ATenDispatch.h>
|
|
|
|
#include <cstddef>
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <utility>
|
|
|
|
#include <ATen/Config.h>
|
|
|
|
namespace at {
|
|
|
|
struct ComplexCPUType : public at::CPUTypeDefault {
|
|
ComplexCPUType()
|
|
: CPUTypeDefault(
|
|
ComplexCPUTensorId(),
|
|
/*is_variable=*/false,
|
|
/*is_undefined=*/false) {}
|
|
|
|
Backend backend() const override;
|
|
const char* toString() const override;
|
|
TypeID ID() const override;
|
|
|
|
static Tensor empty(IntArrayRef size, const TensorOptions & options, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
|
TORCH_CHECK(!optional_memory_format.has_value(), "memory format is not supported")
|
|
AT_ASSERT(options.device().is_cpu());
|
|
|
|
for (auto x: size) {
|
|
TORCH_CHECK(x >= 0, "Trying to create tensor using size with negative dimension: ", size);
|
|
}
|
|
auto* allocator = at::getCPUAllocator();
|
|
int64_t nelements = at::prod_intlist(size);
|
|
auto dtype = options.dtype();
|
|
auto storage_impl = c10::make_intrusive<StorageImpl>(
|
|
dtype,
|
|
nelements,
|
|
allocator->allocate(nelements * dtype.itemsize()),
|
|
allocator,
|
|
/*resizable=*/true);
|
|
|
|
auto tensor = detail::make_tensor<TensorImpl>(storage_impl, at::ComplexCPUTensorId());
|
|
// Default TensorImpl has size [0]
|
|
if (size.size() != 1 || size[0] != 0) {
|
|
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
|
|
}
|
|
return tensor;
|
|
}
|
|
};
|
|
|
|
struct ComplexHooks : public at::ComplexHooksInterface {
|
|
ComplexHooks(ComplexHooksArgs) {}
|
|
void registerComplexTypes(Context* context) const override {
|
|
context->registerType(Backend::ComplexCPU, new ComplexCPUType());
|
|
}
|
|
};
|
|
|
|
// Reports the backend this type object serves: always Backend::ComplexCPU.
Backend ComplexCPUType::backend() const { return Backend::ComplexCPU; }
|
|
|
|
// Human-readable name of this type object; the returned pointer refers to a
// string literal.
const char* ComplexCPUType::toString() const { return "ComplexCPUType"; }
|
|
|
|
// Identifier of this type object: always TypeID::ComplexCPU.
TypeID ComplexCPUType::ID() const { return TypeID::ComplexCPU; }
|
|
|
|
static auto& complex_empty_registration = globalATenDispatch()
|
|
.registerOp(Backend::ComplexCPU, "aten::empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", &ComplexCPUType::empty);
|
|
|
|
REGISTER_COMPLEX_HOOKS(ComplexHooks);
|
|
|
|
} // namespace at
|
|
|
|
// The Python module deliberately exposes no bindings. NOTE(review): presumably
// the extension is imported only so that the static registrations in this file
// run when the library is loaded -- confirm against the code that loads it.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
}
|