diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 8375ff7dfb..3c8e8c2557 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -15,13 +15,17 @@ using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { namespace contrib { +static void FreePackedWeights(BufferUniquePtr* array, size_t array_size) { + for (size_t i = 0; i < array_size; i++) { + array[i].reset(); + } +} + template class Attention : public OpKernel, public AttentionCPUBase { public: explicit Attention(const OpKernelInfo& info); - bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size, size_t input_hidden_size, const T* weights_data, size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights); - Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, @@ -33,13 +37,13 @@ class Attention : public OpKernel, public AttentionCPUBase { /*out*/ bool& used_shared_buffers) override; private: - BufferUniquePtr q_packed_weights_; - BufferUniquePtr k_packed_weights_; - BufferUniquePtr v_packed_weights_; + bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size, + size_t input_hidden_size, const T* weights_data, + size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights); - size_t q_packed_weights_size_ = 0; - size_t k_packed_weights_size_ = 0; - size_t v_packed_weights_size_ = 0; + BufferUniquePtr packed_weights_[3]; + size_t packed_weights_size_[3] = {0, 0, 0}; + bool is_prepack_ = false; TensorShape weight_shape_; }; @@ -297,22 +301,8 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, // buffer memory and we don not want it uninitialized and generate different hashes // if and when we try to cache this pre-packed buffer for sharing between sessions. memset(packed_weights_data, 0, packed_weights_data_size); - switch (qkv_index) { - case 0: - q_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); - q_packed_weights_size_ = packb_size; - break; - case 1: - k_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); - k_packed_weights_size_ = packb_size; - break; - case 2: - v_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); - v_packed_weights_size_ = packb_size; - break; - default: - return false; - } + packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); + packed_weights_size_[qkv_index] = packb_size; for (size_t i = 0; i < loop_len; i++) { MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); @@ -320,22 +310,8 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, weights_data += head_size; } - bool share_prepacked_weights = (prepacked_weights != nullptr); - if (share_prepacked_weights) { - switch (qkv_index) { - case 0: - prepacked_weights->buffers_.push_back(std::move(q_packed_weights_)); - break; - case 1: - prepacked_weights->buffers_.push_back(std::move(k_packed_weights_)); - break; - case 2: - prepacked_weights->buffers_.push_back(std::move(v_packed_weights_)); - break; - default: - break; - } - + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_weights_[qkv_index])); prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size); } return true; @@ -345,6 +321,15 @@ template Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { + + /* The PrePack() massages the weights to speed up Compute(), there is an option to + * use shared prepacked weights in which case prepacked_weights parameter would be non-null. + * + * We use an array of buffers to store prepacked Q, K, V weights for the sake of simplicity + * and easy offset management in Compute(). They are packed one after the other. In case of failure, + * 1. With shared pre-pack weights the caller of this fn() frees up the memory so far allocated. + * 2. When weights are held by kernel, it will be freed before returning. + */ is_packed = false; if (1 != input_idx) { @@ -386,19 +371,20 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr v_hidden_size = hidden_size; } - const size_t q_head_size = q_hidden_size / num_heads_; - const size_t k_head_size = k_hidden_size / num_heads_; - const size_t v_head_size = v_hidden_size / num_heads_; + const size_t qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_}; const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size; - if (!IsPackWeightsSuccessful(0, alloc, q_head_size, input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) || - !IsPackWeightsSuccessful(1, alloc, k_head_size, input_hidden_size, weights_data + (num_heads_ * q_head_size), weight_matrix_col_size, prepacked_weights) || - !IsPackWeightsSuccessful(2, alloc, v_head_size, input_hidden_size, weights_data + (num_heads_ * (q_head_size + k_head_size)), weight_matrix_col_size, prepacked_weights)) { - // we are not cleaning up anything, assuming caller takes care of this + if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) || + !IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size, weights_data + (num_heads_ * qkv_head_size[0]), weight_matrix_col_size, prepacked_weights) || + !IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size, weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])), weight_matrix_col_size, prepacked_weights)) { + if (prepacked_weights == nullptr) { + FreePackedWeights(packed_weights_, qkv_hidden_sizes_.size()); + } return Status::OK(); } is_packed = true; + is_prepack_ = true; return Status::OK(); } @@ -411,9 +397,9 @@ Status Attention::UseSharedPrePackedBuffers(std::vector& pre } used_shared_buffers = true; - q_packed_weights_ = std::move(prepacked_buffers[0]); - k_packed_weights_ = std::move(prepacked_buffers[1]); - v_packed_weights_ = std::move(prepacked_buffers[2]); + packed_weights_[0] = std::move(prepacked_buffers[0]); + packed_weights_[1] = std::move(prepacked_buffers[1]); + packed_weights_[2] = std::move(prepacked_buffers[2]); return Status::OK(); } @@ -421,7 +407,7 @@ Status Attention::UseSharedPrePackedBuffers(std::vector& pre template Status Attention::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); - const Tensor* weights = q_packed_weights_ ? nullptr : context->Input(1); + const Tensor* weights = is_prepack_ ? nullptr : context->Input(1); const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); @@ -512,7 +498,7 @@ Status Attention::Compute(OpKernelContext* context) const { int weights_offset = 0; int bias_offset = qkv_index * q_hidden_size + head_index * head_size; - if (q_packed_weights_ == nullptr) { + if (!is_prepack_) { weights_offset = bias_offset; } else { weights_offset = head_index * head_size; @@ -534,15 +520,9 @@ Status Attention::Compute(OpKernelContext* context) const { // A: input (BxSxD) (B.)S x D S x D // B: weights (DxNxT) D x (N.)T D x H // C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H - if (q_packed_weights_) { + if (is_prepack_) { uint8_t* packed_weight; - if (qkv_index == 0) { - packed_weight = static_cast(q_packed_weights_.get()) + q_packed_weights_size_ * (weights_offset / head_size); - } else if (qkv_index == 1) { - packed_weight = static_cast(k_packed_weights_.get()) + k_packed_weights_size_ * (weights_offset / head_size); - } else { - packed_weight = static_cast(v_packed_weights_.get()) + v_packed_weights_size_ * (weights_offset / head_size); - } + packed_weight = static_cast(packed_weights_[qkv_index].get()) + packed_weights_size_[qkv_index] * (weights_offset / head_size); MlasGemm( CblasNoTrans, // TransA = no diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 7830971c4e..691f40d949 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -128,20 +128,18 @@ class FusionAttention(Fusion): return num_heads, hidden_size def get_add_qk_str(self, add_qk: NodeProto): - # Note: Does not work for dynamic models, reshape node shape inference would fail with more than 2 dims being -1 - # inputs_ids has to be the name of input, this may be changes if needed - inputs_ids = self.model.find_graph_input("input_ids") - if inputs_ids == None: - logger.debug("no input with name \"input_ids\"") - return None - batch_size, seq_len = get_shape_from_type_proto(inputs_ids.type) + shape_infer = self.model.infer_runtime_shape(update=True) + if shape_infer is None: + return - if batch_size < 0 or seq_len < 0: - logger.debug(f"batch_size: {batch_size} and seq_len {seq_len} cannot be -ve") + input_0_shape = shape_infer.get_edge_shape(add_qk.input[0]) + input_1_shape = shape_infer.get_edge_shape(add_qk.input[1]) + + if input_0_shape is None or input_1_shape is None: + logger.debug(f"one of the inputs of {add_qk} is None") return None - shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model) - shape_infer_helper.infer({"batch_size": batch_size, "seq_len": seq_len}) - if shape_infer_helper.get_edge_shape(add_qk.input[0]) != shape_infer_helper.get_edge_shape(add_qk.input[1]): + + if input_0_shape != input_1_shape: logger.debug(f"the shape of two inputs of {add_qk} is not same") return None