From 529feb01f4f8aa96e5a4a2f55078fbfeff8f8f0e Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Wed, 22 May 2024 17:01:17 -0700 Subject: [PATCH] Add BF16 for Scale Op. (#20753) Adding Bfloat16 to scale op --------- Co-authored-by: Adam Louly --- .../core/graph/training_op_defs.cc | 4 +- .../test/training_ops/cuda/scale_test.cc | 68 ++++++++++++++++++- .../cuda/cuda_training_kernels.cc | 2 + .../training_ops/cuda/math/scale.cc | 5 +- .../training_ops/cuda/math/scale_impl.cu | 1 + .../rocm/rocm_training_kernels.cc | 2 + 6 files changed, 78 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 677f383264..ad449459fb 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3247,11 +3247,11 @@ Return true if all elements are true and false otherwise. static_cast(0)) .TypeConstraint( "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)", "tensor(double)"}, "Constrain input types to float tensors.") .TypeConstraint( "ScaleT", - {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"}, + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"}, "Constrain scale types to float and int64 tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); diff --git a/orttraining/orttraining/test/training_ops/cuda/scale_test.cc b/orttraining/orttraining/test/training_ops/cuda/scale_test.cc index f4857a1b68..ec48cccf92 100644 --- a/orttraining/orttraining/test/training_ops/cuda/scale_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/scale_test.cc @@ -3,6 +3,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/common/cuda_op_test_utils.h" namespace onnxruntime { namespace test { @@ -13,10 +14,21 @@ struct ScaleInputOutput { output_up_half.resize(output_up_float.size()); output_down_half.resize(output_down_float.size()); scale_half.resize(scale_float.size()); + + input_bf16.resize(input_float.size()); + output_up_bf16.resize(output_up_float.size()); + output_down_bf16.resize(output_down_float.size()); + scale_bf16.resize(scale_float.size()); + ConvertFloatToMLFloat16(input_float.data(), input_half.data(), int(input_float.size())); ConvertFloatToMLFloat16(output_up_float.data(), output_up_half.data(), int(output_up_float.size())); ConvertFloatToMLFloat16(output_down_float.data(), output_down_half.data(), int(output_down_float.size())); ConvertFloatToMLFloat16(scale_float.data(), scale_half.data(), int(scale_float.size())); + + input_bf16 = FloatsToBFloat16s(input_float); + output_up_bf16 = FloatsToBFloat16s(output_up_float); + output_down_bf16 = FloatsToBFloat16s(output_down_float); + scale_bf16 = FloatsToBFloat16s(scale_float); } // Fp32 Inputs/Output @@ -36,6 +48,12 @@ struct ScaleInputOutput { std::vector output_up_half; std::vector output_down_half; std::vector scale_half; + + // BFloat16 Inputs/Output + std::vector input_bf16; + std::vector output_up_bf16; + std::vector output_down_bf16; + std::vector scale_bf16; }; TEST(CudaKernelTest, ScaleFloatFloatScaleUp) { @@ -116,5 +134,53 @@ TEST(CudaKernelTest, ScaleHalfInt64ScaleDown) { test.Run(); } +#if defined(USE_CUDA) || defined(USE_ROCM) +TEST(CudaKernelTest, ScaleBFloat16BFloat16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware does not support BFP16"; + return; + } +#endif + ScaleInputOutput data; + OpTester test("Scale", 1, onnxruntime::kMSDomain); + test.AddInput("input", {3}, data.input_bf16); + test.AddInput("scale", {1}, data.scale_bf16); + test.AddOutput("output", {3}, data.output_up_bf16); + + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(CudaKernelTest, ScaleFloatBFloat16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware does not support BFP16"; + return; + } +#endif + ScaleInputOutput data; + OpTester test("Scale", 1, onnxruntime::kMSDomain); + test.AddInput("input", {3}, data.input_float); + test.AddInput("scale", {1}, data.scale_bf16); + test.AddOutput("output", {3}, data.output_up_float); + + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +#endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 8b2bc7e2ef..bcc9a06f5a 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -164,6 +164,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Gath class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal); @@ -421,6 +422,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/math/scale.cc b/orttraining/orttraining/training_ops/cuda/math/scale.cc index 5550957482..925dfa33b8 100644 --- a/orttraining/orttraining/training_ops/cuda/math/scale.cc +++ b/orttraining/orttraining/training_ops/cuda/math/scale.cc @@ -21,6 +21,7 @@ namespace cuda { .TypeConstraint("ScaleT", {DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}) \ .InputMemoryType(OrtMemTypeCPUInput, 1), \ @@ -47,7 +48,7 @@ Status Scale::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; float scale_value; auto scale_tensor = context->Input(1); - utils::MLTypeCallDispatcher t_disp(scale_tensor->GetElementType()); + utils::MLTypeCallDispatcher t_disp(scale_tensor->GetElementType()); t_disp.Invoke(scale_tensor, scale_value); if (scale_down_) { @@ -69,10 +70,12 @@ Status Scale::ComputeInternal(OpKernelContext* context) const { REGISTER_SCALE_KERNEL_TYPED(MLFloat16) REGISTER_SCALE_KERNEL_TYPED(float) REGISTER_SCALE_KERNEL_TYPED(double) +REGISTER_SCALE_KERNEL_TYPED(BFloat16) template Status Scale::ComputeInternal(OpKernelContext* context) const; template Status Scale::ComputeInternal(OpKernelContext* context) const; template Status Scale::ComputeInternal(OpKernelContext* context) const; +template Status Scale::ComputeInternal(OpKernelContext* context) const; } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu b/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu index 376a9696c7..e1e4fcc968 100644 --- a/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu @@ -61,6 +61,7 @@ template void Impl_Scale( \ SPECIALIZE_SCALE_IMPL(half) SPECIALIZE_SCALE_IMPL(float) SPECIALIZE_SCALE_IMPL(double) +SPECIALIZE_SCALE_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index e107a2542f..7824e98fe8 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -156,6 +156,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Gath class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Scale); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder); @@ -360,6 +361,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,