mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Enable add + softmax fusion for Rocm platform (#6259)
* add bias softmax; tests appear to pass * check fusion occurs for rocm as well * check for rocm provider compatible as well * build for cpu scenario as well * try again; broader scope * proper scope on kGpuExecutionProvider * been editing wrong file * remove commented #include lines * try again due to mac os ci error * try again * test fusion both cuda and rocm to avoid mac ci error
This commit is contained in:
parent
56ab2166e8
commit
62e404591a
7 changed files with 551 additions and 14 deletions
127
onnxruntime/contrib_ops/rocm/math/bias_softmax.cc
Normal file
127
onnxruntime/contrib_ops/rocm/math/bias_softmax.cc
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/rocm/math/bias_softmax.h"
|
||||
|
||||
#include "core/providers/common.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace onnxruntime::contrib::rocm;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
void DispatchBiasSoftmaxForwardImpl(
|
||||
Tensor* output_tensor,
|
||||
const Tensor* input_tensor,
|
||||
const Tensor* input_bias_tensor,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int batch_stride,
|
||||
int bias_broadcast_size_per_batch);
|
||||
|
||||
template <typename T>
|
||||
void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
|
||||
miopenHandle_t miopenHandle,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int broadcast_axis,
|
||||
int softmax_axis,
|
||||
const onnxruntime::TensorShape& X_shape,
|
||||
const onnxruntime::Tensor* X,
|
||||
const onnxruntime::TensorShape& B_shape,
|
||||
const onnxruntime::Tensor* B,
|
||||
onnxruntime::Tensor* Y);
|
||||
|
||||
// Registers the BiasSoftmax kernel (kMSDomain, version 1) with the ROCm execution
// provider; type constraint "T" admits float and MLFloat16 tensors.
ONNX_OPERATOR_KERNEL_EX(
    BiasSoftmax,
    kMSDomain,
    1,
    kRocmExecutionProvider,
    KernelDefBuilder().TypeConstraint(
        "T",
        {DataTypeImpl::GetTensorType<float>(),
         DataTypeImpl::GetTensorType<MLFloat16>()}),
    BiasSoftmax);
