mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[ROCm] Add SkipGroupNorm for ROCm EP (#19303)
Add SkipGroupNorm for ROCm EP. --------- Co-authored-by: Peixuan Zuo <peixuanzuo@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
8fadc6c913
commit
6226c5f62f
18 changed files with 383 additions and 733 deletions
|
|
@ -44,12 +44,7 @@ set(contrib_ops_excluded_files
|
|||
"bert/packed_multihead_attention.cc"
|
||||
"bert/packed_multihead_attention_impl.h"
|
||||
"bert/packed_multihead_attention_impl.cu"
|
||||
"diffusion/group_norm.cc"
|
||||
"diffusion/group_norm_impl.cu"
|
||||
"diffusion/group_norm_impl.h"
|
||||
"diffusion/group_norm_impl_kernel.cuh"
|
||||
"diffusion/group_norm_common_base.h"
|
||||
"diffusion/group_norm_common_base.cc"
|
||||
"diffusion/nhwc_conv.cc"
|
||||
"math/gemm_float8.cc"
|
||||
"math/gemm_float8.cu"
|
||||
|
|
|
|||
|
|
@ -1,152 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
#define GROUP_NORM_TYPES float, MLFloat16
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
GroupNorm, kMSDomain, 1, kRocmExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct DispatchGroupNorm {
|
||||
Status operator()(RocmTuningContext* tuning_ctx,
|
||||
Stream* stream,
|
||||
Tensor* output,
|
||||
const Tensor* input,
|
||||
const Tensor* gamma,
|
||||
const Tensor* beta,
|
||||
void* workspace,
|
||||
float epsilon,
|
||||
int batch_size,
|
||||
int num_channels,
|
||||
int height,
|
||||
int width,
|
||||
int num_groups,
|
||||
bool use_swish_activation) {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
return LaunchGroupNormKernel<HipT>(
|
||||
tuning_ctx,
|
||||
stream,
|
||||
reinterpret_cast<HipT*>(output->MutableData<T>()),
|
||||
reinterpret_cast<const HipT*>(input->Data<T>()),
|
||||
gamma->Data<float>(),
|
||||
beta->Data<float>(),
|
||||
workspace,
|
||||
epsilon,
|
||||
batch_size,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
num_groups,
|
||||
use_swish_activation);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
|
||||
epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
|
||||
ORT_ENFORCE(epsilon_ >= 0);
|
||||
|
||||
int64_t num_groups;
|
||||
ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK());
|
||||
ORT_ENFORCE(num_groups >= 0);
|
||||
num_groups_ = static_cast<int>(num_groups);
|
||||
|
||||
int64_t activation;
|
||||
ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
|
||||
ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish
|
||||
use_swish_activation_ = (activation == 1);
|
||||
|
||||
channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
|
||||
}
|
||||
|
||||
Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
|
||||
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
|
||||
is_packed = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* gamma = context->Input<Tensor>(1);
|
||||
const Tensor* beta = context->Input<Tensor>(2);
|
||||
Tensor* output = context->Output(0, input->Shape());
|
||||
|
||||
if (!channels_last_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"only the channels_last layout is supported");
|
||||
}
|
||||
|
||||
const auto& input_dims = input->Shape().GetDims();
|
||||
if (input_dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input is expected to have 4 dimensions, got ", input_dims.size());
|
||||
}
|
||||
|
||||
const auto& gamma_dims = gamma->Shape().GetDims();
|
||||
if (gamma_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
|
||||
}
|
||||
if (gamma_dims[0] != input_dims[3]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Number of channels in gamma and input does not match");
|
||||
}
|
||||
|
||||
const auto& beta_dims = beta->Shape().GetDims();
|
||||
if (beta_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"beta is expected to have 1 dimension, got ", beta_dims.size());
|
||||
}
|
||||
if (beta_dims[0] != input_dims[3]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Number of channels in beta and input does not match");
|
||||
}
|
||||
|
||||
// Input and output format is NHWC
|
||||
int batch_size = static_cast<int>(input_dims[0]);
|
||||
int num_channels = static_cast<int>(input_dims[3]);
|
||||
int height = static_cast<int>(input_dims[1]);
|
||||
int width = static_cast<int>(input_dims[2]);
|
||||
|
||||
if (num_channels % num_groups_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"number of channels should be divisible by num_groups");
|
||||
}
|
||||
|
||||
if (context->GetUseDeterministicCompute()) {
|
||||
static std::once_flag log_warning;
|
||||
std::call_once(log_warning, []() {
|
||||
LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic.";
|
||||
});
|
||||
}
|
||||
|
||||
auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());
|
||||
|
||||
utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
|
||||
return dispatcher.InvokeRet<Status, DispatchGroupNorm>(GetTuningContext(), context->GetComputeStream(),
|
||||
output, input, gamma, beta, workspace.get(),
|
||||
epsilon_,
|
||||
batch_size,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
num_groups_,
|
||||
use_swish_activation_);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -26,13 +26,18 @@ namespace rocm {
|
|||
|
||||
using onnxruntime::rocm::CKDataTypeAdaptor;
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
// The SiLU function is a special case of Swish function,
|
||||
// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as:
|
||||
// SiLU(x) = x * sigmoid(x)
|
||||
// Swish(x) = x * sigmoid(bx)
|
||||
// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here.
|
||||
using Silu = ck::tensor_operation::element_wise::Swish;
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
constexpr int Rank = 5;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
template <typename T, typename AccT, bool WithSwish>
|
||||
template <typename T, typename AccT, bool WithSilu>
|
||||
auto GetCKGroupNormNHWCTypeStringAndOps() {
|
||||
using XDataType = typename CKDataTypeAdaptor<T>::type;
|
||||
using YDataType = typename CKDataTypeAdaptor<T>::type;
|
||||
|
|
@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
|
|||
using GammaDataType = float;
|
||||
using BetaDataType = float;
|
||||
|
||||
using Activation = std::conditional_t<WithSwish, Swish, Pass>;
|
||||
using Activation = std::conditional_t<WithSilu, Silu, Pass>;
|
||||
|
||||
std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCParams<T>>>> ret;
|
||||
std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCTunableParams<T>>>> ret;
|
||||
for (auto&& impl : internal::GetDeviceGroupNormInstances<XDataType, GammaDataType, BetaDataType, YDataType,
|
||||
SaveMeanInvStdDataType, Activation, Rank, NumReduceDim>()) {
|
||||
std::string swish_suffix = WithSwish ? "_Swish" : "_Pass";
|
||||
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
|
||||
std::string silu_suffix = WithSilu ? "_Silu" : "_Pass";
|
||||
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix;
|
||||
auto invoker = impl->MakeInvokerPointer();
|
||||
|
||||
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams<T>* params) -> Status {
|
||||
if constexpr (WithSwish) {
|
||||
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](
|
||||
const GroupNormNHWCTunableParams<T>* params) -> Status {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
|
||||
"Input skip or bias is not supported by composable kernel.");
|
||||
if constexpr (WithSilu) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!params->withSwish, "Swish version only support groupnorm with swish");
|
||||
!params->use_silu, "Silu version only support groupnorm with silu");
|
||||
} else {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->withSwish, "Pass version only support groupnorm without swish");
|
||||
params->use_silu, "Pass version only support groupnorm without silu");
|
||||
}
|
||||
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup};
|
||||
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1};
|
||||
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->cPerGroup, 1};
|
||||
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group};
|
||||
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c,
|
||||
params->c, params->channels_per_group, 1};
|
||||
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->channels_per_group, 1};
|
||||
std::vector<ck::index_t> reduce_dims{1, 2, 4};
|
||||
|
||||
auto activation = Activation{};
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ namespace internal {
|
|||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Silu = ck::tensor_operation::element_wise::Swish;
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface
|
||||
|
|
@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() {
|
|||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<
|
||||
F16, F32, F32, F16, F32, Swish, 5, 3>>>
|
||||
F16, F32, F32, F16, F32, Silu, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<
|
||||
F16, F32, F32, F16, F32, Swish, 5, 3>();
|
||||
F16, F32, F32, F16, F32, Silu, 5, 3>();
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<
|
||||
|
|
@ -113,9 +113,9 @@ GetDeviceGroupNormInstances<
|
|||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<
|
||||
F32, F32, F32, F32, F32, Swish, 5, 3>>>
|
||||
F32, F32, F32, F32, F32, Silu, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<
|
||||
F32, F32, F32, F32, F32, Swish, 5, 3>();
|
||||
F32, F32, F32, F32, F32, Silu, 5, 3>();
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<
|
||||
|
|
|
|||
|
|
@ -11,12 +11,12 @@ namespace rocm {
|
|||
namespace internal {
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Swish, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>> instances;
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Silu, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>> instances;
|
||||
ck::tensor_operation::device::instance::add_device_operation_instances(
|
||||
instances,
|
||||
device_normalization_f16_instances<Swish, 5, 3>{});
|
||||
device_normalization_f16_instances<Silu, 5, 3>{});
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,12 +11,12 @@ namespace rocm {
|
|||
namespace internal {
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Swish, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>> instances;
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Silu, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>> instances;
|
||||
ck::tensor_operation::device::instance::add_device_operation_instances(
|
||||
instances,
|
||||
device_normalization_f32_instances<Swish, 5, 3>{});
|
||||
device_normalization_f32_instances<Silu, 5, 3>{});
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,110 +8,47 @@
|
|||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_common_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
using onnxruntime::rocm::CeilDiv;
|
||||
|
||||
int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
|
||||
int32_t maxDivisor = -1;
|
||||
for (int32_t i = 1; i <= std::sqrt(n); i++) {
|
||||
if (n % i == 0) {
|
||||
int32_t divisor1 = n / i;
|
||||
int32_t divisor2 = i;
|
||||
|
||||
if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
|
||||
maxDivisor = divisor1;
|
||||
}
|
||||
if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
|
||||
maxDivisor = divisor2;
|
||||
}
|
||||
}
|
||||
}
|
||||
return maxDivisor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GroupNormNHWCParams : OpParams {
|
||||
GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma,
|
||||
const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish)
|
||||
: OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) {
|
||||
int32_t maxBlocksPerHW = 1024;
|
||||
switch (c) {
|
||||
case 960:
|
||||
case 1920:
|
||||
cPerBlock = 480;
|
||||
break;
|
||||
case 512:
|
||||
case 256:
|
||||
cPerBlock = 256;
|
||||
break;
|
||||
case 128:
|
||||
cPerBlock = 128;
|
||||
break;
|
||||
default:
|
||||
cPerBlock = 320;
|
||||
}
|
||||
|
||||
hw = h * w;
|
||||
const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW);
|
||||
hwPerBlock = CeilDiv(hw, blocksPerHW);
|
||||
cPerGroup = c / groups;
|
||||
hwc = hw * c;
|
||||
invHWC = 1.F / (float)(hw * cPerGroup);
|
||||
groupsPerBlock = cPerBlock / cPerGroup;
|
||||
}
|
||||
struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams<T> {
|
||||
GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx,
|
||||
onnxruntime::Stream* ort_stream,
|
||||
T* output,
|
||||
T* add_out,
|
||||
const T* input,
|
||||
const T* skip,
|
||||
const T* bias,
|
||||
const float* gamma,
|
||||
const float* beta,
|
||||
float* workspace,
|
||||
float epsilon,
|
||||
int batch_size,
|
||||
int num_channels,
|
||||
int height,
|
||||
int width,
|
||||
int num_groups,
|
||||
bool use_silu,
|
||||
bool broadcast_skip,
|
||||
int channels_per_block)
|
||||
: OpParams(tuning_ctx, ort_stream),
|
||||
GroupNormNHWCParams<T>(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size,
|
||||
num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {}
|
||||
|
||||
std::string Signature() const override {
|
||||
std::string swish_suffix = withSwish ? "_Swish" : "_Pass";
|
||||
std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix;
|
||||
std::string silu_suffix = this->use_silu ? "_silu" : "_pass";
|
||||
std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip";
|
||||
std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast";
|
||||
std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias";
|
||||
std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" +
|
||||
std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix +
|
||||
skip_suffix + broadcast_suffix + bias_suffix;
|
||||
return sig;
|
||||
}
|
||||
|
||||
// The output buffer. Layout NHWC.
|
||||
T* dst;
|
||||
// The input buffer. Layout NHWC.
|
||||
T const* src;
|
||||
// The gamma scaling factor.
|
||||
float const* gamma;
|
||||
// The beta term to add in GN.
|
||||
float const* beta;
|
||||
// The temporary buffer to do the global parallel reduction. Size:
|
||||
// BLOCKS_PER_BATCH x C x 2.
|
||||
float* redBuffer;
|
||||
float epsilon;
|
||||
|
||||
// The number of instances in the batch.
|
||||
int32_t n;
|
||||
// The height and width of each activation map.
|
||||
int32_t h;
|
||||
int32_t w;
|
||||
// The number of channels.
|
||||
int32_t c;
|
||||
// The number of groups.
|
||||
int32_t groups;
|
||||
// Do we apply the Swish activation function?
|
||||
bool withSwish;
|
||||
|
||||
// Precomputed values and parameters to control the execution of the kernels.
|
||||
|
||||
// The number of activations per instance (h * w) and the number of
|
||||
// activations per block.
|
||||
int32_t hw;
|
||||
int32_t hwPerBlock;
|
||||
// The number of channels per group and blocks per activation in the C
|
||||
// dimension.
|
||||
int32_t cPerBlock;
|
||||
int32_t cPerGroup;
|
||||
|
||||
// The precomputed stride between instances.
|
||||
int32_t hwc;
|
||||
// The inverse of hwc in floats (to compute mean/var).
|
||||
float invHWC;
|
||||
// The precomputed number of groups per block.
|
||||
int32_t groupsPerBlock;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
|
|
@ -15,9 +15,12 @@ namespace rocm {
|
|||
template <typename T>
|
||||
Status LaunchGroupNormKernel(
|
||||
RocmTuningContext* tuning_ctx,
|
||||
Stream* stream,
|
||||
Stream* ort_stream,
|
||||
T* output,
|
||||
T* add_out,
|
||||
const T* input,
|
||||
const T* skip,
|
||||
const T* bias,
|
||||
const float* gamma,
|
||||
const float* beta,
|
||||
void* workspace,
|
||||
|
|
@ -27,19 +30,26 @@ Status LaunchGroupNormKernel(
|
|||
int height,
|
||||
int width,
|
||||
int num_groups,
|
||||
bool use_swish_activation) {
|
||||
if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
|
||||
"only support batch_size <= 32. Got", batch_size);
|
||||
bool use_silu,
|
||||
bool broadcast_skip,
|
||||
int channels_per_block) {
|
||||
GroupNormNHWCTunableParams<T> params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta,
|
||||
reinterpret_cast<float*>(workspace), epsilon, batch_size, num_channels,
|
||||
height, width, num_groups, use_silu, broadcast_skip, channels_per_block);
|
||||
|
||||
if (params.channels_per_block % params.channels_per_group != 0 ||
|
||||
params.channels_per_block > kMaxSize ||
|
||||
(params.channels_per_group % CHANNELS_PER_THREAD != 0)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"GroupNorm in ROCM does not support the input: n=", batch_size,
|
||||
" h=", height,
|
||||
" w=", width,
|
||||
" c=", num_channels,
|
||||
" groups=", num_groups);
|
||||
}
|
||||
|
||||
if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
|
||||
"only num_groups=32 is supported. Got", num_groups);
|
||||
}
|
||||
|
||||
GroupNormNHWCParams<T> params(tuning_ctx, stream, output, reinterpret_cast<float*>(workspace), input, gamma, beta,
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation);
|
||||
HIP_RETURN_IF_ERROR(hipMemsetAsync(
|
||||
params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle()));
|
||||
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
static GroupNormNHWCTunableOp<T> op;
|
||||
|
|
@ -50,14 +60,17 @@ Status LaunchGroupNormKernel(
|
|||
}
|
||||
|
||||
template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, Stream* stream, half* output,
|
||||
const half* input, const float* gamma, const float* beta, void* workspace,
|
||||
float epsilon, int batch_size, int num_channels,
|
||||
int height, int width, int num_groups, bool swish);
|
||||
half* add_out, const half* input, const half* skip, const half* bias,
|
||||
const float* gamma, const float* beta, void* workspace, float epsilon,
|
||||
int batch_size, int num_channels, int height, int width, int num_groups,
|
||||
bool use_silu, bool broadcast_skip, int channels_per_block);
|
||||
|
||||
template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, Stream* stream, float* output,
|
||||
const float* input, const float* gamma, const float* beta, void* workspace,
|
||||
float epsilon, int batch_size, int num_channels,
|
||||
int height, int width, int num_groups, bool swish);
|
||||
float* add_out, const float* input, const float* skip, const float* bias,
|
||||
const float* gamma, const float* beta, void* workspace, float epsilon,
|
||||
int batch_size, int num_channels, int height, int width, int num_groups,
|
||||
bool use_silu, bool broadcast_skip, int channels_per_block);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
|
||||
using onnxruntime::rocm::tunable::RocmTuningContext;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
constexpr size_t kMaxGroupNormBatchSize = 32;
|
||||
constexpr size_t kGroupNormNumberOfGroups = 32;
|
||||
|
||||
constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
|
||||
// Two buffers for sum and squared sum
|
||||
return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status LaunchGroupNormKernel(
|
||||
RocmTuningContext* tuning_ctx,
|
||||
Stream* stream,
|
||||
T* output, // normalized output tensor
|
||||
const T* input, // input tensor
|
||||
const float* gamma, // gamma (also known as weight or scale)
|
||||
const float* beta, // beta (also known as bias)
|
||||
void* workspace, // Work space
|
||||
float epsilon, // epsilon used normalization
|
||||
int batch_size, // N
|
||||
int num_channels, // C
|
||||
int height, // H
|
||||
int width, // W
|
||||
int num_groups, // number of groups
|
||||
bool use_swish_activation // Whether there is Swish activation after group normalization
|
||||
);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// The ROCm kernel is modified from TensorRT 8.5.
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
static inline __device__ __host__ float sigmoid(float x) {
|
||||
return 1.F / (1.F + expf(-x));
|
||||
}
|
||||
|
||||
struct GroupSums {
|
||||
// Is it the 1st element of the group?
|
||||
int32_t flag;
|
||||
// The sum.
|
||||
float sum;
|
||||
// The sum of squares.
|
||||
float sumSq;
|
||||
};
|
||||
|
||||
struct GroupSumsOp {
|
||||
inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) {
|
||||
GroupSums dst;
|
||||
dst.sum = b.flag ? b.sum : (a.sum + b.sum);
|
||||
dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
|
||||
dst.flag = a.flag + b.flag;
|
||||
return dst;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int ILP>
|
||||
inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) {
|
||||
using VecT = onnxruntime::rocm::aligned_vector<T, ILP>;
|
||||
const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
const U val = static_cast<U>(input_v.val[i]);
|
||||
sum += val;
|
||||
sumSq += val * val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int ILP>
|
||||
__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw,
|
||||
int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) {
|
||||
// The object in charge of doing the sums for the different blocks.
|
||||
typedef hipcub::BlockScan<GroupSums, ThreadsPerBlock> BlockScan;
|
||||
|
||||
// Allocate shared memory for BlockScan.
|
||||
__shared__ typename BlockScan::TempStorage tempStorage;
|
||||
// Allocate shared memory for the groups. We could reduce the amount of shared
|
||||
// memory reserved.
|
||||
__shared__ float2 smem[ThreadsPerBlock];
|
||||
|
||||
// The instance in the batch.
|
||||
int32_t ni = blockIdx.z;
|
||||
// The channel loaded by that thread (ILP channels per thread).
|
||||
int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP;
|
||||
|
||||
// The first activation loaded by that block.
|
||||
int32_t hwBegin = blockIdx.y * hwPerBlock;
|
||||
// The last activation loaded by that block.
|
||||
int32_t hwEnd = min(hwBegin + hwPerBlock, hw);
|
||||
|
||||
// The sums.
|
||||
float sum = 0.F;
|
||||
float sumSq = 0.F;
|
||||
|
||||
// Iterate over the activations to compute the sums.
|
||||
if (ci < c) {
|
||||
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
|
||||
// The offset.
|
||||
int64_t offset = static_cast<int64_t>(ni) * hwc + static_cast<int64_t>(hwi) * c + ci;
|
||||
UpdateSum<T, float, ILP>(src, offset, sum, sumSq);
|
||||
}
|
||||
}
|
||||
|
||||
// The group that thread works on and the channel in the group (modulus).
|
||||
int32_t gi = threadIdx.x * ILP / cPerGroup;
|
||||
int32_t cj = threadIdx.x * ILP - cPerGroup * gi;
|
||||
|
||||
// The data for the summations.
|
||||
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
|
||||
|
||||
// Do the segmented scan.
|
||||
GroupSums out;
|
||||
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
|
||||
|
||||
// Store the results for the groups in shared memory (to produce coalesced
|
||||
// stores later).
|
||||
if (cj == cPerGroup - ILP) { // ILP channels per thread
|
||||
smem[gi] = make_float2(out.sum, out.sumSq);
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
// The global group index.
|
||||
int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x;
|
||||
|
||||
// Threads that have nothing left to do, exit.
|
||||
if (threadIdx.x >= groupsPerBlock || gj >= groups) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The first threads (those storing to global memory, load the values).
|
||||
float2 sums = smem[threadIdx.x];
|
||||
|
||||
// Store to global memory.
|
||||
atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x);
|
||||
atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int32_t ILP>
|
||||
__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev,
|
||||
const U* gamma_v, const U* beta_v, bool swish) {
|
||||
using VecT = onnxruntime::rocm::aligned_vector<T, ILP>;
|
||||
const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
|
||||
VecT output_v;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
U val = static_cast<U>(input_v.val[i]);
|
||||
val = (val - mean) * invStdDev;
|
||||
val = gamma_v[i] * val + beta_v[i];
|
||||
|
||||
if (swish) {
|
||||
val = val * sigmoid(val);
|
||||
}
|
||||
output_v.val[i] = static_cast<T>(val);
|
||||
}
|
||||
*(reinterpret_cast<VecT*>(dst + offset)) = output_v;
|
||||
}
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int ILP>
|
||||
__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock,
|
||||
int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) {
|
||||
// The channel loaded by that thread (ILP channels per thread for F16x2).
|
||||
int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP;
|
||||
if (ci >= c) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The instance in the batch.
|
||||
int32_t ni = blockIdx.z;
|
||||
|
||||
// The group that thread works on and the channel in the group (modulus).
|
||||
int32_t gi = ci / cPerGroup;
|
||||
|
||||
// Load the sum and sum of squares for the group.
|
||||
float sum = 0.F, sumSq = 0.F;
|
||||
if (gi < groups) {
|
||||
sum = redBuffer[(2 * ni + 0) * groups + gi];
|
||||
sumSq = redBuffer[(2 * ni + 1) * groups + gi];
|
||||
}
|
||||
|
||||
using VecF = onnxruntime::rocm::aligned_vector<float, ILP>;
|
||||
|
||||
const VecF gamma_v = *reinterpret_cast<const VecF*>(gamma + ci);
|
||||
const VecF beta_v = *reinterpret_cast<const VecF*>(beta + ci);
|
||||
|
||||
// Compute the mean.
|
||||
float mean = sum * invHWC;
|
||||
// Compute the variance.
|
||||
float var = sumSq * invHWC - (mean * mean);
|
||||
// Compute the inverse of the stddev.
|
||||
float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon);
|
||||
|
||||
// The first activation loaded by that block.
|
||||
int32_t hwBegin = blockIdx.y * hwPerBlock;
|
||||
// The last activation loaded by that block.
|
||||
int32_t hwEnd = min(hwBegin + hwPerBlock, hw);
|
||||
|
||||
// Iterate over the activations to compute the sums.
|
||||
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
|
||||
// The src/dst offset.
|
||||
int64_t offset = (int64_t)ni * hwc + hwi * c + ci;
|
||||
|
||||
// Fetch ILP channels per thread.
|
||||
computeGroupNorm<T, float, ILP>(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -20,21 +20,21 @@ namespace rocm {
|
|||
|
||||
namespace {
|
||||
|
||||
template <typename T, bool WithSwish>
|
||||
template <typename T, bool WithSilu>
|
||||
std::string GetGroupNormTritonGroupName() {
|
||||
std::string ret = "GroupNormTriton_";
|
||||
std::string swish_suffix = WithSwish ? "Swish_" : "Pass_";
|
||||
ret += swish_suffix;
|
||||
std::string silu_suffix = WithSilu ? "Silu_" : "Pass_";
|
||||
ret += silu_suffix;
|
||||
ret += GetDataTypeName<T>();
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T, bool WithSwish>
|
||||
template <typename T, bool WithSilu>
|
||||
auto GetTritonGroupNormNHWCTypeStringAndOps() {
|
||||
std::vector<std::pair<std::string, tunable::Op<GroupNormNHWCParams<T>>>> ret;
|
||||
auto group_name = GetGroupNormTritonGroupName<T, WithSwish>();
|
||||
std::vector<std::pair<std::string, tunable::Op<GroupNormNHWCTunableParams<T>>>> ret;
|
||||
auto group_name = GetGroupNormTritonGroupName<T, WithSilu>();
|
||||
auto* kernel_list = GetOrtTritonKernelByGroup(group_name);
|
||||
if (kernel_list == nullptr) {
|
||||
return ret;
|
||||
|
|
@ -45,16 +45,19 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
|
|||
auto* metadata = GetOrtTritonKernelMetadata(i);
|
||||
auto block_size = metadata->constants.at("BLOCK_SIZE");
|
||||
auto hw_size = metadata->constants.at("HW_SIZE");
|
||||
auto impl = [i, block_size, hw_size](const GroupNormNHWCParams<T>* params) -> Status {
|
||||
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
|
||||
"Input skip or bias is not supported by triton kernel.");
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size,
|
||||
"Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ").");
|
||||
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
|
||||
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
|
||||
params->channels_per_group, ").");
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ").");
|
||||
if constexpr (WithSwish) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish.");
|
||||
if constexpr (WithSilu) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu.");
|
||||
} else {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish.");
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu.");
|
||||
}
|
||||
// Construct args for launch kernel
|
||||
struct {
|
||||
|
|
@ -73,7 +76,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
|
|||
(const void*)params->beta,
|
||||
params->hw,
|
||||
params->c,
|
||||
params->cPerGroup,
|
||||
params->channels_per_group,
|
||||
params->epsilon};
|
||||
|
||||
// Grid dim is (batch_count, groups, 1)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def group_norm_kernel(
|
|||
eps,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
HW_SIZE: tl.constexpr,
|
||||
ACTIVATION_SWISH: tl.constexpr,
|
||||
ACTIVATION_SILU: tl.constexpr,
|
||||
):
|
||||
row_x = tl.program_id(0)
|
||||
row_y = tl.program_id(1)
|
||||
|
|
@ -62,7 +62,7 @@ def group_norm_kernel(
|
|||
x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
|
||||
x_hat = (x - group_mean) * rstd
|
||||
y = x_hat * gamma + beta
|
||||
if ACTIVATION_SWISH:
|
||||
if ACTIVATION_SILU:
|
||||
y *= tl.sigmoid(y)
|
||||
tl.store(y_ptr + offsets, y, mask=mask)
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ def group_norm_kernel(
|
|||
# blocks = [16, 32, 64, 128, 256, 512]
|
||||
# hw_sizes = [8, 16, 32, 64, 128, 256, 512]
|
||||
# but this will result in too many functions and slow down the compilation.
|
||||
with_swish = [True, False]
|
||||
with_silu = [True, False]
|
||||
dtypes = ["fp32", "fp16"]
|
||||
blocks = [16, 32, 64, 128]
|
||||
hw_sizes = [8, 16, 32, 64, 128, 256]
|
||||
|
|
@ -84,14 +84,14 @@ group_pattern = "GroupNormTriton_{}_{}"
|
|||
def get_function_table():
|
||||
func_table = []
|
||||
|
||||
for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks):
|
||||
swish_suffix = "Swish" if swish else "Pass"
|
||||
name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp)
|
||||
group = group_pattern.format(swish_suffix, dtype)
|
||||
for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks):
|
||||
silu_suffix = "Silu" if silu else "Pass"
|
||||
name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp)
|
||||
group = group_pattern.format(silu_suffix, dtype)
|
||||
sig = sig_pattern.format(dtype, dtype)
|
||||
kwargs = {
|
||||
"num_warps": warp,
|
||||
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)},
|
||||
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)},
|
||||
}
|
||||
func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs}
|
||||
func_table.append(func_desc)
|
||||
|
|
|
|||
|
|
@ -20,115 +20,117 @@ namespace rocm {
|
|||
using onnxruntime::rocm::GPU_WARP_SIZE;
|
||||
|
||||
template <typename T>
|
||||
void groupNormNHWCSum(const GroupNormNHWCParams<T>* params) {
|
||||
// Make sure the values are as we expect.
|
||||
ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0);
|
||||
// Make sure a group does not span multiple blocks.
|
||||
ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0);
|
||||
|
||||
void GroupNormNHWCSum(const GroupNormNHWCTunableParams<T>* params) {
|
||||
dim3 grid;
|
||||
|
||||
// The number of blocks to compute all the channels.
|
||||
grid.x = params->c / params->cPerBlock;
|
||||
grid.x = DivUp(params->c, params->channels_per_block);
|
||||
// The number of blocks to compute all the activations in a given instance.
|
||||
grid.y = CeilDiv(params->hw, params->hwPerBlock);
|
||||
grid.y = DivUp(params->hw, params->hw_per_block);
|
||||
// The number of instances.
|
||||
grid.z = params->n;
|
||||
|
||||
#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \
|
||||
groupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize> \
|
||||
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
|
||||
params->src, params->redBuffer, params->cPerBlock, \
|
||||
params->hwPerBlock, params->hw, params->hwc, params->c, \
|
||||
params->cPerGroup, params->groups, params->groupsPerBlock); \
|
||||
#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \
|
||||
GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize> \
|
||||
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
|
||||
params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \
|
||||
params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \
|
||||
params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \
|
||||
break;
|
||||
|
||||
switch (params->cPerBlock) {
|
||||
case 320:
|
||||
LAUNCH_GROUPNORM_SUM(256, 2)
|
||||
case 480:
|
||||
LAUNCH_GROUPNORM_SUM(256, 2)
|
||||
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
|
||||
switch (params->threads_per_block) {
|
||||
case 256:
|
||||
LAUNCH_GROUPNORM_SUM(128, 2)
|
||||
LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD)
|
||||
case 192:
|
||||
LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD)
|
||||
case 160:
|
||||
LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD)
|
||||
case 128:
|
||||
LAUNCH_GROUPNORM_SUM(64, 2)
|
||||
LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD)
|
||||
case 64:
|
||||
LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD)
|
||||
default:
|
||||
ORT_NOT_IMPLEMENTED("Not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
Status GroupNormNHWCSumOp(const GroupNormNHWCParams<T>* params) {
|
||||
Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams<T>* params) {
|
||||
dim3 grid;
|
||||
grid.x = params->c / params->cPerBlock;
|
||||
grid.y = CeilDiv(params->hw, params->hwPerBlock);
|
||||
grid.x = DivUp(params->c, params->channels_per_block);
|
||||
grid.y = DivUp(params->hw, params->hw_per_block);
|
||||
grid.z = params->n;
|
||||
|
||||
groupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>
|
||||
GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>
|
||||
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(
|
||||
params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock,
|
||||
params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock);
|
||||
params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias,
|
||||
params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c,
|
||||
params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void groupNormNHWCScale(const GroupNormNHWCParams<T>* params) {
|
||||
// Make sure the dimensions are aligned with what we expect.
|
||||
ORT_ENFORCE(params->c % params->cPerBlock == 0);
|
||||
// Make sure a group does not span multiple blocks.
|
||||
ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0);
|
||||
|
||||
void GroupNormNHWCScale(const GroupNormNHWCTunableParams<T>* params) {
|
||||
dim3 grid;
|
||||
|
||||
// The number of blocks to compute all the channels.
|
||||
grid.x = params->c / params->cPerBlock;
|
||||
grid.x = DivUp(params->c, params->channels_per_block);
|
||||
// The number of blocks to compute all the activations in a given instance.
|
||||
grid.y = CeilDiv(params->hw, params->hwPerBlock);
|
||||
grid.y = DivUp(params->hw, params->hw_per_block);
|
||||
// The number of instances.
|
||||
grid.z = params->n;
|
||||
|
||||
#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \
|
||||
groupNormNHWCScaleKernel<T, ThreadsPerBlock, VecSize> \
|
||||
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
|
||||
params->dst, params->src, params->gamma, params->beta, \
|
||||
params->redBuffer, params->epsilon, params->c, params->cPerBlock, \
|
||||
params->cPerGroup, params->groups, params->hwc, params->invHWC, \
|
||||
params->hw, params->hwPerBlock, params->withSwish); \
|
||||
#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \
|
||||
GroupNormNHWCScaleKernel<T, VecSize> \
|
||||
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
|
||||
params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \
|
||||
params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \
|
||||
params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \
|
||||
params->hw, params->hw_per_block, params->use_silu); \
|
||||
break;
|
||||
|
||||
switch (params->cPerBlock) {
|
||||
case 320:
|
||||
LAUNCH_GROUPNORM_SCALE(256, 2)
|
||||
case 480:
|
||||
LAUNCH_GROUPNORM_SCALE(256, 2)
|
||||
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
|
||||
switch (params->threads_per_block) {
|
||||
case 256:
|
||||
LAUNCH_GROUPNORM_SCALE(128, 2)
|
||||
LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD)
|
||||
case 192:
|
||||
LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD)
|
||||
case 160:
|
||||
LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD)
|
||||
case 128:
|
||||
LAUNCH_GROUPNORM_SCALE(64, 2)
|
||||
LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD)
|
||||
case 64:
|
||||
LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD)
|
||||
default:
|
||||
ORT_NOT_IMPLEMENTED("Not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
Status GroupNormNHWCScaleOp(const GroupNormNHWCParams<T>* params) {
|
||||
Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams<T>* params) {
|
||||
dim3 grid;
|
||||
grid.x = params->c / params->cPerBlock;
|
||||
grid.y = CeilDiv(params->hw, params->hwPerBlock);
|
||||
grid.x = DivUp(params->c, params->channels_per_block);
|
||||
grid.y = DivUp(params->hw, params->hw_per_block);
|
||||
grid.z = params->n;
|
||||
|
||||
groupNormNHWCScaleKernel<T, ThreadsPerBlock, VecSize>
|
||||
GroupNormNHWCScaleKernel<T, VecSize>
|
||||
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(
|
||||
params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock,
|
||||
params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish);
|
||||
params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace,
|
||||
params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group,
|
||||
params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block,
|
||||
params->use_silu);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
class GroupNormNHWCOp {
|
||||
public:
|
||||
Status operator()(const GroupNormNHWCParams<T>* params) {
|
||||
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle()));
|
||||
Status operator()(const GroupNormNHWCTunableParams<T>* params) {
|
||||
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
|
||||
0,
|
||||
GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
|
||||
params->StreamHandle()));
|
||||
auto status = GroupNormNHWCSumOp<T, ThreadsPerBlock, VecSize>(params);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
HIP_RETURN_IF_ERROR(hipGetLastError());
|
||||
|
|
@ -138,29 +140,30 @@ class GroupNormNHWCOp {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IsSupported(const GroupNormNHWCParams<T>* params) {
|
||||
Status IsSupported(const GroupNormNHWCTunableParams<T>* params) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0),
|
||||
"The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup,
|
||||
!(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0),
|
||||
"The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group,
|
||||
") isn't divisible by the number of vector size: ", VecSize);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 &&
|
||||
params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0),
|
||||
"The value of attributes don't meet the requirements.");
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize &&
|
||||
params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize &&
|
||||
params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
|
||||
"Configuration: Threads (", ThreadsPerBlock, "), vector size (",
|
||||
VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock);
|
||||
VecSize, ") is redundant for the number of channels per group: ",
|
||||
params->channels_per_block);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
|
||||
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle()));
|
||||
groupNormNHWCSum<T>(params);
|
||||
Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams<T>* params) {
|
||||
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
|
||||
0,
|
||||
GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
|
||||
params->StreamHandle()));
|
||||
GroupNormNHWCSum<T>(params);
|
||||
HIP_RETURN_IF_ERROR(hipGetLastError());
|
||||
groupNormNHWCScale<T>(params);
|
||||
GroupNormNHWCScale<T>(params);
|
||||
HIP_RETURN_IF_ERROR(hipGetLastError());
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
|
|||
ADD_OP_FOR_ALL_VEC_SIZE(name, 320)
|
||||
|
||||
template <typename T>
|
||||
class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCParams<T>> {
|
||||
class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCTunableParams<T>> {
|
||||
public:
|
||||
GroupNormNHWCTunableOp() {
|
||||
this->RegisterOp(GroupNormNHWCStaticSelection<T>);
|
||||
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp)
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/false>()) {
|
||||
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSilu=*/false>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
|
||||
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/true>()) {
|
||||
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSilu=*/true>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
#ifdef USE_TRITON_KERNEL
|
||||
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSwish=*/false>()) {
|
||||
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSilu=*/false>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSwish=*/true>()) {
|
||||
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSilu=*/true>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
|
||||
|
|
@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
|
||||
|
|
|
|||
|
|
@ -35,7 +35,11 @@ def sigmoid_function(x):
|
|||
return 1.0 / (1.0 + np.exp(-x))
|
||||
|
||||
|
||||
def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
|
||||
def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip):
|
||||
add_output = None
|
||||
if has_skip:
|
||||
input_x = input_x + skip_x + bias_x
|
||||
add_output = input_x
|
||||
n, h, w, c = input_x.shape
|
||||
input_x = input_x.transpose([0, 3, 1, 2])
|
||||
assert c % num_groups == 0
|
||||
|
|
@ -45,46 +49,70 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
|
|||
x = x.transpose([0, 2, 3, 1])
|
||||
x = x * gamma + beta
|
||||
|
||||
if with_swish:
|
||||
if with_silu:
|
||||
x = x * sigmoid_function(x)
|
||||
return x
|
||||
return x, add_output
|
||||
|
||||
|
||||
def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func):
|
||||
def run_group_norm(
|
||||
batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func
|
||||
):
|
||||
np.random.seed(0)
|
||||
width = height
|
||||
input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
|
||||
gamma = np.random.rand(num_channels).astype(np.float32)
|
||||
beta = np.random.rand(num_channels).astype(np.float32)
|
||||
# the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18
|
||||
workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32)
|
||||
workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32)
|
||||
epsilon = 1e-05
|
||||
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
use_swish = swish
|
||||
|
||||
host_x = input_x.astype(dtype)
|
||||
input_d = ke.DeviceArray(host_x)
|
||||
skip_x = (
|
||||
np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
|
||||
if has_skip
|
||||
else np.empty((0), dtype=dtype)
|
||||
)
|
||||
bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype)
|
||||
add_output = (
|
||||
np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
if has_skip
|
||||
else np.empty((0), dtype=dtype)
|
||||
)
|
||||
use_silu = silu
|
||||
broadcast_skip = False
|
||||
channels_per_block = 0 # Compute in params initialization
|
||||
|
||||
input_d = ke.DeviceArray(input_x.astype(dtype))
|
||||
skip_d = ke.DeviceArray(skip_x.astype(dtype))
|
||||
bias_d = ke.DeviceArray(bias_x.astype(dtype))
|
||||
gamma_d = ke.DeviceArray(gamma)
|
||||
beta_d = ke.DeviceArray(beta)
|
||||
workspace_d = ke.DeviceArray(workspace)
|
||||
y_d = ke.DeviceArray(output_y)
|
||||
y_add_d = ke.DeviceArray(add_output)
|
||||
f = getattr(ke, func)
|
||||
|
||||
my_op = f(
|
||||
y_d,
|
||||
workspace_d,
|
||||
y_add_d,
|
||||
input_d,
|
||||
skip_d,
|
||||
bias_d,
|
||||
gamma_d,
|
||||
beta_d,
|
||||
workspace_d,
|
||||
epsilon,
|
||||
batch_size,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
num_channels,
|
||||
num_groups,
|
||||
epsilon,
|
||||
use_swish,
|
||||
use_silu,
|
||||
broadcast_skip,
|
||||
channels_per_block,
|
||||
)
|
||||
y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype)
|
||||
y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip)
|
||||
y_ref = y_ref.astype(dtype)
|
||||
|
||||
for impl in my_op.ListOps():
|
||||
if not my_op.SelectOp(impl):
|
||||
|
|
@ -95,6 +123,10 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
|
|||
y_d.UpdateHostNumpyArray()
|
||||
|
||||
np.testing.assert_allclose(y_ref, output_y, atol=1e-02)
|
||||
if has_skip:
|
||||
y_add_d_ref = y_add_d_ref.astype(dtype)
|
||||
y_add_d.UpdateHostNumpyArray()
|
||||
np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02)
|
||||
|
||||
|
||||
dtypes = ["float32", "float16"]
|
||||
|
|
@ -102,19 +134,21 @@ dtypes = ["float32", "float16"]
|
|||
|
||||
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
@pytest.mark.parametrize("swish", [True])
|
||||
def test_group_norm(sd_sizes, dtype, swish):
|
||||
@pytest.mark.parametrize("silu", [True])
|
||||
@pytest.mark.parametrize("has_skip", [True, False])
|
||||
def test_group_norm(sd_sizes, dtype, silu, has_skip):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
run_group_norm(*sd_sizes, dtype, swish, func)
|
||||
run_group_norm(*sd_sizes, dtype, silu, has_skip, func)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
@pytest.mark.parametrize("swish", [True])
|
||||
def test_group_norm_ck(sd_sizes, dtype, swish):
|
||||
swish_suffix = "Swish" if swish else "Pass"
|
||||
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
|
||||
run_group_norm(*sd_sizes, dtype, swish, ck_f_name)
|
||||
@pytest.mark.parametrize("silu", [True])
|
||||
@pytest.mark.parametrize("has_skip", [False])
|
||||
def test_group_norm_ck(sd_sizes, dtype, silu, has_skip):
|
||||
silu_suffix = "Silu" if silu else "Pass"
|
||||
ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype)
|
||||
run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -136,37 +170,67 @@ class GroupNormNHWCMetric(ke.BandwidthMetric):
|
|||
|
||||
|
||||
def profile_group_norm_func(
|
||||
batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func
|
||||
batch_size: int,
|
||||
height: int,
|
||||
width: int,
|
||||
num_channels: int,
|
||||
num_groups: int,
|
||||
dtype: str,
|
||||
silu: bool,
|
||||
has_skip: bool,
|
||||
func,
|
||||
):
|
||||
np.random.seed(0)
|
||||
input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
gamma = np.random.rand(num_channels).astype(np.float32)
|
||||
beta = np.random.rand(num_channels).astype(np.float32)
|
||||
workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32)
|
||||
workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32)
|
||||
epsilon = 0.05
|
||||
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
use_swish = swish
|
||||
|
||||
skip_x = (
|
||||
np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
if has_skip
|
||||
else np.empty((0), dtype=dtype)
|
||||
)
|
||||
bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype)
|
||||
add_output = (
|
||||
np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
if has_skip
|
||||
else np.empty((0), dtype=dtype)
|
||||
)
|
||||
use_silu = silu
|
||||
broadcast_skip = False
|
||||
channels_per_block = 0 # Compute in params initialization
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
skip_d = ke.DeviceArray(skip_x)
|
||||
bias_d = ke.DeviceArray(bias_x)
|
||||
gamma_d = ke.DeviceArray(gamma)
|
||||
beta_d = ke.DeviceArray(beta)
|
||||
workspace_d = ke.DeviceArray(workspace)
|
||||
y_d = ke.DeviceArray(output_y)
|
||||
y_add_d = ke.DeviceArray(add_output)
|
||||
f = getattr(ke, func)
|
||||
|
||||
my_op = f(
|
||||
y_d,
|
||||
workspace_d,
|
||||
y_add_d,
|
||||
input_d,
|
||||
skip_d,
|
||||
bias_d,
|
||||
gamma_d,
|
||||
beta_d,
|
||||
workspace_d,
|
||||
epsilon,
|
||||
batch_size,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
num_channels,
|
||||
num_groups,
|
||||
epsilon,
|
||||
use_swish,
|
||||
use_silu,
|
||||
broadcast_skip,
|
||||
channels_per_block,
|
||||
)
|
||||
for impl in my_op.ListOps():
|
||||
duration_ms = -1
|
||||
|
|
@ -181,14 +245,14 @@ def profile_group_norm_func(
|
|||
)
|
||||
|
||||
|
||||
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True):
|
||||
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True):
|
||||
with ke.benchmark(sort):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func)
|
||||
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func)
|
||||
# ck function
|
||||
swish_suffix = "Swish" if swish else "Pass"
|
||||
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
|
||||
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name)
|
||||
silu_suffix = "Silu" if silu else "Pass"
|
||||
ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype)
|
||||
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name)
|
||||
|
||||
|
||||
sd_profile_sizes = [
|
||||
|
|
@ -227,7 +291,8 @@ if __name__ == "__main__":
|
|||
group.add_argument("num_channels", type=int)
|
||||
group.add_argument("num_groups", type=int)
|
||||
group.add_argument("dtype", choices=dtypes)
|
||||
group.add_argument("--swish", action="store_true")
|
||||
group.add_argument("--silu", action="store_true")
|
||||
group.add_argument("--has_skip", action="store_true")
|
||||
group.add_argument("--sort", action="store_true")
|
||||
|
||||
if len(sys.argv) == 1:
|
||||
|
|
@ -241,6 +306,7 @@ if __name__ == "__main__":
|
|||
args.num_channels,
|
||||
args.num_groups,
|
||||
args.dtype,
|
||||
args.swish,
|
||||
args.silu,
|
||||
args.has_skip,
|
||||
args.sort,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,17 +12,21 @@
|
|||
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes;
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
class GroupNormNHWC : public IKernelExplorer {
|
||||
public:
|
||||
GroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
|
||||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
|
||||
int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu,
|
||||
bool broadcast_skip, int channels_per_block)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
|
||||
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
|
||||
channels_per_block) {
|
||||
type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize);
|
||||
}
|
||||
|
||||
|
|
@ -40,7 +44,7 @@ class GroupNormNHWC : public IKernelExplorer {
|
|||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
|
||||
ParamsT params_{};
|
||||
contrib::rocm::GroupNormNHWCOp<T, ThreadsPerBlock, VecSize> op_{};
|
||||
std::string type_string_{};
|
||||
|
|
@ -49,11 +53,15 @@ class GroupNormNHWC : public IKernelExplorer {
|
|||
template <typename T>
|
||||
class GroupNormNHWCStaticSelection : public IKernelExplorer {
|
||||
public:
|
||||
GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
|
||||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
|
||||
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
|
||||
bool use_silu, bool broadcast_skip, int channels_per_block)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
|
||||
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
|
||||
channels_per_block) {
|
||||
type_string_ = "GroupNormNHWCStaticSelection";
|
||||
}
|
||||
|
||||
|
|
@ -71,7 +79,7 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
|
|||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
|
||||
ParamsT params_{};
|
||||
std::string type_string_{};
|
||||
};
|
||||
|
|
@ -79,11 +87,15 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
|
|||
template <typename T>
|
||||
class GroupNormNHWCTunable : public IKernelExplorer {
|
||||
public:
|
||||
GroupNormNHWCTunable(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
|
||||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
|
||||
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
|
||||
bool use_silu, bool broadcast_skip, int channels_per_block)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
|
||||
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
|
||||
channels_per_block) {
|
||||
params_.TuningContext()->EnableTunableOpAndTuning();
|
||||
}
|
||||
|
||||
|
|
@ -100,21 +112,25 @@ class GroupNormNHWCTunable : public IKernelExplorer {
|
|||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
|
||||
ParamsT params_{};
|
||||
contrib::rocm::GroupNormNHWCTunableOp<T> op_{};
|
||||
};
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
template <typename T, bool WithSwish>
|
||||
template <typename T, bool WithSilu>
|
||||
class CKGroupNormNHWC : public IKernelExplorer {
|
||||
public:
|
||||
CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
|
||||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSwish>()) {
|
||||
CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
|
||||
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
|
||||
bool use_silu, bool broadcast_skip, int channels_per_block)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
|
||||
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
|
||||
channels_per_block) {
|
||||
for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSilu>()) {
|
||||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
|
|
@ -141,7 +157,7 @@ class CKGroupNormNHWC : public IKernelExplorer {
|
|||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
|
||||
using OpT = rocm::tunable::Op<ParamsT>;
|
||||
ParamsT params_{};
|
||||
std::vector<OpT> ops_;
|
||||
|
|
@ -151,15 +167,19 @@ class CKGroupNormNHWC : public IKernelExplorer {
|
|||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
#ifdef USE_TRITON_KERNEL
|
||||
template <typename T, bool WithSwish>
|
||||
template <typename T, bool WithSilu>
|
||||
class GroupNormNHWCTriton : public IKernelExplorer {
|
||||
public:
|
||||
GroupNormNHWCTriton(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
|
||||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps<T, WithSwish>()) {
|
||||
GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
|
||||
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
|
||||
bool use_silu, bool broadcast_skip, int channels_per_block)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
|
||||
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
|
||||
channels_per_block) {
|
||||
for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps<T, WithSilu>()) {
|
||||
name_strings_.emplace_back(name);
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
|
|
@ -186,7 +206,7 @@ class GroupNormNHWCTriton : public IKernelExplorer {
|
|||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
|
||||
using OpT = rocm::tunable::Op<ParamsT>;
|
||||
ParamsT params_{};
|
||||
std::vector<OpT> ops_;
|
||||
|
|
@ -198,7 +218,8 @@ class GroupNormNHWCTriton : public IKernelExplorer {
|
|||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
int, int, int, int, int, float, bool>()) \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, float, \
|
||||
int, int, int, int, int, bool, bool, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
|
|
@ -220,7 +241,8 @@ class GroupNormNHWCTriton : public IKernelExplorer {
|
|||
#define REGISTER_COMMON(name, type, ...) \
|
||||
py::class_<type<__VA_ARGS__>>(m, name) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
int, int, int, int, int, float, bool>()) \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, float, \
|
||||
int, int, int, int, int, bool, bool, int>()) \
|
||||
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
|
||||
.def("Profile", &type<__VA_ARGS__>::Profile) \
|
||||
.def("Run", &type<__VA_ARGS__>::Run) \
|
||||
|
|
@ -230,11 +252,11 @@ class GroupNormNHWCTriton : public IKernelExplorer {
|
|||
#define REGISTER_OP_TYPED(name, type) \
|
||||
REGISTER_COMMON(#name "_" #type, name, type)
|
||||
|
||||
#define REGISTER_CK(type, with_swish, swish_suffix) \
|
||||
REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish)
|
||||
#define REGISTER_CK(type, with_silu, silu_suffix) \
|
||||
REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu)
|
||||
|
||||
#define REGISTER_TRITON(type, with_swish, swish_suffix) \
|
||||
REGISTER_COMMON("GroupNormNHWCTriton" swish_suffix "_" #type, GroupNormNHWCTriton, type, with_swish)
|
||||
#define REGISTER_TRITON(type, with_silu, silu_suffix) \
|
||||
REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu)
|
||||
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half);
|
||||
|
|
@ -248,16 +270,16 @@ KE_REGISTER(m) {
|
|||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
REGISTER_CK(half, false, "Pass");
|
||||
REGISTER_CK(half, true, "Swish");
|
||||
REGISTER_CK(half, true, "Silu");
|
||||
REGISTER_CK(float, false, "Pass");
|
||||
REGISTER_CK(float, true, "Swish");
|
||||
REGISTER_CK(float, true, "Silu");
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
#ifdef USE_TRITON_KERNEL
|
||||
REGISTER_TRITON(half, false, "Pass");
|
||||
REGISTER_TRITON(half, true, "Swish");
|
||||
REGISTER_TRITON(half, true, "Silu");
|
||||
REGISTER_TRITON(float, false, "Pass");
|
||||
REGISTER_TRITON(float, true, "Swish");
|
||||
REGISTER_TRITON(float, true, "Silu");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -114,16 +114,21 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) {
|
|||
|
||||
int min_cuda_architecture = 530;
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
|
||||
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
|
||||
|
||||
std::array<int, 2> channels_last_values = {-1, 1};
|
||||
|
||||
for (const int channels_last : channels_last_values) {
|
||||
if (enable_cuda) {
|
||||
if (enable_cuda || enable_rocm) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
if (enable_cuda && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
}
|
||||
|
||||
if (enable_rocm && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
}
|
||||
|
||||
// Don't run the test if no providers are supported
|
||||
if (execution_providers.empty()) {
|
||||
continue;
|
||||
|
|
@ -230,6 +235,7 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
|
|||
|
||||
int min_cuda_architecture = 530;
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
|
||||
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
|
||||
|
||||
std::array<bool, 2> has_add_out_values = {true, false};
|
||||
std::array<int, 2> skip_dims = {2, 4};
|
||||
|
|
@ -237,12 +243,16 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
|
|||
constexpr int channels_last = 1;
|
||||
for (const int skip_dim : skip_dims) {
|
||||
for (const bool has_add_out : has_add_out_values) {
|
||||
if (enable_cuda) {
|
||||
if (enable_cuda || enable_rocm) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
if (enable_cuda && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
}
|
||||
|
||||
if (enable_rocm && channels_last != 0) {
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
}
|
||||
|
||||
// Don't run the test if no providers are supported
|
||||
if (execution_providers.empty()) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -181,6 +181,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
|
|||
s = s.replace("rocm_device_prop_", "cuda_device_prop_")
|
||||
s = s.replace("rocm_device_arch_", "cuda_device_arch_")
|
||||
|
||||
s = s.replace("HipTuningContext", "RocmTuningContext")
|
||||
|
||||
# We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names
|
||||
# And we do this last, undoing or fixing hipify mistakes.
|
||||
if "fft" in src_file_path:
|
||||
|
|
|
|||
Loading…
Reference in a new issue