mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
BFloat16 support for QuickGeluGrad (#18336)
### Description
<!-- Describe your changes. -->
Registers the BFloat16 data type as a valid input type for the CUDA QuickGeluGrad kernel.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This change enables `meta-llama/Llama-2-70b` to be fine-tuned with ONNX Runtime training.

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
2151c79bf1
commit
34f77eaa24
3 changed files with 11 additions and 4 deletions
|
|
@ -43,11 +43,15 @@ namespace cuda {
|
|||
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, float) \
|
||||
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, double)
|
||||
|
||||
#define ACTIVATION_GRAD_OP_HFDX(name, ver, domain) \
|
||||
ACTIVATION_GRAD_OP_HFD(name, ver, domain) \
|
||||
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, BFloat16)
|
||||
|
||||
ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFDX(QuickGeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);
|
||||
|
||||
|
|
|
|||
|
|
@ -83,14 +83,15 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
|
|||
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(name, T) \
|
||||
template void Impl_##name<T>(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count);
|
||||
|
||||
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
|
||||
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(x) \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
|
||||
|
||||
#define ACTIVATION_GRAD_OP_NAME(name) \
|
||||
BINARY_ELEMENTWISE_IMPL(name); \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(name)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(name)
|
||||
|
||||
ACTIVATION_GRAD_OPS()
|
||||
#undef ACTIVATION_GRAD_OP_NAME
|
||||
|
|
|
|||
|
|
@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
|
||||
|
|
@ -378,6 +379,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue