From c9fbd0b15af9f2b595e59632c146f997790ed25c Mon Sep 17 00:00:00 2001 From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Wed, 9 Feb 2022 12:51:06 -0800 Subject: [PATCH] Optimize cuComputePartGradGammaBeta kernel for MI100 (#10475) * Optimize cuComputePartGradGammaBeta kernel for MI100 Co-authored-by: root Co-authored-by: Jeff Daily --- .../orttraining/training_ops/cuda/nn/layer_norm.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc index 15bac8f7d7..4813382211 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc @@ -89,7 +89,12 @@ Status LayerNormGrad::ComputeInternal(OpKernelContext* p_op_ke bias_grad_data = reinterpret_cast(bias_grad->template MutableData()); } + #ifndef USE_ROCM const int part_size = 16; + #else + // Optimization for ROCm MI100 + const int part_size = 64; + #endif auto part_grad_gamma = GetScratchBuffer(part_size * n2); auto part_grad_beta = GetScratchBuffer(part_size * n2); @@ -138,7 +143,12 @@ Status InvertibleLayerNormGrad::ComputeInternal(OpKernelContext* p_op_kern auto scale_grad_data = reinterpret_cast(scale_grad->template MutableData()); auto bias_grad_data = reinterpret_cast(bias_grad->template MutableData()); + #ifndef USE_ROCM const int part_size = 16; + #else + // Optimization for ROCm MI100 + const int part_size = 64; + #endif auto part_grad_gamma = GetScratchBuffer(part_size * n2); auto part_grad_beta = GetScratchBuffer(part_size * n2);