2022-07-11 22:11:58 +00:00
|
|
|
#include <c10/core/impl/alloc_cpu.h>
|
|
|
|
|
#include <c10/core/Allocator.h>
|
|
|
|
|
|
|
|
|
|
#include <torch/csrc/Device.h>
|
2023-04-17 19:18:35 +00:00
|
|
|
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
|
|
|
|
#include <c10/macros/Macros.h>
|
2022-07-11 22:11:58 +00:00
|
|
|
#include <torch/extension.h>
|
|
|
|
|
|
|
|
|
|
#include <ATen/native/cpu/Loops.h>
|
|
|
|
|
#include <ATen/native/DispatchStub.h>
|
2023-04-17 19:18:35 +00:00
|
|
|
#include <ATen/native/Resize.h>
|
2022-07-11 22:11:58 +00:00
|
|
|
#include <ATen/EmptyTensor.h>
|
2023-04-08 05:32:21 +00:00
|
|
|
#include <ATen/core/GeneratorForPrivateuseone.h>
|
2022-07-11 22:11:58 +00:00
|
|
|
|
|
|
|
|
// Number of times custom_add_Tensor has run; read by custom_add_called().
static uint64_t add_counter = 0;
|
|
|
|
|
// Snapshot of add_counter taken by custom_add_called() the last time it reported a call.
static uint64_t last_saved_value = 0;
|
|
|
|
|
|
2023-04-17 19:18:35 +00:00
|
|
|
// register guard
|
|
|
|
|
namespace at {
|
|
|
|
|
namespace detail {
|
|
|
|
|
|
|
|
|
|
C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
|
|
|
|
|
|
|
|
|
|
}} // namespace at::detail
|
|
|
|
|
|
2022-07-11 22:11:58 +00:00
|
|
|
// basic dummy add function
|
|
|
|
|
at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
|
|
|
|
|
add_counter += 1;
|
|
|
|
|
// Since this custom device is just for testing, not bothering to implement kernels.
|
|
|
|
|
return at::empty(self.sizes(), self.options());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// A dummy allocator for our custom device, that secretly uses the CPU
|
|
|
|
|
struct DummyCustomAllocator final : at::Allocator {
|
|
|
|
|
DummyCustomAllocator() = default;
|
|
|
|
|
at::DataPtr allocate(size_t nbytes) const override {
|
|
|
|
|
void* data = c10::alloc_cpu(nbytes);
|
|
|
|
|
return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void ReportAndDelete(void* ptr) {
|
|
|
|
|
if (!ptr) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
c10::free_cpu(ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
at::DeleterFnPtr raw_deleter() const override {
|
|
|
|
|
return &ReportAndDelete;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Register our dummy allocator
|
|
|
|
|
// Single file-local allocator instance; its address is handed to REGISTER_ALLOCATOR
// and to the custom empty kernels in this file.
static DummyCustomAllocator global_custom_alloc;
|
|
|
|
|
// Associate global_custom_alloc with the PrivateUse1 device type.
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
|
|
|
|
|
|
|
|
|
|
// basic dummy empty function, so we can directly construct tensors on the custom device
|
|
|
|
|
// This dummy test device will just use the CPU allocator, and ignores pinned memory.
|
|
|
|
|
at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
|
|
|
|
|
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
|
|
|
|
|
return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
|
|
|
|
|
}
|
Make dispatcher registrations of SymInt functions backwards compatible (#84557)
Previously, when we SymInt-ified a schema, this was a BC-breaking change
for all people who registered functions for that function; they
must accept c10::SymInt where they previously accepted int64_t.
This is not great.
With this change, I accept old type registrations transparently. The
idea is in several parts:
- At the registration site, at compile time I have no idea whether
the function being registered has a SymInt schema or not. So I
must defer the exact compatibility check. What I do instead is
check if the function pointer registered to me has SymInt in the
argument or not. If it does, I assume it is new-style and ensure
it is also registered to a special sym_ slot on KernelFunction.
If not, it only goes in the conventional slot.
- At the dispatcher site, I know at compile time whether or not this
is a SymInt function. If it is, I check for a sym_ slot on the
KernelFunction, and preferentially use that. If no such slot
exists, I then fall back to the regular slot... but I convert
all SymInt arguments to int64_t arguments (doing assertions that
no true symbolic integer was passed.) I can skip this test entirely
if the function doesn't have any SymInts in it; in that case I know
that only the original slot could have been registered. Fortunately,
both branches of the short circuit typecheck, so I didn't have to
use SFINAE or if-constexpr to make it work; just a plain if statement
that I expect the compiler to optimize away.
- Schema validation is now modestly more complicated. There are two parts. First, function schema validation proceeds by checking if the signature in question has any SymInt-like types in it or not. If it does, we do function schema validation against the real types; if it doesn't, we do validation against the fake types (but only for symint; MemoryFormat is always MemoryFormat). Second, cpp signature validation also keeps track of a "symint" cpp signature and a "non-symint" cpp signature. We only compare symint with symint, and non-symint with non-symint. I did not implement checking a conflict between a symint and non-symint cpp signature, though in principle you could try converting the SymInt types to non-SymInt types and doing the comparison that way.
To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now.
I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84557
Approved by: https://github.com/wconstab
2022-09-07 12:58:32 +00:00
|
|
|
// `empty` kernel that this file registers for "empty.memory_format".
// It takes a plain c10::IntArrayRef for `size`. The tensor is allocated through
// global_custom_alloc under the PrivateUse1 dispatch key; `layout`, `device` and
// `pin_memory` are not consulted.
at::Tensor custom_empty_symint(c10::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  auto* const allocator = &global_custom_alloc;
  const auto resolved_dtype = c10::dtype_or_default(dtype);
  return at::detail::empty_generic(size, allocator, private_use_ks, resolved_dtype, memory_format);
}
|
|
|
|
|
|
|
|
|
|
// Placeholder `fill_.Scalar` kernel: `value` is ignored and `self` is handed
// back untouched.
at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
  // Not bothering to implement.
  at::Tensor & unchanged = self;
  return unchanged;
}
|
|
|
|
|
|
|
|
|
|
// basic dummy copy_() function, so we can copy from the custom device to/from CPU
|
|
|
|
|
at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
|
|
|
|
|
TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
|
|
|
|
|
TORCH_CHECK(dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
|
|
|
|
|
|
|
|
|
|
// Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
|
|
|
|
|
TORCH_CHECK(self.sizes() == dst.sizes());
|
|
|
|
|
TORCH_CHECK(self.scalar_type() == dst.scalar_type());
|
|
|
|
|
TORCH_CHECK(self.is_contiguous() && dst.is_contiguous());
|
|
|
|
|
|
|
|
|
|
std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), self.storage().nbytes());
|
|
|
|
|
return dst;
|
|
|
|
|
}
|
|
|
|
|
|
2023-04-13 22:04:01 +00:00
|
|
|
// Dummy `empty_strided` kernel for the custom device.
// Allocates through global_custom_alloc under the PrivateUse1 dispatch key.
// `layout_opt`, `device_opt` and `pin_memory_opt` are not used; an unset
// `dtype_opt` resolves to the default dtype.
at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt, c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt, c10::optional<bool> pin_memory_opt) {
  const auto scalar_type = c10::dtype_or_default(dtype_opt);
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  auto* const allocator = &global_custom_alloc;
  return at::detail::empty_strided_generic(size, stride, allocator, private_use_ks, scalar_type);
}
|
2022-07-11 22:11:58 +00:00
|
|
|
|
2023-04-17 19:18:35 +00:00
|
|
|
// Some set operations for the basic use case
|
|
|
|
|
at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
|
|
|
|
|
int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
|
|
|
|
|
c10::IntArrayRef stride = {};
|
|
|
|
|
result.unsafeGetTensorImpl()->set_storage_offset(0);
|
|
|
|
|
at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : c10::nullopt;
|
|
|
|
|
at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(), new_size, stride_opt, /*resize_storage=*/!result.is_meta());
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
2023-04-21 18:31:01 +00:00
|
|
|
// basic dummy functions related to pin_memory.
|
|
|
|
|
// Storage data pointers recorded by custom__pin_memory and looked up by custom_is_pinned.
std::vector<void*> custom_pinned_data_ptr;
|
|
|
|
|
|
|
|
|
|
// Dummy `_pin_memory` kernel.
//
// self:   tensor to "pin"; must live on the CPU, otherwise TORCH_CHECK fails.
// device: accepted for signature compatibility; not used.
// Returns a copy of `self` whose storage data pointer has been appended to
// custom_pinned_data_ptr, so that custom_is_pinned() reports it as pinned.
at::Tensor custom__pin_memory(const at::Tensor& self, c10::optional<at::Device> device) {
  TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");

  // Make the copy with clone() instead of `self * 1.0`: multiplying by a
  // floating-point scalar type-promotes integral and bool tensors, which would
  // silently change the dtype (and storage size) of the pinned result.
  at::Tensor pinned_copy = self.clone();

  // record pinned data ptr
  custom_pinned_data_ptr.push_back(pinned_copy.storage().data_ptr().get());

  return pinned_copy;
}
|
|
|
|
|
|
|
|
|
|
// Dummy `is_pinned` kernel.
// Returns true only when `self` is a CPU tensor whose storage data pointer was
// previously recorded in custom_pinned_data_ptr. `device` is not used.
bool custom_is_pinned(const at::Tensor& self, c10::optional<at::Device> device) {
  // Only CPU tensors can be pinned
  if (!self.is_cpu()) {
    return false;
  }

  const void* const target = self.storage().data_ptr().get();
  bool found = false;
  for (void* recorded : custom_pinned_data_ptr) {
    if (recorded == target) {
      found = true;
      break;
    }
  }
  return found;
}
|
|
|
|
|
|
2022-07-11 22:11:58 +00:00
|
|
|
// This macro does the heavy lifting.
|
|
|
|
|
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
|
|
|
|
|
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
|
|
|
|
|
// Later in this file, we map a custom device to the PrivateUse1 device type,
|
|
|
|
|
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
|
|
|
|
|
// into the kernels registered here.
|
|
|
|
|
//
|
|
|
|
|
// This macro registers your kernels to the PyTorch Dispatcher.
|
|
|
|
|
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
|
|
|
|
|
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
|
|
|
|
|
m.impl("add.Tensor", &custom_add_Tensor);
|
2022-08-29 13:08:43 +00:00
|
|
|
m.impl("empty.memory_format", &custom_empty_symint);
|
2022-07-11 22:11:58 +00:00
|
|
|
m.impl("fill_.Scalar", &custom_fill__scalar);
|
|
|
|
|
m.impl("_copy_from", &custom__copy_from);
|
2023-04-13 22:04:01 +00:00
|
|
|
m.impl("empty_strided", &custom_empty_strided);
|
2023-04-17 19:18:35 +00:00
|
|
|
m.impl("set_.source_Storage", &custom_set_source_Storage);
|
2023-04-21 18:31:01 +00:00
|
|
|
m.impl("_pin_memory", &custom__pin_memory);
|
|
|
|
|
m.impl("is_pinned", &custom_is_pinned);
|
2022-07-11 22:11:58 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This basic implementation doesn't bother dealing with different device indices
|
|
|
|
|
// (e.g. custom_device:0 vs. custom_device:1).
|
|
|
|
|
// We could do that by letting the user pass in a device index in our exposed device function.
|
|
|
|
|
// Note that if you do that, you'll also need to register a device guard to core.
|
|
|
|
|
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
|
|
|
|
|
c10::Device get_custom_device() {
|
|
|
|
|
return c10::Device(c10::DeviceType::PrivateUse1, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool custom_add_called() {
|
|
|
|
|
bool called = false;
|
|
|
|
|
if (add_counter > last_saved_value) {
|
|
|
|
|
called = true;
|
|
|
|
|
last_saved_value = add_counter;
|
|
|
|
|
}
|
|
|
|
|
return called;
|
|
|
|
|
}
|
|
|
|
|
|
2023-03-07 03:43:23 +00:00
|
|
|
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
|
|
|
|
|
public:
|
|
|
|
|
// Constructors
|
|
|
|
|
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
|
|
|
|
|
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
|
|
|
|
|
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
|
|
|
|
|
}
|
|
|
|
|
~PrivateGeneratorImpl() override = default;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// this is used to register generator
|
|
|
|
|
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
|
|
|
|
|
return at::make_generator<PrivateGeneratorImpl>(device_index);
|
|
|
|
|
}
|
|
|
|
|
|
2023-03-20 20:43:56 +00:00
|
|
|
void register_generator() {
|
2023-03-07 03:43:23 +00:00
|
|
|
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
|
|
|
|
|
}
|
|
|
|
|
|
2022-07-11 22:11:58 +00:00
|
|
|
// Here, we're exposing a custom device object that corresponds to our custom backend.
|
|
|
|
|
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
|
|
|
|
|
// that's implemented in C++.
|
|
|
|
|
// The implementation in this file maps directly to the `PrivateUse1` device type.
|
|
|
|
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
|
|
|
m.def("custom_device", &get_custom_device, "get custom device object");
|
|
|
|
|
m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
|
2023-03-20 20:43:56 +00:00
|
|
|
m.def("register_generator", ®ister_generator, "register generator for custom device");
|
2022-07-11 22:11:58 +00:00
|
|
|
}
|