2019-02-01 18:55:00 +00:00
|
|
|
#include <torch/extension.h>
|
2020-04-22 16:15:41 +00:00
|
|
|
#include <torch/library.h>
|
2019-02-01 18:55:00 +00:00
|
|
|
|
|
|
|
|
using namespace at;
|
|
|
|
|
|
|
|
|
|
// Marker written by each fake MAIA kernel in this file so tests can tell which
// override ran: 0 = empty.memory_format, 1 = add.out,
// 2 = convolution_overrideable, 3 = convolution_backward_overrideable.
// Read back from Python through get_test_int().
static int test_int{0};
|
|
|
|
|
|
2019-08-28 01:18:45 +00:00
|
|
|
// Builds a metadata-only MAIA tensor with the given dtype and sizes.
//
// The backing storage is zero bytes long, non-resizable, and has a null data
// pointer tagged with device MAIA:0, so only the tensor's metadata (dtype,
// sizes, contiguous strides, dispatch key) is meaningful.
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  const Device maia_device(DeviceType::MAIA, 0);

  Storage empty_storage(
      Storage::use_byte_size_t(),
      /*size_bytes=*/0,
      at::DataPtr(nullptr, maia_device),
      /*allocator=*/nullptr,
      /*resizable=*/false);

  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(empty_storage), DispatchKey::MAIA, dtype);

  // This is a hack to workaround the shape checks in _convolution.
  impl->set_sizes_contiguous(size);

  return Tensor(std::move(impl));
}
|
|
|
|
|
|
2024-05-24 00:26:15 +00:00
|
|
|
// Fake kernel registered for aten::empty.memory_format on the MAIA backend.
//
// Sets the test_int marker to 0, then returns a storage-less tensor built by
// get_tensor() with the requested `size` and dtype; when `dtype` is unset,
// dtype_or_default() supplies the fallback. `layout`, `device`, `pin_memory`
// and `optional_memory_format` are accepted but never read.
Tensor empty_override(IntArrayRef size, std::optional<ScalarType> dtype, std::optional<Layout> layout, std::optional<Device> device,
|
|
|
|
|
std::optional<bool> pin_memory, std::optional<c10::MemoryFormat> optional_memory_format) {
|
2019-02-01 18:55:00 +00:00
|
|
|
// Marker observed from Python via get_test_int().
test_int = 0;
|
Make dispatcher registrations of SymInt functions backwards compatible (#84557)
Previously, when we SymInt-ify a schema, this is a BC-breaking change
for all people who registered functions for that function; they
must accept c10::SymInt where they previously accepted int64_t.
This is not great.
With this change, I accept old type registrations transparently. The
idea is in several parts:
- At the registration site, at compile time I have no idea whether or not
if the function being registered has a SymInt schema or not. So I
must defer the exact compatibility check. What I do instead is
check if the function pointer registered to me has SymInt in the
argument or not. If it does, I assume it is new-style and ensure
it is also registered to a special sym_ slot on KernelFunction.
If not, it only goes in the conventional slot.
- At the dispatcher site, I know at compile time whether or not this
is a SymInt function. If it is, I check for a sym_ slot on the
KernelFunction, and preferentially use that. If no such slot
exists, I then fall back to the regular slot... but I convert
all SymInt arguments to int64_t arguments (doing assertions that
no true symbolic integer was passed.) I can skip this test entirely
if the function doesn't have any SymInts in it; in that case I know
that only the original slot could have been registered. Fortunately,
both branches of the short circuit typecheck, so I didn't have to
use SFINAE or if-constexpr to make it work; just a plain if statement
that I expect the compiler to optimize away.
- Schema validation is now modestly more complicated. There are two parts. First, function schema validation proceeds by checking if the signature in question has any SymInt-like types in it or not. If it does, we do function schema validation against the real types; if it doesn't, we do validation against the fake types (but only for symint; MemoryFormat is always MemoryFormat). Second, cpp signature validation also keeps track of a "symint" cpp signature and a "non-symint" cpp signature. We only compare symint with symint, and non-symint with non-symint. I did not implement checking a conflict between a symint and non-symint cpp signature, though in principle you could try converting the SymInt types to non-SymInt types and doing the comparison that way.
To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now.
I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84557
Approved by: https://github.com/wconstab
2022-09-07 12:58:32 +00:00
|
|
|
// Only dtype and size influence the result; get_tensor() allocates no data.
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
|
2022-07-04 20:08:53 +00:00
|
|
|
}
|
|
|
|
|
|
Pass Scalar by reference (#53583)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53583
`Scalar` takes 32 bytes because `c10::complex<double>`
requires 16-byte alignment. Passing `Scalar` by reference
shows about a 1% improvement in instruction count.
All the changes in this commit are codemoded except for
the following 4 files (which code-gen signatures):
```
tools/codegen/api/cpp.py
tools/codegen/api/native.py
tools/codegen/api/structured.py
caffe2/contrib/aten/gen_op.py
```
# Codemod
## Main Step
For the codemod part, here is the main command used:
```
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)optional<Scalar> (\w+)' '${1}const optional<Scalar>& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)optional<Scalar> (\w+)' '${1}const optional<Scalar>& ${2}'
```
As you can tell, it codemods both `Scalar` and `optional<Scalar>`. Apply these commands iteratively until reaching a fixed point, since one method signature might contain multiple `Scalar` parameters.
In retrospect, excluding `third_party` and `torch/csrc/jit` would have been a good idea. (I reverted those changes manually later; see https://github.com/pytorch/pytorch/pull/53479 for reference.)
## Pre-Step
Because some occurrences of `Scalar` are spelled `at::Scalar` or `c10::Scalar`, I codemod some of them before applying the main command. Here is an incomplete list:
```
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)at::Scalar (\w+)' '${1}const at::Scalar& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)at::Scalar (\w+)' '${1}const at::Scalar& ${2}'
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)c10::optional<Scalar> (\w+)' '${1}const c10::optional<Scalar>& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)c10::optional<Scalar> (\w+)' '${1}const c10::optional<Scalar>& ${2}'
```
## Fixup
A couple of post-codemod fixups are needed. For example, `const Scalar` will be codemodded into `const const Scalar&`, and `at::Scalar` will be codemodded into `at::const Scalar&` (if the `Pre-Step` is not done comprehensively). Here is an incomplete list:
```
fastmod --extensions cpp 'const const Scalar' 'const Scalar'
fastmod --extensions h 'const const c10::optional<Scalar>' 'const c10::optional<Scalar>'
fastmod --extensions cpp 'const const c10::optional<Scalar>' 'const c10::optional<Scalar>'
fastmod 'at::const Scalar&' 'const at::Scalar&'
```
## Supplementary
`cu` and `mm` files also need to be codemoded, for example:
```
fastmod --extensions cu 'at::const Scalar&' 'const at::Scalar&'
fastmod --extensions mm '([a-zA-Z_+]\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'
```
Function pointer types are not covered by the main command, so they need their own codemod. Here is an incomplete list:
```
# Cover case: using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, Scalar source);
fastmod --extensions h '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'
# Cover case: using softplus_fn = void (*)(TensorIterator&, Scalar, Scalar);
fastmod --extensions h '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)Scalar([, \)])' '${1}const Scalar&${2}'
fastmod --extensions cpp '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)Scalar([, \)])' '${1}const Scalar&${2}'
fastmod --extensions h '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)optional<Scalar>([, \)])' '${1}const optional<Scalar>&${2}'
```
Some corner cases need to be fixed manually.
ghstack-source-id: 123970306
Test Plan: Imported from OSS
Reviewed By: smessmer
Differential Revision: D26904445
fbshipit-source-id: 8d8a002af4b5125f153a32f03c6956be7ae5671d
2021-03-16 06:13:12 +00:00
|
|
|
// Fake kernel registered for aten::add.out on the MAIA backend.
//
// Does no arithmetic at all: it records that dispatch reached this override
// (test_int = 1) and hands `out` back unchanged. `a`, `b` and `c` are unread.
Tensor& add_out_override(
    const Tensor& a,
    const Tensor& b,
    const Scalar& c,
    Tensor& out) {
  test_int = 1;  // observed from Python via get_test_int()
  return out;
}
|
|
|
|
|
|
|
|
|
|
// Fake kernel registered for aten::convolution_overrideable on MAIA.
//
// Sets the test_int marker to 2 and returns a storage-less tensor (see
// get_tensor) whose batch (dim 0) and channel (dim 1) sizes match what a real
// convolution would produce. Any remaining (spatial) dims are copied from
// `input` and are NOT the true output extent: only the first 2 dimensions of
// the output shape are correct. bias, stride, padding, dilation and
// output_padding are ignored.
//
// The output has the same rank as `input`, so conv1d/2d/3d-shaped inputs all
// work; `input` must have at least 2 dims.
Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;

  TORCH_CHECK(input.dim() >= 2,
              "fake_convolution: expected input with at least 2 dimensions, got ",
              input.dim());

  // Regular conv weights are laid out (C_out, C_in / groups, ...), while
  // transposed conv weights are (C_in, C_out / groups, ...).
  const int64_t out_channels =
      transposed ? weight.size(1) * groups : weight.size(0);

  // Start from the input shape (keeps batch size and rank), then fix the
  // channel dimension.
  auto out_sizes = input.sizes().vec();
  out_sizes[1] = out_channels;

  return get_tensor(input.dtype(), out_sizes);
}
|
|
|
|
|
|
|
|
|
|
// Fake kernel registered for aten::convolution_backward_overrideable on MAIA.
//
// Sets the test_int marker to 3 and returns storage-less tensors (see
// get_tensor) shaped like the gradients a real backward would produce:
//   grad_input  -> shape of `input`
//   grad_weight -> shape of `weight`
//   grad_bias   -> one element per output channel (the shape of the bias)
// The grad_bias shape was previously `{}`, which can never match a bias of
// shape (C_out). No values are computed, and output_mask is not consulted:
// all three tensors are always returned.
std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
    const Tensor & grad_output, const Tensor & input, const Tensor & weight,
    IntArrayRef stride, IntArrayRef padding,
    IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool,3> output_mask) {
  test_int = 3;

  // Regular conv weights are laid out (C_out, C_in / groups, ...), while
  // transposed conv weights are (C_in, C_out / groups, ...).
  const int64_t out_channels =
      transposed ? weight.size(1) * groups : weight.size(0);

  return std::tuple<Tensor, Tensor, Tensor>(
      get_tensor(input.dtype(), input.sizes()),
      get_tensor(weight.dtype(), weight.sizes()),
      get_tensor(input.dtype(), {out_channels}));
}
|
|
|
|
|
|
2024-04-23 00:33:20 +00:00
|
|
|
// Routes the aten operators exercised by the tests to the fake MAIA kernels
// defined in this file. Each kernel sets a distinct test_int marker.
TORCH_LIBRARY_IMPL(aten, MAIA, lib) {
  // Tensor creation.
  lib.impl("empty.memory_format", empty_override);

  // Convolution forward and backward.
  lib.impl("convolution_overrideable", fake_convolution);
  lib.impl("convolution_backward_overrideable", fake_convolution_backward);

  // Out-variant arithmetic.
  lib.impl("add.out", add_out_override);
}
|
|
|
|
|
|
2019-03-20 20:47:41 +00:00
|
|
|
// TODO: Extend this to exercise multi-device setting. In that case,
|
|
|
|
|
// we need to add a thread local variable to track the current device.
|
2024-04-23 00:33:20 +00:00
|
|
|
struct MAIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
|
|
|
|
static constexpr DeviceType static_type = DeviceType::MAIA;
|
|
|
|
|
MAIAGuardImpl() {}
|
|
|
|
|
MAIAGuardImpl(DeviceType t) {
|
|
|
|
|
AT_ASSERT(t == DeviceType::MAIA);
|
2019-03-20 20:47:41 +00:00
|
|
|
}
|
|
|
|
|
DeviceType type() const override {
|
2024-04-23 00:33:20 +00:00
|
|
|
return DeviceType::MAIA;
|
2019-03-20 20:47:41 +00:00
|
|
|
}
|
|
|
|
|
Device exchangeDevice(Device d) const override {
|
2024-04-23 00:33:20 +00:00
|
|
|
AT_ASSERT(d.type() == DeviceType::MAIA);
|
2019-03-20 20:47:41 +00:00
|
|
|
AT_ASSERT(d.index() == 0);
|
|
|
|
|
return d;
|
|
|
|
|
}
|
|
|
|
|
Device getDevice() const override {
|
2024-04-23 00:33:20 +00:00
|
|
|
return Device(DeviceType::MAIA, 0);
|
2019-03-20 20:47:41 +00:00
|
|
|
}
|
|
|
|
|
void setDevice(Device d) const override {
|
2024-04-23 00:33:20 +00:00
|
|
|
AT_ASSERT(d.type() == DeviceType::MAIA);
|
2019-03-20 20:47:41 +00:00
|
|
|
AT_ASSERT(d.index() == 0);
|
|
|
|
|
}
|
|
|
|
|
void uncheckedSetDevice(Device d) const noexcept override {
|
|
|
|
|
}
|
|
|
|
|
Stream getStream(Device d) const noexcept override {
|
2024-04-23 00:33:20 +00:00
|
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
|
2019-03-20 20:47:41 +00:00
|
|
|
}
|
|
|
|
|
Stream exchangeStream(Stream s) const noexcept override {
|
2024-04-23 00:33:20 +00:00
|
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
|
2019-03-20 20:47:41 +00:00
|
|
|
}
|
2019-03-26 16:42:41 +00:00
|
|
|
DeviceIndex deviceCount() const noexcept override {
|
2019-03-20 20:47:41 +00:00
|
|
|
return 1;
|
|
|
|
|
}
|
2019-09-01 19:36:22 +00:00
|
|
|
|
|
|
|
|
// Event-related functions
|
|
|
|
|
void record(void** event,
|
|
|
|
|
const Stream& stream,
|
|
|
|
|
const DeviceIndex device_index,
|
|
|
|
|
const EventFlag flag) const override {
|
2024-04-23 00:33:20 +00:00
|
|
|
TORCH_CHECK(false, "MAIA backend doesn't support events.");
|
2019-09-01 19:36:22 +00:00
|
|
|
}
|
|
|
|
|
void block(
|
|
|
|
|
void* event,
|
|
|
|
|
const Stream& stream) const override {
|
2024-04-23 00:33:20 +00:00
|
|
|
TORCH_CHECK(false, "MAIA backend doesn't support events.");
|
2019-09-01 19:36:22 +00:00
|
|
|
}
|
|
|
|
|
bool queryEvent(void* event) const override {
|
2024-04-23 00:33:20 +00:00
|
|
|
TORCH_CHECK(false, "MAIA backend doesn't support events.");
|
2019-09-01 19:36:22 +00:00
|
|
|
}
|
|
|
|
|
void destroyEvent(
|
|
|
|
|
void* event,
|
|
|
|
|
const DeviceIndex device_index) const noexcept override { }
|
2019-03-20 20:47:41 +00:00
|
|
|
};
|
|
|
|
|
|
2024-04-23 00:33:20 +00:00
|
|
|
constexpr DeviceType MAIAGuardImpl::static_type;
|
|
|
|
|
C10_REGISTER_GUARD_IMPL(MAIA, MAIAGuardImpl);
|
2019-03-20 20:47:41 +00:00
|
|
|
|
2019-02-01 18:55:00 +00:00
|
|
|
int get_test_int() {
|
|
|
|
|
return test_int;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Python bindings: exposes only the test_int marker accessor. Kernel
// registration happens through TORCH_LIBRARY_IMPL, not here.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("get_test_int", get_test_int);
}
|