|
||||
|
||||
// Computes Y = softmax(X + B), where input 0 is X and input 1 is the bias B.
//
// X is viewed as N rows of D elements, split at softmax_axis_. Rows of at most 1024
// elements (and at most 4096 bytes) go to the fused warp kernel; anything larger falls
// back to a broadcast Add kernel followed by the MIOpen softmax.
//
// Returns Status::OK(); the dispatch helpers themselves do not report errors.
Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const {
  const Tensor* X = ctx->Input<Tensor>(0);
  const Tensor* B = ctx->Input<Tensor>(1);
  const TensorShape& X_shape = X->Shape();
  const TensorShape& B_shape = B->Shape();
  Tensor* Y = ctx->Output(0, X_shape);

  // Nothing to compute for an empty tensor. Bailing out here also protects the
  // broadcast_size division below, whose divisor is 0 when a leading dimension is 0.
  if (X_shape.Size() == 0) {
    return Status::OK();
  }

  const int softmax_axis = static_cast<int>(HandleNegativeAxis(softmax_axis_, X_shape.NumDimensions()));
  const int N = static_cast<int>(X_shape.SizeToDimension(softmax_axis));
  const int D = static_cast<int>(X_shape.SizeFromDimension(softmax_axis));

  const int broadcast_axis = static_cast<int>(HandleNegativeAxis(broadcast_axis_, X_shape.NumDimensions()));
  // Number of consecutive rows of X that share one row of the bias.
  const int broadcast_size = N / static_cast<int>(X_shape.SizeToDimension(broadcast_axis));

  const size_t elem_size = X->DataType()->Size();
  if (D <= 1024 && static_cast<size_t>(D) * elem_size <= 4096) {
    // expect thread blocks can fill the GPU at high occupancy without overflowing registers
    utils::MLTypeCallDispatcher<DispatchBiasSoftmaxForward, float, MLFloat16>
        t_disp(X->GetElementType());
    // element_count = D, batch_count = N, batch_stride = D
    t_disp.Invoke(Y, X, B, D, N, D, broadcast_size);
  } else {
    // need to fall back to add kernel + MIOpen softmax call :/
    utils::MLTypeCallDispatcher<DispatchBiasSoftMaxForwardViaDnnLibrary, float, MLFloat16>
        t_disp(X->GetElementType());
    t_disp.Invoke(MiopenHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
  }

  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
void DispatchBiasSoftmaxForward<T>::operator()(
|
||||
Tensor* output,
|
||||
const Tensor* input,
|
||||
const Tensor* input_bias,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int batch_stride,
|
||||
int bias_broadcast_size_per_batch) {
|
||||
DispatchBiasSoftmaxForwardImpl<T>(
|
||||
output,
|
||||
input,
|
||||
input_bias,
|
||||
element_count,
|
||||
batch_count,
|
||||
batch_stride,
|
||||
bias_broadcast_size_per_batch);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DispatchBiasSoftMaxForwardViaDnnLibrary<T>::operator()(
|
||||
miopenHandle_t miopenHandle,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int broadcast_axis,
|
||||
int softmax_axis,
|
||||
const onnxruntime::TensorShape& X_shape,
|
||||
const onnxruntime::Tensor* X,
|
||||
const onnxruntime::TensorShape& B_shape,
|
||||
const onnxruntime::Tensor* B,
|
||||
onnxruntime::Tensor* Y) {
|
||||
DispatchBiasSoftMaxForwardViaDnnLibraryImpl<T>(
|
||||
miopenHandle,
|
||||
element_count,
|
||||
batch_count,
|
||||
broadcast_axis,
|
||||
softmax_axis,
|
||||
X_shape,
|
||||
X,
|
||||
B_shape,
|
||||
B,
|
||||
Y);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
56
onnxruntime/contrib_ops/rocm/math/bias_softmax.h
Normal file
56
onnxruntime/contrib_ops/rocm/math/bias_softmax.h
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "core/providers/rocm/rocm_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
struct DispatchBiasSoftmaxForward {
|
||||
void operator()(
|
||||
Tensor* output,
|
||||
const Tensor* input,
|
||||
const Tensor* input_bias,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int batch_stride,
|
||||
int bias_broadcast_size_per_batch);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DispatchBiasSoftMaxForwardViaDnnLibrary {
|
||||
void operator()(
|
||||
miopenHandle_t miopenHandle,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int broadcast_axis,
|
||||
int softmax_axis,
|
||||
const onnxruntime::TensorShape& X_shape,
|
||||
const onnxruntime::Tensor* X,
|
||||
const onnxruntime::TensorShape& B_shape,
|
||||
const onnxruntime::Tensor* B,
|
||||
onnxruntime::Tensor* Y);
|
||||
};
|
||||
|
||||
class BiasSoftmax final : public onnxruntime::rocm::RocmKernel {
|
||||
public:
|
||||
BiasSoftmax(const OpKernelInfo& info) : RocmKernel{info} {
|
||||
info.GetAttrOrDefault("softmax_axis", &softmax_axis_, static_cast<int64_t>(1));
|
||||
info.GetAttrOrDefault("broadcast_axis", &broadcast_axis_, static_cast<int64_t>(1));
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
int64_t softmax_axis_;
|
||||
int64_t broadcast_axis_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
331
onnxruntime/contrib_ops/rocm/math/bias_softmax_impl.cu
Normal file
331
onnxruntime/contrib_ops/rocm/math/bias_softmax_impl.cu
Normal file
|
|
@ -0,0 +1,331 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/rocm/math/bias_softmax.h"

#include <algorithm>
#include <cmath>
#include <limits>

#include "hip/hip_runtime.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/common.h"
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/rocm/shared_inc/accumulation_type.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"
#include "core/providers/rocm/math/softmax_impl.cuh"
|
||||
|
||||
using namespace onnxruntime;
|
||||
using namespace onnxruntime::rocm;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
// Duplicated softmax_impl.cu here
|
||||
// So far attempt to use shared kernel with additional template resulted in lost performance
|
||||
|
||||
// Note: The intended case for 'input_bias' is the input sequence mask for transformer models
|
||||
// As an additive mask, it should be zero for preserved tokens and -infty for tokens to screen
|
||||
// The mask will broadcast from [batch_size, 1, 1, seq_len] to input [batch_size, num_heads, seq_len, seq_len]
|
||||
// Here element_count = seq_len and bias_broadcast_size_per_batch = num_heads * seq_len
|
||||
|
||||
// The softmax + additive mask fusion follows NVIDIA apex's additive_masked_softmax_warp_forward
|
||||
// see https://github.com/NVIDIA/apex/blob/4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a/apex/contrib/csrc/multihead_attn/softmax.h
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void BiasSoftmaxWarpForward(
|
||||
output_t* output,
|
||||
const input_t* input,
|
||||
const input_t* input_bias,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int batch_stride,
|
||||
int bias_broadcast_count_per_batch) {
|
||||
// "WARP" refers to cooperative threads and might not equal 32 threads of GPU warp
|
||||
// thread block is (WARP_SIZE, 128/WARP_SIZE)
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = next_power_of_two < GPU_WARP_SIZE ? next_power_of_two : GPU_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
// constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int WARP_BATCH = 1;
|
||||
|
||||
// each "WARP" (<=32) processes WARP_BATCH(one of {1,2}) batches
|
||||
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
|
||||
// last warp may have fewer batches
|
||||
int local_batches = batch_count - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// thread will process elements (local_index + n * warp_size) within batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// push input, input_bias output pointers to batch we need to process
|
||||
input += first_batch * batch_stride + local_idx;
|
||||
output += first_batch * batch_stride + local_idx;
|
||||
|
||||
// load from global memory and apply bias (likely an additive mask)
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
// the bias has assumed shape [batch_size, element_count]
|
||||
// .. and needs to broadcast to [batch_size, broadcast_size, element_count]
|
||||
int bias_offset = (first_batch + i) / bias_broadcast_count_per_batch * batch_stride + local_idx;
|
||||
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
int element_index = local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
elements[i][it] = (acc_t)input[i * element_count + it * WARP_SIZE] + (acc_t)input_bias[bias_offset + it * WARP_SIZE];
|
||||
} else {
|
||||
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// find maximum value within batch for numerical stability
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
// normalization factor Z = Sum[ exp(element_i), for element_i in batch ]
|
||||
acc_t sum[WARP_BATCH]{acc_t(0.0)};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
elements[i][it] = expf((acc_t)(elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// write back normalized value = exp(element_i)/Z to global memory
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
int element_index = local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
output[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DispatchBiasSoftmaxForwardImpl(
|
||||
Tensor* output_tensor,
|
||||
const Tensor* input_tensor,
|
||||
const Tensor* input_bias_tensor,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int batch_stride,
|
||||
int bias_broadcast_size_per_batch) {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
typedef HipT input_t;
|
||||
typedef HipT output_t;
|
||||
typedef AccumulationType_t<HipT> acc_t;
|
||||
|
||||
const auto* input = reinterpret_cast<const HipT*>(input_tensor->template Data<T>());
|
||||
const auto* input_bias = reinterpret_cast<const HipT*>(input_bias_tensor->template Data<T>());
|
||||
auto* output = reinterpret_cast<HipT*>(output_tensor->template MutableData<T>());
|
||||
|
||||
if (element_count == 0)
|
||||
return;
|
||||
|
||||
int log2_elements = log2_ceil(element_count);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
|
||||
int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE);
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
|
||||
// int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
int batches_per_warp = 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 256;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 0>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 1: // 2
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 1>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 2: // 4
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 2>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 3: // 8
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 3>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 4: // 16
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 4>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 5: // 32
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 5>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 6: // 64
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 6>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 7: // 128
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 7>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 8: // 256
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 8>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 9: // 512
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 9>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
case 10: // 1024
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<input_t, output_t, acc_t, 10>), dim3(blocks), dim3(threads), 0, 0,
|
||||
output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define SPECIALIZED_BIAS_SOFTMAX_IMPL(T) \
|
||||
template void DispatchBiasSoftmaxForwardImpl<T>( \
|
||||
Tensor * output_tensor, \
|
||||
const Tensor* input_tensor, \
|
||||
const Tensor* input_bias_tensor, \
|
||||
int element_count, \
|
||||
int batch_count, \
|
||||
int batch_stride, \
|
||||
int bias_broadcast_size_per_batch);
|
||||
|
||||
SPECIALIZED_BIAS_SOFTMAX_IMPL(double)
|
||||
SPECIALIZED_BIAS_SOFTMAX_IMPL(float)
|
||||
SPECIALIZED_BIAS_SOFTMAX_IMPL(MLFloat16)
|
||||
|
||||
// For large element count we fall back to explicit Add kernel + CUDA DNN library
|
||||
// note: This is an unhappy path! There is no performance benefit for the fusion.
|
||||
template <typename T>
|
||||
void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
|
||||
miopenHandle_t miopenHandle,
|
||||
int element_count,
|
||||
int batch_count,
|
||||
int broadcast_axis,
|
||||
int softmax_axis,
|
||||
const onnxruntime::TensorShape& X_shape,
|
||||
const onnxruntime::Tensor* X,
|
||||
const onnxruntime::TensorShape& B_shape,
|
||||
const onnxruntime::Tensor* B,
|
||||
onnxruntime::Tensor* Y) {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
|
||||
const auto* X_data = reinterpret_cast<const HipT*>(X->template Data<T>());
|
||||
const auto* B_data = reinterpret_cast<const HipT*>(B->template Data<T>());
|
||||
auto* Y_data = reinterpret_cast<HipT*>(Y->template MutableData<T>());
|
||||
|
||||
// binary elementise kernel requires input pitches
|
||||
TArray<int64_t> lhs_padded_strides(X_shape.NumDimensions());
|
||||
for (int i = -1, lhs_pitch = 1; i >= -(int)X_shape.NumDimensions(); i--) {
|
||||
size_t positive_i = X_shape.NumDimensions() + i;
|
||||
lhs_padded_strides[positive_i] = lhs_pitch;
|
||||
lhs_pitch *= X_shape[positive_i];
|
||||
}
|
||||
|
||||
// set pitches for bias so it broadcasts along relevant dimensions
|
||||
TArray<int64_t> rhs_padded_strides(X_shape.NumDimensions());
|
||||
for (int i = -1, rhs_pitch = 1; i >= -(int)X_shape.NumDimensions(); i--) {
|
||||
size_t positive_ix = X_shape.NumDimensions() + i;
|
||||
size_t positive_ib = B_shape.NumDimensions() + i;
|
||||
if (broadcast_axis <= positive_ix && positive_ix < softmax_axis) {
|
||||
rhs_padded_strides[positive_ix] = 0;
|
||||
continue;
|
||||
}
|
||||
rhs_padded_strides[positive_ix] = rhs_pitch;
|
||||
rhs_pitch *= B_shape[positive_ib];
|
||||
}
|
||||
|
||||
TArray<fast_divmod> fdm_output_strides(X_shape.NumDimensions());
|
||||
for (int i = 0; i < fdm_output_strides.Size(); i++)
|
||||
fdm_output_strides[i] = fast_divmod(lhs_padded_strides[i]);
|
||||
fast_divmod fdm_H, fdm_C;
|
||||
|
||||
// invoke elementwise add with broadcast kernel
|
||||
::onnxruntime::rocm::BinaryElementWiseImpl(
|
||||
(int32_t)X_shape.NumDimensions(),
|
||||
&lhs_padded_strides,
|
||||
X_data,
|
||||
&rhs_padded_strides,
|
||||
B_data,
|
||||
&fdm_output_strides,
|
||||
fdm_H,
|
||||
fdm_C,
|
||||
Y_data,
|
||||
OP_Add<HipT, HipT, HipT>(),
|
||||
(size_t)X_shape.Size());
|
||||
|
||||
// invoke cuda DNN library for Y = softmax(X)
|
||||
std::vector<int64_t> dims({batch_count, 1, 1, element_count});
|
||||
const auto alpha = Consts<HipT>::One;
|
||||
const auto beta = Consts<HipT>::Zero;
|
||||
onnxruntime::rocm::MiopenTensor input_tensor, output_tensor;
|
||||
input_tensor.Set(dims, onnxruntime::rocm::MiopenTensor::GetDataType<HipT>());
|
||||
output_tensor.Set(dims, onnxruntime::rocm::MiopenTensor::GetDataType<HipT>());
|
||||
miopenSoftmaxForward_V2(
|
||||
miopenHandle,
|
||||
&alpha,
|
||||
input_tensor,
|
||||
Y_data,
|
||||
&beta,
|
||||
output_tensor,
|
||||
Y_data,
|
||||
MIOPEN_SOFTMAX_ACCURATE,
|
||||
MIOPEN_SOFTMAX_MODE_INSTANCE);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(T) \
|
||||
template void DispatchBiasSoftMaxForwardViaDnnLibraryImpl<T>( \
|
||||
miopenHandle_t miopenHandle, \
|
||||
int element_count, \
|
||||
int batch_count, \
|
||||
int broadcast_axis, \
|
||||
int softmax_axis, \
|
||||
const onnxruntime::TensorShape& X_shape, \
|
||||
const Tensor* X_data, \
|
||||
const onnxruntime::TensorShape& B_shape, \
|
||||
const Tensor* B_data, \
|
||||
Tensor* Y_data);
|
||||
|
||||
// SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(double)
|
||||
SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(float)
|
||||
SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(MLFloat16)
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -33,6 +33,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax);
|
||||
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to maintain backward compatibility
|
||||
|
|
@ -141,6 +142,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax)>,
|
||||
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s
|
|||
|
||||
// check node is add and has single output
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
|
||||
!graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider}) ||
|
||||
!graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider, kRocmExecutionProvider}) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -224,9 +224,9 @@ Status BiasSoftmaxFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
|
|||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
// only support CUDA execution provider
|
||||
// only support GPU execution provider
|
||||
auto& cep = GetCompatibleExecutionProviders();
|
||||
if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end())
|
||||
if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end() && cep.find(kRocmExecutionProvider) == cep.end())
|
||||
return Status::OK();
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/common/cuda_op_test_utils.h"
|
||||
#include "test/providers/compare_provider_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -12,6 +13,12 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
#if USE_ROCM
|
||||
constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider;
|
||||
#else
|
||||
constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider;
|
||||
#endif
|
||||
|
||||
// followed example of fastgelu_op_test.cc
|
||||
// in retrospect would have been better to compare BiasSoftmax to Add + Softmax graph
|
||||
|
||||
|
|
@ -130,7 +137,8 @@ class BiasSoftmaxTester {
|
|||
void RunComparison() {
|
||||
// BiasSoftmax only implemented for cuda architecture
|
||||
int min_cuda_architecture = use_float16_ ? 530 : 0;
|
||||
if (HasCudaEnvironment(min_cuda_architecture)) {
|
||||
if (HasCudaEnvironment(min_cuda_architecture) ||
|
||||
kGpuExecutionProvider == kRocmExecutionProvider) {
|
||||
OpTester tester("BiasSoftmax", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("softmax_axis", softmax_axis_);
|
||||
tester.AddAttribute<int64_t>("broadcast_axis", broadcast_axis_);
|
||||
|
|
@ -146,7 +154,12 @@ class BiasSoftmaxTester {
|
|||
}
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> ep;
|
||||
ep.push_back(DefaultCudaExecutionProvider());
|
||||
#ifdef USE_CUDA
|
||||
ep.push_back(DefaultCudaExecutionProvider());
|
||||
#elif USE_ROCM
|
||||
ep.push_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2417,20 +2417,22 @@ struct BiasSoftmaxFusionTester {
|
|||
BiasSoftmaxFusionTester(
|
||||
const PathString& model_uri,
|
||||
onnxruntime::logging::Logger* logger,
|
||||
bool on_cuda_ = true) : logger_(logger), graph_transformation_mgr_{5} {
|
||||
const char* execution_provider = kCudaExecutionProvider) : logger_(logger), graph_transformation_mgr_{5} {
|
||||
model_load_ = Model::Load(model_uri, p_model_, nullptr, *logger_);
|
||||
|
||||
// move to cuda since fusion only takes place in that case
|
||||
if (on_cuda_) {
|
||||
for (auto& node : p_model_->MainGraph().Nodes()) {
|
||||
node.SetExecutionProviderType(kCudaExecutionProvider);
|
||||
}
|
||||
}
|
||||
SetExecutionProvider(execution_provider);
|
||||
|
||||
graph_transformation_mgr_.Register(
|
||||
onnxruntime::make_unique<BiasSoftmaxFusion>(), TransformerLevel::Level2);
|
||||
}
|
||||
|
||||
void SetExecutionProvider(const char* ep) {
|
||||
for (auto& node : p_model_->MainGraph().Nodes()) {
|
||||
node.SetExecutionProviderType(ep);
|
||||
}
|
||||
}
|
||||
|
||||
void TestFusionOccurs(int expected_broadcast_axis) {
|
||||
ASSERT_STATUS_OK(model_load_);
|
||||
|
||||
|
|
@ -2466,13 +2468,19 @@ struct BiasSoftmaxFusionTester {
|
|||
}
|
||||
};
|
||||
|
||||
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_CudaOnly) {
|
||||
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_GpuOnly) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx";
|
||||
BiasSoftmaxFusionTester tester(model_uri, logger_.get(), false);
|
||||
BiasSoftmaxFusionTester tester(model_uri, logger_.get(), kCpuExecutionProvider);
|
||||
tester.TestNoFusionOccurs();
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple) {
|
||||
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Rocm) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx";
|
||||
BiasSoftmaxFusionTester tester(model_uri, logger_.get(), kRocmExecutionProvider);
|
||||
tester.TestFusionOccurs(1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Cuda) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx";
|
||||
BiasSoftmaxFusionTester tester(model_uri, logger_.get());
|
||||
tester.TestFusionOccurs(1);
|
||||
|
|
|
|||
Loading…
Reference in a new issue