Enable fast qlinear_dynamic path for AArch64 through ACL directly

This enables a fast path for eager-mode dynamic quantization on AArch64 through Arm Compute Library (ACL) directly.

Context: PR #126687 enabled an optimized implementation of qlinear_dynamic for AArch64 through ideep → oneDNN → ACL, which improved performance by ~10x over the previous implementation.
However, that path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API. For example, ACL's lowp_gemm objects cache state such as the weights reduction (sum of columns) and the weights in an optimized memory format, which oneDNN's stateless nature does not allow.
Hence, ACL currently runs a (redundant) sum of columns and a pre-transposition (to the GEMM kernel's optimal format) for each GEMM operation.
This PR addresses these inefficiencies by integrating ACL directly with qlinear_dynamic. This approach yields an average speedup (averaged over context lengths of 2^3 up to 2^9) of ~50% for bert-base-uncased, bert-large-uncased, roberta-base, and distilbert-base-uncased with 16 threads on a Neoverse V1 (with transformers==4.48).
To achieve this, we:
* Use ACL, which is already built as a shared library with PyTorch when USE_MKLDNN_ACL is set.
* Add ACL to ATen's CPU include directories and dependency libraries.
* Introduce PackedLinearWeightsACL (a subclass of PackedLinearWeightsOnednn) with an implementation of qlinear_dynamic that calls ACL directly, while qlinear still follows the oneDNN path.
* A future PR will introduce a direct ACL implementation of qlinear, which will allow us to remove the dependence on PackedLinearWeightsOnednn.
Fadi Arafeh 2025-01-29 16:55:23 +00:00
parent 7b07415aaa
commit 3d05899222
6 changed files with 468 additions and 12 deletions

@@ -1250,6 +1250,18 @@ if(USE_MIMALLOC)
include_directories(third_party/mimalloc/include)
endif()
if(USE_MKLDNN_ACL)
find_package(ACL REQUIRED)
if(ACL_FOUND)
include_directories(${ACL_INCLUDE_DIRS})
message(STATUS "ACL Include: ${ACL_INCLUDE_DIRS}")
message(STATUS "ACL Library: ${ACL_LIBRARIES}")
else()
message(FATAL_ERROR "ACL not found")
endif()
endif()
if(USE_MIMALLOC AND USE_MIMALLOC_ON_MKL)
add_definitions(-DUSE_MIMALLOC_ON_MKL)
endif()

@@ -458,6 +458,12 @@ if(MKLDNN_FOUND)
list(APPEND ATen_CPU_DEPENDENCY_LIBS ${MKLDNN_LIBRARIES})
endif(MKLDNN_FOUND)
if(USE_MKLDNN_ACL)
list(APPEND ATen_CPU_INCLUDE ${ACL_INCLUDE_DIRS})
list(APPEND ATen_CPU_DEPENDENCY_LIBS ${ACL_LIBRARIES})
endif()
if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$")
list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)
endif()

@@ -0,0 +1,251 @@
#pragma once
#include <ATen/Config.h>
#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <arm_compute/core/Error.h>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/function_info/ActivationLayerInfo.h>
#include <arm_compute/runtime/Allocator.h>
#include <arm_compute/runtime/NEON/functions/NEActivationLayer.h>
#include <arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h>
#include <arm_compute/runtime/NEON/functions/NEQuantizationLayer.h>
#include <arm_compute/runtime/Tensor.h>
#include <array>
using ACLDynamicQuantMatmulCacheKey = std::tuple<
int64_t, // M
bool, // FUSE_RELU
int64_t // NUM_THREADS
>;
enum ACLDynamicQuantMatmulCacheKeyIndex {
M,
FUSE_RELU,
NUM_THREADS,
};
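// For illustration, a key for M = 128, no fused ReLU, and 16 threads:
//   ACLDynamicQuantMatmulCacheKey key{128, false, 16};
//   int64_t m = std::get<M>(key); // fields are read back via the enum above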
struct ACLDynamicQuantMatmul {
arm_compute::Tensor src_s8_tensor;
arm_compute::Tensor src_fp32_tensor;
arm_compute::Tensor wei_tensor;
arm_compute::Tensor bia_tensor;
arm_compute::Tensor dst_tensor;
arm_compute::NEQuantizationLayer quant;
std::shared_ptr<arm_compute::IMemoryManager> memory_manager{
arm_compute::MemoryManagerOnDemand::make_default()};
arm_compute::NEGEMMLowpMatrixMultiplyCore gemm{memory_manager};
arm_compute::NEActivationLayer acl_relu;
// configuration details for the ACL gemm
arm_compute::TensorInfo src_s8_tensor_info;
arm_compute::TensorInfo src_fp32_tensor_info;
arm_compute::TensorInfo wei_tensor_info;
arm_compute::TensorInfo bia_tensor_info;
arm_compute::TensorInfo dst_tensor_info;
arm_compute::GEMMInfo gemm_info;
arm_compute::ActivationLayerInfo acl_relu_info{
arm_compute::ActivationFunction::RELU};
bool with_bias{false};
// key for use in the cache
ACLDynamicQuantMatmulCacheKey key;
~ACLDynamicQuantMatmul() {
// this will free memory allocated for the quantized src tensor since the
// allocation happened through ACL: src_s8_tensor.allocator()->allocate()
src_s8_tensor.allocator()->free();
// this will not free memory, it will just tell ACL that we're no longer
// using the pointer
wei_tensor.allocator()->free();
if (with_bias) {
bia_tensor.allocator()->free();
}
// deallocate memory used for auxiliary tensors
memory_manager->clear();
}
};
struct PackedLinearWeightsACL : public PackedLinearWeightsOnednn {
PackedLinearWeightsACL(
std::unique_ptr<ideep::tensor> weight,
std::optional<ideep::tensor> bias,
at::Tensor orig_weight,
std::optional<at::Tensor> orig_bias)
: PackedLinearWeightsOnednn(
std::move(weight),
std::move(bias),
std::move(orig_weight),
std::move(orig_bias)) {
auto w = *(weight_.get());
k_ = w.get_dim(0);
n_ = w.get_dim(1);
wei_zero_point_ = orig_weight_.q_zero_point();
wei_scale_ = orig_weight_.q_scale();
}
int64_t k_;
int64_t n_;
int64_t wei_zero_point_;
double wei_scale_;
at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false)
override;
at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
override;
std::shared_ptr<ACLDynamicQuantMatmul> get_acl_dynamic_quant_matmul(
const ACLDynamicQuantMatmulCacheKey& key) {
// We're only maintaining a 2 element LRU cache
// hit first
if (acl_dynamic_quant_cache[0] != nullptr &&
acl_dynamic_quant_cache[0]->key == key) {
return acl_dynamic_quant_cache[0];
}
// hit second
if (acl_dynamic_quant_cache[1] != nullptr &&
acl_dynamic_quant_cache[1]->key == key) {
// update LRU
std::rotate(
acl_dynamic_quant_cache.begin(),
acl_dynamic_quant_cache.begin() + 1,
acl_dynamic_quant_cache.end());
return acl_dynamic_quant_cache[0];
}
// miss -> replace Least Recently Used - i.e. element at index 1
acl_dynamic_quant_cache[1] = create_acl_dynamic_quant_matmul(key);
std::rotate(
acl_dynamic_quant_cache.begin(),
acl_dynamic_quant_cache.begin() + 1,
acl_dynamic_quant_cache.end());
return acl_dynamic_quant_cache[0];
}
private:
// A 2-element (per layer) cache. Given it's not intended to store more than 2
// elements, we do not need a fancy implementation. The idea behind it is to
// allow for a (configuration-free) fast path for autoregressive
// transformer-like models, which usually involve 2 input tensor shapes: one
// for the prefill phase and another for the autoregressive (decode) phase.
std::array<std::shared_ptr<ACLDynamicQuantMatmul>, 2> acl_dynamic_quant_cache;
std::shared_ptr<ACLDynamicQuantMatmul> create_acl_dynamic_quant_matmul(
const ACLDynamicQuantMatmulCacheKey& key) {
int64_t m = std::get<M>(key);
bool fuse_relu = std::get<FUSE_RELU>(key);
auto acl_gemm = std::make_shared<ACLDynamicQuantMatmul>();
acl_gemm->with_bias = bias_.has_value();
acl_gemm->key = key;
acl_gemm->src_fp32_tensor_info = arm_compute::TensorInfo(
arm_compute::TensorShape(k_, m), arm_compute::Format::F32);
acl_gemm->src_fp32_tensor_info.set_are_values_constant(false);
acl_gemm->src_s8_tensor_info = arm_compute::TensorInfo(
arm_compute::TensorShape(k_, m),
1,
arm_compute::DataType::QASYMM8_SIGNED,
// TODO: We set the initial offset value to int8_t max instead of zero,
// because ACL currently skips the MatrixBReduction calculation if the
// source offset at configuration time is zero. This is fixed by this ACL
// patch: https://review.mlplatform.org/c/ml/ComputeLibrary/+/12820/8
// The offset will be set to the actual src offset value at runtime.
arm_compute::QuantizationInfo(
1.0, std::numeric_limits<int8_t>::max(), true));
acl_gemm->src_s8_tensor_info.set_are_values_constant(false);
acl_gemm->wei_tensor_info = arm_compute::TensorInfo(
arm_compute::TensorShape(n_, k_),
1,
arm_compute::DataType::QASYMM8_SIGNED,
arm_compute::QuantizationInfo(wei_scale_, wei_zero_point_, true));
acl_gemm->wei_tensor_info.set_are_values_constant(true);
acl_gemm->bia_tensor_info = arm_compute::TensorInfo(
arm_compute::TensorShape(), 1, arm_compute::DataType::F32);
if (acl_gemm->with_bias) {
acl_gemm->bia_tensor_info.set_tensor_shape(
arm_compute::TensorShape(1, n_));
}
acl_gemm->dst_tensor_info = arm_compute::TensorInfo(
arm_compute::TensorShape(n_, m), arm_compute::Format::F32);
// validate that ACL can handle the given problem and inputs.
if (fuse_relu) {
arm_compute::Status relu_status =
arm_compute::NEActivationLayer::validate(
&acl_gemm->dst_tensor_info,
&acl_gemm->dst_tensor_info,
acl_gemm->acl_relu_info);
if (relu_status.error_code() != arm_compute::ErrorCode::OK) {
return nullptr;
}
}
arm_compute::Status quant_status =
arm_compute::NEQuantizationLayer::validate(
&acl_gemm->src_fp32_tensor_info, &acl_gemm->src_s8_tensor_info);
if (quant_status.error_code() != arm_compute::ErrorCode::OK) {
return nullptr;
}
arm_compute::Status gemm_status =
arm_compute::NEGEMMLowpMatrixMultiplyCore::validate(
&acl_gemm->src_s8_tensor_info,
&acl_gemm->wei_tensor_info,
acl_gemm->with_bias ? &acl_gemm->bia_tensor_info : nullptr,
&acl_gemm->dst_tensor_info,
acl_gemm->gemm_info);
if (gemm_status.error_code() != arm_compute::ErrorCode::OK) {
return nullptr;
}
// set the tensor info (i.e. shape, datatype, quant info) for the ACL
// tensors
acl_gemm->src_fp32_tensor.allocator()->init(acl_gemm->src_fp32_tensor_info);
acl_gemm->src_s8_tensor.allocator()->init(acl_gemm->src_s8_tensor_info);
acl_gemm->wei_tensor.allocator()->init(acl_gemm->wei_tensor_info);
if (acl_gemm->with_bias) {
acl_gemm->bia_tensor.allocator()->init(acl_gemm->bia_tensor_info);
}
acl_gemm->dst_tensor.allocator()->init(acl_gemm->dst_tensor_info);
// allocate memory only for the quantized tensor; the rest will use memory
// already available from PyTorch
acl_gemm->src_s8_tensor.allocator()->allocate();
// give ACL access to weight and bias pointer
acl_gemm->wei_tensor.allocator()->import_memory(
(int8_t*)weight_.get()->get_data_handle());
if (bias_.has_value()) {
acl_gemm->bia_tensor.allocator()->import_memory(
(float*)bias_.value().get_data_handle());
}
// configure
acl_gemm->quant.configure(
&acl_gemm->src_fp32_tensor, &acl_gemm->src_s8_tensor);
acl_gemm->gemm.configure(
&acl_gemm->src_s8_tensor,
&acl_gemm->wei_tensor,
acl_gemm->with_bias ? &acl_gemm->bia_tensor : nullptr,
&acl_gemm->dst_tensor,
acl_gemm->gemm_info);
if (fuse_relu) {
acl_gemm->acl_relu.configure(
&acl_gemm->dst_tensor,
&acl_gemm->dst_tensor,
acl_gemm->acl_relu_info);
}
// allocate memory for ACL's auxiliary tensors
arm_compute::Allocator alloc{};
acl_gemm->memory_manager->populate(alloc, 1);
return acl_gemm;
}
template <bool ReluFused>
at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false);
};
#endif // #if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
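To make the two-slot caching scheme above concrete, here is a minimal standalone sketch of the same rotate-based LRU policy. It is not part of the diff; the TwoElementLru and Entry names and the driver in main are invented for illustration, and only the lookup logic mirrors get_acl_dynamic_quant_matmul.

#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>
#include <memory>
#include <tuple>

// Same shape as ACLDynamicQuantMatmulCacheKey: (M, FUSE_RELU, NUM_THREADS).
using Key = std::tuple<int64_t, bool, int64_t>;

struct Entry {
  Key key;
};

struct TwoElementLru {
  std::array<std::shared_ptr<Entry>, 2> slots{};

  std::shared_ptr<Entry> lookup(const Key& key) {
    // Hit in the most recently used slot: nothing to reorder.
    if (slots[0] && slots[0]->key == key) {
      return slots[0];
    }
    // Hit in the least recently used slot: rotate it to the front.
    if (slots[1] && slots[1]->key == key) {
      std::rotate(slots.begin(), slots.begin() + 1, slots.end());
      return slots[0];
    }
    // Miss: evict the LRU slot (index 1), then rotate the new entry to front.
    slots[1] = std::make_shared<Entry>(Entry{key});
    std::rotate(slots.begin(), slots.begin() + 1, slots.end());
    return slots[0];
  }
};

int main() {
  TwoElementLru cache;
  cache.lookup({128, false, 16}); // prefill shape: miss, cached
  cache.lookup({1, false, 16});   // decode shape: miss, cached
  for (int i = 0; i < 3; ++i) {
    cache.lookup({1, false, 16}); // steady-state decode: hits slot 0
  }
  auto prefill = cache.lookup({128, false, 16}); // still resident in slot 1
  std::cout << "prefill M = " << std::get<0>(prefill->key) << '\n';
  return 0;
}

In an autoregressive transformer run, the prefill shape and the decode shape each keep one slot, so steady-state decoding always hits slot 0 and never reconfigures an ACL GEMM.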

@@ -5,6 +5,7 @@
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/ACLUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/quantized/library.h>
#include <ATen/native/quantized/PackedParams.h>
@@ -697,6 +698,127 @@ static at::Tensor linear_dynamic_fp16_with_onednn_weight(
primitive.execute(ideep::stream::default_stream(), args);
return dim == 2 ? output : output.reshape(output_size);
}
#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
template <bool ReluFused>
at::Tensor PackedLinearWeightsACL::apply_dynamic_impl(
at::Tensor input,
bool reduce_range) {
// Dynamic: fp32 * int8 -> fp32
using at::Tensor;
TORCH_CHECK(
input.dim() >= 2,
"The dimension of input tensor should be larger than or equal to 2");
TORCH_CHECK(
input.scalar_type() == c10::ScalarType::Float,
"qlinear_dynamic (ONEDNN): data type of input should be float.");
auto input_contig = input.contiguous();
const int64_t dim = input.dim();
auto input_reshaped =
dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
auto input_dims = input_reshaped.sizes().vec();
int64_t m = input_dims[0];
auto key = std::make_tuple(
m, ReluFused, static_cast<int64_t>(at::get_num_threads()));
auto acl_gemm = get_acl_dynamic_quant_matmul(key);
if (acl_gemm) {
// Find quantization parameters
float x_max = 0, x_min = 0;
#ifdef USE_FBGEMM
// Use FBGEMM's FindMinMax if available since it's faster
fbgemm::FindMinMax(
/*m=*/input_contig.data_ptr<float>(),
/*min=*/&x_min,
/*max=*/&x_max,
/*len=*/input.numel());
#else
if (input_contig.numel() > 0) {
auto [t_min, t_max] = at::aminmax(input_contig);
x_max = t_max.item<float>();
x_min = t_min.item<float>();
}
#endif
auto q_params = quant_utils::ChooseQuantizationParams(
/*min=*/x_min,
/*max=*/x_max,
/*qmin=*/std::numeric_limits<int8_t>::min(),
/*qmax=*/std::numeric_limits<int8_t>::max(),
/*preserve_sparsity=*/false,
/*force_scale_power_of_two=*/false,
/*reduce_range=*/reduce_range);
acl_gemm->src_fp32_tensor.allocator()->import_memory(
(float*)input_contig.data_ptr());
acl_gemm->src_s8_tensor.info()->set_quantization_info(
arm_compute::QuantizationInfo(
q_params.scale, q_params.zero_point, true));
// quantize src tensor: fp32 -> s8
acl_gemm->quant.run();
// allocation for fp32 out tensor
at::Tensor output = at::empty({m, n_}, input.options().dtype(at::kFloat));
if (output.numel() == 0)
return output;
// We set the offset to "-zero_point" for the GEMM, but to "zero_point" for
// the quantization layer. This is a known inconsistency in ACL.
acl_gemm->src_s8_tensor.info()->set_quantization_info(
arm_compute::QuantizationInfo(
q_params.scale, -q_params.zero_point, true));
acl_gemm->dst_tensor.allocator()->import_memory((float*)output.data_ptr());
// s8 src, s8 wei -> f32 dst
acl_gemm->gemm.run();
if (ReluFused) {
acl_gemm->acl_relu.run();
}
// this will not free memory, it will just tell ACL that we're no longer
// using the pointer
acl_gemm->src_fp32_tensor.allocator()->free();
acl_gemm->dst_tensor.allocator()->free();
auto out_sizes = input.sizes().vec();
out_sizes.back() = n_;
if (output.sizes().vec() == out_sizes)
return output;
return output.reshape(out_sizes);
}
// fall back to oneDNN in the unlikely scenario that ACL's validation fails
if (ReluFused) {
return PackedLinearWeightsOnednn::apply_dynamic_relu(input, reduce_range);
} else {
return PackedLinearWeightsOnednn::apply_dynamic(input, reduce_range);
}
}
at::Tensor PackedLinearWeightsACL::apply_dynamic(
at::Tensor input,
bool reduce_range) {
return apply_dynamic_impl</*ReluFused=*/false>(
std::move(input), reduce_range);
}
at::Tensor PackedLinearWeightsACL::apply_dynamic_relu(
at::Tensor input,
bool reduce_range) {
return apply_dynamic_impl</*ReluFused=*/true>(std::move(input), reduce_range);
}
#endif // #if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
#endif // #if AT_MKLDNN_ENABLED()
namespace at::native {
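For reference, the quantization parameters chosen above follow the standard asymmetric min/max recipe. Below is a simplified sketch of that arithmetic; it ignores reduce_range, sparsity preservation, and power-of-two scales, which quant_utils::ChooseQuantizationParams also handles, and the exact nudging there may differ.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

struct QParams {
  float scale;
  int32_t zero_point;
};

QParams choose_qparams_int8(float x_min, float x_max) {
  // Ensure zero is exactly representable (e.g. so zero padding stays exact).
  x_min = std::min(x_min, 0.f);
  x_max = std::max(x_max, 0.f);
  constexpr int32_t qmin = -128; // std::numeric_limits<int8_t>::min()
  constexpr int32_t qmax = 127;  // std::numeric_limits<int8_t>::max()
  float scale = (x_max - x_min) / static_cast<float>(qmax - qmin);
  if (scale == 0.f) {
    scale = 1.f; // degenerate all-zero input
  }
  // Nudge the zero point to an integer inside [qmin, qmax].
  auto zp = qmin - static_cast<int32_t>(std::round(x_min / scale));
  return {scale, std::clamp(zp, qmin, qmax)};
}

int main() {
  auto qp = choose_qparams_int8(-1.5f, 3.0f);
  std::cout << "scale = " << qp.scale << ", zero_point = " << qp.zero_point
            << '\n'; // scale ~ 0.0176, zero_point = -43
  return 0;
}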

@@ -1,15 +1,16 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/Context.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/cpp_custom_type_hack.h>
-#include <ATen/Context.h>
+#include <ATen/native/mkldnn/MKLDNNCommon.h>
+#include <ATen/native/quantized/PackedParams.h>
+#include <ATen/native/quantized/cpu/ACLUtils.h>
+#include <ATen/native/quantized/cpu/OnednnUtils.h>
+#include <ATen/native/quantized/cpu/QnnpackUtils.h>
+#include <ATen/native/quantized/cpu/QuantUtils.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
-#include <ATen/native/quantized/cpu/QnnpackUtils.h>
-#include <ATen/native/quantized/cpu/OnednnUtils.h>
-#include <ATen/native/quantized/cpu/QuantUtils.h>
 #include <ATen/native/quantized/library.h>
-#include <ATen/native/quantized/PackedParams.h>
-#include <ATen/native/mkldnn/MKLDNNCommon.h>
 #include <ATen/quantized/Quantizer.h>
 #include <torch/custom_class.h>
 #include <torch/library.h>
@@ -279,12 +280,15 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
     packed_bias.init(bias_desc, b.data_ptr());
     onednn_bias = std::optional<ideep::tensor>(packed_bias);
   }
-  auto ret_ptr = c10::make_intrusive<PackedLinearWeightsOnednn>(
-      PackedLinearWeightsOnednn{
-          std::move(weight_ptr),
-          onednn_bias,
-          weight,
-          bias});
+#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
+  if (qtype == c10::kPerTensorAffine) {
+    return c10::make_intrusive<PackedLinearWeightsACL>(PackedLinearWeightsACL{
+        std::move(weight_ptr), onednn_bias, weight, bias});
+  }
+#endif // #if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
+  auto ret_ptr =
+      c10::make_intrusive<PackedLinearWeightsOnednn>(PackedLinearWeightsOnednn{
+          std::move(weight_ptr), onednn_bias, weight, bias});
   return ret_ptr;
 }

@@ -0,0 +1,61 @@
# ----------
# FindACL
# ----------
#
# Finds the Arm Compute Library
# https://arm-software.github.io/ComputeLibrary/latest/
#
# This module defines the following variables:
#
# ACL_FOUND - True if ACL was found
# ACL_INCLUDE_DIRS - include directories for ACL
# ACL_LIBRARIES - link against this library to use ACL
#
# The module will also define two cache variables:
#
# ACL_INCLUDE_DIR - the ACL include directory
# ACL_LIBRARY - the path to the ACL library
#
# Use the ACL_ROOT_DIR environment variable to find the library and headers
find_path(ACL_INCLUDE_DIR
NAMES arm_compute/graph.h
PATHS ENV ACL_ROOT_DIR
)
find_library(ACL_LIBRARY
NAMES arm_compute
PATHS ENV ACL_ROOT_DIR
PATH_SUFFIXES lib build
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(ACL DEFAULT_MSG
ACL_INCLUDE_DIR
ACL_LIBRARY
)
mark_as_advanced(
ACL_LIBRARY
ACL_INCLUDE_DIR
)
# Find the extra libraries and include dirs
if(ACL_FOUND)
find_path(ACL_EXTRA_INCLUDE_DIR
NAMES half/half.hpp
PATHS ENV ACL_ROOT_DIR
PATH_SUFFIXES include
)
find_library(ACL_GRAPH_LIBRARY
NAMES arm_compute_graph
PATHS ENV ACL_ROOT_DIR
PATH_SUFFIXES lib build
)
list(APPEND ACL_INCLUDE_DIRS
${ACL_INCLUDE_DIR} ${ACL_EXTRA_INCLUDE_DIR})
list(APPEND ACL_LIBRARIES
${ACL_LIBRARY} ${ACL_GRAPH_LIBRARY})
endif()
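Typical usage, assuming ACL was built out of tree: export the ACL_ROOT_DIR environment variable to point at the ACL source/build tree before configuring PyTorch, so that find_package(ACL REQUIRED) in the top-level CMakeLists.txt can locate arm_compute/graph.h together with the arm_compute and arm_compute_graph libraries (searched in the lib and build subdirectories).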