pytorch/test/cpp_extensions/complex_registration_extension.cpp
Roy Li 24a6c32407 Replace Type dispatch with ATenDispatch (#21320)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21320
ghimport-source-id: cc18f746a1c74df858cb0f6d8b7d4de4315683c7

Test Plan: Imported from OSS

Differential Revision: D15637222

Pulled By: li-roy

fbshipit-source-id: fcfaea0b5480ab966175341cce92e3aa0be7e3cb
2019-06-19 15:46:45 -07:00

91 lines
2.6 KiB
C++

#include <torch/extension.h>
#include <ATen/Type.h>
#include <ATen/core/VariableHooksInterface.h>
#include <ATen/detail/ComplexHooksInterface.h>
#include <ATen/CPUTypeDefault.h>
#include <c10/core/Allocator.h>
#include <ATen/CPUGenerator.h>
#include <ATen/DeviceGuard.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Half.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/Optional.h>
#include <ATen/core/ATenDispatch.h>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <ATen/Config.h>
namespace at {
// Minimal Type implementation backing complex tensors on CPU; dispatch for
// this backend is keyed off ComplexCPUTensorId.
struct ComplexCPUType : public at::CPUTypeDefault {
  ComplexCPUType()
      : CPUTypeDefault(
            ComplexCPUTensorId(),
            /*is_variable=*/false,
            /*is_undefined=*/false) {}

  Backend backend() const override;
  const char* toString() const override;
  TypeID ID() const override;

  // Allocates an uninitialized complex CPU tensor of the requested size.
  // Mirrors the stock CPU empty(): validate dims, size a storage, wrap it in
  // a TensorImpl tagged with this backend's tensor id.
  static Tensor empty(IntArrayRef size, const TensorOptions & options) {
    AT_ASSERT(options.device().is_cpu());

    // Reject any negative dimension up front.
    for (auto dim_size : size) {
      TORCH_CHECK(dim_size >= 0, "Trying to create tensor using size with negative dimension: ", size);
    }

    const auto dtype = options.dtype();
    auto* allocator = at::getCPUAllocator();
    const int64_t numel = at::prod_intlist(size);
    auto storage = c10::make_intrusive<StorageImpl>(
        dtype,
        numel,
        allocator->allocate(numel * dtype.itemsize()),
        allocator,
        /*resizable=*/true);

    auto result = detail::make_tensor<TensorImpl>(storage, at::ComplexCPUTensorId());
    // A freshly constructed TensorImpl reports size [0]; only overwrite the
    // sizes when the caller asked for something different.
    if (!(size.size() == 1 && size[0] == 0)) {
      result.unsafeGetTensorImpl()->set_sizes_contiguous(size);
    }
    return result;
  }
};
// Hooks object through which the complex backend is made known to the global
// at::Context (installed via REGISTER_COMPLEX_HOOKS in this file).
struct ComplexHooks : public at::ComplexHooksInterface {
ComplexHooks(ComplexHooksArgs) {}
// Installs ComplexCPUType as the handler for Backend::ComplexCPU.
// NOTE(review): the raw `new`ed instance is handed to the context —
// presumably registerType takes ownership; confirm against its contract.
void registerComplexTypes(Context* context) const override {
context->registerType(Backend::ComplexCPU, new ComplexCPUType());
}
};
// Identifies this Type as the ComplexCPU backend.
Backend ComplexCPUType::backend() const {
return Backend::ComplexCPU;
}
// Human-readable name used in diagnostics/printing of this Type.
const char* ComplexCPUType::toString() const {
return "ComplexCPUType";
}
// Stable TypeID enum value for this Type.
TypeID ComplexCPUType::ID() const {
return TypeID::ComplexCPU;
}
// Register the complex implementation of aten::empty with the global ATen
// dispatcher so factory calls on the ComplexCPU backend route to
// ComplexCPUType::empty. Binding the result to a file-local static makes the
// registration run once, when this shared object is loaded.
static auto& complex_empty_registration = globalATenDispatch()
.registerOp(Backend::ComplexCPU, "aten::empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", &ComplexCPUType::empty);
// Register ComplexHooks with ATen's complex-hooks registry (see
// ATen/detail/ComplexHooksInterface.h) so the context can reach
// registerComplexTypes above.
REGISTER_COMPLEX_HOOKS(ComplexHooks);
} // namespace at
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { }