[ROCm] BFloat16 support (#10416)

* reducesum bf16 support

* bf16 for add/sub/mul/div

* fix build

* bf16 for Cast

* bf16 for softmax

Co-authored-by: root <root@GCRAMDRR1-MI100-087.redmond.corp.microsoft.com>
This commit is contained in:
ytaous 2022-01-28 22:43:27 -08:00 committed by GitHub
parent b02f4ece5e
commit 85cbe8367e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 177 additions and 37 deletions

View file

@ -45,6 +45,24 @@ SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
// SPECIALIZED_SOFTMAX_HELPER_IMPL(double)
SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16)
// miopenSoftmaxForward/Backward doesn't support BFloat16.
#define SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(is_log_softmax) \
template <> \
Status SoftMaxComputeHelper<BFloat16, is_log_softmax>(hipStream_t stream, const BFloat16* X, \
const TensorShape& input_shape, BFloat16* Y, int64_t axis) { \
typedef typename ToHipType<BFloat16>::MappedType HipT; \
int64_t N = input_shape.SizeToDimension(axis); \
int64_t D = input_shape.SizeFromDimension(axis); \
auto Y_data = reinterpret_cast<HipT*>(Y); \
auto X_data = reinterpret_cast<const HipT*>(X); \
dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>( \
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N)); \
return Status::OK(); \
}
SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true)
SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false)
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Softmax, \
@ -203,6 +221,7 @@ SPECIALIZED_COMPUTE(float)
// MIOpen double data type not supported
// SPECIALIZED_COMPUTE(double)
SPECIALIZED_COMPUTE(MLFloat16)
SPECIALIZED_COMPUTE(BFloat16)
} // namespace rocm
} // namespace onnxruntime

View file

@ -97,6 +97,7 @@ template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(
SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements, int softmax_elements_stride, int batch_count) {
@ -119,6 +120,7 @@ template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)
}

View file

@ -379,7 +379,6 @@ Status PrepareForReduce(const Tensor* X,
const auto input_dims = input_shape.GetDims();
InlinedShapeVector<bool> reduced(rank, false);
prepare_reduce_metadata.output_dims.reserve(input_dims.size());
if (axes.size() > 0) {
prepare_reduce_metadata.output_dims = input_shape.AsShapeVector();
for (auto axis : axes) {
@ -393,6 +392,7 @@ Status PrepareForReduce(const Tensor* X,
}
} else {
// no axes provided (i.e.) default axes => reduce on all dims
prepare_reduce_metadata.output_dims.reserve(input_dims.size());
for (auto dim : input_dims) {
ORT_ENFORCE(keepdims || dim != 0,
"Can't reduce on dim with value of 0 if 'keepdims' is false. "
@ -823,6 +823,111 @@ SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int64_t)
SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int8_t)
SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(uint8_t)
template <>
template <>
Status ReduceKernel<true>::ComputeImpl<BFloat16, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const {
typedef typename ToHipType<BFloat16>::MappedType HipT;
const Tensor* X = ctx->Input<Tensor>(0);
TensorShapeVector axes;
size_t num_inputs = ctx->InputCount();
if (num_inputs == 2) {
const Tensor* axes_tensor = ctx->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor.");
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
const auto* data = axes_tensor->template Data<int64_t>();
axes.assign(data, data + nDims);
} else {
axes.assign(axes_.begin(), axes_.end());
}
if (axes.empty() && noop_with_empty_axes_) {
auto* Y = ctx->Output(0, X->Shape());
HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData<BFloat16>(), X->template Data<BFloat16>(),
X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream()));
return Status::OK();
}
PrepareReduceMetadata prepare_reduce_metadata;
ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata));
Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);
int64_t input_count = prepare_reduce_metadata.input_count;
int64_t output_count = prepare_reduce_metadata.output_count;
auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen;
auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen;
if (input_count == 0) {
assert(Y->Shape().Size() == 0);
return Status::OK();
}
if (input_count == output_count) {
if (Y->template MutableData<BFloat16>() != X->template Data<BFloat16>()) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData<BFloat16>(), X->template Data<BFloat16>(),
input_count * sizeof(BFloat16), hipMemcpyDeviceToDevice, Stream()));
}
return Status::OK();
}
if (fast_reduction_ && !ctx->GetUseDeterministicCompute()) {
int m{}, n{};
const auto applicable_matrix_reduction =
get_applicable_matrix_reduction(miopen_reduce_op, X->Shape().GetDims(), axes, m, n);
switch (applicable_matrix_reduction) {
case ApplicableMatrixReduction::Rows: {
return reduce_matrix_rows(Stream(), reinterpret_cast<const HipT*>(X->template Data<BFloat16>()),
reinterpret_cast<HipT*>(Y->template MutableData<BFloat16>()), m, n);
}
case ApplicableMatrixReduction::Columns: {
const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<HipT>(m, n);
auto buffer = rocm_ep_->GetScratchBuffer<void>(buffer_size_bytes);
return reduce_matrix_columns(Stream(), reinterpret_cast<const HipT*>(X->template Data<BFloat16>()),
reinterpret_cast<HipT*>(Y->template MutableData<BFloat16>()), m, n, buffer.get(),
buffer_size_bytes);
}
default:
break;
}
}
HIP_RETURN_IF_ERROR(hipMemsetAsync(Y->MutableDataRaw(), 0, Y->SizeInBytes(), Stream()));
size_t indices_bytes = 0;
size_t workspace_bytes = 0;
MiopenTensor input_tensor;
MiopenTensor output_tensor;
MiopenReduceDescriptor reduce_desc;
miopenDataType_t miopen_type_X = miopenFloat;
IAllocatorUniquePtr<float> temp_X = GetScratchBuffer<float>(input_count);
Impl_Cast<HipT, float>(Stream(), reinterpret_cast<const HipT*>(X->template Data<BFloat16>()), temp_X.get(),
X->Shape().Size());
ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, MIOPEN_REDUCE_TENSOR_NO_INDICES));
ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X));
ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X));
MIOPEN_RETURN_IF_ERROR(
miopenGetReductionIndicesSize(MiopenHandle(), reduce_desc, input_tensor, output_tensor, &indices_bytes));
MIOPEN_RETURN_IF_ERROR(
miopenGetReductionIndicesSize(MiopenHandle(), reduce_desc, input_tensor, output_tensor, &workspace_bytes));
IAllocatorUniquePtr<uint32_t> indices_rocm = GetScratchBuffer<uint32_t>(indices_bytes);
IAllocatorUniquePtr<HipT> workspace_rocm = GetScratchBuffer<HipT>(workspace_bytes);
const auto one = Consts<float>::One;
const auto zero = Consts<float>::Zero;
auto temp_Y = GetScratchBuffer<float>(output_count);
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(MiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes, &one, input_tensor, temp_X.get(),
&zero, output_tensor, temp_Y.get()));
Impl_Cast<float, HipT>(Stream(), temp_Y.get(), reinterpret_cast<HipT*>(Y->template MutableData<BFloat16>()), output_count);
return Status::OK();
}
namespace ReductionOps {
template <typename T, miopenReduceTensorIndices_t ReduceTensorIndices>
@ -880,7 +985,8 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N
#define REGISTER_KERNEL_HFD(name) \
REGISTER_KERNEL_TYPED(name, MLFloat16) \
REGISTER_KERNEL_TYPED(name, float)
REGISTER_KERNEL_TYPED(name, float) \
REGISTER_KERNEL_TYPED(name, BFloat16)
// REGISTER_KERNEL_TYPED(name, double)
#define REGISTER_KERNEL_HFD_11(name) \
@ -926,6 +1032,7 @@ REGISTER_KERNEL_TYPED_13(ReduceSum, float)
// REGISTER_KERNEL_TYPED_13(ReduceSum, double)
REGISTER_KERNEL_TYPED_13(ReduceSum, int32_t)
REGISTER_KERNEL_TYPED_13(ReduceSum, int64_t)
REGISTER_KERNEL_TYPED_13(ReduceSum, BFloat16)
REGISTER_KERNEL_HFD(ReduceLogSum)
REGISTER_KERNEL_HFD(ReduceSumSquare)

View file

@ -1120,18 +1120,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum);
// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
@ -1188,10 +1188,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Relu);
//OpSet 15
@ -1964,18 +1964,18 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum)>,
// OpSet 14
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum)>,
@ -2031,10 +2031,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Relu)>,
// OpSet 15

