pytorch/test/cpp_extensions/msnpu_extension.cpp
Wenlei Xie 2ecb2c7931 Pass Scalar by reference (#53583)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53583

`Scalar` takes 32 bytes because `c10::complex<double>`
requires 16-byte alignment. Passing `Scalar` by reference
shows about a 1% improvement in instruction count.
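
For a rough sanity check, here is a minimal sketch; the exact layout depends on the c10 headers at this revision, so the asserted values are an assumption:
```
#include <c10/core/Scalar.h>

// c10::complex<double> holds two doubles and is over-aligned to 16 bytes,
// so the tagged union inside Scalar pads the object out to 32 bytes.
static_assert(alignof(c10::complex<double>) == 16, "16-byte-aligned payload");
static_assert(sizeof(c10::Scalar) == 32, "tag + padded payload = 32 bytes");
```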

All the changes in this commit are codemodded, except for
the following 4 files (which generate signatures):
```
tools/codegen/api/cpp.py
tools/codegen/api/native.py
tools/codegen/api/structured.py
caffe2/contrib/aten/gen_op.py
```

# Codemod

## Main Step

For the codemod part, here are the main commands used:
```
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)optional<Scalar> (\w+)' '${1}const optional<Scalar>& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)optional<Scalar> (\w+)' '${1}const optional<Scalar>& ${2}'
```

As you can tell, this codemods both `Scalar` and `optional<Scalar>`. Apply these commands iteratively until reaching a fix-point, since one method signature might contain multiple `Scalar` parameters.
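
For illustration, here is the effect on a hypothetical signature with two `Scalar` parameters (each pass rewrites only one occurrence, since the captured prefix consumes the preceding parameters, which is why the fix-point iteration is needed):
```
// Hypothetical signature, before the codemod:
Tensor foo(const Tensor& self, Scalar alpha, Scalar beta);
// After iterating to the fix-point:
Tensor foo(const Tensor& self, const Scalar& alpha, const Scalar& beta);
```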

In retrospect, excluding `third_party` and `torch/csrc/jit` would have been a good idea (I reverted those changes manually later; see https://github.com/pytorch/pytorch/pull/53479 for reference).

## Pre-Step

Prior to applying the main commands, since some `Scalar` parameters appear as `at::Scalar` or `c10::Scalar`, I codemodded some of them in advance. Here is an incomplete list:
```
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)at::Scalar (\w+)' '${1}const at::Scalar& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)at::Scalar (\w+)' '${1}const at::Scalar& ${2}'
fastmod --extensions h '([a-zA-Z_+]\([^)]*,?\s*)c10::optional<Scalar> (\w+)' '${1}const c10::optional<Scalar>& ${2}'
fastmod --extensions cpp '([a-zA-Z_+]\([^)]*,?\s*)c10::optional<Scalar> (\w+)' '${1}const c10::optional<Scalar>& ${2}'
```
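
For illustration, the pre-step rewrites a hypothetical qualified signature as follows:
```
// Hypothetical signature, before:
Tensor bar(const Tensor& self, at::Scalar alpha);
// After the at::Scalar pre-step:
Tensor bar(const Tensor& self, const at::Scalar& alpha);
```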

## Fixup
There are a couple of post-codemod fixups. For example, `const Scalar` gets codemodded into `const const Scalar&`, and `at::Scalar` into `at::const Scalar&` (when the Pre-Step was not done comprehensively). Here is an incomplete list:
```
fastmod --extensions cpp 'const const Scalar' 'const Scalar'
fastmod --extensions h 'const const c10::optional<Scalar>' 'const c10::optional<Scalar>'
fastmod --extensions cpp 'const const c10::optional<Scalar>' 'const c10::optional<Scalar>'
fastmod 'at::const Scalar&' 'const at::Scalar&'
```
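
For illustration, on a hypothetical signature that already took `const Scalar` by value:
```
// Artifact produced by the main step:
Tensor baz(const Tensor& self, const const Scalar& x);
// After the fixup pass:
Tensor baz(const Tensor& self, const Scalar& x);
```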

## Supplementary

`cu` and `mm` files also need to be codemodded, for example:

```
fastmod --extensions cu 'at::const Scalar&' 'const at::Scalar&'
fastmod --extensions mm '([a-zA-Z_+]\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'
```
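
For illustration, the `cu` pass repairs artifacts like this hypothetical kernel wrapper:
```
// Artifact left in a .cu file by earlier passes:
void foo_out_cuda(TensorIterator& iter, at::const Scalar& value);
// After the .cu fixup:
void foo_out_cuda(TensorIterator& iter, const at::Scalar& value);
```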

Function pointer types are not matched by the commands above, so they need their own passes. Here is an incomplete list:

```
# Cover case: using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, Scalar source);
fastmod --extensions h '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)Scalar (\w+)' '${1}const Scalar& ${2}'

# Cover case: using softplus_fn = void (*)(TensorIterator&, Scalar, Scalar);
fastmod --extensions h '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)Scalar([, \)])' '${1}const Scalar&${2}'
fastmod --extensions cpp '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)Scalar([, \)])' '${1}const Scalar&${2}'
fastmod --extensions h '(void\s*\(\s*\*\s*\)\([^)]*,?\s*)optional<Scalar>([, \)])' '${1}const optional<Scalar>&${2}'
```
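
For illustration, the `softplus_fn` case from the comment above ends up as:
```
// Before:
using softplus_fn = void (*)(TensorIterator&, Scalar, Scalar);
// After the function-pointer passes (again iterated to a fix-point):
using softplus_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&);
```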

Some corner cases need to be fixed manually.

ghstack-source-id: 123970306

Test Plan: Imported from OSS

Reviewed By: smessmer

Differential Revision: D26904445

fbshipit-source-id: 8d8a002af4b5125f153a32f03c6956be7ae5671d
2021-03-15 23:17:06 -07:00

#include <torch/extension.h>
#include <torch/library.h>

using namespace at;

static int test_int;
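// Each override below records a distinct value in test_int so the Python
// test can check, via get_test_int() at the bottom, which MSNPU kernel ran.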

Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      Storage(
          Storage::use_byte_size_t(),
          0,
          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
          nullptr,
          false),
      DispatchKey::MSNPU,
      dtype);
  // This is a hack to work around the shape checks in _convolution.
  tensor_impl->set_sizes_contiguous(size);
  return Tensor(std::move(tensor_impl));
}

Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout, c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  test_int = 0;
  return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
}

Tensor& add_out_override(const Tensor& a, const Tensor& b, const Scalar& c,
    Tensor& out) {
  test_int = 1;
  return out;
}

Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;
  // Only the first 2 dimensions of the output shape are correct.
  return get_tensor(
      input.dtype(), {input.size(0), weight.size(0), input.size(2), input.size(3)});
}

std::tuple<Tensor, Tensor, Tensor> fake_convolution_backward(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    IntArrayRef stride, IntArrayRef padding,
    IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool, 3> output_mask) {
  test_int = 3;
  return std::tuple<Tensor, Tensor, Tensor>(
      get_tensor(input.dtype(), input.sizes()),
      get_tensor(weight.dtype(), weight.sizes()),
      get_tensor(input.dtype(), {}));
}

TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
  m.impl("empty.memory_format", empty_override);
  m.impl("add.out", add_out_override);
  m.impl("convolution_overrideable", fake_convolution);
  m.impl("convolution_backward_overrideable", fake_convolution_backward);
}

// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::MSNPU;

  MSNPUGuardImpl() {}
  MSNPUGuardImpl(DeviceType t) {
    AT_ASSERT(t == DeviceType::MSNPU);
  }
  DeviceType type() const override {
    return DeviceType::MSNPU;
  }
  Device exchangeDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
    return d;
  }
  Device getDevice() const override {
    return Device(DeviceType::MSNPU, 0);
  }
  void setDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MSNPU);
    AT_ASSERT(d.index() == 0);
  }
  void uncheckedSetDevice(Device d) const noexcept override {}
  Stream getStream(Device d) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  Stream exchangeStream(Stream s) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // Event-related functions
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  void block(void* event, const Stream& stream) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  bool queryEvent(void* event) const override {
    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
  }
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {}
};

constexpr DeviceType MSNPUGuardImpl::static_type;
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);

int get_test_int() {
  return test_int;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_test_int", &get_test_int);
}