Added vectorized horizontal flip path for channels last for NCHW (#91806)
## Description

- Added AVX2-only vectorization for the horizontal flip op applied to channels-last NCHW input, where **2 <= C * sizeof(dtype) <= 16**. This PR is slightly faster than Pillow and substantially faster (x2 - x5) than Nightly.
- ~~Still keeping the `cpu_vflip_memcpy` code ([its PR](https://github.com/pytorch/pytorch/pull/89414) was reverted and is under investigation)~~

## Benchmarks

```
[---------------------------------------------------------------- Horizontal flip ----------------------------------------------------------------]
                                                                  |  torch (2.0.0a0+gitf6d73f3) PR  |  Pillow (9.4.0)    |  torch (2.0.0a0+git4386f31) nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------
      channels=2, size=256, dtype=torch.uint8, mf=channels_last   |    31.859 (+-0.498)             |                    |    190.599 (+-7.579)
      channels=2, size=520, dtype=torch.uint8, mf=channels_last   |    60.648 (+-0.074)             |                    |    706.895 (+-11.219)
      channels=2, size=712, dtype=torch.uint8, mf=channels_last   |    95.994 (+-2.510)             |                    |   1340.685 (+-169.279)
      channels=3, size=256, dtype=torch.uint8, mf=channels_last   |    45.490 (+-0.108)             |   47.359 (+-0.942) |    179.520 (+-2.916)
      channels=3, size=520, dtype=torch.uint8, mf=channels_last   |   146.802 (+-2.175)             |  174.201 (+-4.124) |    707.765 (+-2.691)
      channels=3, size=712, dtype=torch.uint8, mf=channels_last   |   215.148 (+-0.925)             |  313.606 (+-3.972) |   1346.678 (+-89.854)
      channels=3, size=256, dtype=torch.int8, mf=channels_last    |    43.618 (+-0.160)             |                    |    191.613 (+-16.252)
      channels=3, size=520, dtype=torch.int8, mf=channels_last    |   147.487 (+-0.691)             |                    |    755.020 (+-25.045)
      channels=3, size=712, dtype=torch.int8, mf=channels_last    |   216.687 (+-0.906)             |                    |   1314.854 (+-31.137)
      channels=4, size=256, dtype=torch.uint8, mf=channels_last   |    32.169 (+-0.092)             |                    |    195.415 (+-3.647)
      channels=4, size=520, dtype=torch.uint8, mf=channels_last   |    89.465 (+-0.154)             |                    |    776.459 (+-14.845)
      channels=4, size=712, dtype=torch.uint8, mf=channels_last   |   152.773 (+-0.610)             |                    |   1456.304 (+-45.280)
      channels=8, size=256, dtype=torch.uint8, mf=channels_last   |    43.444 (+-0.158)             |                    |    163.669 (+-4.580)
      channels=8, size=520, dtype=torch.uint8, mf=channels_last   |   151.285 (+-0.602)             |                    |    642.396 (+-13.500)
      channels=8, size=712, dtype=torch.uint8, mf=channels_last   |   278.471 (+-0.912)             |                    |   1205.472 (+-47.609)
      channels=16, size=256, dtype=torch.uint8, mf=channels_last  |    75.176 (+-0.188)             |                    |    181.278 (+-3.388)
      channels=16, size=520, dtype=torch.uint8, mf=channels_last  |   291.105 (+-1.163)             |                    |    716.906 (+-30.842)
      channels=16, size=712, dtype=torch.uint8, mf=channels_last  |   893.267 (+-10.899)            |                    |   1434.931 (+-40.399)
      channels=2, size=256, dtype=torch.int16, mf=channels_last   |    31.437 (+-0.143)             |                    |    195.299 (+-2.916)
      channels=2, size=520, dtype=torch.int16, mf=channels_last   |    89.834 (+-0.175)             |                    |    774.940 (+-8.638)
      channels=2, size=712, dtype=torch.int16, mf=channels_last   |   154.806 (+-0.550)             |                    |   1443.435 (+-37.799)
      channels=3, size=256, dtype=torch.int16, mf=channels_last   |    70.909 (+-0.146)             |                    |    195.347 (+-1.986)
      channels=3, size=520, dtype=torch.int16, mf=channels_last   |   212.998 (+-1.181)             |                    |    776.282 (+-15.598)
      channels=3, size=712, dtype=torch.int16, mf=channels_last   |   382.991 (+-0.968)             |                    |   1441.674 (+-9.873)
      channels=4, size=256, dtype=torch.int16, mf=channels_last   |    43.574 (+-0.157)             |                    |    163.176 (+-1.941)
      channels=4, size=520, dtype=torch.int16, mf=channels_last   |   151.289 (+-0.557)             |                    |    641.169 (+-9.457)
      channels=4, size=712, dtype=torch.int16, mf=channels_last   |   275.275 (+-0.874)             |                    |   1186.589 (+-12.063)
      channels=8, size=256, dtype=torch.int16, mf=channels_last   |    74.455 (+-0.292)             |                    |    181.191 (+-1.721)
      channels=8, size=520, dtype=torch.int16, mf=channels_last   |   289.591 (+-1.134)             |                    |    715.755 (+-2.368)
      channels=8, size=712, dtype=torch.int16, mf=channels_last   |   923.831 (+-68.807)            |                    |   1437.078 (+-14.649)
      channels=2, size=256, dtype=torch.int32, mf=channels_last   |    44.217 (+-0.203)             |                    |    163.011 (+-1.497)
      channels=2, size=520, dtype=torch.int32, mf=channels_last   |   150.920 (+-0.950)             |                    |    640.761 (+-1.882)
      channels=2, size=712, dtype=torch.int32, mf=channels_last   |   281.648 (+-1.163)             |                    |   1188.464 (+-10.374)
      channels=3, size=256, dtype=torch.int32, mf=channels_last   |   103.708 (+-0.517)             |                    |    165.001 (+-1.315)
      channels=3, size=520, dtype=torch.int32, mf=channels_last   |   409.785 (+-8.004)             |                    |    647.939 (+-11.431)
      channels=3, size=712, dtype=torch.int32, mf=channels_last   |   790.819 (+-16.471)            |                    |   1219.206 (+-9.503)
      channels=4, size=256, dtype=torch.int32, mf=channels_last   |    72.975 (+-0.155)             |                    |    181.298 (+-1.059)
      channels=4, size=520, dtype=torch.int32, mf=channels_last   |   291.584 (+-0.905)             |                    |    716.033 (+-4.824)
      channels=4, size=712, dtype=torch.int32, mf=channels_last   |   938.790 (+-15.930)            |                    |   1434.134 (+-15.060)

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/8e8c989d35835d7ab20567bff36632be#file-20230123-143303-pr_vs_nightly-md)

## Context

Follow-up work to PRs: https://github.com/pytorch/pytorch/pull/88989, https://github.com/pytorch/pytorch/pull/89414 and https://github.com/pytorch/pytorch/pull/90013

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91806
Approved by: https://github.com/peterbell10, https://github.com/lezcano
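For reference, single-threaded timings like those in the table can be collected with `torch.utils.benchmark`; the sketch below is a minimal approximation, not the exact harness from the linked gist:

```python
# Minimal, illustrative timing sketch (not the exact harness from the gist).
import torch
from torch.utils.benchmark import Timer

for c, size, dtype in [(2, 256, torch.uint8), (3, 256, torch.uint8)]:
    x = torch.randint(0, 256, (1, c, size, size), dtype=dtype)
    x = x.contiguous(memory_format=torch.channels_last)
    m = Timer(
        stmt="torch.flip(x, dims=(3,))",  # torch is available in Timer's namespace
        globals={"x": x},
        num_threads=1,
    ).blocked_autorange(min_run_time=1)
    print(f"channels={c}, size={size}, dtype={dtype}: {m.median * 1e6:.3f} us")
```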
parent a112814a7f
commit e994e78397
2 changed files with 179 additions and 1 deletion
@@ -569,6 +569,145 @@ void cpu_vflip_memcpy(at::TensorIterator& iter) {

```cpp
  iter.cast_outputs();
}

constexpr int64_t hflip_mask_size = 32;

std::array<char, hflip_mask_size> generate_vec_hflip_reg_mask(int64_t data_stride) {
  std::array<char, hflip_mask_size> mask;
  for (const auto k : c10::irange(hflip_mask_size / 2)) {
    int j = k / data_stride + 1;
    int v = (j * data_stride - 1) - (k % data_stride);
    v = std::min(v, (int) (hflip_mask_size / 2 - 1));
    mask[hflip_mask_size - 1 - k] = v;
    mask[hflip_mask_size / 2 - 1 - k] = v;
  }
  return mask;
}
```
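As a sanity check, the mask construction can be mirrored in plain Python for `data_stride = 3` (three `uint8` channels). This is an illustrative sketch, not code from the PR:

```python
# Python mirror of generate_vec_hflip_reg_mask for data_stride=3.
# Entries are byte indices later consumed by _mm256_shuffle_epi8; the same
# 16-byte pattern serves both 128-bit lanes.
hflip_mask_size = 32
data_stride = 3
mask = [0] * hflip_mask_size
for k in range(hflip_mask_size // 2):
    j = k // data_stride + 1
    v = min((j * data_stride - 1) - (k % data_stride), hflip_mask_size // 2 - 1)
    mask[hflip_mask_size - 1 - k] = v
    mask[hflip_mask_size // 2 - 1 - k] = v

print(mask[:16])  # [15, 12, 13, 14, 9, 10, 11, 6, 7, 8, 3, 4, 5, 0, 1, 2]
```

Each 128-bit lane therefore reverses the order of its 3-byte pixels; the leading index is clamped to 15, producing one placeholder byte per lane that the overlapping, shifted stores in the kernel below render harmless. The vectorized kernel consumes this mask as follows.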
```cpp
int64_t vectorized_cpu_hflip_channels_last(
    char * C10_RESTRICT *data, const int64_t data_size, const int64_t data_stride,
    const std::array<char, 32> & mdata) {

  int64_t i = 0;
#ifdef CPU_CAPABILITY_AVX2

  constexpr auto vec_size = 256 / 8;

  if (data_size > vec_size) {

    // Example for num channels=3 and dtype=uint8
    // -> data_stride = 3
    // -> usable_vec_stride = 30
    // -> usable_vec_half_stride = 15
    //
    // Data: (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 32 33)
    //
    // Load in 2 parts:
    // R = [ (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 | (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 ]
    // flip(R) ->
    // R = [ 31 (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) | 16 (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3) ]
    //
    // Write in 2 parts. Output pointer: output_ptr = data[0]
    // - Init:
    //                                                                                 v
    //   (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X)
    // 0) Move to the initial position: output_ptr = data[0] + data_stride - vec_size / 2;
    //                                            v
    //   (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X)
    // - In the loop:
    // 1) Write the 1st block from output_ptr:
    //                                            v
    //                                            |------------> vec_size / 2 ------------->|
    //   (X X X) (X X X) (X X X) (X X X) (X X X) (X X 16) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)      <- Output part 1
    // 2) Write the 2nd block from output_ptr - usable_vec_half_stride:
    //        v
    //        |------------> vec_size / 2 ------------->|
    //   (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)   <- Output part 2
    //
    // 3) Move to the next position: output_ptr -= usable_vec_stride
    //
    // - After the loop:
    // 4) Move to the write position:
    //   v
    //   (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)

    const __m256i mask = _mm256_loadu_si256((__m256i *) mdata.data());

    const auto usable_vec_stride = 2 * (vec_size / 2 / data_stride) * data_stride;
    const auto usable_vec_half_stride = usable_vec_stride / 2;

    auto output_ptr = data[0] + data_stride - vec_size / 2;
    auto input_ptr = data[1];

    for (; i < data_size - vec_size; i += usable_vec_stride) {

      // load 256 bits as two 128-bit parts
      auto a0 = _mm_loadu_si128((__m128i *) (input_ptr + i));
      auto b0 = _mm256_castsi128_si256(a0);
      auto a1 = _mm_loadu_si128((__m128i *) (input_ptr + i + usable_vec_half_stride));
      auto data_vec = _mm256_inserti128_si256(b0, a1, 1);

      auto reversed_vec = _mm256_shuffle_epi8(data_vec, mask);

      // write the output in two parts
      auto rev_vec_h = _mm256_extracti128_si256(reversed_vec, 0);
      _mm_storeu_si128((__m128i *) (output_ptr - i), rev_vec_h);
      auto rev_vec_l = _mm256_extracti128_si256(reversed_vec, 1);
      _mm_storeu_si128((__m128i *) (output_ptr - i - usable_vec_half_stride), rev_vec_l);
    }

    data[0] -= i;
    data[1] += i;
  }
#endif
  return i;
}
```
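The per-lane effect of `_mm256_shuffle_epi8` with this mask can be checked against the diagram using plain NumPy fancy indexing (an illustrative sketch, not code from the PR):

```python
# Applying the 16-byte lane mask to 16 consecutive bytes reverses five
# 3-byte pixels, matching "16 (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)"
# in the diagram above.
import numpy as np

lane = np.arange(1, 17, dtype=np.uint8)  # bytes (1 2 3) (4 5 6) ... (13 14 15) 16
lane_mask = np.array([15, 12, 13, 14, 9, 10, 11, 6, 7, 8, 3, 4, 5, 0, 1, 2])
print(lane[lane_mask])  # [16 13 14 15 10 11 12  7  8  9  4  5  6  1  2  3]
```

The driver below feeds this kernel row by row through the tensor iterator and falls back to a per-pixel `memcpy` for whatever tail the vectorized loop leaves over.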
```cpp
void cpu_hflip_channels_last_vec(at::TensorIterator& iter) {

  auto input_strides = iter.strides(1);
  const auto data_stride = input_strides[1];

  // Generate the AVX mask once
  alignas(hflip_mask_size) auto mdata = generate_vec_hflip_reg_mask(data_stride);

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // Here ntensors covers the output and 1 input. The tensor iterator actually defines
    // output, input and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66),
    // but we only use the output and the input.
    static constexpr int ntensors = 2;
    const int64_t *outer_strides = &strides[3];
    const int64_t stride = strides[0];

    TORCH_INTERNAL_ASSERT(stride == strides[1]);

    auto c = -outer_strides[0];
    TORCH_INTERNAL_ASSERT(c == outer_strides[1]);

    char* C10_RESTRICT data[ntensors] = {base[0], base[1]};
    const int64_t size = size0 * size1;

    int64_t i = 0;

    if (c >= 2 && c <= 16) {
      i = vectorized_cpu_hflip_channels_last(data, size * stride, c, mdata) / stride;
    }

    auto data_stride = size0 * stride;
    for (; i < size; i += size0) {

      memcpy(data[0], data[1], data_stride);

      // advance:
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] += outer_strides[arg];
      }
    }
  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

void flip_kernel(TensorIterator& iter, const bool quantized) {
  if (quantized) {
    AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cpu",
```
@@ -613,10 +752,21 @@ void flip_kernel(TensorIterator& iter, const bool quantized) {

```cpp
      } else if (iter_dtype == kDouble) {
        return cpu_hflip_vec<double>(iter);
      }
    }
    // other dtypes (float16, bfloat16, complex) are handled by cpu_kernel_vec (see below)
  } else if (iter.has_contiguous_first_dim()) {
    // Special cases:
    // a) channels last hflip on (N, C, H, W) and outer_stride (= dtype_size * C) in [2, 16]
    // b) flip dim=-2 on (N, ..., M, C) and outer_stride (= dtype_size * C) in [2, 16]
    auto output_strides = iter.strides(0);
    auto input_strides = iter.strides(1);
    auto c = -output_strides[1];
    if (c >= 2 && c <= 16 &&
        c == input_strides[1] &&
        c == iter.element_size(0) * iter.shape()[0]  // checks if dim=1 is contiguous as well
    ) {
      return cpu_hflip_channels_last_vec(iter);
    }
    // Special case: vertical flip using memcpy (faster than generic cpu_kernel_vec)
    return cpu_vflip_memcpy(iter);
  }
```
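In terms of user-visible inputs, tensors expected to take the new path satisfy `C * sizeof(dtype)` in `[2, 16]` with a channels-last layout. A hedged illustration (the exact routing also depends on the iterator's layout checks):

```python
# Inputs that satisfy the outer-stride condition C * sizeof(dtype) in [2, 16].
import torch

# 3 uint8 channels: outer stride = 3 bytes
x = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
x = x.contiguous(memory_format=torch.channels_last)
y = torch.flip(x, dims=(3,))  # horizontal flip on channels-last input

# 4 int16 channels: outer stride = 8 bytes, still within [2, 16]
z = torch.randint(-100, 100, (2, 4, 32, 32), dtype=torch.int16)
z = z.contiguous(memory_format=torch.channels_last)
w = torch.flip(z, dims=(3,))
```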
@@ -406,6 +406,34 @@ class TestShapeOps(TestCase):

```python
        out_t = make_from_data([[3, 2, 1], [6, 5, 4]])
        yield in_t, dims, out_t

        # vectorized NCHW cases (images)
        if device == "cpu" and dtype != torch.bfloat16:
            for mf in [torch.contiguous_format, torch.channels_last]:
                for c in [2, 3, 8, 16]:
                    in_t = make_from_size((2, c, 32, 32)).contiguous(memory_format=mf)
                    np_in_t = in_t.numpy()

                    np_out_t = np_in_t[:, :, :, ::-1].copy()
                    out_t = torch.from_numpy(np_out_t)
                    yield in_t, 3, out_t

                    np_out_t = np_in_t[:, :, ::-1, :].copy()
                    out_t = torch.from_numpy(np_out_t)
                    yield in_t, 2, out_t

                    # non-contig cases
                    in_tt = in_t[..., ::2, :]
                    np_in_t = in_tt.numpy()
                    np_out_t = np_in_t[:, :, :, ::-1].copy()
                    out_t = torch.from_numpy(np_out_t)
                    yield in_tt, 3, out_t

                    in_tt = in_t[..., ::2]
                    np_in_t = in_tt.numpy()
                    np_out_t = np_in_t[:, :, :, ::-1].copy()
                    out_t = torch.from_numpy(np_out_t)
                    yield in_tt, 3, out_t

        # Noops (edge cases)

        # Size 0
```
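The new test cases use NumPy negative-step slicing as the reference implementation; the equivalence they assert reduces to the following standalone check (illustrative, mirroring one of the generated cases):

```python
# Standalone version of the reference comparison used by the new test cases.
import torch

x = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
x = x.contiguous(memory_format=torch.channels_last)

expected = torch.from_numpy(x.numpy()[:, :, :, ::-1].copy())
assert torch.equal(torch.flip(x, dims=(3,)), expected)
```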