diff --git a/onnxruntime/contrib_ops/rocm/math/bias_softmax.cc b/onnxruntime/contrib_ops/rocm/math/bias_softmax.cc new file mode 100644 index 0000000000..9d318a3edf --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/bias_softmax.cc @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/math/bias_softmax.h" + +#include "core/providers/common.h" + +using namespace onnxruntime; +using namespace onnxruntime::rocm; +using namespace onnxruntime::contrib::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +void DispatchBiasSoftmaxForwardImpl( + Tensor* output_tensor, + const Tensor* input_tensor, + const Tensor* input_bias_tensor, + int element_count, + int batch_count, + int batch_stride, + int bias_broadcast_size_per_batch); + +template +void DispatchBiasSoftMaxForwardViaDnnLibraryImpl( + miopenHandle_t miopenHandle, + int element_count, + int batch_count, + int broadcast_axis, + int softmax_axis, + const onnxruntime::TensorShape& X_shape, + const onnxruntime::Tensor* X, + const onnxruntime::TensorShape& B_shape, + const onnxruntime::Tensor* B, + onnxruntime::Tensor* Y); + +ONNX_OPERATOR_KERNEL_EX( + BiasSoftmax, + kMSDomain, + 1, + kRocmExecutionProvider, + KernelDefBuilder().TypeConstraint("T", { + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType() + }), + BiasSoftmax); + +Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const { + const TensorShape& X_shape{ctx->Input(0)->Shape()}; + const TensorShape& B_shape{ctx->Input(1)->Shape()}; + + const Tensor* X = ctx->Input(0); + const Tensor* B = ctx->Input(1); + Tensor* Y = ctx->Output(0, X_shape); + + const int softmax_axis = static_cast(HandleNegativeAxis(softmax_axis_, X_shape.NumDimensions())); + const int N = static_cast(X_shape.SizeToDimension(softmax_axis)); // forwarded to both dispatch paths as batch_count + const int D = static_cast(X_shape.SizeFromDimension(softmax_axis)); // forwarded as element_count (and as batch_stride on the fused path) + + const int broadcast_axis = 
static_cast(HandleNegativeAxis(broadcast_axis_, X_shape.NumDimensions())); + const int broadcast_size = N / static_cast(X_shape.SizeToDimension(broadcast_axis)); + + const size_t elem_size = X->DataType()->Size(); + if (D <= 1024 && D * elem_size <= 4096) { + // fused kernel path, taken when D <= 1024 and D * elem_size <= 4096: expect thread blocks can fill the GPU at high occupancy without overflowing registers + utils::MLTypeCallDispatcher + t_disp(X->GetElementType()); + t_disp.Invoke(Y, X, B, D, N, D, broadcast_size); + } else { + // need to fallback to add kernel + MIOpen library softmax call :/ + utils::MLTypeCallDispatcher + t_disp(X->GetElementType()); + t_disp.Invoke(MiopenHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y); + } + + return Status::OK(); +} + +template +void DispatchBiasSoftmaxForward::operator()( + Tensor* output, + const Tensor* input, + const Tensor* input_bias, + int element_count, + int batch_count, + int batch_stride, + int bias_broadcast_size_per_batch) { + DispatchBiasSoftmaxForwardImpl( + output, + input, + input_bias, + element_count, + batch_count, + batch_stride, + bias_broadcast_size_per_batch); +} + +template +void DispatchBiasSoftMaxForwardViaDnnLibrary::operator()( + miopenHandle_t miopenHandle, + int element_count, + int batch_count, + int broadcast_axis, + int softmax_axis, + const onnxruntime::TensorShape& X_shape, + const onnxruntime::Tensor* X, + const onnxruntime::TensorShape& B_shape, + const onnxruntime::Tensor* B, + onnxruntime::Tensor* Y) { + DispatchBiasSoftMaxForwardViaDnnLibraryImpl( + miopenHandle, + element_count, + batch_count, + broadcast_axis, + softmax_axis, + X_shape, + X, + B_shape, + B, + Y); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/rocm/math/bias_softmax.h b/onnxruntime/contrib_ops/rocm/math/bias_softmax.h new file mode 100644 index 0000000000..04bc4d93b0 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/bias_softmax.h @@ -0,0 +1,56 @@ +// 
Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "gsl/gsl" +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +struct DispatchBiasSoftmaxForward { // functor for the fused kernel path; operator() forwards to DispatchBiasSoftmaxForwardImpl + void operator()( + Tensor* output, + const Tensor* input, + const Tensor* input_bias, + int element_count, + int batch_count, + int batch_stride, + int bias_broadcast_size_per_batch); +}; + +template +struct DispatchBiasSoftMaxForwardViaDnnLibrary { // functor for the fallback path; operator() forwards to DispatchBiasSoftMaxForwardViaDnnLibraryImpl + void operator()( + miopenHandle_t miopenHandle, + int element_count, + int batch_count, + int broadcast_axis, + int softmax_axis, + const onnxruntime::TensorShape& X_shape, + const onnxruntime::Tensor* X, + const onnxruntime::TensorShape& B_shape, + const onnxruntime::Tensor* B, + onnxruntime::Tensor* Y); +}; + +class BiasSoftmax final : public onnxruntime::rocm::RocmKernel { + public: + BiasSoftmax(const OpKernelInfo& info) : RocmKernel{info} { + info.GetAttrOrDefault("softmax_axis", &softmax_axis_, static_cast(1)); + info.GetAttrOrDefault("broadcast_axis", &broadcast_axis_, static_cast(1)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t softmax_axis_; // "softmax_axis" attribute (default 1); passed through HandleNegativeAxis in ComputeInternal + int64_t broadcast_axis_; // "broadcast_axis" attribute (default 1); passed through HandleNegativeAxis in ComputeInternal +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/bias_softmax_impl.cu b/onnxruntime/contrib_ops/rocm/math/bias_softmax_impl.cu new file mode 100644 index 0000000000..e8aad12e68 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/bias_softmax_impl.cu @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/rocm/math/bias_softmax.h" + +#include +#include + +#include "hip/hip_runtime.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh" +#include "core/providers/common.h" +#include "core/providers/rocm/miopen_common.h" +#include "core/providers/rocm/shared_inc/accumulation_type.h" +#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh" +#include "core/providers/rocm/math/softmax_impl.cuh" + +using namespace onnxruntime; +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +// Duplicated from softmax_impl.cu here +// So far attempt to use shared kernel with additional template resulted in lost performance + +// Note: The intended case for 'input_bias' is the input sequence mask for transformer models +// As an additive mask, it should be zero for preserved tokens and -infinity for tokens to screen +// The mask will broadcast from [batch_size, 1, 1, seq_len] to input [batch_size, num_heads, seq_len, seq_len] +// Here element_count = seq_len and bias_broadcast_size_per_batch = num_heads * seq_len + +// The softmax + additive mask fusion follows NVIDIA apex's additive_masked_softmax_warp_forward +// see https://github.com/NVIDIA/apex/blob/4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a/apex/contrib/csrc/multihead_attn/softmax.h + +template +__global__ void BiasSoftmaxWarpForward( + output_t* output, + const input_t* input, + const input_t* input_bias, + int element_count, + int batch_count, + int batch_stride, + int bias_broadcast_count_per_batch) { + // "WARP" refers to cooperative threads and may be smaller than GPU_WARP_SIZE (WARP_SIZE = min(next_power_of_two, GPU_WARP_SIZE)) + // thread block is (WARP_SIZE, 256/WARP_SIZE), matching threads_per_block = 256 in DispatchBiasSoftmaxForwardImpl + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = next_power_of_two < GPU_WARP_SIZE ? 
next_power_of_two : GPU_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + // constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int WARP_BATCH = 1; + + // each "WARP" (<= GPU_WARP_SIZE threads) processes WARP_BATCH batches (WARP_BATCH is fixed to 1 here) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // last warp may have fewer batches + int local_batches = batch_count - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // thread will process elements (local_idx + n * WARP_SIZE) within batch + int local_idx = threadIdx.x; + + // push input and output pointers to the batch we need to process (input_bias is indexed through bias_offset below) + input += first_batch * batch_stride + local_idx; + output += first_batch * batch_stride + local_idx; + + // load from global memory and apply bias (likely an additive mask) + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + // the bias has assumed shape [batch_size, element_count] + // .. and needs to broadcast to [batch_size, broadcast_size, element_count] + int bias_offset = (first_batch + i) / bias_broadcast_count_per_batch * batch_stride + local_idx; + + int batch_element_count = (i >= local_batches) ? 0 : element_count; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + elements[i][it] = (acc_t)input[i * element_count + it * WARP_SIZE] + (acc_t)input_bias[bias_offset + it * WARP_SIZE]; + } else { + elements[i][it] = -std::numeric_limits::infinity(); + } + } + } + + // find maximum value within batch for numerical stability + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + // normalization factor Z = Sum[ exp(element_i - max), for element_i in batch ] + acc_t sum[WARP_BATCH]{acc_t(0.0)}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = expf((acc_t)(elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + +// write back normalized value = exp(element_i - max)/Z to global memory +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + output[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i]; + } else { + break; + } + } + } +} + +template +void DispatchBiasSoftmaxForwardImpl( + Tensor* output_tensor, + const Tensor* input_tensor, + const Tensor* input_bias_tensor, + int element_count, + int batch_count, + int batch_stride, + int bias_broadcast_size_per_batch) { + typedef typename ToHipType::MappedType HipT; + typedef HipT input_t; + typedef HipT output_t; + typedef AccumulationType_t acc_t; + + const auto* input = reinterpret_cast(input_tensor->template Data()); + const auto* input_bias = reinterpret_cast(input_bias_tensor->template Data()); + auto* output = reinterpret_cast(output_tensor->template MutableData()); + + if (element_count == 0) + return; + + int log2_elements = log2_ceil(element_count); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside BiasSoftmaxWarpForward. + int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE); + + // This value must match the WARP_BATCH constexpr value computed inside BiasSoftmaxWarpForward. + // int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + int batches_per_warp = 1; + + // use 256 threads per block to maximize gpu utilization + constexpr int threads_per_block = 256; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 1: // 2 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 2: // 4 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 3: // 8 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 4: // 16 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 5: // 32 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 6: // 64 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, 
batch_stride, bias_broadcast_size_per_batch); + break; + case 7: // 128 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 8: // 256 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 9: // 512 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + case 10: // 1024 + hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward), dim3(blocks), dim3(threads), 0, 0, + output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch); + break; + default: // log2_elements > 10 (element_count > 1024): no kernel is launched and output is left unwritten + break; + } +} + +#define SPECIALIZED_BIAS_SOFTMAX_IMPL(T) \ + template void DispatchBiasSoftmaxForwardImpl( \ + Tensor * output_tensor, \ + const Tensor* input_tensor, \ + const Tensor* input_bias_tensor, \ + int element_count, \ + int batch_count, \ + int batch_stride, \ + int bias_broadcast_size_per_batch); + +SPECIALIZED_BIAS_SOFTMAX_IMPL(double) +SPECIALIZED_BIAS_SOFTMAX_IMPL(float) +SPECIALIZED_BIAS_SOFTMAX_IMPL(MLFloat16) + +// For large element count we fall back to explicit Add kernel + MIOpen library softmax +// note: This is an unhappy path! There is no performance benefit for the fusion. 
+template +void DispatchBiasSoftMaxForwardViaDnnLibraryImpl( + miopenHandle_t miopenHandle, + int element_count, + int batch_count, + int broadcast_axis, + int softmax_axis, + const onnxruntime::TensorShape& X_shape, + const onnxruntime::Tensor* X, + const onnxruntime::TensorShape& B_shape, + const onnxruntime::Tensor* B, + onnxruntime::Tensor* Y) { + typedef typename ToHipType::MappedType HipT; + + const auto* X_data = reinterpret_cast(X->template Data()); + const auto* B_data = reinterpret_cast(B->template Data()); + auto* Y_data = reinterpret_cast(Y->template MutableData()); + + // binary elementwise kernel requires input pitches (computed here from the innermost dimension outwards) + TArray lhs_padded_strides(X_shape.NumDimensions()); + for (int i = -1, lhs_pitch = 1; i >= -(int)X_shape.NumDimensions(); i--) { + size_t positive_i = X_shape.NumDimensions() + i; + lhs_padded_strides[positive_i] = lhs_pitch; + lhs_pitch *= X_shape[positive_i]; + } + + // set pitches for bias so it broadcasts along relevant dimensions: axes in [broadcast_axis, softmax_axis) get pitch 0 + TArray rhs_padded_strides(X_shape.NumDimensions()); + for (int i = -1, rhs_pitch = 1; i >= -(int)X_shape.NumDimensions(); i--) { + size_t positive_ix = X_shape.NumDimensions() + i; + size_t positive_ib = B_shape.NumDimensions() + i; + if (broadcast_axis <= positive_ix && positive_ix < softmax_axis) { + rhs_padded_strides[positive_ix] = 0; + continue; + } + rhs_padded_strides[positive_ix] = rhs_pitch; + rhs_pitch *= B_shape[positive_ib]; + } + + TArray fdm_output_strides(X_shape.NumDimensions()); + for (int i = 0; i < fdm_output_strides.Size(); i++) + fdm_output_strides[i] = fast_divmod(lhs_padded_strides[i]); + fast_divmod fdm_H, fdm_C; + + // invoke elementwise add with broadcast kernel; the sum X + B is written to Y_data + ::onnxruntime::rocm::BinaryElementWiseImpl( + (int32_t)X_shape.NumDimensions(), + &lhs_padded_strides, + X_data, + &rhs_padded_strides, + B_data, + &fdm_output_strides, + fdm_H, + fdm_C, + Y_data, + OP_Add(), + (size_t)X_shape.Size()); + + // invoke MIOpen softmax in place on Y_data (input and output buffer are the same); NOTE(review): the value returned by miopenSoftmaxForward_V2 is not checked + std::vector dims({batch_count, 
1, 1, element_count}); + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + onnxruntime::rocm::MiopenTensor input_tensor, output_tensor; + input_tensor.Set(dims, onnxruntime::rocm::MiopenTensor::GetDataType()); + output_tensor.Set(dims, onnxruntime::rocm::MiopenTensor::GetDataType()); + miopenSoftmaxForward_V2( + miopenHandle, + &alpha, + input_tensor, + Y_data, + &beta, + output_tensor, + Y_data, + MIOPEN_SOFTMAX_ACCURATE, + MIOPEN_SOFTMAX_MODE_INSTANCE); +} + +#define SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(T) \ + template void DispatchBiasSoftMaxForwardViaDnnLibraryImpl( \ + miopenHandle_t miopenHandle, \ + int element_count, \ + int batch_count, \ + int broadcast_axis, \ + int softmax_axis, \ + const onnxruntime::TensorShape& X_shape, \ + const Tensor* X_data, \ + const onnxruntime::TensorShape& B_shape, \ + const Tensor* B_data, \ + Tensor* Y_data); + +// SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(double) +SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(float) +SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(MLFloat16) + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 58342602df..300308d4dc 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -33,6 +33,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax); // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as // contrib ops to maintain backward compatibility @@ -141,6 +142,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc index 0f915abc46..6f4e761e56 100644 --- a/onnxruntime/core/optimizer/bias_softmax_fusion.cc +++ b/onnxruntime/core/optimizer/bias_softmax_fusion.cc @@ -44,7 +44,7 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s // check node is add and has single output if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) || - !graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider}) || + !graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider, kRocmExecutionProvider}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { return false; } @@ -224,9 +224,9 @@ Status BiasSoftmaxFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - // only support CUDA execution provider + // only support GPU execution provider auto& cep = GetCompatibleExecutionProviders(); - if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end()) + if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end() && cep.find(kRocmExecutionProvider) == cep.end()) return Status::OK(); for (auto node_index : node_topology_list) { diff --git a/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc b/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc index 10de6930d5..0e3d281b6a 100644 --- a/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" #include 
"test/common/cuda_op_test_utils.h" +#include "test/providers/compare_provider_test_utils.h" #include "test/providers/provider_test_utils.h" #include @@ -12,6 +13,12 @@ namespace onnxruntime { namespace test { +#if USE_ROCM +constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; +#else +constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; +#endif + // followed example of fastgelu_op_test.cc // in retrospect would have been better to compare BiasSoftmax to Add + Softmax graph @@ -130,7 +137,8 @@ class BiasSoftmaxTester { void RunComparison() { // BiasSoftmax only implemented for cuda architecture int min_cuda_architecture = use_float16_ ? 530 : 0; - if (HasCudaEnvironment(min_cuda_architecture)) { + if (HasCudaEnvironment(min_cuda_architecture) || + kGpuExecutionProvider == kRocmExecutionProvider) { OpTester tester("BiasSoftmax", 1, onnxruntime::kMSDomain); tester.AddAttribute("softmax_axis", softmax_axis_); tester.AddAttribute("broadcast_axis", broadcast_axis_); @@ -146,7 +154,12 @@ class BiasSoftmaxTester { } std::vector> ep; - ep.push_back(DefaultCudaExecutionProvider()); + #ifdef USE_CUDA + ep.push_back(DefaultCudaExecutionProvider()); + #elif USE_ROCM + ep.push_back(DefaultRocmExecutionProvider()); + #endif + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep); } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 42c101a67c..0c5a148143 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2417,20 +2417,22 @@ struct BiasSoftmaxFusionTester { BiasSoftmaxFusionTester( const PathString& model_uri, onnxruntime::logging::Logger* logger, - bool on_cuda_ = true) : logger_(logger), graph_transformation_mgr_{5} { + const char* execution_provider = kCudaExecutionProvider) : logger_(logger), graph_transformation_mgr_{5} { model_load_ = Model::Load(model_uri, p_model_, nullptr, 
*logger_); // move to cuda since fusion only takes place in that case - if (on_cuda_) { - for (auto& node : p_model_->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } - } + SetExecutionProvider(execution_provider); graph_transformation_mgr_.Register( onnxruntime::make_unique(), TransformerLevel::Level2); } + void SetExecutionProvider(const char* ep) { + for (auto& node : p_model_->MainGraph().Nodes()) { + node.SetExecutionProviderType(ep); + } + } + void TestFusionOccurs(int expected_broadcast_axis) { ASSERT_STATUS_OK(model_load_); @@ -2466,13 +2468,19 @@ struct BiasSoftmaxFusionTester { } }; -TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_CudaOnly) { +TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_GpuOnly) { auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx"; - BiasSoftmaxFusionTester tester(model_uri, logger_.get(), false); + BiasSoftmaxFusionTester tester(model_uri, logger_.get(), kCpuExecutionProvider); tester.TestNoFusionOccurs(); } -TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple) { +TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Rocm) { + auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx"; + BiasSoftmaxFusionTester tester(model_uri, logger_.get(), kRocmExecutionProvider); + tester.TestFusionOccurs(1); +} + +TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Cuda) { auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx"; BiasSoftmaxFusionTester tester(model_uri, logger_.get()); tester.TestFusionOccurs(1);