bfloat16 support for quickgelugrad (#18336)

### Description
<!-- Describe your changes. -->

Registers BFloat16 datatype as valid input type for CUDA QuickGeluGrad
Kernel.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Enabling `meta-llama/Llama-2-70b` to be finetuned with ONNX Runtime
training.

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
Prathik Rao 2023-11-08 08:40:02 -08:00 committed by GitHub
parent 2151c79bf1
commit 34f77eaa24
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 4 deletions

View file

@ -43,11 +43,15 @@ namespace cuda {
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, float) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, double)
#define ACTIVATION_GRAD_OP_HFDX(name, ver, domain) \
ACTIVATION_GRAD_OP_HFD(name, ver, domain) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, BFloat16)
ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFDX(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);

View file

@ -83,14 +83,15 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(name, T) \
template void Impl_##name<T>(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count);
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
#define ACTIVATION_GRAD_OP_NAME(name) \
BINARY_ELEMENTWISE_IMPL(name); \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(name)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(name)
ACTIVATION_GRAD_OPS()
#undef ACTIVATION_GRAD_OP_NAME

View file

@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
@ -378,6 +379,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,