mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add BF16 for Scale Op. (#20753)
Adding Bfloat16 to scale op --------- Co-authored-by: Adam Louly <adamlouly@microsoft.com@h100vm-ort.kxelwkzfzxguje5bxvwxxs135a.gvxx.internal.cloudapp.net>
This commit is contained in:
parent
a39f8862fd
commit
529feb01f4
6 changed files with 78 additions and 4 deletions
|
|
@@ -3247,11 +3247,11 @@ Return true if all elements are true and false otherwise.
|
|||
static_cast<int64_t>(0))
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)"},
|
||||
{"tensor(float16)", "tensor(bfloat16)", "tensor(float)", "tensor(double)"},
|
||||
"Constrain input types to float tensors.")
|
||||
.TypeConstraint(
|
||||
"ScaleT",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
|
||||
{"tensor(float16)", "tensor(bfloat16)", "tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
|
||||
"Constrain scale types to float and int64 tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
|
|
|
|||
|
|
@@ -3,6 +3,7 @@
|
|||
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/common/cuda_op_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@@ -13,10 +14,21 @@ struct ScaleInputOutput {
|
|||
output_up_half.resize(output_up_float.size());
|
||||
output_down_half.resize(output_down_float.size());
|
||||
scale_half.resize(scale_float.size());
|
||||
|
||||
input_bf16.resize(input_float.size());
|
||||
output_up_bf16.resize(output_up_float.size());
|
||||
output_down_bf16.resize(output_down_float.size());
|
||||
scale_bf16.resize(scale_float.size());
|
||||
|
||||
ConvertFloatToMLFloat16(input_float.data(), input_half.data(), int(input_float.size()));
|
||||
ConvertFloatToMLFloat16(output_up_float.data(), output_up_half.data(), int(output_up_float.size()));
|
||||
ConvertFloatToMLFloat16(output_down_float.data(), output_down_half.data(), int(output_down_float.size()));
|
||||
ConvertFloatToMLFloat16(scale_float.data(), scale_half.data(), int(scale_float.size()));
|
||||
|
||||
input_bf16 = FloatsToBFloat16s(input_float);
|
||||
output_up_bf16 = FloatsToBFloat16s(output_up_float);
|
||||
output_down_bf16 = FloatsToBFloat16s(output_down_float);
|
||||
scale_bf16 = FloatsToBFloat16s(scale_float);
|
||||
}
|
||||
|
||||
// Fp32 Inputs/Output
|
||||
|
|
@@ -36,6 +48,12 @@ struct ScaleInputOutput {
|
|||
std::vector<MLFloat16> output_up_half;
|
||||
std::vector<MLFloat16> output_down_half;
|
||||
std::vector<MLFloat16> scale_half;
|
||||
|
||||
// BFloat16 Inputs/Output
|
||||
std::vector<BFloat16> input_bf16;
|
||||
std::vector<BFloat16> output_up_bf16;
|
||||
std::vector<BFloat16> output_down_bf16;
|
||||
std::vector<BFloat16> scale_bf16;
|
||||
};
|
||||
|
||||
TEST(CudaKernelTest, ScaleFloatFloatScaleUp) {
|
||||
|
|
@@ -116,5 +134,53 @@ TEST(CudaKernelTest, ScaleHalfInt64ScaleDown) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
TEST(CudaKernelTest, ScaleBFloat16BFloat16) {
|
||||
#ifdef USE_CUDA
|
||||
int min_cuda_architecture = 530;
|
||||
if (!HasCudaEnvironment(min_cuda_architecture)) {
|
||||
LOGS_DEFAULT(WARNING) << "Hardware does not support BFP16";
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
ScaleInputOutput data;
|
||||
OpTester test("Scale", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<BFloat16>("input", {3}, data.input_bf16);
|
||||
test.AddInput<BFloat16>("scale", {1}, data.scale_bf16);
|
||||
test.AddOutput<BFloat16>("output", {3}, data.output_up_bf16);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
#elif USE_ROCM
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, ScaleFloatBFloat16) {
|
||||
#ifdef USE_CUDA
|
||||
int min_cuda_architecture = 530;
|
||||
if (!HasCudaEnvironment(min_cuda_architecture)) {
|
||||
LOGS_DEFAULT(WARNING) << "Hardware does not support BFP16";
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
ScaleInputOutput data;
|
||||
OpTester test("Scale", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("input", {3}, data.input_float);
|
||||
test.AddInput<BFloat16>("scale", {1}, data.scale_bf16);
|
||||
test.AddOutput<float>("output", {3}, data.output_up_float);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
#elif USE_ROCM
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@@ -164,6 +164,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Gath
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal);
|
||||
|
|
@@ -421,6 +422,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal)>,
|
||||
|
|
|
|||
|
|
@@ -21,6 +21,7 @@ namespace cuda {
|
|||
.TypeConstraint("ScaleT", {DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<double>(), \
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<BFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<int64_t>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>()}) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1), \
|
||||
|
|
@@ -47,7 +48,7 @@ Status Scale<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
float scale_value;
|
||||
auto scale_tensor = context->Input<Tensor>(1);
|
||||
utils::MLTypeCallDispatcher<float, double, MLFloat16, int64_t, int32_t> t_disp(scale_tensor->GetElementType());
|
||||
utils::MLTypeCallDispatcher<float, double, MLFloat16, BFloat16, int64_t, int32_t> t_disp(scale_tensor->GetElementType());
|
||||
t_disp.Invoke<GetScaleValueImpl>(scale_tensor, scale_value);
|
||||
|
||||
if (scale_down_) {
|
||||
|
|
@@ -69,10 +70,12 @@ Status Scale<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
REGISTER_SCALE_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_SCALE_KERNEL_TYPED(float)
|
||||
REGISTER_SCALE_KERNEL_TYPED(double)
|
||||
REGISTER_SCALE_KERNEL_TYPED(BFloat16)
|
||||
|
||||
template Status Scale<MLFloat16>::ComputeInternal(OpKernelContext* context) const;
|
||||
template Status Scale<float>::ComputeInternal(OpKernelContext* context) const;
|
||||
template Status Scale<double>::ComputeInternal(OpKernelContext* context) const;
|
||||
template Status Scale<BFloat16>::ComputeInternal(OpKernelContext* context) const;
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@@ -61,6 +61,7 @@ template void Impl_Scale<T>( \
|
|||
SPECIALIZE_SCALE_IMPL(half)
|
||||
SPECIALIZE_SCALE_IMPL(float)
|
||||
SPECIALIZE_SCALE_IMPL(double)
|
||||
SPECIALIZE_SCALE_IMPL(BFloat16)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@@ -156,6 +156,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Gath
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder);
|
||||
|
|
@@ -360,6 +361,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue