mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Optimize cuComputePartGradGammaBeta kernel for MI100 (#10475)
* Optimize cuComputePartGradGammaBeta kernel for MI100 Co-authored-by: root <root@gb-sjc2-10.local.lan> Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
parent
7a2bf3c24c
commit
c9fbd0b15a
1 changed files with 10 additions and 0 deletions
|
|
@ -89,7 +89,12 @@ Status LayerNormGrad<T, U, simplified>::ComputeInternal(OpKernelContext* p_op_ke
|
|||
bias_grad_data = reinterpret_cast<CudaT*>(bias_grad->template MutableData<T>());
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
const int part_size = 16;
|
||||
#else
|
||||
// Optimization for ROCm MI100
|
||||
const int part_size = 64;
|
||||
#endif
|
||||
auto part_grad_gamma = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
auto part_grad_beta = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
|
||||
|
|
@ -138,7 +143,12 @@ Status InvertibleLayerNormGrad<T, U>::ComputeInternal(OpKernelContext* p_op_kern
|
|||
auto scale_grad_data = reinterpret_cast<CudaT*>(scale_grad->template MutableData<T>());
|
||||
auto bias_grad_data = reinterpret_cast<CudaT*>(bias_grad->template MutableData<T>());
|
||||
|
||||
#ifndef USE_ROCM
|
||||
const int part_size = 16;
|
||||
#else
|
||||
// Optimization for ROCm MI100
|
||||
const int part_size = 64;
|
||||
#endif
|
||||
auto part_grad_gamma = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
auto part_grad_beta = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue