pytorch/test/cpp_extensions/msnpu_extension.cpp
Edward Yang 515238e0a5 Unify cudaGetDeviceCount implementations. (#18445)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18445
ghimport-source-id: 30d018737bf6989bc68b7e3676f44e0ca6141fde

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18242 Test running a CUDA build on CPU machine.
* **#18445 Unify cudaGetDeviceCount implementations.**

I went about doing this by searching for calls to cudaGetDeviceCount,
and then methodically replacing them with references to c10::cuda::device_count()
or at::cuda::device_count().

There is a point to doing this: the various implementations wildly differed
in their handling of what to do when cudaGetDeviceCount returns an error.
The final standardized behavior is that **all errors are swallowed** and
we return a device count of zero.  This indirectly fixes running CUDA builds
on CPU, which was broken in #17847.

I added 'noexcept' to the 'deviceCount' virtual method on DeviceGuardImpl.
This is a BC-breaking change for anyone inheriting from DeviceGuardImpl
but all you need to do is put 'noexcept' on your method and it is backwards
compatible with older libtorch.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14612189

fbshipit-source-id: 3c8d186e3dd623c0e27625212c7ce30f75d943cb
2019-03-26 09:50:14 -07:00

140 lines
4.1 KiB
C++

#include <torch/extension.h>
#include <ATen/ExtensionBackendRegistration.h>
using namespace at;
// Marker recording which override ran most recently: several of the overrides
// below store their own small integer here, and get_test_int() returns it.
static int test_int;
// Builds the placeholder Tensor handed back by the Tensor-returning overrides
// in this file. The tensor is tagged with MSNPUTensorId() and is backed by a
// Storage created for `dtype` from a null DataPtr on MSNPU device 0.
Tensor get_dtype_tensor(caffe2::TypeMeta dtype) {
  const Device msnpu_device(DeviceType::MSNPU, 0);

  // No real memory is attached: the data pointer and the allocator are null.
  Storage backing(
      dtype, 0, at::DataPtr(nullptr, msnpu_device), nullptr, false);

  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(backing), MSNPUTensorId(), false);
  return Tensor(std::move(impl));
}
// MSNPU override for `zeros`: sets test_int to 0 and returns a placeholder
// tensor with the dtype requested in `options`. `size` is ignored.
Tensor zeros_override(IntArrayRef size, const TensorOptions & options) {
  test_int = 0;
  const auto requested_dtype = options.dtype();
  return get_dtype_tensor(requested_dtype);
}
// MSNPU override for `add`: sets test_int to 1 and returns a placeholder
// tensor with the dtype of `a`. `b` and `c` are ignored.
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
  test_int = 1;
  const auto result_dtype = a.dtype();
  return get_dtype_tensor(result_dtype);
}
// MSNPU override for `sum`: sets test_int to 2 and returns a placeholder
// tensor with the dtype of `self`.
Tensor sum_override(const Tensor & self) {
  test_int = 2;
  const auto result_dtype = self.dtype();
  return get_dtype_tensor(result_dtype);
}
// MSNPU override for `expand`, needed for sum backwards. It leaves test_int
// untouched and returns a placeholder tensor with the dtype of `self`;
// `size` and `implicit` are ignored.
Tensor expand_override(const Tensor & self, IntArrayRef size, bool implicit) {
  const auto result_dtype = self.dtype();
  return get_dtype_tensor(result_dtype);
}
// MSNPU override for `kl_div`: sets test_int to 3 and returns a placeholder
// tensor with the dtype of `self`. `target` and `reduction` are ignored.
Tensor kl_div_override(
    const Tensor & self,
    const Tensor & target,
    int64_t reduction) {
  test_int = 3;
  const auto result_dtype = self.dtype();
  return get_dtype_tensor(result_dtype);
}
// MSNPU override for `kl_div_backward`: sets test_int to 4 and returns a
// placeholder tensor with the dtype of `self`. The remaining arguments are
// ignored.
Tensor kl_div_backward_override(
    const Tensor & grad_output,
    const Tensor & self,
    const Tensor & target,
    int64_t reduction) {
  test_int = 4;
  const auto result_dtype = self.dtype();
  return get_dtype_tensor(result_dtype);
}
// numel and ones_like are needed for autograd backwards.
// MSNPU override for `numel`: always reports one element, regardless of
// `self`.
int64_t numel_override(const Tensor & self) {
  constexpr int64_t kReportedNumel = 1;
  return kReportedNumel;
}
// MSNPU override for `ones_like`: returns a placeholder tensor with the dtype
// requested in `options`. `self` is ignored and test_int is left untouched.
Tensor ones_like_override(const Tensor & self, const TensorOptions & options) {
  const auto requested_dtype = options.dtype();
  return get_dtype_tensor(requested_dtype);
}
void init_msnpu_extension() {
register_extension_backend_op(
Backend::MSNPU,
"zeros(IntArrayRef size, TensorOptions options) -> Tensor", &zeros_override);
register_extension_backend_op(
Backend::MSNPU,
"add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
register_extension_backend_op(
Backend::MSNPU,
"sum(Tensor self) -> Tensor", &sum_override);
register_extension_backend_op(
Backend::MSNPU,
"expand(Tensor self, IntArrayRef size, bool implicit) -> Tensor",
&expand_override);
register_extension_backend_op(
Backend::MSNPU,
"kl_div(Tensor self, Tensor target, int64_t reduction) -> Tensor",
&kl_div_override);
register_extension_backend_op(
Backend::MSNPU,
"kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction) -> Tensor",
&kl_div_backward_override);
register_extension_backend_op(
Backend::MSNPU,
"numel(Tensor self) -> int64_t", &numel_override);
register_extension_backend_op(
Backend::MSNPU,
"ones_like(Tensor self, TensorOptions options) -> Tensor",
&ones_like_override);
}
// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = DeviceType::MSNPU;
MSNPUGuardImpl() {}
MSNPUGuardImpl(DeviceType t) {
AT_ASSERT(t == DeviceType::MSNPU);
}
DeviceType type() const override {
return DeviceType::MSNPU;
}
Device exchangeDevice(Device d) const override {
AT_ASSERT(d.type() == DeviceType::MSNPU);
AT_ASSERT(d.index() == 0);
return d;
}
Device getDevice() const override {
return Device(DeviceType::MSNPU, 0);
}
void setDevice(Device d) const override {
AT_ASSERT(d.type() == DeviceType::MSNPU);
AT_ASSERT(d.index() == 0);
}
void uncheckedSetDevice(Device d) const noexcept override {
}
Stream getStream(Device d) const noexcept override {
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
}
Stream exchangeStream(Stream s) const noexcept override {
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
}
DeviceIndex deviceCount() const noexcept override {
return 1;
}
};
// Out-of-class definition of the static constexpr data member declared in
// MSNPUGuardImpl, so the member has a definition if it is ever odr-used
// (required before C++17).
constexpr DeviceType MSNPUGuardImpl::static_type;
// Registers MSNPUGuardImpl as the device guard implementation for MSNPU.
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
// Returns the current value of the file-local test_int marker, i.e. the
// value stored by whichever override last wrote to it.
int get_test_int() {
  const int last_marker = test_int;
  return last_marker;
}
// Python bindings: exposes init_msnpu_extension() and get_test_int() on the
// extension module under those same names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_msnpu_extension", &init_msnpu_extension)
      .def("get_test_int", &get_test_int);
}