mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26131 Changes in this PR: - For each operator with use_c10_dispatcher: True, additionally generate a c10 registration line in TypeDefault.cpp, CPUType.cpp, and other backend files. - This doesn't change globalATenDispatch yet, the c10 registration is purely additional and the operator calling path doesn't change. A diff further up the stack will change these things. - Enable the use_c10_dispatcher: True flag for about ~70% of operators - This also changes the c10->jit operator export because ATen ops are already exported to JIT directly and we don't want to export the registered c10 ops because they would clash - For this, we need a way to recognize if a certain operator is already moved from ATen to c10, this is done by generating a OpsAlreadyMovedToC10.cpp file with the list. A diff further up in the stack will also need this file to make sure we don't break the backend extension API for these ops. Reasons for some ops to be excluded (i.e. not have the `use_c10_dispatcher` flag set to true): - `Tensor?(a!)` (i.e. optional tensor with annotations) not supported in c++ function schema parser yet - `-> void` in native_functions.yaml vs `-> ()` expected by function schema parser - out functions have different argument order in C++ as in the jit schema - `Tensor?` (i.e. optional tensor) doesn't work nicely with undefined tensor sometimes being undefined tensor and sometimes being None. - fixed-size arrays like `int[3]` not supported in c10 yet These will be fixed in separate diffs and then the exclusion tag will be removed. ghstack-source-id: 90060748 Test Plan: a diff stacked on top uses these registrations to call these ops from ATen Differential Revision: D16603131 fbshipit-source-id: 315eb83d0b567eb0cd49973060b44ee1d6d64bfb
136 lines
5.4 KiB
C++
#include <torch/extension.h>
|
|
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
|
|
using namespace at;
|
|
|
|
// File-local marker recording which override kernel ran last
// (0 = empty, 1 = add, 2 = convolution, 3 = convolution_backward);
// read back from Python through get_test_int().
static int test_int;
|
|
|
|
// Builds a bare MSNPU tensor carrying only metadata: the storage holds a
// null DataPtr (no real allocation), so the result is usable for shape/dtype
// propagation tests but never for reading or writing data.
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  Storage storage(
      dtype, 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)), nullptr, false);
  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(storage), TensorTypeId::MSNPUTensorId);
  // This is a hack to workaround the shape checks in _convolution.
  impl->set_sizes_contiguous(size);
  return Tensor(std::move(impl));
}
|
|
|
|
// Kernel registered for aten::empty on MSNPU. Marks the call (test_int = 0)
// and hands back a storage-less tensor of the requested dtype and size.
Tensor empty_override(IntArrayRef size, const TensorOptions & options) {
  test_int = 0;
  const auto requested_dtype = options.dtype();
  return get_tensor(requested_dtype, size);
}
|
|
|
|
// Kernel registered for aten::add.Tensor on MSNPU. Marks the call
// (test_int = 1); the result mirrors the dtype/shape of the first operand
// and performs no arithmetic (b and c are intentionally unused).
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
  test_int = 1;
  const auto& self = a;
  return get_tensor(self.dtype(), self.sizes());
}
|
|
|
|
// Kernel registered for aten::convolution_overrideable on MSNPU. Marks the
// call (test_int = 2) and fabricates an output tensor; the conv hyper-
// parameters (stride, padding, dilation, ...) are intentionally ignored.
Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;
  // Only the first 2 dimension of output shape is correct.
  const int64_t out_sizes[] = {
      input.size(0), weight.size(0), input.size(2), input.size(3)};
  return get_tensor(input.dtype(), out_sizes);
}
|
|
|
|
// Kernel registered for aten::convolution_backward_overrideable on MSNPU.
// Marks the call (test_int = 3) and returns placeholder gradients shaped
// like (input, weight, scalar) — no real gradient math is performed.
std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
    const Tensor & grad_output, const Tensor & input, const Tensor & weight,
    IntArrayRef stride, IntArrayRef padding,
    IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool,3> output_mask) {
  test_int = 3;
  auto grad_input  = get_tensor(input.dtype(), input.sizes());
  auto grad_weight = get_tensor(weight.dtype(), weight.sizes());
  auto grad_bias   = get_tensor(input.dtype(), {});
  return std::make_tuple(
      std::move(grad_input), std::move(grad_weight), std::move(grad_bias));
}
|
|
|
|
void init_msnpu_extension() {
|
|
static auto registry = torch::RegisterOperators()
|
|
.op(torch::RegisterOperators::options()
|
|
.schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
|
|
.impl_unboxedOnlyKernel<decltype(empty_override), &empty_override>(TensorTypeId::MSNPUTensorId)
|
|
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
|
|
.op(torch::RegisterOperators::options()
|
|
.schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
|
|
.impl_unboxedOnlyKernel<decltype(add_override), &add_override>(TensorTypeId::MSNPUTensorId)
|
|
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
|
|
.op(torch::RegisterOperators::options()
|
|
.schema("aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor")
|
|
.impl_unboxedOnlyKernel<decltype(fake_convolution), &fake_convolution>(TensorTypeId::MSNPUTensorId)
|
|
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
|
|
.op(torch::RegisterOperators::options()
|
|
.schema("aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)")
|
|
.impl_unboxedOnlyKernel<decltype(fake_convolution_backward), &fake_convolution_backward>(TensorTypeId::MSNPUTensorId)
|
|
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
|
|
;
|
|
}
|
|
|
|
// TODO: Extend this to exercise multi-device setting. In that case,
|
|
// we need to add a thread local variable to track the current device.
|
|
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
|
static constexpr DeviceType static_type = DeviceType::MSNPU;
|
|
MSNPUGuardImpl() {}
|
|
MSNPUGuardImpl(DeviceType t) {
|
|
AT_ASSERT(t == DeviceType::MSNPU);
|
|
}
|
|
DeviceType type() const override {
|
|
return DeviceType::MSNPU;
|
|
}
|
|
Device exchangeDevice(Device d) const override {
|
|
AT_ASSERT(d.type() == DeviceType::MSNPU);
|
|
AT_ASSERT(d.index() == 0);
|
|
return d;
|
|
}
|
|
Device getDevice() const override {
|
|
return Device(DeviceType::MSNPU, 0);
|
|
}
|
|
void setDevice(Device d) const override {
|
|
AT_ASSERT(d.type() == DeviceType::MSNPU);
|
|
AT_ASSERT(d.index() == 0);
|
|
}
|
|
void uncheckedSetDevice(Device d) const noexcept override {
|
|
}
|
|
Stream getStream(Device d) const noexcept override {
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
|
|
}
|
|
Stream exchangeStream(Stream s) const noexcept override {
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
|
|
}
|
|
DeviceIndex deviceCount() const noexcept override {
|
|
return 1;
|
|
}
|
|
|
|
// Event-related functions
|
|
void record(void** event,
|
|
const Stream& stream,
|
|
const DeviceIndex device_index,
|
|
const EventFlag flag) const override {
|
|
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
|
|
}
|
|
void block(
|
|
void* event,
|
|
const Stream& stream) const override {
|
|
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
|
|
}
|
|
bool queryEvent(void* event) const override {
|
|
TORCH_CHECK(false, "MSNPU backend doesn't support events.");
|
|
}
|
|
void destroyEvent(
|
|
void* event,
|
|
const DeviceIndex device_index) const noexcept override { }
|
|
};
|
|
|
|
constexpr DeviceType MSNPUGuardImpl::static_type;
|
|
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
|
|
|
|
int get_test_int() {
|
|
return test_int;
|
|
}
|
|
|
|
// Python bindings: expose the one-time kernel registration entry point and
// the hook used by the test suite to observe which kernel was dispatched.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_test_int", &get_test_int);
  m.def("init_msnpu_extension", &init_msnpu_extension);
}
|