mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
attention fusion kernel refactoring (#8432)
* attention fusion kernel refactored * consider the case of none in add_qk * variabled added to check for pre-pack weights * added a comment to PrePack() * Optimized prepack and try to free the weights * making comment sound better * fixing a bug with optimizer.py * commented out changes to be done * removed comments * make the private fn() private * fix build * making clean up fn static * backed out optimizer tool change, needs more looking into
This commit is contained in:
parent
a396c9e572
commit
6dee9b9d2d
2 changed files with 50 additions and 72 deletions
|
|
@ -15,13 +15,17 @@ using onnxruntime::concurrency::ThreadPool;
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
static void FreePackedWeights(BufferUniquePtr* array, size_t array_size) {
|
||||
for (size_t i = 0; i < array_size; i++) {
|
||||
array[i].reset();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Attention : public OpKernel, public AttentionCPUBase {
|
||||
public:
|
||||
explicit Attention(const OpKernelInfo& info);
|
||||
|
||||
bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size, size_t input_hidden_size, const T* weights_data, size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights);
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
||||
|
|
@ -33,13 +37,13 @@ class Attention : public OpKernel, public AttentionCPUBase {
|
|||
/*out*/ bool& used_shared_buffers) override;
|
||||
|
||||
private:
|
||||
BufferUniquePtr q_packed_weights_;
|
||||
BufferUniquePtr k_packed_weights_;
|
||||
BufferUniquePtr v_packed_weights_;
|
||||
bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size,
|
||||
size_t input_hidden_size, const T* weights_data,
|
||||
size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights);
|
||||
|
||||
size_t q_packed_weights_size_ = 0;
|
||||
size_t k_packed_weights_size_ = 0;
|
||||
size_t v_packed_weights_size_ = 0;
|
||||
BufferUniquePtr packed_weights_[3];
|
||||
size_t packed_weights_size_[3] = {0, 0, 0};
|
||||
bool is_prepack_ = false;
|
||||
TensorShape weight_shape_;
|
||||
};
|
||||
|
||||
|
|
@ -297,22 +301,8 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
|
|||
// buffer memory and we don not want it uninitialized and generate different hashes
|
||||
// if and when we try to cache this pre-packed buffer for sharing between sessions.
|
||||
memset(packed_weights_data, 0, packed_weights_data_size);
|
||||
switch (qkv_index) {
|
||||
case 0:
|
||||
q_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
q_packed_weights_size_ = packb_size;
|
||||
break;
|
||||
case 1:
|
||||
k_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
k_packed_weights_size_ = packb_size;
|
||||
break;
|
||||
case 2:
|
||||
v_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
v_packed_weights_size_ = packb_size;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
packed_weights_size_[qkv_index] = packb_size;
|
||||
|
||||
for (size_t i = 0; i < loop_len; i++) {
|
||||
MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data);
|
||||
|
|
@ -320,22 +310,8 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
|
|||
weights_data += head_size;
|
||||
}
|
||||
|
||||
bool share_prepacked_weights = (prepacked_weights != nullptr);
|
||||
if (share_prepacked_weights) {
|
||||
switch (qkv_index) {
|
||||
case 0:
|
||||
prepacked_weights->buffers_.push_back(std::move(q_packed_weights_));
|
||||
break;
|
||||
case 1:
|
||||
prepacked_weights->buffers_.push_back(std::move(k_packed_weights_));
|
||||
break;
|
||||
case 2:
|
||||
prepacked_weights->buffers_.push_back(std::move(v_packed_weights_));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (prepacked_weights != nullptr) {
|
||||
prepacked_weights->buffers_.push_back(std::move(packed_weights_[qkv_index]));
|
||||
prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size);
|
||||
}
|
||||
return true;
|
||||
|
|
@ -345,6 +321,15 @@ template <typename T>
|
|||
Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
|
||||
/*out*/ bool& is_packed,
|
||||
/*out*/ PrePackedWeights* prepacked_weights) {
|
||||
|
||||
/* The PrePack() massages the weights to speed up Compute(), there is an option to
|
||||
* use shared prepacked weights in which case prepacked_weights parameter would be non-null.
|
||||
*
|
||||
* We use an array of buffers to store prepacked Q, K, V weights for the sake of simplicity
|
||||
* and easy offset management in Compute(). They are packed one after the other. In case of failure,
|
||||
* 1. With shared pre-pack weights the caller of this fn() frees up the memory so far allocated.
|
||||
* 2. When weights are held by kernel, it will be freed before returning.
|
||||
*/
|
||||
is_packed = false;
|
||||
|
||||
if (1 != input_idx) {
|
||||
|
|
@ -386,19 +371,20 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
|
|||
v_hidden_size = hidden_size;
|
||||
}
|
||||
|
||||
const size_t q_head_size = q_hidden_size / num_heads_;
|
||||
const size_t k_head_size = k_hidden_size / num_heads_;
|
||||
const size_t v_head_size = v_hidden_size / num_heads_;
|
||||
const size_t qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
|
||||
const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size;
|
||||
|
||||
if (!IsPackWeightsSuccessful(0, alloc, q_head_size, input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
|
||||
!IsPackWeightsSuccessful(1, alloc, k_head_size, input_hidden_size, weights_data + (num_heads_ * q_head_size), weight_matrix_col_size, prepacked_weights) ||
|
||||
!IsPackWeightsSuccessful(2, alloc, v_head_size, input_hidden_size, weights_data + (num_heads_ * (q_head_size + k_head_size)), weight_matrix_col_size, prepacked_weights)) {
|
||||
// we are not cleaning up anything, assuming caller takes care of this
|
||||
if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
|
||||
!IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size, weights_data + (num_heads_ * qkv_head_size[0]), weight_matrix_col_size, prepacked_weights) ||
|
||||
!IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size, weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])), weight_matrix_col_size, prepacked_weights)) {
|
||||
if (prepacked_weights == nullptr) {
|
||||
FreePackedWeights(packed_weights_, qkv_hidden_sizes_.size());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
is_packed = true;
|
||||
is_prepack_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -411,9 +397,9 @@ Status Attention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& pre
|
|||
}
|
||||
|
||||
used_shared_buffers = true;
|
||||
q_packed_weights_ = std::move(prepacked_buffers[0]);
|
||||
k_packed_weights_ = std::move(prepacked_buffers[1]);
|
||||
v_packed_weights_ = std::move(prepacked_buffers[2]);
|
||||
packed_weights_[0] = std::move(prepacked_buffers[0]);
|
||||
packed_weights_[1] = std::move(prepacked_buffers[1]);
|
||||
packed_weights_[2] = std::move(prepacked_buffers[2]);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -421,7 +407,7 @@ Status Attention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& pre
|
|||
template <typename T>
|
||||
Status Attention<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = q_packed_weights_ ? nullptr : context->Input<Tensor>(1);
|
||||
const Tensor* weights = is_prepack_ ? nullptr : context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
|
|
@ -512,7 +498,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
int weights_offset = 0;
|
||||
int bias_offset = qkv_index * q_hidden_size + head_index * head_size;
|
||||
|
||||
if (q_packed_weights_ == nullptr) {
|
||||
if (!is_prepack_) {
|
||||
weights_offset = bias_offset;
|
||||
} else {
|
||||
weights_offset = head_index * head_size;
|
||||
|
|
@ -534,15 +520,9 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
// A: input (BxSxD) (B.)S x D S x D
|
||||
// B: weights (DxNxT) D x (N.)T D x H
|
||||
// C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H
|
||||
if (q_packed_weights_) {
|
||||
if (is_prepack_) {
|
||||
uint8_t* packed_weight;
|
||||
if (qkv_index == 0) {
|
||||
packed_weight = static_cast<uint8_t*>(q_packed_weights_.get()) + q_packed_weights_size_ * (weights_offset / head_size);
|
||||
} else if (qkv_index == 1) {
|
||||
packed_weight = static_cast<uint8_t*>(k_packed_weights_.get()) + k_packed_weights_size_ * (weights_offset / head_size);
|
||||
} else {
|
||||
packed_weight = static_cast<uint8_t*>(v_packed_weights_.get()) + v_packed_weights_size_ * (weights_offset / head_size);
|
||||
}
|
||||
packed_weight = static_cast<uint8_t*>(packed_weights_[qkv_index].get()) + packed_weights_size_[qkv_index] * (weights_offset / head_size);
|
||||
|
||||
MlasGemm(
|
||||
CblasNoTrans, // TransA = no
|
||||
|
|
|
|||
|
|
@ -128,20 +128,18 @@ class FusionAttention(Fusion):
|
|||
return num_heads, hidden_size
|
||||
|
||||
def get_add_qk_str(self, add_qk: NodeProto):
|
||||
# Note: Does not work for dynamic models, reshape node shape inference would fail with more than 2 dims being -1
|
||||
# inputs_ids has to be the name of input, this may be changes if needed
|
||||
inputs_ids = self.model.find_graph_input("input_ids")
|
||||
if inputs_ids == None:
|
||||
logger.debug("no input with name \"input_ids\"")
|
||||
return None
|
||||
batch_size, seq_len = get_shape_from_type_proto(inputs_ids.type)
|
||||
shape_infer = self.model.infer_runtime_shape(update=True)
|
||||
if shape_infer is None:
|
||||
return
|
||||
|
||||
if batch_size < 0 or seq_len < 0:
|
||||
logger.debug(f"batch_size: {batch_size} and seq_len {seq_len} cannot be -ve")
|
||||
input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
|
||||
input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
|
||||
|
||||
if input_0_shape is None or input_1_shape is None:
|
||||
logger.debug(f"one of the inputs of {add_qk} is None")
|
||||
return None
|
||||
shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model)
|
||||
shape_infer_helper.infer({"batch_size": batch_size, "seq_len": seq_len})
|
||||
if shape_infer_helper.get_edge_shape(add_qk.input[0]) != shape_infer_helper.get_edge_shape(add_qk.input[1]):
|
||||
|
||||
if input_0_shape != input_1_shape:
|
||||
logger.debug(f"the shape of two inputs of {add_qk} is not same")
|
||||
return None
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue