pytorch/test/cpp_extensions/complex_registration_extension.cpp
Pavel Belevich 62b06b9fae Rename TensorTypeId to DispatchKey (#32154)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32154

TensorTypeId -> DispatchKey
	c10/core/TensorTypeId.h -> c10/core/DispatchKey.h
	c10/core/TensorTypeId.cpp -> c10/core/DispatchKey.cpp
	TensorTypeId::* -> DispatchKey::*
	TensorTypeId type_id -> DispatchKey dispatch_key
		type_id -> dispatch_key
	TensorTypeId::NumTensorIds -> DispatchKey::NumDispatchKeys
	RealTensorTypeId -> RealDispatchKey
TensorTypeSet -> DispatchKeySet
	TensorTypeIds -> DispatchKeys
	c10/core/TensorTypeSet.h -> c10/core/DispatchKeySet.h
	c10/core/TensorTypeSet.cpp -> c10/core/DispatchKeySet.cpp
	type_set() -> key_set()
	type_set_ -> key_set_
	typeSet -> keySet
ExcludeTensorTypeIdGuard -> ExcludeDispatchKeyGuard
IncludeTensorTypeIdGuard -> IncludeDispatchKeyGuard
LocalTensorTypeSet -> LocalDispatchKeySet
	c10/core/impl/LocalTensorTypeSet.h -> c10/core/impl/LocalDispatchKeySet.h
	c10/core/impl/LocalTensorTypeSet.cpp -> c10/core/impl/LocalDispatchKeySet.cpp
	tls_local_tensor_type_set -> tls_local_dispatch_key_set
	tls_is_tensor_type_id_excluded -> tls_is_dispatch_key_excluded
	tls_set_tensor_type_id_excluded -> tls_set_dispatch_key_excluded
	tls_is_tensor_type_id_included -> tls_is_dispatch_key_included
	tls_set_tensor_type_id_included -> tls_set_dispatch_key_included
MultiDispatchTensorTypeSet -> MultiDispatchKeySet
	multi_dispatch_tensor_type_set -> multi_dispatch_key_set
tensorTypeIdToBackend -> dispatchKeyToBackend
backendToTensorTypeId -> backendToDispatchKey
initForTensorTypeSet -> initForDispatchKeySet
inferred_type_set -> inferred_key_set
computeTensorTypeId -> computeDispatchKey
PODLocalTensorTypeSet raw_local_tensor_type_set -> PODLocalDispatchKeySet raw_local_dispatch_key_set
get_default_tensor_type_id -> get_default_dispatch_key
inferred_type_id -> inferred_dispatch_key
actual_type_id -> actual_dispatch_key
typeSetToDispatchKey_ -> dispatchKeySetToDispatchKey_
get_type_id() -> get_dispatch_key()
legacyExtractTypeId -> legacyExtractDispatchKey
extractTypeId -> extractDispatchKey

Test Plan: Imported from OSS

Differential Revision: D19398900

Pulled By: pbelevich

fbshipit-source-id: 234ad19f93d33e00201b61e153b740a339035776
2020-01-15 11:16:08 -08:00

58 lines
1.9 KiB
C++

#include <torch/extension.h>
#include <c10/core/Allocator.h>
#include <ATen/CPUGenerator.h>
#include <ATen/DeviceGuard.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Half.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/Optional.h>
#include <ATen/core/op_registration/op_registration.h>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <ATen/Config.h>
namespace at {
namespace {

// CPU kernel for `aten::empty.memory_format`, registered below for
// DispatchKey::ComplexCPUTensorId. Allocates an uninitialized, contiguous,
// resizable tensor with the requested sizes and dtype.
//
// @param size                    Requested sizes; every entry must be >= 0.
// @param options                 Tensor options; the dtype determines the
//                                element size of the allocation. The device
//                                is required to be CPU (internal invariant).
// @param optional_memory_format  Must be empty; an explicit memory format is
//                                rejected with an error.
// @return A tensor whose TensorImpl carries DispatchKey::ComplexCPUTensorId.
Tensor empty_complex(IntArrayRef size, const TensorOptions & options, c10::optional<c10::MemoryFormat> optional_memory_format) {
  // Terminated explicitly: do not rely on TORCH_CHECK expanding to an if-block.
  TORCH_CHECK(!optional_memory_format.has_value(), "memory format is not supported");
  // This kernel is registered only for ComplexCPUTensorId (see below), so a
  // non-CPU device is treated as an internal invariant violation, not a user
  // error. TORCH_INTERNAL_ASSERT replaces the deprecated AT_ASSERT alias.
  TORCH_INTERNAL_ASSERT(options.device().is_cpu());
  for (const auto x : size) {
    TORCH_CHECK(x >= 0, "Trying to create tensor using size with negative dimension: ", size);
  }
  auto* allocator = at::getCPUAllocator();
  const int64_t nelements = at::prod_intlist(size);
  const auto dtype = options.dtype();
  // NOTE(review): nelements * itemsize is not checked for overflow; extremely
  // large shapes would wrap silently — consider a checked multiply if this
  // kernel is ever used beyond tests.
  auto storage_impl = c10::make_intrusive<StorageImpl>(
      dtype,
      nelements,
      allocator->allocate(nelements * dtype.itemsize()),
      allocator,
      /*resizable=*/true);
  auto tensor = detail::make_tensor<TensorImpl>(storage_impl, at::DispatchKey::ComplexCPUTensorId);
  // Default TensorImpl has size [0]; only overwrite the sizes when a
  // different shape was requested.
  if (size.size() != 1 || size[0] != 0) {
    tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
  }
  return tensor;
}

} // namespace

// Registers empty_complex as the unboxed-only kernel for
// aten::empty.memory_format under DispatchKey::ComplexCPUTensorId. As a
// namespace-scope static, this runs during static initialization when the
// extension library is loaded.
static auto complex_empty_registration = torch::RegisterOperators()
  .op(torch::RegisterOperators::options()
    .schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
    .impl_unboxedOnlyKernel<decltype(empty_complex), &empty_complex>(DispatchKey::ComplexCPUTensorId)
    .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA));

} // namespace at
// Python entry point for the extension. No functions are bound on the module;
// it exists so the compiled library can be imported from Python. Importing it
// loads the library, which runs this file's namespace-scope static
// initializers (including the operator registration in namespace `at`).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Intentionally empty: nothing is exposed on the module object.
}