pytorch/test/cpp_extensions/msnpu_extension.cpp
Edward Yang 2db61193bb Add DispatchKey impl overload; remove use of torch::dispatch (#35706)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35706

It is extremely common to define implementations of operators at a
specific dispatch key, so we add an overload to impl specifically for
this case.  I then delete most uses of torch::dispatch

dispatch_autograd call sites can't make use of this overload.  So
instead the new preferred way to specify something as autograd is to
pass kAutograd as the dispatch key (short form, analogous to kCPU/kCUDA
which we support today).

I flip-flopped about whether or not kAutograd should have the type
DispatchKey or some other type (to help better encapsulate the
DispatchKey enum); this is more direct and I can't think of any
BC problems from this usage.

Some other reorganization I did:
- I renamed all of the worker functions in op_registration to have
  a leading underscore and made them private, just to make it more
  clear what the public versus private API were (the private API
  shouldn't be used by users because it doesn't come with && overloads)
- In a few places where I was touching lines already, I replaced
  full DispatchKey typed out enums with shorter kFoo names, similar
  to kAutograd but I didn't publish these globally.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20775783

Pulled By: ezyang

fbshipit-source-id: e45b289e5d1f86c180b24cf14c63cf4459ab5337
2020-04-02 08:51:22 -07:00

124 lines
4.3 KiB
C++

#include <torch/extension.h>
#include <ATen/core/op_registration/op_registration.h>
using namespace at;
static int test_int;
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
Storage(
dtype, 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)), nullptr, false),
DispatchKey::MSNPUTensorId);
// This is a hack to workaround the shape checks in _convolution.
tensor_impl->set_sizes_contiguous(size);
return Tensor(std::move(tensor_impl));
}
Tensor empty_override(IntArrayRef size, const TensorOptions & options) {
test_int = 0;
return get_tensor(options.dtype(), size);
}
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
test_int = 1;
return get_tensor(a.dtype(), a.sizes());
}
Tensor fake_convolution(
const Tensor& input, const Tensor& weight, const Tensor& bias,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
bool transposed, IntArrayRef output_padding, int64_t groups) {
test_int = 2;
// Only the first 2 dimension of output shape is correct.
return get_tensor(input.dtype(), {input.size(0), weight.size(0), input.size(2), input.size(3)});
}
std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
const Tensor & grad_output, const Tensor & input, const Tensor & weight,
IntArrayRef stride, IntArrayRef padding,
IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
int64_t groups, std::array<bool,3> output_mask) {
test_int = 3;
return std::tuple<Tensor, Tensor, Tensor>(
get_tensor(input.dtype(), input.sizes()),
get_tensor(weight.dtype(), weight.sizes()),
get_tensor(input.dtype(), {}));
}
void init_msnpu_extension() {
static auto registry = torch::import()
.impl("aten::empty.memory_format", kMSNPU, CppFunction::makeUnboxedOnly(empty_override))
.impl("aten::add.Tensor", kMSNPU, CppFunction::makeUnboxedOnly(add_override))
.impl("aten::convolution_overrideable", kMSNPU, CppFunction::makeUnboxedOnly(fake_convolution))
.impl("aten::convolution_backward_overrideable", kMSNPU, CppFunction::makeUnboxedOnly(fake_convolution_backward))
;
}
// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = DeviceType::MSNPU;
MSNPUGuardImpl() {}
MSNPUGuardImpl(DeviceType t) {
AT_ASSERT(t == DeviceType::MSNPU);
}
DeviceType type() const override {
return DeviceType::MSNPU;
}
Device exchangeDevice(Device d) const override {
AT_ASSERT(d.type() == DeviceType::MSNPU);
AT_ASSERT(d.index() == 0);
return d;
}
Device getDevice() const override {
return Device(DeviceType::MSNPU, 0);
}
void setDevice(Device d) const override {
AT_ASSERT(d.type() == DeviceType::MSNPU);
AT_ASSERT(d.index() == 0);
}
void uncheckedSetDevice(Device d) const noexcept override {
}
Stream getStream(Device d) const noexcept override {
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
}
Stream exchangeStream(Stream s) const noexcept override {
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
}
DeviceIndex deviceCount() const noexcept override {
return 1;
}
// Event-related functions
void record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
}
void block(
void* event,
const Stream& stream) const override {
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
}
bool queryEvent(void* event) const override {
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override { }
};
constexpr DeviceType MSNPUGuardImpl::static_type;
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
int get_test_int() {
return test_int;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("init_msnpu_extension", &init_msnpu_extension);
m.def("get_test_int", &get_test_int);
}