attention fusion kernel refactoring (#8432)

* attention fusion kernel refactored

* consider the case of none in add_qk

* variable added to check for pre-pack weights

* added a comment to PrePack()

* Optimized prepack and try to free the weights

* making comment sound better

* fixing a bug with optimizer.py

* commented out changes to be done

* removed comments

* make the private fn() private

* fix build

* making clean up fn static

* backed out optimizer tool change, needs more looking into
This commit is contained in:
Viswanath Boga 2021-07-23 17:46:39 -07:00 committed by GitHub
parent a396c9e572
commit 6dee9b9d2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 72 deletions

View file

@ -15,13 +15,17 @@ using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
namespace contrib {
// Releases the buffer owned by each of the first `array_size` entries of `array`
// by calling reset() on it, leaving every entry empty.
//
// array:      pointer to the first BufferUniquePtr of a contiguous array.
// array_size: number of entries to reset, starting from `array`.
static void FreePackedWeights(BufferUniquePtr* array, size_t array_size) {
  BufferUniquePtr* const end = array + array_size;
  for (BufferUniquePtr* buffer = array; buffer != end; ++buffer) {
    buffer->reset();
  }
}
template <typename T>
class Attention : public OpKernel, public AttentionCPUBase {
public:
explicit Attention(const OpKernelInfo& info);
bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size, size_t input_hidden_size, const T* weights_data, size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights);
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
@ -33,13 +37,13 @@ class Attention : public OpKernel, public AttentionCPUBase {
/*out*/ bool& used_shared_buffers) override;
private:
BufferUniquePtr q_packed_weights_;
BufferUniquePtr k_packed_weights_;
BufferUniquePtr v_packed_weights_;
bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size,
size_t input_hidden_size, const T* weights_data,
size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights);
size_t q_packed_weights_size_ = 0;
size_t k_packed_weights_size_ = 0;
size_t v_packed_weights_size_ = 0;
BufferUniquePtr packed_weights_[3];
size_t packed_weights_size_[3] = {0, 0, 0};
bool is_prepack_ = false;
TensorShape weight_shape_;
};
@ -297,22 +301,8 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
// buffer memory and we do not want it uninitialized and generating different hashes
// if and when we try to cache this pre-packed buffer for sharing between sessions.
memset(packed_weights_data, 0, packed_weights_data_size);
switch (qkv_index) {
case 0:
q_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
q_packed_weights_size_ = packb_size;
break;
case 1:
k_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
k_packed_weights_size_ = packb_size;
break;
case 2:
v_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
v_packed_weights_size_ = packb_size;
break;
default:
return false;
}
packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
packed_weights_size_[qkv_index] = packb_size;
for (size_t i = 0; i < loop_len; i++) {
MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data);
@ -320,22 +310,8 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
weights_data += head_size;
}
bool share_prepacked_weights = (prepacked_weights != nullptr);
if (share_prepacked_weights) {
switch (qkv_index) {
case 0:
prepacked_weights->buffers_.push_back(std::move(q_packed_weights_));
break;
case 1:
prepacked_weights->buffers_.push_back(std::move(k_packed_weights_));
break;
case 2:
prepacked_weights->buffers_.push_back(std::move(v_packed_weights_));
break;
default:
break;
}
if (prepacked_weights != nullptr) {
prepacked_weights->buffers_.push_back(std::move(packed_weights_[qkv_index]));
prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size);
}
return true;
@ -345,6 +321,15 @@ template <typename T>
Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
/* PrePack() rearranges the weights to speed up Compute(). Optionally, the prepacked weights
* can be shared, in which case the prepacked_weights parameter is non-null.
*
* We use an array of buffers to store the prepacked Q, K, V weights for the sake of simplicity
* and easy offset management in Compute(). They are packed one after the other. In case of failure:
* 1. With shared prepacked weights, the caller of this function frees the memory allocated so far.
* 2. When the weights are held by the kernel, they are freed before returning.
*/
is_packed = false;
if (1 != input_idx) {
@ -386,19 +371,20 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
v_hidden_size = hidden_size;
}
const size_t q_head_size = q_hidden_size / num_heads_;
const size_t k_head_size = k_hidden_size / num_heads_;
const size_t v_head_size = v_hidden_size / num_heads_;
const size_t qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size;
if (!IsPackWeightsSuccessful(0, alloc, q_head_size, input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
!IsPackWeightsSuccessful(1, alloc, k_head_size, input_hidden_size, weights_data + (num_heads_ * q_head_size), weight_matrix_col_size, prepacked_weights) ||
!IsPackWeightsSuccessful(2, alloc, v_head_size, input_hidden_size, weights_data + (num_heads_ * (q_head_size + k_head_size)), weight_matrix_col_size, prepacked_weights)) {
// we are not cleaning up anything, assuming caller takes care of this
if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
!IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size, weights_data + (num_heads_ * qkv_head_size[0]), weight_matrix_col_size, prepacked_weights) ||
!IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size, weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])), weight_matrix_col_size, prepacked_weights)) {
if (prepacked_weights == nullptr) {
FreePackedWeights(packed_weights_, qkv_hidden_sizes_.size());
}
return Status::OK();
}
is_packed = true;
is_prepack_ = true;
return Status::OK();
}
@ -411,9 +397,9 @@ Status Attention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& pre
}
used_shared_buffers = true;
q_packed_weights_ = std::move(prepacked_buffers[0]);
k_packed_weights_ = std::move(prepacked_buffers[1]);
v_packed_weights_ = std::move(prepacked_buffers[2]);
packed_weights_[0] = std::move(prepacked_buffers[0]);
packed_weights_[1] = std::move(prepacked_buffers[1]);
packed_weights_[2] = std::move(prepacked_buffers[2]);
return Status::OK();
}
@ -421,7 +407,7 @@ Status Attention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& pre
template <typename T>
Status Attention<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = q_packed_weights_ ? nullptr : context->Input<Tensor>(1);
const Tensor* weights = is_prepack_ ? nullptr : context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
@ -512,7 +498,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
int weights_offset = 0;
int bias_offset = qkv_index * q_hidden_size + head_index * head_size;
if (q_packed_weights_ == nullptr) {
if (!is_prepack_) {
weights_offset = bias_offset;
} else {
weights_offset = head_index * head_size;
@ -534,15 +520,9 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
// A: input (BxSxD) (B.)S x D S x D
// B: weights (DxNxT) D x (N.)T D x H
// C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H
if (q_packed_weights_) {
if (is_prepack_) {
uint8_t* packed_weight;
if (qkv_index == 0) {
packed_weight = static_cast<uint8_t*>(q_packed_weights_.get()) + q_packed_weights_size_ * (weights_offset / head_size);
} else if (qkv_index == 1) {
packed_weight = static_cast<uint8_t*>(k_packed_weights_.get()) + k_packed_weights_size_ * (weights_offset / head_size);
} else {
packed_weight = static_cast<uint8_t*>(v_packed_weights_.get()) + v_packed_weights_size_ * (weights_offset / head_size);
}
packed_weight = static_cast<uint8_t*>(packed_weights_[qkv_index].get()) + packed_weights_size_[qkv_index] * (weights_offset / head_size);
MlasGemm(
CblasNoTrans, // TransA = no

View file

@ -128,20 +128,18 @@ class FusionAttention(Fusion):
return num_heads, hidden_size
def get_add_qk_str(self, add_qk: NodeProto):
# Note: Does not work for dynamic models, reshape node shape inference would fail with more than 2 dims being -1
# inputs_ids has to be the name of the input; this may be changed if needed
inputs_ids = self.model.find_graph_input("input_ids")
if inputs_ids == None:
logger.debug("no input with name \"input_ids\"")
return None
batch_size, seq_len = get_shape_from_type_proto(inputs_ids.type)
shape_infer = self.model.infer_runtime_shape(update=True)
if shape_infer is None:
return
if batch_size < 0 or seq_len < 0:
logger.debug(f"batch_size: {batch_size} and seq_len {seq_len} cannot be -ve")
input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
if input_0_shape is None or input_1_shape is None:
logger.debug(f"one of the inputs of {add_qk} is None")
return None
shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model)
shape_infer_helper.infer({"batch_size": batch_size, "seq_len": seq_len})
if shape_infer_helper.get_edge_shape(add_qk.input[0]) != shape_infer_helper.get_edge_shape(add_qk.input[1]):
if input_0_shape != input_1_shape:
logger.debug(f"the shape of two inputs of {add_qk} is not same")
return None