Enable add + softmax fusion for Rocm platform (#6259)

* add bias softmax; tests appear to pass

* check fusion occurs for rocm as well

* check for rocm provider compatible as well

* build for cpu scenario as well

* try again; broader cope

* proper scope on kGpuExecutionProvider

* been editing wrong file

* remove commented #include lines

* try again due to mac os ci error

* try again

* test fusion both cuda and rocm to avoid mac ci error
This commit is contained in:
Suffian Khan 2021-01-13 17:09:09 -05:00 committed by GitHub
parent 56ab2166e8
commit 62e404591a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 551 additions and 14 deletions

View file

@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/math/bias_softmax.h"
#include "core/providers/common.h"
using namespace onnxruntime;
using namespace onnxruntime::rocm;
using namespace onnxruntime::contrib::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
void DispatchBiasSoftmaxForwardImpl(
Tensor* output_tensor,
const Tensor* input_tensor,
const Tensor* input_bias_tensor,
int element_count,
int batch_count,
int batch_stride,
int bias_broadcast_size_per_batch);
template <typename T>
void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
miopenHandle_t miopenHandle,
int element_count,
int batch_count,
int broadcast_axis,
int softmax_axis,
const onnxruntime::TensorShape& X_shape,
const onnxruntime::Tensor* X,
const onnxruntime::TensorShape& B_shape,
const onnxruntime::Tensor* B,
onnxruntime::Tensor* Y);
ONNX_OPERATOR_KERNEL_EX(
BiasSoftmax,
kMSDomain,
1,
kRocmExecutionProvider,
KernelDefBuilder().TypeConstraint("T", {
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()
}),
BiasSoftmax);
Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const {
const TensorShape& X_shape{ctx->Input<Tensor>(0)->Shape()};
const TensorShape& B_shape{ctx->Input<Tensor>(1)->Shape()};
const Tensor* X = ctx->Input<Tensor>(0);
const Tensor* B = ctx->Input<Tensor>(1);
Tensor* Y = ctx->Output(0, X_shape);
const int softmax_axis = static_cast<int>(HandleNegativeAxis(softmax_axis_, X_shape.NumDimensions()));
const int N = static_cast<int>(X_shape.SizeToDimension(softmax_axis));
const int D = static_cast<int>(X_shape.SizeFromDimension(softmax_axis));
const int broadcast_axis = static_cast<int>(HandleNegativeAxis(broadcast_axis_, X_shape.NumDimensions()));
const int broadcast_size = N / static_cast<int>(X_shape.SizeToDimension(broadcast_axis));
const size_t elem_size = X->DataType()->Size();
if (D <= 1024 && D * elem_size <= 4096) {
// expect thread blocks can fill SM at high occupancy without overflowing registers
utils::MLTypeCallDispatcher<DispatchBiasSoftmaxForward, float, MLFloat16>
t_disp(X->GetElementType());
t_disp.Invoke(Y, X, B, D, N, D, broadcast_size);
} else {
// need to fallback to add kernel + CUDA DNN library softmax call :/
utils::MLTypeCallDispatcher<DispatchBiasSoftMaxForwardViaDnnLibrary, float, MLFloat16>
t_disp(X->GetElementType());
t_disp.Invoke(MiopenHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
}
return Status::OK();
}
template <typename T>
void DispatchBiasSoftmaxForward<T>::operator()(
Tensor* output,
const Tensor* input,
const Tensor* input_bias,
int element_count,
int batch_count,
int batch_stride,
int bias_broadcast_size_per_batch) {
DispatchBiasSoftmaxForwardImpl<T>(
output,
input,
input_bias,
element_count,
batch_count,
batch_stride,
bias_broadcast_size_per_batch);
}
template <typename T>
void DispatchBiasSoftMaxForwardViaDnnLibrary<T>::operator()(
miopenHandle_t miopenHandle,
int element_count,
int batch_count,
int broadcast_axis,
int softmax_axis,
const onnxruntime::TensorShape& X_shape,
const onnxruntime::Tensor* X,
const onnxruntime::TensorShape& B_shape,
const onnxruntime::Tensor* B,
onnxruntime::Tensor* Y) {
DispatchBiasSoftMaxForwardViaDnnLibraryImpl<T>(
miopenHandle,
element_count,
batch_count,
broadcast_axis,
softmax_axis,
X_shape,
X,
B_shape,
B,
Y);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "gsl/gsl"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
struct DispatchBiasSoftmaxForward {
void operator()(
Tensor* output,
const Tensor* input,
const Tensor* input_bias,
int element_count,
int batch_count,
int batch_stride,
int bias_broadcast_size_per_batch);
};
template <typename T>
struct DispatchBiasSoftMaxForwardViaDnnLibrary {
void operator()(
miopenHandle_t miopenHandle,
int element_count,
int batch_count,
int broadcast_axis,
int softmax_axis,
const onnxruntime::TensorShape& X_shape,
const onnxruntime::Tensor* X,
const onnxruntime::TensorShape& B_shape,
const onnxruntime::Tensor* B,
onnxruntime::Tensor* Y);
};
class BiasSoftmax final : public onnxruntime::rocm::RocmKernel {
public:
BiasSoftmax(const OpKernelInfo& info) : RocmKernel{info} {
info.GetAttrOrDefault("softmax_axis", &softmax_axis_, static_cast<int64_t>(1));
info.GetAttrOrDefault("broadcast_axis", &broadcast_axis_, static_cast<int64_t>(1));
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
int64_t softmax_axis_;
int64_t broadcast_axis_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,331 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/math/bias_softmax.h"
#include <limits>
#include <algorithm>
#include "hip/hip_runtime.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/common.h"
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/rocm/shared_inc/accumulation_type.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"
#include "core/providers/rocm/math/softmax_impl.cuh"
using namespace onnxruntime;
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Duplicated softmax_impl.cu here
// So far attempt to use shared kernel with additional template resulted in lost performance
// Note: The intended case for 'input_bias' is the input sequence mask for transformer models
// As an additive mask, it should be zero for preserved tokens and -infty for tokens to screen
// The mask will broadcast from [batch_size, 1, 1, seq_len] to input [batch_size, num_heads, seq_len, seq_len]
// Here element_count = seq_len and bias_broadcast_size_per_batch = num_heads * seq_len
// The softmax + additive mask fusion follows NVIDIA apex's additive_masked_softmax_warp_forward
// see https://github.com/NVIDIA/apex/blob/4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a/apex/contrib/csrc/multihead_attn/softmax.h
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void BiasSoftmaxWarpForward(
output_t* output,
const input_t* input,
const input_t* input_bias,
int element_count,
int batch_count,
int batch_stride,
int bias_broadcast_count_per_batch) {
// "WARP" refers to cooperative threads and might not equal 32 threads of GPU warp
// thread block is (WARP_SIZE, 128/WARP_SIZE)
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = next_power_of_two < GPU_WARP_SIZE ? next_power_of_two : GPU_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
// constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int WARP_BATCH = 1;
// each "WARP" (<=32) processes WARP_BATCH(one of {1,2}) batches
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// last warp may have fewer batches
int local_batches = batch_count - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// thread will process elements (local_index + n * warp_size) within batch
int local_idx = threadIdx.x;
// push input, input_bias output pointers to batch we need to process
input += first_batch * batch_stride + local_idx;
output += first_batch * batch_stride + local_idx;
// load from global memory and apply bias (likely an additive mask)
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
// the bias has assumed shape [batch_size, element_count]
// .. and needs to broadcast to [batch_size, broadcast_size, element_count]
int bias_offset = (first_batch + i) / bias_broadcast_count_per_batch * batch_stride + local_idx;
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
elements[i][it] = (acc_t)input[i * element_count + it * WARP_SIZE] + (acc_t)input_bias[bias_offset + it * WARP_SIZE];
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
}
}
}
// find maximum value within batch for numerical stability
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
// normalization factor Z = Sum[ exp(element_i), for element_i in batch ]
acc_t sum[WARP_BATCH]{acc_t(0.0)};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = expf((acc_t)(elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// write back normalized value = exp(element_i)/Z to global memory
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
} else {
break;
}
}
}
}
template <typename T>
void DispatchBiasSoftmaxForwardImpl(
Tensor* output_tensor,
const Tensor* input_tensor,
const Tensor* input_bias_tensor,
int element_count,
int batch_count,
int batch_stride,
int bias_broadcast_size_per_batch) {
typedef typename ToHipType<T>::MappedType HipT;
typedef HipT input_t;
typedef HipT output_t;
typedef AccumulationType_t<HipT> acc_t;
const auto* input = reinterpret_cast<const HipT*>(input_tensor->template Data<T>());
const auto* input_bias = reinterpret_cast<const HipT*>(input_bias_tensor->template Data<T>());
auto* output = reinterpret_cast<HipT*>(output_tensor->template MutableData<T>());
if (element_count == 0)
return;
int log2_elements = log2_ceil(element_count);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE);
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
// int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
int batches_per_warp = 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 256;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 0>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 1: // 2
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 1>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 2: // 4
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 2>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 3: // 8
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 3>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 4: // 16
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 4>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 5: // 32
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 5>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 6: // 64
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 6>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 7: // 128
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 7>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 8: // 256
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 8>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 9: // 512
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 9>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 10: // 1024
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 10>), dim3(blocks), dim3(threads), 0, 0,
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
default:
break;
}
}
#define SPECIALIZED_BIAS_SOFTMAX_IMPL(T) \
template void DispatchBiasSoftmaxForwardImpl<T>( \
Tensor * output_tensor, \
const Tensor* input_tensor, \
const Tensor* input_bias_tensor, \
int element_count, \
int batch_count, \
int batch_stride, \
int bias_broadcast_size_per_batch);
SPECIALIZED_BIAS_SOFTMAX_IMPL(double)
SPECIALIZED_BIAS_SOFTMAX_IMPL(float)
SPECIALIZED_BIAS_SOFTMAX_IMPL(MLFloat16)
// For large element count we fall back to explicit Add kernel + CUDA DNN library
// note: This is an unhappy path! There is no performance benefit for the fusion.
template <typename T>
void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
miopenHandle_t miopenHandle,
int element_count,
int batch_count,
int broadcast_axis,
int softmax_axis,
const onnxruntime::TensorShape& X_shape,
const onnxruntime::Tensor* X,
const onnxruntime::TensorShape& B_shape,
const onnxruntime::Tensor* B,
onnxruntime::Tensor* Y) {
typedef typename ToHipType<T>::MappedType HipT;
const auto* X_data = reinterpret_cast<const HipT*>(X->template Data<T>());
const auto* B_data = reinterpret_cast<const HipT*>(B->template Data<T>());
auto* Y_data = reinterpret_cast<HipT*>(Y->template MutableData<T>());
// binary elementise kernel requires input pitches
TArray<int64_t> lhs_padded_strides(X_shape.NumDimensions());
for (int i = -1, lhs_pitch = 1; i >= -(int)X_shape.NumDimensions(); i--) {
size_t positive_i = X_shape.NumDimensions() + i;
lhs_padded_strides[positive_i] = lhs_pitch;
lhs_pitch *= X_shape[positive_i];
}
// set pitches for bias so it broadcasts along relevant dimensions
TArray<int64_t> rhs_padded_strides(X_shape.NumDimensions());
for (int i = -1, rhs_pitch = 1; i >= -(int)X_shape.NumDimensions(); i--) {
size_t positive_ix = X_shape.NumDimensions() + i;
size_t positive_ib = B_shape.NumDimensions() + i;
if (broadcast_axis <= positive_ix && positive_ix < softmax_axis) {
rhs_padded_strides[positive_ix] = 0;
continue;
}
rhs_padded_strides[positive_ix] = rhs_pitch;
rhs_pitch *= B_shape[positive_ib];
}
TArray<fast_divmod> fdm_output_strides(X_shape.NumDimensions());
for (int i = 0; i < fdm_output_strides.Size(); i++)
fdm_output_strides[i] = fast_divmod(lhs_padded_strides[i]);
fast_divmod fdm_H, fdm_C;
// invoke elementwise add with broadcast kernel
::onnxruntime::rocm::BinaryElementWiseImpl(
(int32_t)X_shape.NumDimensions(),
&lhs_padded_strides,
X_data,
&rhs_padded_strides,
B_data,
&fdm_output_strides,
fdm_H,
fdm_C,
Y_data,
OP_Add<HipT, HipT, HipT>(),
(size_t)X_shape.Size());
// invoke cuda DNN library for Y = softmax(X)
std::vector<int64_t> dims({batch_count, 1, 1, element_count});
const auto alpha = Consts<HipT>::One;
const auto beta = Consts<HipT>::Zero;
onnxruntime::rocm::MiopenTensor input_tensor, output_tensor;
input_tensor.Set(dims, onnxruntime::rocm::MiopenTensor::GetDataType<HipT>());
output_tensor.Set(dims, onnxruntime::rocm::MiopenTensor::GetDataType<HipT>());
miopenSoftmaxForward_V2(
miopenHandle,
&alpha,
input_tensor,
Y_data,
&beta,
output_tensor,
Y_data,
MIOPEN_SOFTMAX_ACCURATE,
MIOPEN_SOFTMAX_MODE_INSTANCE);
}
#define SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(T) \
template void DispatchBiasSoftMaxForwardViaDnnLibraryImpl<T>( \
miopenHandle_t miopenHandle, \
int element_count, \
int batch_count, \
int broadcast_axis, \
int softmax_axis, \
const onnxruntime::TensorShape& X_shape, \
const Tensor* X_data, \
const onnxruntime::TensorShape& B_shape, \
const Tensor* B_data, \
Tensor* Y_data);
// SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(double)
SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(float)
SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(MLFloat16)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -33,6 +33,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax);
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
// contrib ops to maintain backward compatibility
@ -141,6 +142,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,

