From dec11afb834a49a00c8b73049dcd7fec919fe735 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Mon, 3 Apr 2023 18:25:25 -0700 Subject: [PATCH] Fix a prefast warning (#15343) ### Description ### Motivation and Context https://aiinfra.visualstudio.com/ONNX%20Runtime/_workitems/edit/14272/?triage=true --- .../decoder/decoder_masked_multihead_attention.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc index 6130fd9eeb..1a810d24eb 100644 --- a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc @@ -107,9 +107,13 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* output_shape[2] = static_cast(parameters.v_hidden_size); Tensor* output = context->Output(0, output_shape); - // Present input will have the same shape as the past input - Tensor* present_key = context->Output(kPresentOutputIndex, past_key->Shape()); - Tensor* present_value = context->Output(kPresentOutputIndex + 1, past_value->Shape()); + std::vector present_dims{ + parameters.batch_size, parameters.num_heads, + past_present_share_buffer_ ? parameters.max_sequence_length : parameters.total_sequence_length, + parameters.head_size}; + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(kPresentOutputIndex, present_shape); + Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape); auto cuda_stream = Stream(context); @@ -139,6 +143,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* } else { // Sanity check ORT_ENFORCE(past_present_share_buffer_); + ORT_ENFORCE(past_key != nullptr && past_value != nullptr); auto* present_key_data = present_key->MutableData(); auto* present_value_data = present_value->MutableData();