2019-02-01 18:55:00 +00:00
|
|
|
#include <torch/extension.h>
|
|
|
|
|
|
2019-07-30 21:03:58 +00:00
|
|
|
#include <ATen/core/ATenDispatch.h>
|
2019-02-01 18:55:00 +00:00
|
|
|
|
|
|
|
|
using namespace at;
|
|
|
|
|
|
|
|
|
|
// Marker written by the MSNPU overrides in this file and read back through
// get_test_int(): empty_override stores 0 and add_override stores 1.
// Static storage, so it also starts out as 0.
static int test_int = 0;
|
|
|
|
|
|
2019-02-15 21:44:18 +00:00
|
|
|
// Builds a tensor tagged for the MSNPU backend that carries `dtype` but no data.
//
// The backing Storage is created with size 0 and a null DataPtr placed on
// Device(MSNPU, 0), and the TensorImpl is given the MSNPUTensorId type id.
// The trailing `nullptr, false` Storage arguments are passed through unchanged
// (presumably: no allocator, not resizable — confirm against Storage's ctor).
Tensor get_dtype_tensor(caffe2::TypeMeta dtype) {
  const Device msnpu_device(DeviceType::MSNPU, 0);
  Storage empty_storage(
      dtype, 0, at::DataPtr(nullptr, msnpu_device), nullptr, false);
  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(empty_storage), MSNPUTensorId());
  return Tensor(std::move(impl));
}
|
|
|
|
|
|
2019-07-30 21:03:58 +00:00
|
|
|
// MSNPU override registered for "aten::empty.memory_format".
//
// `size` is ignored: the result is always a data-less tensor whose dtype is
// taken from `options`. Sets test_int to 0 so get_test_int() can report that
// this override was the one invoked.
// NOTE(review): the registered schema also declares a `memory_format`
// argument that this C++ signature does not take — confirm this matches what
// ATenDispatch expects for that overload.
Tensor empty_override(IntArrayRef size, const TensorOptions & options) {
  (void)size;
  test_int = 0;
  const auto requested_dtype = options.dtype();
  return get_dtype_tensor(requested_dtype);
}
|
|
|
|
|
|
|
|
|
|
// MSNPU override registered for "aten::add.Tensor".
//
// No arithmetic is performed. Only the dtype of `a` is used to build the
// (data-less) result; `b` and the scalar `c` are ignored. Sets test_int to 1
// so get_test_int() can report that this override was the one invoked.
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
  (void)b;
  (void)c;
  test_int = 1;
  const auto result_dtype = a.dtype();
  return get_dtype_tensor(result_dtype);
}
|
|
|
|
|
|
|
|
|
|
void init_msnpu_extension() {
|
2019-07-30 21:03:58 +00:00
|
|
|
globalATenDispatch().registerOp(
|
2019-02-01 18:55:00 +00:00
|
|
|
Backend::MSNPU,
|
2019-08-01 09:00:41 +00:00
|
|
|
"aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
|
2019-07-30 21:03:58 +00:00
|
|
|
&empty_override);
|
|
|
|
|
globalATenDispatch().registerOp(
|
2019-02-01 18:55:00 +00:00
|
|
|
Backend::MSNPU,
|
2019-08-01 09:00:41 +00:00
|
|
|
"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
|
2019-07-30 21:03:58 +00:00
|
|
|
&add_override);
|
2019-02-01 18:55:00 +00:00
|
|
|
}
|
|
|
|
|
|
2019-03-20 20:47:41 +00:00
|
|
|
// TODO: Extend this to exercise multi-device setting. In that case,
|
|
|
|
|
// we need to add a thread local variable to track the current device.
|
|
|
|
|
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
|
|
|
|
static constexpr DeviceType static_type = DeviceType::MSNPU;
|
|
|
|
|
MSNPUGuardImpl() {}
|
|
|
|
|
MSNPUGuardImpl(DeviceType t) {
|
|
|
|
|
AT_ASSERT(t == DeviceType::MSNPU);
|
|
|
|
|
}
|
|
|
|
|
DeviceType type() const override {
|
|
|
|
|
return DeviceType::MSNPU;
|
|
|
|
|
}
|
|
|
|
|
Device exchangeDevice(Device d) const override {
|
|
|
|
|
AT_ASSERT(d.type() == DeviceType::MSNPU);
|
|
|
|
|
AT_ASSERT(d.index() == 0);
|
|
|
|
|
return d;
|
|
|
|
|
}
|
|
|
|
|
Device getDevice() const override {
|
|
|
|
|
return Device(DeviceType::MSNPU, 0);
|
|
|
|
|
}
|
|
|
|
|
void setDevice(Device d) const override {
|
|
|
|
|
AT_ASSERT(d.type() == DeviceType::MSNPU);
|
|
|
|
|
AT_ASSERT(d.index() == 0);
|
|
|
|
|
}
|
|
|
|
|
void uncheckedSetDevice(Device d) const noexcept override {
|
|
|
|
|
}
|
|
|
|
|
Stream getStream(Device d) const noexcept override {
|
|
|
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
|
|
|
|
|
}
|
|
|
|
|
Stream exchangeStream(Stream s) const noexcept override {
|
|
|
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
|
|
|
|
|
}
|
2019-03-26 16:42:41 +00:00
|
|
|
DeviceIndex deviceCount() const noexcept override {
|
2019-03-20 20:47:41 +00:00
|
|
|
return 1;
|
|
|
|
|
}
|
2019-08-26 22:17:26 +00:00
|
|
|
|
|
|
|
|
// Event-related functions
|
|
|
|
|
void record(void** event,
|
|
|
|
|
const Stream& stream,
|
|
|
|
|
const DeviceIndex device_index,
|
|
|
|
|
const EventFlag flag) const override {
|
|
|
|
|
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
|
|
|
|
|
}
|
|
|
|
|
void block(
|
|
|
|
|
void* event,
|
|
|
|
|
const Stream& stream) const override {
|
|
|
|
|
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
|
|
|
|
|
}
|
|
|
|
|
bool queryEvent(void* event) const override {
|
|
|
|
|
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
|
|
|
|
|
}
|
|
|
|
|
void destroyEvent(
|
|
|
|
|
void* event,
|
|
|
|
|
const DeviceIndex device_index) const noexcept override { }
|
2019-03-20 20:47:41 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Out-of-line definition of the in-class static constexpr member; needed
// before C++17 if static_type is ODR-used (e.g. bound to a reference).
constexpr DeviceType MSNPUGuardImpl::static_type;
|
|
|
|
|
// Registers MSNPUGuardImpl as the device guard implementation for the MSNPU
// device type.
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
|
|
|
|
|
|
2019-02-01 18:55:00 +00:00
|
|
|
int get_test_int() {
|
|
|
|
|
return test_int;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Python bindings for the extension:
//   init_msnpu_extension() - registers the MSNPU overrides with ATen dispatch.
//   get_test_int()         - returns the marker written by the last override.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_msnpu_extension", &init_msnpu_extension)
   .def("get_test_int", &get_test_int);
}
|