View file

@ -44,7 +44,7 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s
// check node is add and has single output
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
!graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider}) ||
!graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider, kRocmExecutionProvider}) ||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
return false;
}
@ -224,9 +224,9 @@ Status BiasSoftmaxFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
// only support CUDA execution provider
// only support GPU execution provider
auto& cep = GetCompatibleExecutionProviders();
if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end())
if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end() && cep.find(kRocmExecutionProvider) == cep.end())
return Status::OK();
for (auto node_index : node_topology_list) {

View file

@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/compare_provider_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include <algorithm>
@ -12,6 +13,12 @@
namespace onnxruntime {
namespace test {
#if USE_ROCM
constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider;
#else
constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider;
#endif
// followed example of fastgelu_op_test.cc
// in retrospect would have been better to compare BiasSoftmax to Add + Softmax graph
@ -130,7 +137,8 @@ class BiasSoftmaxTester {
void RunComparison() {
// BiasSoftmax only implemented for cuda architecture
int min_cuda_architecture = use_float16_ ? 530 : 0;
if (HasCudaEnvironment(min_cuda_architecture)) {
if (HasCudaEnvironment(min_cuda_architecture) ||
kGpuExecutionProvider == kRocmExecutionProvider) {
OpTester tester("BiasSoftmax", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("softmax_axis", softmax_axis_);
tester.AddAttribute<int64_t>("broadcast_axis", broadcast_axis_);
@ -146,7 +154,12 @@ class BiasSoftmaxTester {
}
std::vector<std::unique_ptr<IExecutionProvider>> ep;
ep.push_back(DefaultCudaExecutionProvider());
#ifdef USE_CUDA
ep.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
ep.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep);
}
}

View file

@ -2417,20 +2417,22 @@ struct BiasSoftmaxFusionTester {
BiasSoftmaxFusionTester(
const PathString& model_uri,
onnxruntime::logging::Logger* logger,
bool on_cuda_ = true) : logger_(logger), graph_transformation_mgr_{5} {
const char* execution_provider = kCudaExecutionProvider) : logger_(logger), graph_transformation_mgr_{5} {
model_load_ = Model::Load(model_uri, p_model_, nullptr, *logger_);
// move to cuda since fusion only takes place in that case
if (on_cuda_) {
for (auto& node : p_model_->MainGraph().Nodes()) {
node.SetExecutionProviderType(kCudaExecutionProvider);
}
}
SetExecutionProvider(execution_provider);
graph_transformation_mgr_.Register(
onnxruntime::make_unique<BiasSoftmaxFusion>(), TransformerLevel::Level2);
}
void SetExecutionProvider(const char* ep) {
for (auto& node : p_model_->MainGraph().Nodes()) {
node.SetExecutionProviderType(ep);
}
}
void TestFusionOccurs(int expected_broadcast_axis) {
ASSERT_STATUS_OK(model_load_);
@ -2466,13 +2468,19 @@ struct BiasSoftmaxFusionTester {
}
};
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_CudaOnly) {
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_GpuOnly) {
auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx";
BiasSoftmaxFusionTester tester(model_uri, logger_.get(), false);
BiasSoftmaxFusionTester tester(model_uri, logger_.get(), kCpuExecutionProvider);
tester.TestNoFusionOccurs();
}
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple) {
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Rocm) {
auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx";
BiasSoftmaxFusionTester tester(model_uri, logger_.get(), kRocmExecutionProvider);
tester.TestFusionOccurs(1);
}
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Cuda) {
auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx";
BiasSoftmaxFusionTester tester(model_uri, logger_.get());
tester.TestFusionOccurs(1);