mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
fix error due to () not used on operator priority. (#14699)
This commit is contained in:
parent
6f99fb9d4b
commit
ff3aed8540
1 changed files with 5 additions and 4 deletions
|
|
@ -146,7 +146,8 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
|
|||
const auto BNS = batch_size * num_heads_ * seq_len;
|
||||
const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
|
||||
const size_t elements_after_gemm = (size_t)BNS *(size_t)D;
|
||||
size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0);
|
||||
bool reuse_output = (seq_len >= D);
|
||||
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
|
||||
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());
|
||||
|
||||
// format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
|
||||
|
|
@ -161,9 +162,9 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
|
|||
false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);
|
||||
|
||||
// reuse output if possible
|
||||
CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
|
||||
: reinterpret_cast<CudaT*>(output->template MutableData<T>());
|
||||
int ld_gemm_output = max(seq_len, D);
|
||||
CudaT* gemm_output = reuse_output ? reinterpret_cast<CudaT*>(output->template MutableData<T>())
|
||||
: (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query);
|
||||
int ld_gemm_output = reuse_output ? seq_len : D;
|
||||
|
||||
const CudaT one = ToCudaType<T>::FromFloat(1.0f);
|
||||
const CudaT zero = ToCudaType<T>::FromFloat(0.0f);
|
||||
|
|
|
|||
Loading…
Reference in a new issue