View file

@ -18,7 +18,7 @@ std::vector<MLFloat16> MakeMLFloat16(const std::initializer_list<float>& input)
return output;
}
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
void TestFloat16(const char* op_name, const std::vector<int64_t>& lhs_dim,
const std::initializer_list<float>& lhs_values, const std::vector<int64_t>& rhs_dim,
const std::initializer_list<float>& rhs_values, const std::vector<int64_t>& out_dim,
@ -29,7 +29,11 @@ void TestFloat16(const char* op_name, const std::vector<int64_t>& lhs_dim,
tester.AddInput<MLFloat16>("B", rhs_dim, MakeMLFloat16(rhs_values));
tester.AddOutput<MLFloat16>("C", out_dim, MakeMLFloat16(out_values));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
@ -39,7 +43,11 @@ void TestFloat16(const char* op_name, const std::vector<int64_t>& lhs_dim,
tester.AddInput<BFloat16>("B", rhs_dim, MakeBFloat16(rhs_values));
tester.AddOutput<BFloat16>("C", out_dim, MakeBFloat16(out_values));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
@ -128,7 +136,7 @@ TEST(MathOpTest, Add_float) {
test.Run();
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TestFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values);
#endif
}
@ -163,7 +171,7 @@ TEST(MathOpTest, Add_Broadcast_Axis) {
test.AddOutput<float>("C", dims, out_values);
test.Run(OpTester::ExpectResult::kExpectSuccess, "");
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TestFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values);
#endif
}
@ -186,7 +194,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) {
{kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1]
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TestFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values);
#endif
}
@ -208,7 +216,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) {
{kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1]
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TestFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values);
#endif
}
@ -404,7 +412,7 @@ TEST(MathOpTest, Sub) {
test.Run();
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TestFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values);
#endif
}
@ -462,7 +470,7 @@ TEST(MathOpTest, Mul) {
test.Run();
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TestFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values);
#endif
}
@ -501,7 +509,7 @@ TEST(MathOpTest, Div) {
test.Run();
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TestFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values);
#endif
}

View file

@ -106,9 +106,9 @@ TEST(GemmOpTest, GemmNoTrans_bfloat16) {
test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

View file

@ -210,9 +210,9 @@ TEST(MathOpTest, MatMul_BFloat16) {
test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({10.0f, 10.0f, 10.0f, -10.0f, -10.0f, -10.0f}));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

View file

@ -1491,7 +1491,7 @@ TEST(ReductionOpTest, ReduceSum_half_bert) {
// Add more UTs for half as needed
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(ReductionOpTest, ReduceSumBFloat16) {
OpTester test("ReduceSum", 14);
test.AddAttribute("keepdims", (int64_t)0);
@ -1500,7 +1500,11 @@ TEST(ReductionOpTest, ReduceSumBFloat16) {
test.AddInput<int64_t>("axes", {2}, std::vector<int64_t>{0, 1});
test.AddOutput<BFloat16>("reduced", {2}, MakeBFloat16({36.0f, 42.0f}));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif