2019-02-01 18:55:00 +00:00
|
|
|
#include <torch/extension.h>
|
2020-04-22 16:15:41 +00:00
|
|
|
#include <torch/library.h>
|
2019-02-01 18:55:00 +00:00
|
|
|
|
|
|
|
|
using namespace at;
|
|
|
|
|
|
|
|
|
|
static int test_int;
|
|
|
|
|
|
2019-08-28 01:18:45 +00:00
|
|
|
// Builds a fake MSNPU tensor of the given dtype and sizes. The storage is
// empty (zero bytes, null data pointer) — these tensors only carry metadata
// for the dispatcher tests below, never real data.
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  // Zero-byte, non-resizable storage pinned to the single MSNPU device.
  Storage storage(
      Storage::use_byte_size_t(),
      dtype,
      /*size_bytes=*/0,
      at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
      /*allocator=*/nullptr,
      /*resizable=*/false);
  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(storage), DispatchKey::MSNPU);
  // This is a hack to workaround the shape checks in _convolution.
  impl->set_sizes_contiguous(size);
  return Tensor(std::move(impl));
}
|
|
|
|
|
|
2019-07-30 21:03:58 +00:00
|
|
|
// MSNPU kernel for empty-like allocation; records that it was the last
// override invoked (test_int == 0) and returns a metadata-only tensor.
Tensor empty_override(IntArrayRef size, const TensorOptions & options) {
  test_int = 0;
  const auto dtype = options.dtype();
  return get_tensor(dtype, size);
}
|
|
|
|
|
|
|
|
|
|
// MSNPU kernel for add.Tensor; records that it was the last override
// invoked (test_int == 1). The result mirrors `a`'s dtype and sizes only —
// no arithmetic is performed.
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
  test_int = 1;
  const auto result_sizes = a.sizes();
  return get_tensor(a.dtype(), result_sizes);
}
|
|
|
|
|
|
|
|
|
|
// MSNPU kernel for convolution_overrideable; records that it was the last
// override invoked (test_int == 2).
Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;
  // Only the first 2 dimension of output shape is correct.
  const int64_t batch = input.size(0);
  const int64_t out_channels = weight.size(0);
  return get_tensor(
      input.dtype(), {batch, out_channels, input.size(2), input.size(3)});
}
|
|
|
|
|
|
|
|
|
|
std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
|
|
|
|
|
const Tensor & grad_output, const Tensor & input, const Tensor & weight,
|
|
|
|
|
IntArrayRef stride, IntArrayRef padding,
|
|
|
|
|
IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
|
|
|
|
|
int64_t groups, std::array<bool,3> output_mask) {
|
|
|
|
|
test_int = 3;
|
|
|
|
|
return std::tuple<Tensor, Tensor, Tensor>(
|
|
|
|
|
get_tensor(input.dtype(), input.sizes()),
|
|
|
|
|
get_tensor(weight.dtype(), weight.sizes()),
|
|
|
|
|
get_tensor(input.dtype(), {}));
|
2019-02-01 18:55:00 +00:00
|
|
|
}
|
|
|
|
|
|
Switch to pybind11 style registration function API. (#36258)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36258
Previously we had a && chaining style API. There are some downsides to
this API:
- It's easy to forget the 'static' qualifier in front, leading to
subtle ODR bugs.
- It is not compatible with torchbind class_ definitions, as these
need multiple levels of chaining. So in practice people end
up having to define multiple static initializers, one per class.
- It's not like pybind11.
- There's no way to conveniently get the file and line number of
the registration, as there is no macro point in the API.
- The old API doesn't really encourage people to put all of their
definitions for a library in one place, and to give a custom
namespace for it. Similarly, the old API wasn't very DRY, because
you had to keep repeating the namespace/dispatch key you
were writing implementations for.
The new API is modeled exactly off of the PYBIND11_MODULE macro:
you write:
```
TORCH_LIBRARY(aten, m) {
m.def("aten::add(Tensor self, Tensor other) -> Tensor");
...
}
```
in a non-chaining fashion, and under the hood the macro expands to
define a function, and define a static initializer that allocates
c10::Library (previously called c10::Module, but we renamed it
to avoid confusion with the existing NN module concept), passes
it to your function, and then retains it for the rest of the lifetime
of the program. Specification of the namespace is mandatory,
and in later commit I plan to make it a hard error to TORCH_LIBRARY
the same library name twice.
If you are specifying an implementation for an existing operator
(e.g., you're the XLA backend, or even if you're just putting
registrations for implementations at the implementation site),
you should use TORCH_LIBRARY_IMPL, which instead takes a backend
argument (instead of namespace) and can be used to specify an
implementation for a backend. Unlike TORCH_LIBRARY, you can do
as many of these as you want for a backend.
This needs updates to the mobile code analyzer.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D20929257
Pulled By: ezyang
fbshipit-source-id: ba04d78492e8c93ae7190165fb936f6872896ada
2020-04-16 17:40:43 +00:00
|
|
|
// Register the fake kernels above as the MSNPU implementations of these
// aten operators. The dispatcher routes calls on MSNPU tensors here.
TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
  m.impl_UNBOXED("empty.memory_format", empty_override);
  m.impl_UNBOXED("add.Tensor", add_override);
  m.impl_UNBOXED("convolution_overrideable", fake_convolution);
  m.impl_UNBOXED("convolution_backward_overrideable", fake_convolution_backward);
}
|
|
|
|
|
|
2019-03-20 20:47:41 +00:00
|
|
|
// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
// Device-guard implementation for the fake MSNPU backend. There is exactly
// one device (index 0); device switches validate and no-op, streams are
// stubbed with the default stream, and events are unsupported.
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::MSNPU;
  MSNPUGuardImpl() {}
  // Construction from a runtime DeviceType is only valid for MSNPU.
  MSNPUGuardImpl(DeviceType t) {
    AT_ASSERT(t == DeviceType::MSNPU);
  }
  DeviceType type() const override {
    return DeviceType::MSNPU;
  }
  // With a single device, "exchange" just validates the request and
  // echoes it back (the previous device is necessarily the same one).
  Device exchangeDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
    return d;
  }
  Device getDevice() const override {
    return Device(DeviceType::MSNPU, 0);
  }
  // Validating no-op: only device MSNPU:0 exists.
  void setDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
  }
  // Unchecked variant must be noexcept, so no asserts here.
  void uncheckedSetDevice(Device d) const noexcept override {
  }
  // Streams are not modeled; always report the default stream on MSNPU:0.
  Stream getStream(Device d) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  Stream exchangeStream(Stream s) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // Event-related functions
  // Events are unsupported: record/block/queryEvent fail loudly via
  // TORCH_CHECK; destroyEvent is a silent no-op (it is noexcept).
  void record(void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  void block(
    void* event,
    const Stream& stream) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  bool queryEvent(void* event) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override { }
};
|
|
|
|
|
|
|
|
|
|
// Out-of-line definition for the static constexpr member (required for
// odr-use before C++17 made such members implicitly inline).
constexpr DeviceType MSNPUGuardImpl::static_type;
// Install MSNPUGuardImpl as the device-guard implementation for MSNPU.
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
|
|
|
|
|
|
2019-02-01 18:55:00 +00:00
|
|
|
int get_test_int() {
|
|
|
|
|
return test_int;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Python bindings: expose get_test_int so the test harness can verify
// which MSNPU override the dispatcher actually invoked.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_test_int", &get_test_int);
}
|