[Inductor] Add support for NEON ISA in the Inductor C++ backend (#105590)

Fixes #104729

As suggested in the [blog](https://dev-discuss.pytorch.org/t/torchinductor-update-5-cpu-backend-backend-performance-update-and-deep-dive-on-key-optimizations/1117#:~:text=It%20can%20be,sub%2Dclasses.), I subclassed the `VecISA` class and implemented a NEON version of the `vec_reduce_all()` function, to go along with the existing AVX2 and AVX512 versions. Any operation that calls `vec_reduce_all()` will also take the NEON path and benefit from its vectorization.

The `vec_reduce_all()` is invoked by Softmax and other operations like norms. Using the fast path results in 30% time savings for Softmax as compared to the previously taken slow path.

  | Slow path | Fast path (NEON intrinsics)
-- | -- | --
Softmax (100 passes, 1024 dimension) | 623.706ms | 452.011ms

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105590
Approved by: https://github.com/jgong5, https://github.com/malfet
This commit is contained in:
Rohan 2024-02-22 23:55:32 +00:00 committed by PyTorch MergeBot
parent 4c6ba16f82
commit 156954d6a2
4 changed files with 64 additions and 5 deletions

View file

@ -76,6 +76,34 @@ struct VecReduceAllSIMD<float, Op> {
}
};
#endif // defined(CPU_CAPABILITY_AVX512)
#if defined(CPU_CAPABILITY_NEON)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle: [a1, a2, a3, a4, a5, a6, a7, a8] -> [a5, a6, a7, a8, a1, a2, a3, a4]
Vec v1 = {v.get_high(), v.get_low()};
// [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] ('+' stands for the reduction function. Note that the last 4 elements are not required)
v = vec_fun(v, v1);
// 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
float32x4_t v1_1 = vextq_f32(v.get_low(), v.get_low(), 2);
v1 = {v1_1, v1_1};
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
v = vec_fun(v, v1);
// 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
v1_1 = vrev64q_f32(v.get_low());
v1 = {v1_1, v1_1};
// [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
v = vec_fun(v, v1);
return v.get_low()[0];
}
};
#endif // defined(CPU_CAPABILITY_NEON)
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
template <typename scalar_t, typename Op>

View file

@ -1527,19 +1527,25 @@ class CPUReproTests(TestCase):
def test_auto_simd(self):
vec_avx512 = codecache.supported_vec_isa_list[0]
vec_avx2 = codecache.supported_vec_isa_list[1]
vec_neon = codecache.supported_vec_isa_list[2]
self.assertTrue(vec_avx512.bit_width() == 512)
self.assertTrue(vec_avx2.bit_width() == 256)
self.assertTrue(vec_neon.bit_width() == 256)
self.assertTrue(vec_avx512.nelements() == 16)
self.assertTrue(vec_avx2.nelements() == 8)
self.assertTrue(vec_neon.nelements() == 8)
self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
self.assertTrue(vec_neon.nelements(torch.bfloat16) == 16)
with config.patch({"cpp.simdlen": None}):
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx512)
else:
elif vec_avx2 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx2)
else:
self.assertTrue(isa == vec_neon)
with config.patch({"cpp.simdlen": 0}):
isa = codecache.pick_vec_isa()
@ -1569,6 +1575,9 @@ class CPUReproTests(TestCase):
if vec_avx2 in isa_list:
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_avx2)
elif vec_neon in isa_list:
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_neon)
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"

View file

@ -953,7 +953,7 @@ class VecISA:
# In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
# making the runtime check unnecessary.
_avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif
@ -1026,6 +1026,19 @@ cdll.LoadLibrary("__lib_path__")
return True
@dataclasses.dataclass
class VecNEON(VecISA):
_bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
_macro = "-DCPU_CAPABILITY_NEON"
_arch_flags = "" # Unused
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16}
def __str__(self) -> str:
return "neon" # Unused
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
@dataclasses.dataclass
class VecAVX512(VecISA):
_bit_width = 512
@ -1081,7 +1094,11 @@ class InvalidVecISA(VecISA):
invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAVX512(), VecAVX2()]
supported_vec_isa_list = [
VecAVX512(),
VecAVX2(),
VecNEON(),
] # This order matters for test_cpu_repro
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
@ -1099,7 +1116,12 @@ def valid_vec_isa_list() -> List[VecISA]:
with open("/proc/cpuinfo") as _cpu_info:
_cpu_info_content = _cpu_info.read()
for isa in supported_vec_isa_list:
if str(isa) in _cpu_info_content and isa:
# cpuinfo does not reveal info about NEON support. All aarch64 processors do support NEON though.
if (
(str(isa) in _cpu_info_content)
or (isinstance(isa, VecNEON) and platform.processor() == "aarch64")
and isa
):
isa_list.append(isa)
return isa_list

View file

@ -19,7 +19,7 @@
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
#define INDUCTOR_USE_VECTOR_TYPES() 1
#else
#define INDUCTOR_USE_VECTOR_TYPES() 0