Added vectorized horizontal flip path for channels last NCHW input (#91806)

## Description

- Added AVX2-only vectorization for the horizontal flip op applied to channels last NCHW input, where **2 <= C * sizeof(dtype) <= 16**. This PR is slightly faster than Pillow and considerably faster (x2 - x5) than nightly (see the usage sketch below).
- ~~Still keeping the `cpu_vflip_memcpy` code ([its PR](https://github.com/pytorch/pytorch/pull/89414) was reverted and is under investigation)~~
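
For illustration only (not part of the PR), this is the kind of call the new path targets; whether the vectorized kernel is actually selected depends on AVX2 support and on the dispatch checks in `flip_kernel` shown further down:

```python
import torch

# Hypothetical example: channels-last uint8 NCHW input with C * sizeof(dtype) = 3,
# i.e. inside the [2, 16] range covered by the new vectorized path.
x = torch.randint(0, 256, (1, 3, 712, 712), dtype=torch.uint8)
x = x.contiguous(memory_format=torch.channels_last)

# Horizontal flip (last dim of NCHW); on AVX2 CPUs this is expected to be
# routed to the new cpu_hflip_channels_last_vec kernel.
y = torch.flip(x, dims=(-1,))
```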

## Benchmarks

```
[---------------------------------------------------------------------- Horizontal flip ----------------------------------------------------------------------]
                                                                  |  torch (2.0.0a0+gitf6d73f3) PR  |    Pillow (9.4.0)   |  torch (2.0.0a0+git4386f31) nightly
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      channels=2, size=256, dtype=torch.uint8, mf=channels_last   |         31.859 (+-0.498)        |                     |          190.599 (+-7.579)
      channels=2, size=520, dtype=torch.uint8, mf=channels_last   |         60.648 (+-0.074)        |                     |          706.895 (+-11.219)
      channels=2, size=712, dtype=torch.uint8, mf=channels_last   |         95.994 (+-2.510)        |                     |         1340.685 (+-169.279)

      channels=3, size=256, dtype=torch.uint8, mf=channels_last   |         45.490 (+-0.108)        |   47.359 (+-0.942)  |          179.520 (+-2.916)
      channels=3, size=520, dtype=torch.uint8, mf=channels_last   |        146.802 (+-2.175)        |  174.201 (+-4.124)  |          707.765 (+-2.691)
      channels=3, size=712, dtype=torch.uint8, mf=channels_last   |        215.148 (+-0.925)        |  313.606 (+-3.972)  |         1346.678 (+-89.854)

      channels=3, size=256, dtype=torch.int8, mf=channels_last    |         43.618 (+-0.160)        |                     |          191.613 (+-16.252)
      channels=3, size=520, dtype=torch.int8, mf=channels_last    |        147.487 (+-0.691)        |                     |          755.020 (+-25.045)
      channels=3, size=712, dtype=torch.int8, mf=channels_last    |        216.687 (+-0.906)        |                     |         1314.854 (+-31.137)

      channels=4, size=256, dtype=torch.uint8, mf=channels_last   |         32.169 (+-0.092)        |                     |          195.415 (+-3.647)
      channels=4, size=520, dtype=torch.uint8, mf=channels_last   |         89.465 (+-0.154)        |                     |          776.459 (+-14.845)
      channels=4, size=712, dtype=torch.uint8, mf=channels_last   |        152.773 (+-0.610)        |                     |         1456.304 (+-45.280)

      channels=8, size=256, dtype=torch.uint8, mf=channels_last   |         43.444 (+-0.158)        |                     |          163.669 (+-4.580)
      channels=8, size=520, dtype=torch.uint8, mf=channels_last   |        151.285 (+-0.602)        |                     |          642.396 (+-13.500)
      channels=8, size=712, dtype=torch.uint8, mf=channels_last   |        278.471 (+-0.912)        |                     |         1205.472 (+-47.609)

      channels=16, size=256, dtype=torch.uint8, mf=channels_last  |         75.176 (+-0.188)        |                     |          181.278 (+-3.388)
      channels=16, size=520, dtype=torch.uint8, mf=channels_last  |        291.105 (+-1.163)        |                     |          716.906 (+-30.842)
      channels=16, size=712, dtype=torch.uint8, mf=channels_last  |        893.267 (+-10.899)       |                     |         1434.931 (+-40.399)

      channels=2, size=256, dtype=torch.int16, mf=channels_last   |         31.437 (+-0.143)        |                     |          195.299 (+-2.916)
      channels=2, size=520, dtype=torch.int16, mf=channels_last   |         89.834 (+-0.175)        |                     |          774.940 (+-8.638)
      channels=2, size=712, dtype=torch.int16, mf=channels_last   |        154.806 (+-0.550)        |                     |         1443.435 (+-37.799)

      channels=3, size=256, dtype=torch.int16, mf=channels_last   |         70.909 (+-0.146)        |                     |          195.347 (+-1.986)
      channels=3, size=520, dtype=torch.int16, mf=channels_last   |        212.998 (+-1.181)        |                     |          776.282 (+-15.598)
      channels=3, size=712, dtype=torch.int16, mf=channels_last   |        382.991 (+-0.968)        |                     |          1441.674 (+-9.873)

      channels=4, size=256, dtype=torch.int16, mf=channels_last   |         43.574 (+-0.157)        |                     |          163.176 (+-1.941)
      channels=4, size=520, dtype=torch.int16, mf=channels_last   |        151.289 (+-0.557)        |                     |          641.169 (+-9.457)
      channels=4, size=712, dtype=torch.int16, mf=channels_last   |        275.275 (+-0.874)        |                     |         1186.589 (+-12.063)

      channels=8, size=256, dtype=torch.int16, mf=channels_last   |         74.455 (+-0.292)        |                     |          181.191 (+-1.721)
      channels=8, size=520, dtype=torch.int16, mf=channels_last   |        289.591 (+-1.134)        |                     |          715.755 (+-2.368)
      channels=8, size=712, dtype=torch.int16, mf=channels_last   |        923.831 (+-68.807)       |                     |         1437.078 (+-14.649)

      channels=2, size=256, dtype=torch.int32, mf=channels_last   |         44.217 (+-0.203)        |                     |          163.011 (+-1.497)
      channels=2, size=520, dtype=torch.int32, mf=channels_last   |        150.920 (+-0.950)        |                     |          640.761 (+-1.882)
      channels=2, size=712, dtype=torch.int32, mf=channels_last   |        281.648 (+-1.163)        |                     |         1188.464 (+-10.374)

      channels=3, size=256, dtype=torch.int32, mf=channels_last   |        103.708 (+-0.517)        |                     |          165.001 (+-1.315)
      channels=3, size=520, dtype=torch.int32, mf=channels_last   |        409.785 (+-8.004)        |                     |          647.939 (+-11.431)
      channels=3, size=712, dtype=torch.int32, mf=channels_last   |        790.819 (+-16.471)       |                     |          1219.206 (+-9.503)

      channels=4, size=256, dtype=torch.int32, mf=channels_last   |         72.975 (+-0.155)        |                     |          181.298 (+-1.059)
      channels=4, size=520, dtype=torch.int32, mf=channels_last   |        291.584 (+-0.905)        |                     |          716.033 (+-4.824)
      channels=4, size=712, dtype=torch.int32, mf=channels_last   |        938.790 (+-15.930)       |                     |         1434.134 (+-15.060)

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/8e8c989d35835d7ab20567bff36632be#file-20230123-143303-pr_vs_nightly-md)
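
The exact benchmark script is in the linked gist. As a rough, illustrative sketch (not the gist's code), a single such measurement could be reproduced with `torch.utils.benchmark`:

```python
import torch
import torch.utils.benchmark as benchmark

# Illustrative setup mirroring one benchmark row (channels=3, size=520, uint8, channels_last).
x = torch.randint(0, 256, (1, 3, 520, 520), dtype=torch.uint8)
x = x.contiguous(memory_format=torch.channels_last)

t = benchmark.Timer(
    stmt="torch.flip(x, dims=(-1,))",
    globals={"x": x},
    num_threads=1,
    description="channels=3, size=520, dtype=torch.uint8, mf=channels_last",
)
print(t.blocked_autorange(min_run_time=1))
```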

## Context

Follow-up work to PRs: https://github.com/pytorch/pytorch/pull/88989, https://github.com/pytorch/pytorch/pull/89414 and https://github.com/pytorch/pytorch/pull/90013

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91806
Approved by: https://github.com/peterbell10, https://github.com/lezcano
vfdev-5 2023-01-23 20:15:26 +00:00 committed by PyTorch MergeBot
parent a112814a7f
commit e994e78397
2 changed files with 179 additions and 1 deletion


@@ -569,6 +569,145 @@ void cpu_vflip_memcpy(at::TensorIterator& iter) {
  iter.cast_outputs();
}

constexpr int64_t hflip_mask_size = 32;

std::array<char, hflip_mask_size> generate_vec_hflip_reg_mask(int64_t data_stride) {
  std::array<char, hflip_mask_size> mask;
  for (const auto k : c10::irange(hflip_mask_size / 2)) {
    int j = k / data_stride + 1;
    int v = (j * data_stride - 1) - (k % data_stride);
    v = std::min(v, (int) (hflip_mask_size / 2 - 1));
    mask[hflip_mask_size - 1 - k] = v;
    mask[hflip_mask_size / 2 - 1 - k] = v;
  }
  return mask;
}
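
For intuition (illustration only, not part of the diff), the mask construction can be mirrored in Python. For `data_stride = 3` it produces, per 16-byte lane, the source-byte indices that reverse the order of 3-byte pixels while keeping each pixel's bytes in order:

```python
# Python mirror of generate_vec_hflip_reg_mask above, for illustration only.
def generate_vec_hflip_reg_mask(data_stride, mask_size=32):
    mask = [0] * mask_size
    for k in range(mask_size // 2):
        j = k // data_stride + 1
        v = min((j * data_stride - 1) - (k % data_stride), mask_size // 2 - 1)
        # the same 16-byte pattern is written into both 128-bit lane halves
        mask[mask_size - 1 - k] = v
        mask[mask_size // 2 - 1 - k] = v
    return mask

print(generate_vec_hflip_reg_mask(3)[:16])
# [15, 12, 13, 14, 9, 10, 11, 6, 7, 8, 3, 4, 5, 0, 1, 2]
# _mm256_shuffle_epi8 applies this per lane: output byte 0 takes lane byte 15,
# output bytes 1..3 take the pixel stored at bytes 12..14, and so on.
```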
int64_t vectorized_cpu_hflip_channels_last(
    char * C10_RESTRICT *data, const int64_t data_size, const int64_t data_stride, const std::array<char, 32> & mdata) {

  int64_t i = 0;

#ifdef CPU_CAPABILITY_AVX2

  constexpr auto vec_size = 256 / 8;

  if (data_size > vec_size) {
    // Example for num channels=3 and dtype=uint8
    // -> data_stride = 3
    // -> usable_vec_stride = 30
    // -> usable_vec_half_stride = 15
    // Data: (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 32 33)
    // load by 2 parts
    // R = [ (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 | (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 ]
    // flip(R) ->
    // R = [ 31 (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) | 16 (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3) ]
    //
    // Write in 2 parts
    // Output pointer: output_ptr = data[0] v
    // - Init:
    // (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X)
    // 0) Move to initial position: output_ptr = data[0] + data_stride - vec_size / 2;
    // v
    // (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X)
    // - In the loop:
    // 1) Write 1st block from output_ptr
    // v
    // |----> vec_size / 2 ---------------------------|
    // Output part 1: (X X X) (X X X) (X X X) (X X X) (X X X) (X X 16) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)
    // 2) Write 2nd block from output_ptr - usable_vec_half_stride:
    // v
    // |-----> vec_size / 2 ----------------------------------|
    // Output part 2: (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)
    //
    // 3) Move to the next position: output_ptr -= usable_vec_stride
    //
    // - After the loop:
    // 4) Move to write position
    // v
    // (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)

    const __m256i mask = _mm256_loadu_si256((__m256i *) mdata.data());
    const auto usable_vec_stride = 2 * (vec_size / 2 / data_stride) * data_stride;
    const auto usable_vec_half_stride = usable_vec_stride / 2;

    auto output_ptr = data[0] + data_stride - vec_size / 2;
    auto input_ptr = data[1];

    for (; i < data_size - vec_size; i += usable_vec_stride) {
      // load 256-bits by two 128-bits parts
      auto a0 = _mm_loadu_si128((__m128i *) (input_ptr + i));
      auto b0 = _mm256_castsi128_si256(a0);
      auto a1 = _mm_loadu_si128((__m128i *) (input_ptr + i + usable_vec_half_stride));
      auto data_vec = _mm256_inserti128_si256(b0, a1, 1);

      auto reversed_vec = _mm256_shuffle_epi8(data_vec, mask);

      // write output in two parts
      auto rev_vec_h = _mm256_extracti128_si256(reversed_vec, 0);
      _mm_storeu_si128((__m128i *) (output_ptr - i), rev_vec_h);
      auto rev_vec_l = _mm256_extracti128_si256(reversed_vec, 1);
      _mm_storeu_si128((__m128i *) (output_ptr - i - usable_vec_half_stride), rev_vec_l);
    }

    data[0] -= i;
    data[1] += i;
  }
#endif

  return i;
}
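
Conceptually (a reference sketch under the assumption of a single contiguous row of bytes, not the PR's code), the routine above reverses a row pixel by pixel while preserving the byte order inside each pixel:

```python
import numpy as np

def hflip_channels_last_row(row_bytes: np.ndarray, data_stride: int) -> np.ndarray:
    # data_stride = C * sizeof(dtype); each group of data_stride bytes (one pixel)
    # keeps its internal byte order, only the pixel order is reversed.
    assert row_bytes.size % data_stride == 0
    return row_bytes.reshape(-1, data_stride)[::-1].reshape(-1)

row = np.arange(1, 34, dtype=np.uint8)      # 11 pixels of 3 bytes, as in the comment above
print(hflip_channels_last_row(row, 3)[:9])  # [31 32 33 28 29 30 25 26 27]
```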

void cpu_hflip_channels_last_vec(at::TensorIterator& iter) {
  auto input_strides = iter.strides(1);
  const auto data_stride = input_strides[1];

  // Generate avx mask once
  alignas(hflip_mask_size) auto mdata = generate_vec_hflip_reg_mask(data_stride);

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // ntensors here covers the output and one input. The tensor iterator actually defines
    // output, input and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66),
    // but only output and input are used here.
    static constexpr int ntensors = 2;

    const int64_t *outer_strides = &strides[3];
    const int64_t stride = strides[0];
    TORCH_INTERNAL_ASSERT(stride == strides[1]);
    auto c = -outer_strides[0];
    TORCH_INTERNAL_ASSERT(c == outer_strides[1]);

    char* C10_RESTRICT data[ntensors] = {base[0], base[1]};
    const int64_t size = size0 * size1;

    int64_t i = 0;
    if (c >= 2 && c <= 16) {
      i = vectorized_cpu_hflip_channels_last(data, size * stride, c, mdata) / stride;
    }

    auto data_stride = size0 * stride;
    for (; i < size; i += size0) {
      memcpy(data[0], data[1], data_stride);
      // advance:
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] += outer_strides[arg];
      }
    }
  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

void flip_kernel(TensorIterator& iter, const bool quantized) {
  if (quantized) {
    AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cpu",

@@ -613,10 +752,21 @@ void flip_kernel(TensorIterator& iter, const bool quantized) {
      } else if (iter_dtype == kDouble) {
        return cpu_hflip_vec<double>(iter);
      }
    }
    // other dtypes (float16, bfloat16, complex) are handled by cpu_kernel_vec (see below)
  } else if (iter.has_contiguous_first_dim()) {
    // Special cases:
    // a) channels last hflip on (N, C, H, W) and outer_stride(=dtype_size * C) in [2, 16]
    // b) flip dim=-2 on (N, ..., M, C) and outer_stride(=dtype_size * C) in [2, 16]
    auto output_strides = iter.strides(0);
    auto input_strides = iter.strides(1);
    auto c = -output_strides[1];
    if (c >= 2 && c <= 16 &&
        c == input_strides[1] &&
        c == iter.element_size(0) * iter.shape()[0] // checks if dim=1 is contiguous as well
    ) {
      return cpu_hflip_channels_last_vec(iter);
    }

    // Special case: vertical flip using memcpy (faster than generic cpu_kernel_vec)
    return cpu_vflip_memcpy(iter);
  }
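
For illustration only (not from the diff), the two special cases called out above correspond to inputs like the following; `C * sizeof(dtype)` has to fall in `[2, 16]` for the new vectorized kernel to be selected:

```python
import torch

# (a) channels-last hflip on (N, C, H, W): outer stride = 3 * sizeof(uint8) = 3 bytes
a = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
a = a.contiguous(memory_format=torch.channels_last)
a_flipped = torch.flip(a, dims=(3,))

# (b) flip over dim=-2 of a contiguous (N, ..., M, C) tensor: outer stride = 4 * sizeof(int16) = 8 bytes
b = torch.randint(0, 100, (2, 6, 10, 4), dtype=torch.int16)
b_flipped = torch.flip(b, dims=(-2,))

# Outside the [2, 16] byte range (e.g. 32 float32 channels -> 128 bytes per pixel)
# the op is expected to fall back to the memcpy-based / generic paths.
c = torch.randn(2, 32, 32, 32).contiguous(memory_format=torch.channels_last)
c_flipped = torch.flip(c, dims=(3,))
```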


@@ -406,6 +406,34 @@ class TestShapeOps(TestCase):
out_t = make_from_data([[3, 2, 1], [6, 5, 4]])
yield in_t, dims, out_t

# vectorized NCHW cases (images)
if device == "cpu" and dtype != torch.bfloat16:
    for mf in [torch.contiguous_format, torch.channels_last]:
        for c in [2, 3, 8, 16]:
            in_t = make_from_size((2, c, 32, 32)).contiguous(memory_format=mf)
            np_in_t = in_t.numpy()

            np_out_t = np_in_t[:, :, :, ::-1].copy()
            out_t = torch.from_numpy(np_out_t)
            yield in_t, 3, out_t

            np_out_t = np_in_t[:, :, ::-1, :].copy()
            out_t = torch.from_numpy(np_out_t)
            yield in_t, 2, out_t

            # non-contig cases
            in_tt = in_t[..., ::2, :]
            np_in_t = in_tt.numpy()
            np_out_t = np_in_t[:, :, :, ::-1].copy()
            out_t = torch.from_numpy(np_out_t)
            yield in_tt, 3, out_t

            in_tt = in_t[..., ::2]
            np_in_t = in_tt.numpy()
            np_out_t = np_in_t[:, :, :, ::-1].copy()
            out_t = torch.from_numpy(np_out_t)
            yield in_tt, 3, out_t
# Noops (edge cases)
# Size 0