mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Follow up fix for Gelu impl (#19693)
### Follow up fix for Gelu impl There are two minor comments in https://github.com/microsoft/onnxruntime/pull/19560. Fix them in this pull request. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
2a857d9a86
commit
acbfc29f27
3 changed files with 7 additions and 7 deletions
|
|
@ -293,7 +293,7 @@ A classical usage of disabling the deep copy: when the deep copy before module e
|
|||
export ORTMODULE_MEMORY_OPT_LEVEL=0
|
||||
```
|
||||
|
||||
### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
|
||||
#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
|
||||
|
||||
- **Feature Area**: *ORTMODULE/Optimizations*
|
||||
- **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes.
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@
|
|||
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
|
||||
#ifdef USE_ROCM
|
||||
#include "contrib_ops/rocm/bert/elementwise.h"
|
||||
#endif
|
||||
#ifdef USE_CUDA
|
||||
#else
|
||||
#include "contrib_ops/cuda/bert/transformer_common.h"
|
||||
#endif
|
||||
|
||||
|
|
@ -36,7 +35,7 @@ using namespace ONNX_NAMESPACE;
|
|||
|
||||
template <typename T>
|
||||
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
|
||||
#ifdef USE_CUDA
|
||||
#ifndef USE_ROCM
|
||||
const TransformerOptions* options = TransformerOptions::GetInstance();
|
||||
use_half2_ = !options->DisableHalf2();
|
||||
#endif
|
||||
|
|
@ -63,8 +62,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
reinterpret_cast<const CudaT*>(input->Data<T>()), static_cast<int>(input_length),
|
||||
(nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, static_cast<int>(bias_length),
|
||||
reinterpret_cast<CudaT*>(output->MutableData<T>()));
|
||||
#endif
|
||||
#ifdef USE_CUDA
|
||||
#else
|
||||
return LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
|
||||
Stream(context),
|
||||
static_cast<int>(input_length),
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
bool use_half2_; // Only applicable to CUDA kernel (not ROCM).
|
||||
#ifndef USE_ROCM
|
||||
bool use_half2_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
Loading…
Reference in a new issue