mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39492

This PR adds use_c10_dispatcher: full to ops taking TensorOptions. Since the c10 operator library doesn't know about TensorOptions, the operator kernels must instead be registered, and called, with scattered optional<ScalarType>, optional<Device>, optional<Layout>, optional<bool> arguments.

Changes:
- Add use_c10_dispatcher: full to those ops.
- Write hacky_wrapper_for_legacy_signatures, which takes an old-style kernel (i.e. one written to take TensorOptions) and creates a wrapper kernel for it that takes the scattered optional<ScalarType>, optional<Device>, optional<Layout>, optional<bool> instead.
- Change codegen so that all op registrations are wrapped in hacky_wrapper_for_legacy_signatures. This is applied to all ops but is a no-op if the op doesn't take TensorOptions. It allows us, in the future, to change a kernel signature from TensorOptions to the scattered version and have it work without touching codegen.
- Change codegen so that the frontend calls those operators with expanded arguments instead of a TensorOptions object. This is required because the kernels are now written this way.

This PR does not remove TensorOptions special cases from codegen; instead, it decouples kernels from the codegen/frontend issues. After this, kernels can be worked on without touching codegen, and codegen can be worked on without touching kernels.

Codegen diff: P133121032
ghstack-source-id: 106426630

Test Plan: waitforsandcastle

Differential Revision: D21581908

fbshipit-source-id: 6d4a9f526fd70fae40581bf26f3ccf794ce6a89e
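For illustration, here is a minimal sketch of the gathering such a wrapper performs. The names legacy_empty and wrapped_empty are hypothetical; this is not the actual hacky_wrapper_for_legacy_signatures implementation, which does this generically for any kernel signature.

#include <ATen/ATen.h>

// Old-style kernel: written against a gathered TensorOptions.
at::Tensor legacy_empty(at::IntArrayRef size, const at::TensorOptions& options) {
  return at::empty(size, options);
}

// New calling convention: the dispatcher passes scattered optionals. The
// wrapper gathers them back into a TensorOptions and forwards to the legacy
// kernel, so the kernel body itself doesn't have to change yet.
at::Tensor wrapped_empty(
    at::IntArrayRef size,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory) {
  auto options = at::TensorOptions()
                     .dtype(dtype)
                     .layout(layout)
                     .device(device)
                     .pinned_memory(pin_memory);
  return legacy_empty(size, options);
}

Codegen then emits calls with the expanded argument list, matching the wrapped signature. The test extension below writes its empty_override kernel directly in the scattered form.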
129 lines
4.3 KiB
C++
#include <torch/extension.h>
#include <torch/library.h>

using namespace at;

// Set by each kernel override below so the Python-side test can verify
// which kernel was dispatched to.
static int test_int;

// Constructs a tensor for the fake MSNPU backend: no real storage is
// allocated (null DataPtr); only metadata such as sizes and dtype is set.
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      Storage(
          Storage::use_byte_size_t(),
          0,
          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
          nullptr,
          false),
      DispatchKey::MSNPU,
      dtype);
  // This is a hack to work around the shape checks in _convolution.
  tensor_impl->set_sizes_contiguous(size);
  return Tensor(std::move(tensor_impl));
}

// Note the scattered c10::optional<ScalarType>/Layout/Device/bool arguments
// instead of a gathered TensorOptions: this is the calling convention used
// for kernels registered with use_c10_dispatcher: full.
Tensor empty_override(
    IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  test_int = 0;
  caffe2::TypeMeta typeMeta;
  if (dtype.has_value()) {
    typeMeta = scalarTypeToTypeMeta(*dtype);
  }
  return get_tensor(typeMeta, size);
}

Tensor add_override(const Tensor& a, const Tensor& b, Scalar c) {
  test_int = 1;
  return get_tensor(a.dtype(), a.sizes());
}

Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;
  // Only the first 2 dimensions of the output shape are correct.
  return get_tensor(
      input.dtype(),
      {input.size(0), weight.size(0), input.size(2), input.size(3)});
}

std::tuple<Tensor, Tensor, Tensor> fake_convolution_backward(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    IntArrayRef stride, IntArrayRef padding,
    IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool, 3> output_mask) {
  test_int = 3;
  return std::tuple<Tensor, Tensor, Tensor>(
      get_tensor(input.dtype(), input.sizes()),
      get_tensor(weight.dtype(), weight.sizes()),
      get_tensor(input.dtype(), {}));
}

// Register the fake kernels for the MSNPU dispatch key, so calls to these
// aten ops on MSNPU tensors are routed to the overrides above.
TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
  m.impl_UNBOXED("empty.memory_format", empty_override);
  m.impl_UNBOXED("add.Tensor", add_override);
  m.impl_UNBOXED("convolution_overrideable", fake_convolution);
  m.impl_UNBOXED("convolution_backward_overrideable", fake_convolution_backward);
}

// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::MSNPU;
  MSNPUGuardImpl() {}
  MSNPUGuardImpl(DeviceType t) {
    AT_ASSERT(t == DeviceType::MSNPU);
  }
  DeviceType type() const override {
    return DeviceType::MSNPU;
  }
  Device exchangeDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
    return d;
  }
  Device getDevice() const override {
    return Device(DeviceType::MSNPU, 0);
  }
  void setDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
  }
  void uncheckedSetDevice(Device d) const noexcept override {}
  Stream getStream(Device d) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  Stream exchangeStream(Stream s) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // Event-related functions
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  void block(void* event, const Stream& stream) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  bool queryEvent(void* event) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {}
};

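// A hedged sketch of the multi-device extension the TODO above describes,
// kept commented out so it is clearly hypothetical and not part of this
// test: a thread-local index would track the current device, and the guard
// methods would read/write it instead of hard-coding device 0.
//
//   static thread_local c10::DeviceIndex current_msnpu_device = 0;
//
//   Device exchangeDevice(Device d) const override {
//     AT_ASSERT(d.type() == DeviceType::MSNPU);
//     Device old_device(DeviceType::MSNPU, current_msnpu_device);
//     current_msnpu_device = d.index();
//     return old_device;
//   }
//   Device getDevice() const override {
//     return Device(DeviceType::MSNPU, current_msnpu_device);
//   }
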
constexpr DeviceType MSNPUGuardImpl::static_type;
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);

int get_test_int() {
  return test_int;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_test_int", &get_test_int);
}