From 34f77eaa243ed16bbcea8fa585c9f89539488b27 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 8 Nov 2023 08:40:02 -0800 Subject: [PATCH] bfloat16 support for quickgelugrad (#18336) ### Description Registers BFloat16 datatype as valid input type for CUDA QuickGeluGrad Kernel. ### Motivation and Context Enabling `meta-llama/Llama-2-70b` to be finetuned with ONNX Runtime training. --------- Co-authored-by: Prathik Rao --- .../training_ops/cuda/activation/activations_grad.cc | 6 +++++- .../training_ops/cuda/activation/activations_grad_impl.cu | 7 ++++--- .../orttraining/training_ops/cuda/cuda_training_kernels.cc | 2 ++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc index 7fde69d758..98e3b878c9 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc @@ -43,11 +43,15 @@ namespace cuda { ACTIVATION_GRAD_OP_TYPED(name, ver, domain, float) \ ACTIVATION_GRAD_OP_TYPED(name, ver, domain, double) +#define ACTIVATION_GRAD_OP_HFDX(name, ver, domain) \ + ACTIVATION_GRAD_OP_HFD(name, ver, domain) \ + ACTIVATION_GRAD_OP_TYPED(name, ver, domain, BFloat16) + ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain); -ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain); +ACTIVATION_GRAD_OP_HFDX(QuickGeluGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain); diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu index 164aba8667..dd6a44b9e3 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu @@ -83,14 +83,15 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad { #define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(name, T) \ template void Impl_##name(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count); -#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \ +#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(x) \ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16) #define ACTIVATION_GRAD_OP_NAME(name) \ BINARY_ELEMENTWISE_IMPL(name); \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(name) + SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(name) ACTIVATION_GRAD_OPS() #undef ACTIVATION_GRAD_OP_NAME diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index ae4f48b6b4..eeaa51c4dc 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad); @@ -378,6 +379,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,