diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt index 1b46916cb92..83e4a5f915d 100644 --- a/caffe2/perfkernels/CMakeLists.txt +++ b/caffe2/perfkernels/CMakeLists.txt @@ -10,13 +10,9 @@ endif() file(GLOB common_srcs *.cc) file(GLOB avx_srcs *_avx.cc) file(GLOB avx2_srcs *_avx2.cc) -file(GLOB avx512_srcs *_avx512.cc) -file(GLOB sve_srcs *_sve.cc) -# exclude avx, avx2, avx512, and sve srcs from common_srcs +# exclude avx and avx2 srcs from common_srcs exclude(common_srcs "${common_srcs}" ${avx_srcs}) exclude(common_srcs "${common_srcs}" ${avx2_srcs}) -exclude(common_srcs "${common_srcs}" ${avx512_srcs}) -exclude(common_srcs "${common_srcs}" ${sve_srcs}) # We will always build common srcs. set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs}) @@ -46,22 +42,6 @@ if(CXX_AVX2_FOUND) "Caffe2_perfkernels_avx2_interface") endif() -# We will only build the SVE perfkernel files if the compiler supports SVE -# extensions. -if(CXX_SVE_FOUND) - add_library(Caffe2_perfkernels_sve STATIC ${sve_srcs}) - target_link_libraries(Caffe2_perfkernels_sve PRIVATE c10) - install(TARGETS Caffe2_perfkernels_sve - ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") - - target_compile_options(Caffe2_perfkernels_sve PRIVATE "-march=armv8-a+sve") - - caffe2_interface_library( - Caffe2_perfkernels_sve Caffe2_perfkernels_sve_interface) - list(APPEND - Caffe2_DEPENDENCY_WHOLE_LINK_LIBS "Caffe2_perfkernels_sve_interface") -endif() - # TODO(jiayq): currently, we only implement the very base files for the # perfkernels. This is because to implement avx and avx2 files, we actually # need to set up different compilation units and this is a bit more involving diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h index 6e069861b28..6fed9e1d6d0 100644 --- a/caffe2/perfkernels/common.h +++ b/caffe2/perfkernels/common.h @@ -61,8 +61,9 @@ In foo.cc, do: // we use cpuinfo to identify cpu support and run the proper functions. #pragma once -#if defined(CAFFE2_PERF_WITH_SVE) || defined(CAFFE2_PERF_WITH_AVX512) || \ - defined(CAFFE2_PERF_WITH_AVX2) || defined(CAFFE2_PERF_WITH_AVX) + +#if defined(CAFFE2_PERF_WITH_AVX512) || defined(CAFFE2_PERF_WITH_AVX2) \ + || defined(CAFFE2_PERF_WITH_AVX) #include #endif @@ -71,18 +72,6 @@ In foo.cc, do: #define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__); -#ifdef CAFFE2_PERF_WITH_SVE -#define SVE_DO(funcname, ...) \ - { \ - static const bool isDo = cpuinfo_initialize() && cpuinfo_has_arm_sve(); \ - if (isDo) { \ - return funcname##__sve(__VA_ARGS__); \ - } \ - } -#else // CAFFE2_PERF_WITH_SVE -#define SVE_DO(funcname, ...) -#endif // CAFFE2_PERF_WITH_SVE - #ifdef CAFFE2_PERF_WITH_AVX512 #define AVX512_DO(funcname, ...) \ { \ diff --git a/caffe2/perfkernels/common_sve.cc b/caffe2/perfkernels/common_sve.cc deleted file mode 100644 index 03b0bf983c8..00000000000 --- a/caffe2/perfkernels/common_sve.cc +++ /dev/null @@ -1,22 +0,0 @@ -// This file is here merely to check that the flags are not mixed up: for -// example, if your compiler did not specify -march=armv8-a+sve, you should not -// provide the CAFFE2_PERF_WITH_SVE macro. - -#include "caffe2/core/common.h" - -#ifdef CAFFE2_PERF_WITH_SVE -#ifndef __ARM_FEATURE_SVE -#error( \ - "You found a build system error: CAFFE2_PERF_WITH_SVE is defined" \ - "but __ARM_FEATURE_SVE is not defined (via e.g. -march=armv8-a+sve)."); -#endif // __ARM_FEATURE_SVE -#endif // CAFFE2_PERF_WITH_SVE - -#ifdef __ARM_FEATURE_SVE -#ifndef CAFFE2_PERF_WITH_SVE -#error( \ - "You found a build system error: __SVE__ is defined \ - (via e.g. -march=armv8-a+sve) " \ - "but CAFFE2_PERF_WITH_SVE is not defined."); -#endif // CAFFE2_PERF_WITH_SVE -#endif diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 7c62d9e883f..5fcf71016ae 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -88,7 +88,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -113,9 +113,6 @@ static bool EmbeddingLookupGenericSlowIdx( decltype( \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \ - decltype( \ - EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ - EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__sve; \ bool \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \ const int64_t block_size, \ @@ -124,7 +121,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -134,19 +131,6 @@ static bool EmbeddingLookupGenericSlowIdx( } else { \ CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \ } \ - SVE_DO( \ - EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - scale_bias, \ - normalize_by_lengths, \ - out); \ AVX2_FMA_DO( \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ block_size, \ @@ -182,7 +166,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ diff --git a/caffe2/perfkernels/embedding_lookup_idx_sve.cc b/caffe2/perfkernels/embedding_lookup_idx_sve.cc deleted file mode 100644 index 873823536b5..00000000000 --- a/caffe2/perfkernels/embedding_lookup_idx_sve.cc +++ /dev/null @@ -1,6769 +0,0 @@ -//// -------------------------- -//// ATTENTION: -//// THIS CODE IS AUTOGENERATED -//// BY sve_emblookup_codegen.py -//// DO NOT MODIFY!!! -//// -------------------------- - -#include -#include -#include -#include -#include -namespace caffe2 { - -template -static bool EmbeddingLookupIdx_int32_t_float_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const float* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - vsum16 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); - vsum17 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); - vsum18 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); - vsum19 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); - vsum20 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); - vsum21 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); - vsum22 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); - vsum23 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); - vsum24 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); - vsum25 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); - vsum26 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); - vsum27 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); - vsum28 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); - vsum29 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); - vsum30 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); - vsum31 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int32_t_float_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const float* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_float_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int32_t_float_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const float* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_float_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -template -static bool EmbeddingLookupIdx_int64_t_float_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const float* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - vsum16 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); - vsum17 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); - vsum18 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); - vsum19 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); - vsum20 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); - vsum21 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); - vsum22 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); - vsum23 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); - vsum24 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); - vsum25 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); - vsum26 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); - vsum27 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); - vsum28 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); - vsum29 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); - vsum30 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); - vsum31 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int64_t_float_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const float* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_float_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int64_t_float_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const float* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_float_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -template -static bool EmbeddingLookupIdx_int32_t_half_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::Half* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])))), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])))), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])))), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])))), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])))), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])))), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])))), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])))), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])))), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])))), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])))), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])))), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])))), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])))), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])))), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])))), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_f16_x( - pg, - svreinterpret_f16_u32(svld1uh_u32( - pg, reinterpret_cast(&ip[k])))), - svld1_f32(pg, &op[k]))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int32_t_half_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::Half* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_half_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int32_t_half_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::Half* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_half_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -template -static bool EmbeddingLookupIdx_int64_t_half_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::Half* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])))), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])))), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])))), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])))), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])))), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])))), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])))), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])))), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])))), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])))), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])))), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])))), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])))), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])))), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])))), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])))), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_f16_x( - pg, - svreinterpret_f16_u32(svld1uh_u32( - pg, reinterpret_cast(&ip[k])))), - svld1_f32(pg, &op[k]))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int64_t_half_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::Half* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_half_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int64_t_half_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::Half* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_half_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -template -static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::BFloat16* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])), - 16)), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])), - 16)), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])), - 16)), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])), - 16)), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])), - 16)), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])), - 16)), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])), - 16)), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])), - 16)), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])), - 16)), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])), - 16)), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])), - 16)), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])), - 16)), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])), - 16)), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])), - 16)), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])), - 16)), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])), - 16)), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - pg, - svld1uh_u32( - pg, reinterpret_cast(&ip[k])), - 16)), - svld1_f32(pg, &op[k]))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::BFloat16* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_bfloat16_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::BFloat16* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_bfloat16_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -template -static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::BFloat16* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])), - 16)), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])), - 16)), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])), - 16)), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])), - 16)), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])), - 16)), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])), - 16)), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])), - 16)), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])), - 16)), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])), - 16)), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])), - 16)), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])), - 16)), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])), - 16)), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])), - 16)), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])), - 16)), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])), - 16)), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])), - 16)), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - pg, - svld1uh_u32( - pg, reinterpret_cast(&ip[k])), - 16)), - svld1_f32(pg, &op[k]))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::BFloat16* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_bfloat16_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const at::BFloat16* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_bfloat16_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -template -static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const uint8_t* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), - svadd_f32_x(svAll, vsum16, vbio)); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), - svadd_f32_x(svAll, vsum17, vbio)); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), - svadd_f32_x(svAll, vsum18, vbio)); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), - svadd_f32_x(svAll, vsum19, vbio)); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), - svadd_f32_x(svAll, vsum20, vbio)); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), - svadd_f32_x(svAll, vsum21, vbio)); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), - svadd_f32_x(svAll, vsum22, vbio)); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), - svadd_f32_x(svAll, vsum23, vbio)); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), - svadd_f32_x(svAll, vsum24, vbio)); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), - svadd_f32_x(svAll, vsum25, vbio)); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), - svadd_f32_x(svAll, vsum26, vbio)); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), - svadd_f32_x(svAll, vsum27, vbio)); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), - svadd_f32_x(svAll, vsum28, vbio)); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), - svadd_f32_x(svAll, vsum29, vbio)); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), - svadd_f32_x(svAll, vsum30, vbio)); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), - svadd_f32_x(svAll, vsum31, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - // unimplemented - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), - svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const uint8_t* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_uint8_t_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const uint8_t* input, - const int32_t* indices, - const int32_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int32_t_uint8_t_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -template -static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const uint8_t* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - const svbool_t svAll = svptrue_b32(); - const auto vLen = static_cast(svcntw()); - int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), - svadd_f32_x(svAll, vsum16, vbio)); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), - svadd_f32_x(svAll, vsum17, vbio)); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), - svadd_f32_x(svAll, vsum18, vbio)); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), - svadd_f32_x(svAll, vsum19, vbio)); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), - svadd_f32_x(svAll, vsum20, vbio)); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), - svadd_f32_x(svAll, vsum21, vbio)); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), - svadd_f32_x(svAll, vsum22, vbio)); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), - svadd_f32_x(svAll, vsum23, vbio)); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), - svadd_f32_x(svAll, vsum24, vbio)); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), - svadd_f32_x(svAll, vsum25, vbio)); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), - svadd_f32_x(svAll, vsum26, vbio)); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), - svadd_f32_x(svAll, vsum27, vbio)); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), - svadd_f32_x(svAll, vsum28, vbio)); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), - svadd_f32_x(svAll, vsum29, vbio)); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), - svadd_f32_x(svAll, vsum30, vbio)); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), - svadd_f32_x(svAll, vsum31, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } - } - } else if (block_size == 16 * vLen) { - // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - } - } - } else if (block_size == 8 * vLen) { - // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - } - } - } else if (block_size == 4 * vLen) { - // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - } - } - } else if (block_size == 2 * vLen) { - // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - } - } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - // unimplemented - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), - svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); - } - - ++pos; - } - const int64_t length = end_offset - start_offset; - - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } - } - } - } - return pos == index_size; -} -bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const uint8_t* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_uint8_t_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} -bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__sve( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const uint8_t* input, - const int64_t* indices, - const int64_t* offsets, - const float* weights, - const float* scale_bias, - bool normalize_by_lengths, - float* out) { - return EmbeddingLookupIdx_int64_t_uint8_t_float__sve( - block_size, - output_size, - index_size, - data_size, - input, - indices, - offsets, - weights, - scale_bias, - normalize_by_lengths, - out); -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/sve_emblookup_codegen.py b/caffe2/perfkernels/sve_emblookup_codegen.py deleted file mode 100644 index 02f010ccc25..00000000000 --- a/caffe2/perfkernels/sve_emblookup_codegen.py +++ /dev/null @@ -1,408 +0,0 @@ -# mypy: allow-untyped-defs -import argparse -import sys - -# Unroll loops when block_size is a multiple of vector length. -def unroll(num_unrolls, IndexType, InType, OutType, use_weights): - def compute(regid, InType, use_weights): - code = [] - - if InType == "float": - code.append( - f" vsum{regid} =\n" - " svmad_f32_x(" - f"svAll, vwgt, svld1_f32(svAll, &ip[{regid} * vLen])," - f" vsum{regid});" - ) - elif InType == "at::Half": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svcvt_f32_f16_x(\n" - " svAll,\n" - " svreinterpret_f16_u32(svld1uh_u32(\n" - " svAll, reinterpret_cast(" - f"&ip[{regid} * vLen])))),\n" # noqa - f" vsum{regid});" - ) - elif InType == "at::BFloat16": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svreinterpret_f32_u32(svlsl_n_u32_x(\n" - " svAll,\n" - " svld1uh_u32(\n" - " svAll, reinterpret_cast(" - f"&ip[{regid} * vLen])),\n" - " 16)),\n" # noqa - f" vsum{regid});" - ) - elif InType == "uint8_t": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svcvt_f32_u32_x(svAll," - f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" # noqa - f" svadd_f32_x(svAll, vsum{regid}, vbio));" - ) - else: - raise ValueError(f"Unknown datatype \"{InType}\"") - - return code - - code = [] - code.append(f" // unrolling {num_unrolls} times") - - code.append(" for (int64_t i = 0; i < output_size; ++i) {") - - code.append(" " + OutType + "* const op = &out[i * block_size];") - code.append( - " if (pos != offsets[i] - offsets[0]) {\n" - + " return false;\n" - + " }" - ) - - # Initialise vector sum registers - for i in range(num_unrolls): - code.append(f" svfloat32_t vsum{i} = svdup_n_f32(0);") - - # inner loop - code.append("""\ - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1];""") - code.append( - " for (" - + "int64_t" - + " j = start_offset; j < end_offset; ++j) {" # noqa - ) - - code.append(" const auto idx = indices[pos];") - code.append( - " if (idx < 0 || idx >= data_size) {\n" - + " return false;\n" - + " }" - ) - - if InType == "uint8_t": - code.append(" " + OutType + " wgt = 1.f;") - code.append(" " + OutType + " bio{};") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa - ) - code.append(" }") - code.append(" if (scale_bias) {") - code.append(" bio = wgt * scale_bias[2 * idx + 1];") - code.append(" wgt = wgt * scale_bias[2 * idx];") - code.append(" }") - code.append(" svfloat32_t vbio = svdup_n_f32(bio);") - else: - code.append(" " + OutType + " wgt = 1.f;") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa - ) - code.append(" }") - - code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") - code.append(f" const {InType}* const ip = &input[idx * block_size];") - code.append(" // weight * input + out") - - for i in range(num_unrolls): - code.extend(compute(i, InType, use_weights)) - - code.append(" ++pos;") - code.append(" }") - - code.append(" // Normalisation") - code.append(" const int64_t length = end_offset - start_offset;") - code.append(" if (normalize_by_lengths && length != 0) {") - code.append(" const float len_inv = 1.0f / length;") - code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);") - - for i in range(num_unrolls): - code.append(f" svst1_f32(svAll, &op[{i} * vLen]," - + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));") - - code.append(" } else {") - # inv of length - for i in range(num_unrolls): - code.append(f" svst1_f32(svAll, &op[{i} * vLen], vsum{i});") - - code.append(" }") - code.append(" }") - return code - - -# Handle the case where block_size is not a multiple of vector length. -def generic(IndexType, InType, OutType, use_weights): - def compute(InType, use_weights): - code = [] - if InType == "float": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg, vwgt, svld1_f32(pg, &ip[k])," - " svld1_f32(pg, &op[k])));" - ) - elif InType == "at::Half": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svcvt_f32_f16_x(\n" - " pg,\n" - " svreinterpret_f16_u32(svld1uh_u32(\n" - " pg," - " reinterpret_cast(&ip[k])))),\n" - " svld1_f32(pg, &op[k])));" - ) - elif InType == "at::BFloat16": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svreinterpret_f32_u32(svlsl_n_u32_x(\n" - " pg,\n" - " svld1uh_u32(\n" - " pg," - " reinterpret_cast(&ip[k])),\n" - " 16)),\n" - " svld1_f32(pg, &op[k])));" - ) - elif InType == "uint8_t": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svcvt_f32_u32_x(pg," - " svld1ub_u32(pg, &ip[k])),\n" # noqa - " svadd_f32_x(pg," - " svld1_f32(pg, &op[k]), vbio)));" - ) - else: - raise ValueError(f"Unknown datatype \"{InType}\"") - - return code - - code = [] - - code.append( - " for (int64_t i = 0; i < output_size; ++i) {" - ) - - code.append(" " + OutType + "* const op = &out[i * block_size];") - - # initialize to 0 - code.append(" memset(op, 0, sizeof(float) * block_size);") - - # inner loop - code.append( - " if (pos != offsets[i] - offsets[0]) {\n" - + " return false;\n" - + " }" - ) - code.append( - " int64_t start_offset = offsets[i];\n" - + " int64_t end_offset = offsets[i + 1];" - ) - code.append( - " for (" - + "int64_t" - + " j = start_offset; j < end_offset; ++j) {" # noqa - ) - - code.append(" const auto idx = indices[pos];") - code.append( - " if (idx < 0 || idx >= data_size) {\n" - + " return false;\n" - + " }" - ) - - if InType == "uint8_t": - code.append(" // unimplemented") - code.append(" " + OutType + " wgt = 1.f;") - code.append(" " + OutType + " bio{};") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa - ) - code.append(" }") - code.append(" if (scale_bias) {") - code.append(" bio = wgt * scale_bias[2 * idx + 1];") - code.append(" wgt = wgt * scale_bias[2 * idx];") - code.append(" }") - code.append(" svfloat32_t vbio = svdup_n_f32(bio);") - else: - code.append(" " + OutType + " wgt = 1.f;") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa - ) - code.append(" }") - - code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") - code.append(f" const {InType}* ip = &input[idx * block_size];") - - # compute and store main loop - code.append(" svbool_t pg;") - code.append(" for (int64_t k = 0;") - code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64(" - + "k, block_size));") - code.append(" k += vLen) {") - code.extend(compute(InType, use_weights)) - code.append(" }\n") - code.append(" ++pos;") - code.append(" }") - - code.append(" const int64_t length = end_offset - start_offset;\n") - code.append(" if (normalize_by_lengths && length != 0) {") - code.append(" const float len_inv = 1.0f / length;") - code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);") - code.append(" svbool_t pg;") - code.append(" for (int64_t j = 0;\n" - " svptest_first(svAll, pg = svwhilelt_b32_s64(" - "j, block_size));") - code.append(" j += vLen) {") - code.append( - " svst1_f32(\n" - " pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));" - ) - code.append(" }") - code.append(" }") - code.append(" }") - return code - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--filename", help="file name") - opts = parser.parse_args() - if opts.filename: - filename = opts.filename - else: - filename = "embedding_lookup_idx_sve.cc" - - options = [ - ["int32_t", "int32_t", "float", "float", "float", "float"], - ["int64_t", "int64_t", "float", "float", "float", "float"], - ["int32_t", "int32_t", "half", "at::Half", "float", "float"], - ["int64_t", "int64_t", "half", "at::Half", "float", "float"], - ["int32_t", "int32_t", "bfloat16", "at::BFloat16", "float", "float"], - ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"], - ["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"], - ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"], - ] - - code = [] - # includes - code.append("//// --------------------------") - code.append("//// ATTENTION:") - code.append("//// THIS CODE IS AUTOGENERATED") - code.append(f"//// BY {' '.join(sys.argv)}") - code.append("//// DO NOT MODIFY!!!") - code.append("//// --------------------------\n") - - code.append("#include ") - code.append("#include ") - code.append("#include ") - code.append("#include ") - code.append("#include ") - - code.append("namespace caffe2 {\n") - for o in options: - [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o - - code.append("template ") - fn_base = f"EmbeddingLookupIdx_{IndexTypeName}_{InTypeName}_{OutTypeName}" - - suffix = "__sve" - fn = "static bool " + fn_base + suffix - code.append(fn + "(") - - args = [] - args.append(" const int64_t block_size,") - args.append(" const int64_t output_size,") - args.append(" const int64_t index_size,") - args.append(" const int64_t data_size,") - args.append(" const " + InType + "* input,") - args.append(" const " + IndexType + "* indices,") - args.append(" const " + IndexType + "* offsets,") - args.append(" const float* weights,") - args.append(" const float* scale_bias,") - args.append(" bool normalize_by_lengths,") - args.append(" " + OutType + "* out) {") - code += args - - code.append(" const svbool_t svAll = svptrue_b32();") - code.append(" const auto vLen = static_cast(svcntw());") - code.append(" int64_t pos = 0;") - - code.append(" if (block_size == 32 * vLen) {") - code += unroll(32, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 16 * vLen) {") - code += unroll(16, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 8 * vLen) {") - code += unroll(8, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 4 * vLen) {") - code += unroll(4, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 2 * vLen) {") - code += unroll(2, IndexType, InType, OutType, True) - code.append(" } else {") - code.append(" // generic code:") - code += generic(IndexType, InType, OutType, True) - code.append(" }") - code.append(" return pos == index_size;") - - code.append("}") - - for is_weight_positional in ["false", "true"]: - code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(") - code += args - - # Resolve the Lint warnings: Limit of 80 characters in one line. - extra_space = "\n " - ret_string = " return " + fn_base + suffix \ - + "<" + is_weight_positional + ">(" - if len(ret_string) <= 80: - code.append(ret_string) - else: - code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(") - - code.append(" block_size,") - code.append(" output_size,") - code.append(" index_size,") - code.append(" data_size,") - code.append(" input,") - code.append(" indices,") - code.append(" offsets,") - code.append(" weights,") - code.append(" scale_bias,") - code.append(" normalize_by_lengths,") - code.append(" out);") - code.append("}") - - code.append("") - - code.append("} // namespace caffe2") - - with open(filename, "w") as fout: - fout.write("\n".join(code) + "\n") - - print("Created " + filename) - -if __name__ == "__main__": - main() diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 74fc1487333..10fa810b8fd 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -101,16 +101,6 @@ endif() # Also, we will turn off deprecated-declarations # due to protobuf. -# ---[ Check if the compiler has SVE support. -find_package(ARM) # checks SVE -if(CXX_SVE_FOUND) - message(STATUS "Compiler supports SVE extension. Will build perfkernels.") - # Also see CMakeLists.txt under caffe2/perfkernels. - add_compile_definitions(CAFFE2_PERF_WITH_SVE=1) -else() - message(STATUS "Compiler does not support SVE extension. Will not build perfkernels.") -endif() - if(IOS AND (${IOS_ARCH} MATCHES "armv7*")) add_definitions("-mfpu=neon-fp16") add_definitions("-arch" ${IOS_ARCH}) diff --git a/cmake/public/ComputeLibrary.cmake b/cmake/public/ComputeLibrary.cmake index e18527ce65b..d0b3b56ff53 100644 --- a/cmake/public/ComputeLibrary.cmake +++ b/cmake/public/ComputeLibrary.cmake @@ -21,10 +21,10 @@ if("${ACL_VERSION_FILE}" STREQUAL "") message(WARNING "Build may fail: Could not determine ACL version (minimum required is ${ACL_MINIMUM_VERSION})") else() file(READ ${ACL_VERSION_FILE} ACL_VERSION_STRING) - string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION "${ACL_VERSION_STRING}") + string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION ${ACL_VERSION_STRING}) set(ACL_VERSION "${CMAKE_MATCH_1}") - if("${ACL_VERSION}" VERSION_EQUAL "0.0") + if(${ACL_VERSION} VERSION_EQUAL "0.0") # Unreleased ACL versions come with version string "v0.0-unreleased", and may not be compatible with oneDNN. # It is recommended to use the latest release of ACL. message(WARNING "Build may fail: Using unreleased ACL version (minimum required is ${ACL_MINIMUM_VERSION})")