Add BF16 for Scale Op. (#20753)

Adds bfloat16 (BFloat16) support to the Scale training op: the kMSDomain schema's T and ScaleT type constraints, the CUDA and ROCm kernel registrations, the kernel's scale-type dispatcher, the elementwise CUDA implementation, and the unit tests are all extended to cover the new type.
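For context, Scale is an elementwise contrib op: it multiplies the input by a scalar scale, or divides by it when the scale_down attribute is set (the scale_down_ branch visible in scale.cc below). A minimal reference sketch of those semantics in plain C++, outside ORT's tensor machinery:

#include <vector>

// Reference semantics of the Scale op (illustrative, not the ORT kernel).
// T stands in for any supported element type, including bfloat16.
template <typename T>
std::vector<T> ScaleReference(const std::vector<T>& input, float scale, bool scale_down) {
  const float multiplier = scale_down ? 1.0f / scale : scale;
  std::vector<T> output;
  output.reserve(input.size());
  for (const T& v : input) {
    // Compute in float, then narrow back to T; computing in float avoids
    // compounding rounding error for narrow types like bfloat16.
    output.push_back(static_cast<T>(static_cast<float>(v) * multiplier));
  }
  return output;
}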

---------

Co-authored-by: Adam Louly <adamlouly@microsoft.com>
Adam Louly 2024-05-22 17:01:17 -07:00, committed by GitHub
parent a39f8862fd
commit 529feb01f4
6 changed files with 78 additions and 4 deletions


@@ -3247,11 +3247,11 @@ Return true if all elements are true and false otherwise.
static_cast<int64_t>(0))
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
{"tensor(float16)", "tensor(bfloat16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.TypeConstraint(
"ScaleT",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
{"tensor(float16)", "tensor(bfloat16)", "tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
"Constrain scale types to float and int64 tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
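Note the schema keeps propagateShapeAndTypeFromFirstInput, so the output's element type and shape follow input 0: a bfloat16 input produces a bfloat16 output regardless of which ScaleT the scale tensor uses. A tiny stand-in for that inference rule (illustrative, not the ONNX helper):

#include <cstdint>
#include <vector>

struct TypeAndShape {
  int32_t elem_type;             // TensorProto element type enum value
  std::vector<int64_t> dims;
};

// Scale is elementwise, so output inference simply copies input 0.
TypeAndShape InferScaleOutput(const TypeAndShape& input0,
                              const TypeAndShape& /*scale*/) {
  return input0;
}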


@@ -3,6 +3,7 @@
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
namespace onnxruntime {
namespace test {
@@ -13,10 +14,21 @@ struct ScaleInputOutput {
output_up_half.resize(output_up_float.size());
output_down_half.resize(output_down_float.size());
scale_half.resize(scale_float.size());
+input_bf16.resize(input_float.size());
+output_up_bf16.resize(output_up_float.size());
+output_down_bf16.resize(output_down_float.size());
+scale_bf16.resize(scale_float.size());
ConvertFloatToMLFloat16(input_float.data(), input_half.data(), int(input_float.size()));
ConvertFloatToMLFloat16(output_up_float.data(), output_up_half.data(), int(output_up_float.size()));
ConvertFloatToMLFloat16(output_down_float.data(), output_down_half.data(), int(output_down_float.size()));
ConvertFloatToMLFloat16(scale_float.data(), scale_half.data(), int(scale_float.size()));
+input_bf16 = FloatsToBFloat16s(input_float);
+output_up_bf16 = FloatsToBFloat16s(output_up_float);
+output_down_bf16 = FloatsToBFloat16s(output_down_float);
+scale_bf16 = FloatsToBFloat16s(scale_float);
}
// Fp32 Inputs/Output
@@ -36,6 +48,12 @@ struct ScaleInputOutput {
std::vector<MLFloat16> output_up_half;
std::vector<MLFloat16> output_down_half;
std::vector<MLFloat16> scale_half;
+// BFloat16 Inputs/Output
+std::vector<BFloat16> input_bf16;
+std::vector<BFloat16> output_up_bf16;
+std::vector<BFloat16> output_down_bf16;
+std::vector<BFloat16> scale_bf16;
};
TEST(CudaKernelTest, ScaleFloatFloatScaleUp) {
@@ -116,5 +134,53 @@ TEST(CudaKernelTest, ScaleHalfInt64ScaleDown) {
test.Run();
}
+#if defined(USE_CUDA) || defined(USE_ROCM)
+TEST(CudaKernelTest, ScaleBFloat16BFloat16) {
+#ifdef USE_CUDA
+int min_cuda_architecture = 530;
+if (!HasCudaEnvironment(min_cuda_architecture)) {
+LOGS_DEFAULT(WARNING) << "Hardware does not support BF16";
+return;
+}
+#endif
+ScaleInputOutput data;
+OpTester test("Scale", 1, onnxruntime::kMSDomain);
+test.AddInput<BFloat16>("input", {3}, data.input_bf16);
+test.AddInput<BFloat16>("scale", {1}, data.scale_bf16);
+test.AddOutput<BFloat16>("output", {3}, data.output_up_bf16);
+std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(CudaKernelTest, ScaleFloatBFloat16) {
+#ifdef USE_CUDA
+int min_cuda_architecture = 530;
+if (!HasCudaEnvironment(min_cuda_architecture)) {
+LOGS_DEFAULT(WARNING) << "Hardware does not support BF16";
+return;
+}
+#endif
+ScaleInputOutput data;
+OpTester test("Scale", 1, onnxruntime::kMSDomain);
+test.AddInput<float>("input", {3}, data.input_float);
+test.AddInput<BFloat16>("scale", {1}, data.scale_bf16);
+test.AddOutput<float>("output", {3}, data.output_up_float);
+std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+#endif
} // namespace test
} // namespace onnxruntime
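The new test data is produced by FloatsToBFloat16s from tensor_op_test_utils.h. The core of that conversion, assuming IEEE-754 binary32 layout: bfloat16 keeps the top 16 bits of a float, so the round trip is a pair of shifts. A sketch (truncation shown for brevity; a production helper may round to nearest even):

#include <cstdint>
#include <cstring>

// float -> bfloat16 bit pattern: keep sign, exponent, and top 7 mantissa bits.
uint16_t FloatToBFloat16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// bfloat16 -> float: zero-fill the 16 dropped mantissa bits.
float BFloat16BitsToFloat(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

Because the round trip only loses low mantissa bits, the tests can build expected bfloat16 outputs directly from the float reference values.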


@@ -164,6 +164,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Gath
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal);
@@ -421,6 +422,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal)>,
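Both halves of this registration matter: the ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME declaration names a per-type kernel class, and the BuildKernelCreateInfo entry adds its factory to the provider's registry; if either half is missing, a bfloat16 Scale node fails kernel lookup at session initialization. The pattern in miniature (illustrative names, not ORT's macros):

// Each registry row keys a kernel by op, element type, and opset version;
// supporting a new dtype is one more row.
struct KernelEntry {
  const char* op;
  const char* type;
  int since_version;
};

static const KernelEntry kMSDomainScaleKernels[] = {
    {"Scale", "MLFloat16", 1},
    {"Scale", "float", 1},
    {"Scale", "double", 1},
    {"Scale", "BFloat16", 1},  // the row this change adds
};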


@@ -21,6 +21,7 @@ namespace cuda {
.TypeConstraint("ScaleT", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<double>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
+DataTypeImpl::GetTensorType<BFloat16>(), \
DataTypeImpl::GetTensorType<int64_t>(), \
DataTypeImpl::GetTensorType<int32_t>()}) \
.InputMemoryType(OrtMemTypeCPUInput, 1), \
@@ -47,7 +48,7 @@ Status Scale<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
float scale_value;
auto scale_tensor = context->Input<Tensor>(1);
-utils::MLTypeCallDispatcher<float, double, MLFloat16, int64_t, int32_t> t_disp(scale_tensor->GetElementType());
+utils::MLTypeCallDispatcher<float, double, MLFloat16, BFloat16, int64_t, int32_t> t_disp(scale_tensor->GetElementType());
t_disp.Invoke<GetScaleValueImpl>(scale_tensor, scale_value);
if (scale_down_) {
@@ -69,10 +70,12 @@ Status Scale<T>::ComputeInternal(OpKernelContext* context) const {
REGISTER_SCALE_KERNEL_TYPED(MLFloat16)
REGISTER_SCALE_KERNEL_TYPED(float)
REGISTER_SCALE_KERNEL_TYPED(double)
+REGISTER_SCALE_KERNEL_TYPED(BFloat16)
template Status Scale<MLFloat16>::ComputeInternal(OpKernelContext* context) const;
template Status Scale<float>::ComputeInternal(OpKernelContext* context) const;
template Status Scale<double>::ComputeInternal(OpKernelContext* context) const;
+template Status Scale<BFloat16>::ComputeInternal(OpKernelContext* context) const;
} // namespace cuda
} // namespace onnxruntime
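The scale tensor's element type is only known at runtime, so utils::MLTypeCallDispatcher bridges from the runtime type enum to the templated GetScaleValueImpl; the one-line change above widens its compile-time type list to include BFloat16. A minimal sketch of that dispatch pattern (illustrative, not ORT's implementation):

#include <cstdint>
#include <stdexcept>
#include <utility>

enum class ElemType { kFloat, kDouble, kInt64 };  // trimmed for brevity

// Invoke Fn<T> for whichever T matches the runtime element type.
template <template <typename> class Fn, typename... Args>
void Dispatch(ElemType t, Args&&... args) {
  switch (t) {
    case ElemType::kFloat:  Fn<float>{}(std::forward<Args>(args)...); break;
    case ElemType::kDouble: Fn<double>{}(std::forward<Args>(args)...); break;
    case ElemType::kInt64:  Fn<int64_t>{}(std::forward<Args>(args)...); break;
    default: throw std::runtime_error("unsupported scale element type");
  }
}

// Stand-in for GetScaleValueImpl: read the scalar and widen it to float.
template <typename T>
struct ReadScalarAsFloat {
  void operator()(const void* data, float& out) const {
    out = static_cast<float>(*static_cast<const T*>(data));
  }
};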


@@ -61,6 +61,7 @@ template void Impl_Scale<T>( \
SPECIALIZE_SCALE_IMPL(half)
SPECIALIZE_SCALE_IMPL(float)
SPECIALIZE_SCALE_IMPL(double)
+SPECIALIZE_SCALE_IMPL(BFloat16)
} // namespace cuda
} // namespace onnxruntime
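SPECIALIZE_SCALE_IMPL(BFloat16) instantiates the elementwise CUDA kernel for the new type. An illustrative standalone kernel in the same spirit (not ORT's Impl_Scale, which goes through ToCudaType and its own launch helpers; assumes CUDA 11+ for cuda_bf16.h):

#include <cuda_bf16.h>

// Elementwise scale for bfloat16: widen to float, multiply, narrow back.
__global__ void ScaleBF16Kernel(const __nv_bfloat16* input, float scale,
                                __nv_bfloat16* output, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    output[i] = __float2bfloat16(__bfloat162float(input[i]) * scale);
  }
}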


@@ -156,6 +156,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Gath
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Scale);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder);
@@ -360,6 +361,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Scale)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder)>,