Fix workspace size miscalculation caused by missing parentheses around a ternary expression (operator precedence). (#14699)

This commit is contained in:
Zhang Lei 2023-02-16 13:11:52 -08:00 committed by GitHub
parent 6f99fb9d4b
commit ff3aed8540
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -146,7 +146,8 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
const auto BNS = batch_size * num_heads_ * seq_len;
const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
const size_t elements_after_gemm = (size_t)BNS *(size_t)D;
size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0);
bool reuse_output = (seq_len >= D);
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());
// format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
@@ -161,9 +162,9 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);
// reuse output if possible
CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
: reinterpret_cast<CudaT*>(output->template MutableData<T>());
int ld_gemm_output = max(seq_len, D);
CudaT* gemm_output = reuse_output ? reinterpret_cast<CudaT*>(output->template MutableData<T>())
: (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query);
int ld_gemm_output = reuse_output ? seq_len : D;
const CudaT one = ToCudaType<T>::FromFloat(1.0f);
const CudaT zero = ToCudaType<T>::FromFloat(0.0f);