mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36240 It's annoying, historical, and unnecessary (enum class is already namespaced). I did this codemod with: ``` git grep -l 'CPUTensorId' | xargs sed -i 's/CPUTensorId/CPU/g' git grep -l 'CUDATensorId' | xargs sed -i 's/CUDATensorId/CUDA/g' git grep -l 'VariableTensorId' | xargs sed -i 's/VariableTensorId/Autograd/g' git grep -l 'HIPTensorId' | xargs sed -i 's/HIPTensorId/HIP/g' git grep -l 'MSNPUTensorId' | xargs sed -i 's/MSNPUTensorId/MSNPU/g' git grep -l 'XLATensorId' | xargs sed -i 's/XLATensorId/XLA/g' git grep -l 'PrivateUse1_TensorId' | xargs sed -i 's/PrivateUse1_TensorId/PrivateUse1/g' git grep -l 'PrivateUse2_TensorId' | xargs sed -i 's/PrivateUse2_TensorId/PrivateUse2/g' git grep -l 'PrivateUse3_TensorId' | xargs sed -i 's/PrivateUse3_TensorId/PrivateUse3/g' git grep -l 'AutocastTensorId' | xargs sed -i 's/AutocastTensorId/Autocast/g' git grep -l '_PreAutogradTensorId' | xargs sed -i 's/_PreAutogradTensorId/_PreAutograd/g' git grep -l 'TESTING_ONLY_GenericWrapperTensorId' | xargs sed -i 's/TESTING_ONLY_GenericWrapperTensorId/TESTING_ONLY_GenericWrapper/g' git grep -l 'TESTING_ONLY_GenericModeTensorId' | xargs sed -i 's/TESTING_ONLY_GenericModeTensorId/TESTING_ONLY_GenericMode/g' ``` Then I did a git grep for remaining TensorId occurrences, and manually killed those (mostly in codegen, and some docs that needed updating). Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D20929255 Pulled By: ezyang fbshipit-source-id: dc371b6aa6e6ea7c0a5660137c14debde806a09d
124 lines
4.2 KiB
C++
#include <torch/extension.h>
|
|
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
|
|
using namespace at;
|
|
|
|
// Global marker recording which override the dispatcher routed to last
// (0 = empty, 1 = add, 2 = convolution, 3 = convolution_backward);
// read back from Python via get_test_int().
static int test_int;
|
|
|
|
// Builds a metadata-only MSNPU tensor: the requested dtype and sizes, but a
// null data pointer and zero-byte storage.  Sufficient for tests that only
// inspect tensor metadata and never touch the data.
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  Storage storage(
      dtype,
      /*numel=*/0,
      at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
      /*allocator=*/nullptr,
      /*resizable=*/false);
  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(storage), DispatchKey::MSNPU);
  // Hack: set sizes directly to work around the shape checks in _convolution.
  impl->set_sizes_contiguous(size);
  return Tensor(std::move(impl));
}
|
|
|
|
// Override for aten::empty.memory_format: records that it was dispatched
// (test_int = 0) and returns a metadata-only MSNPU tensor of the requested
// size and dtype.
Tensor empty_override(IntArrayRef size, const TensorOptions & options) {
  test_int = 0;  // mark: empty ran
  return get_tensor(options.dtype(), size);
}
|
|
|
|
// Override for aten::add.Tensor: records that it was dispatched
// (test_int = 1).  The result mirrors the first operand's dtype and sizes;
// `b` and `c` are intentionally ignored — no arithmetic is performed.
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
  test_int = 1;  // mark: add ran
  return get_tensor(a.dtype(), a.sizes());
}
|
|
|
|
// Override for aten::convolution_overrideable: records that it was
// dispatched (test_int = 2) and returns a metadata-only tensor.  Only the
// first two dimensions of the returned shape (batch, out-channels) are
// correct; the spatial dims are copied verbatim from the input.
Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;  // mark: convolution ran
  const auto batch = input.size(0);
  const auto out_channels = weight.size(0);
  return get_tensor(
      input.dtype(), {batch, out_channels, input.size(2), input.size(3)});
}
|
|
|
|
std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
|
|
const Tensor & grad_output, const Tensor & input, const Tensor & weight,
|
|
IntArrayRef stride, IntArrayRef padding,
|
|
IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
|
|
int64_t groups, std::array<bool,3> output_mask) {
|
|
test_int = 3;
|
|
return std::tuple<Tensor, Tensor, Tensor>(
|
|
get_tensor(input.dtype(), input.sizes()),
|
|
get_tensor(weight.dtype(), weight.sizes()),
|
|
get_tensor(input.dtype(), {}));
|
|
}
|
|
|
|
// Registers the MSNPU overrides with the dispatcher.  The function-local
// static ensures registration happens exactly once (on the first call) and
// that the handle returned by torch::import() stays alive for the life of
// the process, keeping the registrations in effect.
void init_msnpu_extension() {
  static auto registry = torch::import()
    .impl_UNBOXED("aten::empty.memory_format", kMSNPU, empty_override)
    .impl_UNBOXED("aten::add.Tensor", kMSNPU, add_override)
    .impl_UNBOXED("aten::convolution_overrideable", kMSNPU, fake_convolution)
    .impl_UNBOXED("aten::convolution_backward_overrideable", kMSNPU, fake_convolution_backward)
    ;
}
|
|
|
|
// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
//
// Minimal DeviceGuardImplInterface for the fake MSNPU backend: exactly one
// device (index 0), only the default stream, and no event support (all
// event-related calls throw).
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  // Device type this guard implementation serves.
  static constexpr DeviceType static_type = DeviceType::MSNPU;
  MSNPUGuardImpl() {}
  // NOTE(review): single-argument ctor is not `explicit` — presumably the
  // guard registration machinery needs the implicit form; confirm before
  // tightening.
  MSNPUGuardImpl(DeviceType t) {
    AT_ASSERT(t == DeviceType::MSNPU);
  }
  DeviceType type() const override {
    return DeviceType::MSNPU;
  }
  // With a single fixed device, "exchange" just validates the request and
  // echoes it back (the previous device is always MSNPU:0).
  Device exchangeDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
    return d;
  }
  Device getDevice() const override {
    return Device(DeviceType::MSNPU, 0);
  }
  void setDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
  }
  // Must be noexcept per the interface; nothing to do for one fixed device.
  void uncheckedSetDevice(Device d) const noexcept override {
  }
  // Only the default stream on MSNPU:0 is ever reported, regardless of the
  // argument.
  Stream getStream(Device d) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  Stream exchangeStream(Stream s) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // Event-related functions: MSNPU has no events, so these all throw.
  void record(void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  void block(
    void* event,
    const Stream& stream) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  // Never returns normally — TORCH_CHECK(false, ...) throws — so the
  // missing return statement is fine.
  bool queryEvent(void* event) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  // noexcept, so it must not throw; destroying a nonexistent event is a
  // no-op.
  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override { }
};
|
|
|
|
// Out-of-class definition for the static constexpr member (required for it
// to be odr-used pre-C++17).
constexpr DeviceType MSNPUGuardImpl::static_type;
// Register the guard implementation so DeviceGuard machinery can find it
// for DeviceType::MSNPU.
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
|
|
|
|
int get_test_int() {
|
|
return test_int;
|
|
}
|
|
|
|
// Python bindings: expose the one-time registration entry point and the
// dispatch-marker accessor to the test harness.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_msnpu_extension", &init_msnpu_extension);
  m.def("get_test_int", &get_test_int);
}
|