[ROCm] Add SkipGroupNorm for ROCm EP (#19303)

Add SkipGroupNorm for ROCm EP.

---------

Co-authored-by: Peixuan Zuo <peixuanzuo@microsoft.com>
This commit is contained in:
PeixuanZuo 2024-02-21 11:08:48 +08:00 committed by GitHub
parent 8fadc6c913
commit 6226c5f62f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 383 additions and 733 deletions

View file

@ -44,12 +44,7 @@ set(contrib_ops_excluded_files
"bert/packed_multihead_attention.cc"
"bert/packed_multihead_attention_impl.h"
"bert/packed_multihead_attention_impl.cu"
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/group_norm_impl_kernel.cuh"
"diffusion/group_norm_common_base.h"
"diffusion/group_norm_common_base.cc"
"diffusion/nhwc_conv.cc"
"math/gemm_float8.cc"
"math/gemm_float8.cu"

View file

@ -1,152 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/diffusion/group_norm.h"
#include "contrib_ops/rocm/diffusion/group_norm_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Element types supported by the GroupNorm kernel ("T" type constraint below).
#define GROUP_NORM_TYPES float, MLFloat16

// Register the contrib op GroupNorm (com.microsoft domain, opset 1) with the
// ROCm execution provider.
ONNX_OPERATOR_KERNEL_EX(
    GroupNorm, kMSDomain, 1, kRocmExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);

using namespace ONNX_NAMESPACE;
namespace {
// Functor invoked by MLTypeCallDispatcher once the runtime element type is
// known: casts the ORT tensor buffers to the HIP-mapped element type and
// launches the GroupNorm kernel.
template <typename T>
struct DispatchGroupNorm {
  Status operator()(RocmTuningContext* tuning_ctx,
                    Stream* stream,
                    Tensor* output,
                    const Tensor* input,
                    const Tensor* gamma,
                    const Tensor* beta,
                    void* workspace,
                    float epsilon,
                    int batch_size,
                    int num_channels,
                    int height,
                    int width,
                    int num_groups,
                    bool use_swish_activation) {
    typedef typename ToHipType<T>::MappedType HipT;
    // Note: gamma and beta are read as float regardless of T.
    return LaunchGroupNormKernel<HipT>(
        tuning_ctx,
        stream,
        reinterpret_cast<HipT*>(output->MutableData<T>()),
        reinterpret_cast<const HipT*>(input->Data<T>()),
        gamma->Data<float>(),
        beta->Data<float>(),
        workspace,
        epsilon,
        batch_size,
        num_channels,
        height,
        width,
        num_groups,
        use_swish_activation);
  }
};
} // namespace
// Reads and validates the GroupNorm attributes (epsilon, groups, activation,
// channels_last). Fails via ORT_ENFORCE on a missing or out-of-range attribute.
GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
  epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
  ORT_ENFORCE(epsilon_ >= 0);

  int64_t num_groups;
  ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK());
  // Must be strictly positive: ComputeInternal computes
  // num_channels % num_groups_, so accepting 0 here (the previous >= 0 check)
  // would defer to a division by zero at inference time.
  ORT_ENFORCE(num_groups > 0);
  num_groups_ = static_cast<int>(num_groups);

  int64_t activation;
  ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
  ORT_ENFORCE(activation == 0 || activation == 1);  // 0 is None, 1 is Swish
  use_swish_activation_ = (activation == 1);

  // Only the NHWC (channels_last) layout is implemented; ComputeInternal
  // rejects channels_last == 0.
  channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
}
// GroupNorm performs no weight pre-packing; always reports is_packed = false
// so the runtime keeps the original initializer tensors.
Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
                          bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
  is_packed = false;
  return Status::OK();
}
// Validates the inputs (layout, ranks, channel counts), allocates the
// reduction workspace, and dispatches the GroupNorm kernel on the element
// type of the input tensor.
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* gamma = context->Input<Tensor>(1);
  const Tensor* beta = context->Input<Tensor>(2);
  Tensor* output = context->Output(0, input->Shape());

  // Only the NHWC (channels_last) path is implemented for this EP.
  if (!channels_last_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "only the channels_last layout is supported");
  }

  const auto& input_dims = input->Shape().GetDims();
  if (input_dims.size() != 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input is expected to have 4 dimensions, got ", input_dims.size());
  }

  // gamma and beta are 1-D and must match the channel dim (dim 3 in NHWC).
  const auto& gamma_dims = gamma->Shape().GetDims();
  if (gamma_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "gamma is expected to have 1 dimension, got ", gamma_dims.size());
  }
  if (gamma_dims[0] != input_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels in gamma and input does not match");
  }

  const auto& beta_dims = beta->Shape().GetDims();
  if (beta_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "beta is expected to have 1 dimension, got ", beta_dims.size());
  }
  if (beta_dims[0] != input_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels in beta and input does not match");
  }

  // Input and output format is NHWC
  int batch_size = static_cast<int>(input_dims[0]);
  int num_channels = static_cast<int>(input_dims[3]);
  int height = static_cast<int>(input_dims[1]);
  int width = static_cast<int>(input_dims[2]);

  // Every group must cover a whole number of channels.
  if (num_channels % num_groups_ != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "number of channels should be divisible by num_groups");
  }

  // The kernel has no deterministic variant; warn once per process when the
  // session requested deterministic compute.
  if (context->GetUseDeterministicCompute()) {
    static std::once_flag log_warning;
    std::call_once(log_warning, []() {
      LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic.";
    });
  }

  // Scratch buffer for the per-(batch, group) sum / sum-of-squares reduction.
  auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());

  utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(GetTuningContext(), context->GetComputeStream(),
                                                         output, input, gamma, beta, workspace.get(),
                                                         epsilon_,
                                                         batch_size,
                                                         num_channels,
                                                         height,
                                                         width,
                                                         num_groups_,
                                                         use_swish_activation_);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -26,13 +26,18 @@ namespace rocm {
using onnxruntime::rocm::CKDataTypeAdaptor;
using Swish = ck::tensor_operation::element_wise::Swish;
// The SiLU function is a special case of Swish function,
// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as:
// SiLU(x) = x * sigmoid(x)
// Swish(x) = x * sigmoid(bx)
// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here.
using Silu = ck::tensor_operation::element_wise::Swish;
using Pass = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
template <typename T, typename AccT, bool WithSwish>
template <typename T, typename AccT, bool WithSilu>
auto GetCKGroupNormNHWCTypeStringAndOps() {
using XDataType = typename CKDataTypeAdaptor<T>::type;
using YDataType = typename CKDataTypeAdaptor<T>::type;
@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
using GammaDataType = float;
using BetaDataType = float;
using Activation = std::conditional_t<WithSwish, Swish, Pass>;
using Activation = std::conditional_t<WithSilu, Silu, Pass>;
std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCParams<T>>>> ret;
std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCTunableParams<T>>>> ret;
for (auto&& impl : internal::GetDeviceGroupNormInstances<XDataType, GammaDataType, BetaDataType, YDataType,
SaveMeanInvStdDataType, Activation, Rank, NumReduceDim>()) {
std::string swish_suffix = WithSwish ? "_Swish" : "_Pass";
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
std::string silu_suffix = WithSilu ? "_Silu" : "_Pass";
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix;
auto invoker = impl->MakeInvokerPointer();
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams<T>* params) -> Status {
if constexpr (WithSwish) {
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](
const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
"Input skip or bias is not supported by composable kernel.");
if constexpr (WithSilu) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!params->withSwish, "Swish version only support groupnorm with swish");
!params->use_silu, "Silu version only support groupnorm with silu");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->withSwish, "Pass version only support groupnorm without swish");
params->use_silu, "Pass version only support groupnorm without silu");
}
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup};
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1};
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->cPerGroup, 1};
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group};
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c,
params->c, params->channels_per_group, 1};
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->channels_per_group, 1};
std::vector<ck::index_t> reduce_dims{1, 2, 4};
auto activation = Activation{};

View file

@ -18,7 +18,7 @@ namespace internal {
using F16 = ck::half_t;
using F32 = float;
using Swish = ck::tensor_operation::element_wise::Swish;
using Silu = ck::tensor_operation::element_wise::Swish;
using Pass = ck::tensor_operation::element_wise::PassThrough;
using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface
@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() {
template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<
F16, F32, F32, F16, F32, Swish, 5, 3>>>
F16, F32, F32, F16, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<
F16, F32, F32, F16, F32, Swish, 5, 3>();
F16, F32, F32, F16, F32, Silu, 5, 3>();
template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<
@ -113,9 +113,9 @@ GetDeviceGroupNormInstances<
template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<
F32, F32, F32, F32, F32, Swish, 5, 3>>>
F32, F32, F32, F32, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<
F32, F32, F32, F32, F32, Swish, 5, 3>();
F32, F32, F32, F32, F32, Silu, 5, 3>();
template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<

View file

@ -11,12 +11,12 @@ namespace rocm {
namespace internal {
template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>
GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Swish, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>> instances;
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Silu, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>> instances;
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_normalization_f16_instances<Swish, 5, 3>{});
device_normalization_f16_instances<Silu, 5, 3>{});
return instances;
}

View file

@ -11,12 +11,12 @@ namespace rocm {
namespace internal {
template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Swish, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>> instances;
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Silu, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>> instances;
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_normalization_f32_instances<Swish, 5, 3>{});
device_normalization_f32_instances<Silu, 5, 3>{});
return instances;
}

View file

@ -8,110 +8,47 @@
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
#include "contrib_ops/rocm/diffusion/group_norm_common_base.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using onnxruntime::rocm::CeilDiv;
// Returns the largest divisor of n that is strictly smaller than
// maxAllowedDivisor, or -1 when no such divisor exists. Divisors are
// enumerated in complementary pairs (small, n / small) up to sqrt(n).
int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
  int32_t best = -1;
  for (int32_t small = 1; small <= std::sqrt(n); ++small) {
    if (n % small != 0) {
      continue;
    }
    const int32_t large = n / small;
    if (large > best && large < maxAllowedDivisor) {
      best = large;
    }
    if (small > best && small < maxAllowedDivisor) {
      best = small;
    }
  }
  return best;
}
template <typename T>
struct GroupNormNHWCParams : OpParams {
GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma,
const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish)
: OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) {
int32_t maxBlocksPerHW = 1024;
switch (c) {
case 960:
case 1920:
cPerBlock = 480;
break;
case 512:
case 256:
cPerBlock = 256;
break;
case 128:
cPerBlock = 128;
break;
default:
cPerBlock = 320;
}
hw = h * w;
const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW);
hwPerBlock = CeilDiv(hw, blocksPerHW);
cPerGroup = c / groups;
hwc = hw * c;
invHWC = 1.F / (float)(hw * cPerGroup);
groupsPerBlock = cPerBlock / cPerGroup;
}
struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams<T> {
GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx,
onnxruntime::Stream* ort_stream,
T* output,
T* add_out,
const T* input,
const T* skip,
const T* bias,
const float* gamma,
const float* beta,
float* workspace,
float epsilon,
int batch_size,
int num_channels,
int height,
int width,
int num_groups,
bool use_silu,
bool broadcast_skip,
int channels_per_block)
: OpParams(tuning_ctx, ort_stream),
GroupNormNHWCParams<T>(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size,
num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {}
std::string Signature() const override {
std::string swish_suffix = withSwish ? "_Swish" : "_Pass";
std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix;
std::string silu_suffix = this->use_silu ? "_silu" : "_pass";
std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip";
std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast";
std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias";
std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" +
std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix +
skip_suffix + broadcast_suffix + bias_suffix;
return sig;
}
// The output buffer. Layout NHWC.
T* dst;
// The input buffer. Layout NHWC.
T const* src;
// The gamma scaling factor.
float const* gamma;
// The beta term to add in GN.
float const* beta;
// The temporary buffer to do the global parallel reduction. Size:
// BLOCKS_PER_BATCH x C x 2.
float* redBuffer;
float epsilon;
// The number of instances in the batch.
int32_t n;
// The height and width of each activation map.
int32_t h;
int32_t w;
// The number of channels.
int32_t c;
// The number of groups.
int32_t groups;
// Do we apply the Swish activation function?
bool withSwish;
// Precomputed values and parameters to control the execution of the kernels.
// The number of activations per instance (h * w) and the number of
// activations per block.
int32_t hw;
int32_t hwPerBlock;
// The number of channels per group and blocks per activation in the C
// dimension.
int32_t cPerBlock;
int32_t cPerGroup;
// The precomputed stride between instances.
int32_t hwc;
// The inverse of hwc in floats (to compute mean/var).
float invHWC;
// The precomputed number of groups per block.
int32_t groupsPerBlock;
};
} // namespace rocm

View file

@ -15,9 +15,12 @@ namespace rocm {
template <typename T>
Status LaunchGroupNormKernel(
RocmTuningContext* tuning_ctx,
Stream* stream,
Stream* ort_stream,
T* output,
T* add_out,
const T* input,
const T* skip,
const T* bias,
const float* gamma,
const float* beta,
void* workspace,
@ -27,19 +30,26 @@ Status LaunchGroupNormKernel(
int height,
int width,
int num_groups,
bool use_swish_activation) {
if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"only support batch_size <= 32. Got", batch_size);
bool use_silu,
bool broadcast_skip,
int channels_per_block) {
GroupNormNHWCTunableParams<T> params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta,
reinterpret_cast<float*>(workspace), epsilon, batch_size, num_channels,
height, width, num_groups, use_silu, broadcast_skip, channels_per_block);
if (params.channels_per_block % params.channels_per_group != 0 ||
params.channels_per_block > kMaxSize ||
(params.channels_per_group % CHANNELS_PER_THREAD != 0)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"GroupNorm in ROCM does not support the input: n=", batch_size,
" h=", height,
" w=", width,
" c=", num_channels,
" groups=", num_groups);
}
if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"only num_groups=32 is supported. Got", num_groups);
}
GroupNormNHWCParams<T> params(tuning_ctx, stream, output, reinterpret_cast<float*>(workspace), input, gamma, beta,
batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation);
HIP_RETURN_IF_ERROR(hipMemsetAsync(
params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle()));
if (tuning_ctx->IsTunableOpEnabled()) {
static GroupNormNHWCTunableOp<T> op;
@ -50,14 +60,17 @@ Status LaunchGroupNormKernel(
}
template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, Stream* stream, half* output,
const half* input, const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);
half* add_out, const half* input, const half* skip, const half* bias,
const float* gamma, const float* beta, void* workspace, float epsilon,
int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block);
template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, Stream* stream, float* output,
const float* input, const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);
float* add_out, const float* input, const float* skip, const float* bias,
const float* gamma, const float* beta, void* workspace, float epsilon,
int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,47 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdint>
#include <hip/hip_runtime.h>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
using onnxruntime::rocm::tunable::RocmTuningContext;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Hard upper bounds baked into the kernel; larger inputs are rejected by
// LaunchGroupNormKernel.
constexpr size_t kMaxGroupNormBatchSize = 32;
constexpr size_t kGroupNormNumberOfGroups = 32;

// Byte size of the reduction scratch buffer: two floats (running sum and
// running sum of squares) per (batch, group) pair, sized for the maxima above.
constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
  return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
}
// Launches the two-pass GroupNorm kernel over an NHWC input.
// Returns NOT_IMPLEMENTED for unsupported configurations (see the .cu file
// for the exact batch-size / group-count limits).
template <typename T>
Status LaunchGroupNormKernel(
    RocmTuningContext* tuning_ctx,
    Stream* stream,
    T* output,            // normalized output tensor
    const T* input,       // input tensor
    const float* gamma,   // gamma (also known as weight or scale)
    const float* beta,    // beta (also known as bias)
    void* workspace,      // Work space
    float epsilon,        // epsilon used normalization
    int batch_size,       // N
    int num_channels,     // C
    int height,           // H
    int width,            // W
    int num_groups,       // number of groups
    bool use_swish_activation  // Whether there is Swish activation after group normalization
);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,213 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// The ROCm kernel is modified from TensorRT 8.5.
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <hip/hip_fp16.h>
#include <hip/hip_runtime_api.h>
#include <hipcub/hipcub.hpp>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Logistic function 1 / (1 + e^-x); callable from both host and device code.
static inline __device__ __host__ float sigmoid(float x) {
  return 1.F / (1.F + expf(-x));
}
// Per-thread partial sums carried through the segmented block scan in
// groupNormNHWCSumKernel.
struct GroupSums {
  // Is it the 1st element of the group?
  int32_t flag;
  // The sum.
  float sum;
  // The sum of squares.
  float sumSq;
};
// Combiner for the segmented inclusive scan: when b begins a new group
// (b.flag set) its sums are taken as-is, otherwise a's partials accumulate
// into b's. flag counts group boundaries seen so far.
struct GroupSumsOp {
  inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) {
    GroupSums dst;
    dst.sum = b.flag ? b.sum : (a.sum + b.sum);
    dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
    dst.flag = a.flag + b.flag;
    return dst;
  }
};
// Loads ILP consecutive elements of src starting at offset as one vector and
// accumulates their values and squared values into sum / sumSq.
// NOTE(review): assumes src + offset satisfies aligned_vector's alignment —
// confirm with the callers' channel indexing.
template <typename T, typename U, int ILP>
inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) {
  using VecT = onnxruntime::rocm::aligned_vector<T, ILP>;
  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);

#pragma unroll
  for (int i = 0; i < ILP; i++) {
    const U val = static_cast<U>(input_v.val[i]);
    sum += val;
    sumSq += val * val;
  }
}
// Pass 1 of the two-pass GroupNorm: accumulates per-(batch, group) sums and
// sums of squares of the NHWC input into redBuffer. atomicAdd is used because
// several blocks along the hw dimension contribute to the same group.
// Grid: x = channel blocks, y = hw blocks, z = batch instance; each thread
// owns ILP consecutive channels.
template <typename T, int ThreadsPerBlock, int ILP>
__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw,
                                       int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) {
  // The object in charge of doing the sums for the different blocks.
  typedef hipcub::BlockScan<GroupSums, ThreadsPerBlock> BlockScan;

  // Allocate shared memory for BlockScan.
  __shared__ typename BlockScan::TempStorage tempStorage;
  // Allocate shared memory for the groups. We could reduce the amount of shared
  // memory reserved.
  __shared__ float2 smem[ThreadsPerBlock];

  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (ILP channels per thread).
  int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP;

  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + hwPerBlock, hw);

  // The sums.
  float sum = 0.F;
  float sumSq = 0.F;

  // Iterate over the activations to compute the sums.
  if (ci < c) {
    for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
      // The offset.
      int64_t offset = static_cast<int64_t>(ni) * hwc + static_cast<int64_t>(hwi) * c + ci;
      UpdateSum<T, float, ILP>(src, offset, sum, sumSq);
    }
  }

  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = threadIdx.x * ILP / cPerGroup;
  int32_t cj = threadIdx.x * ILP - cPerGroup * gi;

  // The data for the summations.
  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};

  // Do the segmented scan.
  GroupSums out;
  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());

  // Store the results for the groups in shared memory (to produce coalesced
  // stores later).
  if (cj == cPerGroup - ILP) {  // ILP channels per thread
    smem[gi] = make_float2(out.sum, out.sumSq);
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The global group index.
  int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x;

  // Threads that have nothing left to do, exit.
  if (threadIdx.x >= groupsPerBlock || gj >= groups) {
    return;
  }

  // The first threads (those storing to global memory, load the values).
  float2 sums = smem[threadIdx.x];

  // Store to global memory.
  atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x);
  atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y);
}
// Normalizes ILP consecutive channels at offset with the precomputed mean and
// inverse stddev, applies the per-channel gamma/beta, optionally Swish
// (x * sigmoid(x)), and writes the vector back to dst at the same offset.
template <typename T, typename U, int32_t ILP>
__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev,
                                 const U* gamma_v, const U* beta_v, bool swish) {
  using VecT = onnxruntime::rocm::aligned_vector<T, ILP>;
  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
  VecT output_v;

#pragma unroll
  for (int i = 0; i < ILP; i++) {
    U val = static_cast<U>(input_v.val[i]);
    val = (val - mean) * invStdDev;
    val = gamma_v[i] * val + beta_v[i];
    if (swish) {
      val = val * sigmoid(val);
    }
    output_v.val[i] = static_cast<T>(val);
  }
  *(reinterpret_cast<VecT*>(dst + offset)) = output_v;
}
// Pass 2 of the two-pass GroupNorm: converts the accumulated sums in
// redBuffer into per-(batch, group) mean and inverse stddev, then normalizes,
// scales (gamma/beta) and optionally applies Swish to every element the
// thread owns. Grid layout matches groupNormNHWCSumKernel.
template <typename T, int ThreadsPerBlock, int ILP>
__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock,
                                         int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) {
  // The channel loaded by that thread (ILP channels per thread for F16x2).
  int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP;
  if (ci >= c) {
    return;
  }

  // The instance in the batch.
  int32_t ni = blockIdx.z;

  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = ci / cPerGroup;

  // Load the sum and sum of squares for the group.
  float sum = 0.F, sumSq = 0.F;
  if (gi < groups) {
    sum = redBuffer[(2 * ni + 0) * groups + gi];
    sumSq = redBuffer[(2 * ni + 1) * groups + gi];
  }

  using VecF = onnxruntime::rocm::aligned_vector<float, ILP>;

  const VecF gamma_v = *reinterpret_cast<const VecF*>(gamma + ci);
  const VecF beta_v = *reinterpret_cast<const VecF*>(beta + ci);

  // Compute the mean.
  float mean = sum * invHWC;
  // Compute the variance.
  float var = sumSq * invHWC - (mean * mean);
  // Compute the inverse of the stddev.
  // Guard against a non-positive variance from floating-point cancellation.
  float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon);

  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + hwPerBlock, hw);

  // Iterate over the activations to compute the sums.
  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
    // The src/dst offset.
    int64_t offset = (int64_t)ni * hwc + hwi * c + ci;

    // Fetch ILP channels per thread.
    computeGroupNorm<T, float, ILP>(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish);
  }
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -20,21 +20,21 @@ namespace rocm {
namespace {
template <typename T, bool WithSwish>
template <typename T, bool WithSilu>
std::string GetGroupNormTritonGroupName() {
std::string ret = "GroupNormTriton_";
std::string swish_suffix = WithSwish ? "Swish_" : "Pass_";
ret += swish_suffix;
std::string silu_suffix = WithSilu ? "Silu_" : "Pass_";
ret += silu_suffix;
ret += GetDataTypeName<T>();
return ret;
}
} // namespace
template <typename T, bool WithSwish>
template <typename T, bool WithSilu>
auto GetTritonGroupNormNHWCTypeStringAndOps() {
std::vector<std::pair<std::string, tunable::Op<GroupNormNHWCParams<T>>>> ret;
auto group_name = GetGroupNormTritonGroupName<T, WithSwish>();
std::vector<std::pair<std::string, tunable::Op<GroupNormNHWCTunableParams<T>>>> ret;
auto group_name = GetGroupNormTritonGroupName<T, WithSilu>();
auto* kernel_list = GetOrtTritonKernelByGroup(group_name);
if (kernel_list == nullptr) {
return ret;
@ -45,16 +45,19 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
auto* metadata = GetOrtTritonKernelMetadata(i);
auto block_size = metadata->constants.at("BLOCK_SIZE");
auto hw_size = metadata->constants.at("HW_SIZE");
auto impl = [i, block_size, hw_size](const GroupNormNHWCParams<T>* params) -> Status {
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
"Input skip or bias is not supported by triton kernel.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ").");
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
params->channels_per_group, ").");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ").");
if constexpr (WithSwish) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish.");
if constexpr (WithSilu) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu.");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu.");
}
// Construct args for launch kernel
struct {
@ -73,7 +76,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
(const void*)params->beta,
params->hw,
params->c,
params->cPerGroup,
params->channels_per_group,
params->epsilon};
// Grid dim is (batch_count, groups, 1)

View file

@ -21,7 +21,7 @@ def group_norm_kernel(
eps,
BLOCK_SIZE: tl.constexpr,
HW_SIZE: tl.constexpr,
ACTIVATION_SWISH: tl.constexpr,
ACTIVATION_SILU: tl.constexpr,
):
row_x = tl.program_id(0)
row_y = tl.program_id(1)
@ -62,7 +62,7 @@ def group_norm_kernel(
x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - group_mean) * rstd
y = x_hat * gamma + beta
if ACTIVATION_SWISH:
if ACTIVATION_SILU:
y *= tl.sigmoid(y)
tl.store(y_ptr + offsets, y, mask=mask)
@ -71,7 +71,7 @@ def group_norm_kernel(
# blocks = [16, 32, 64, 128, 256, 512]
# hw_sizes = [8, 16, 32, 64, 128, 256, 512]
# but this will result in too many functions and slow down the compilation.
with_swish = [True, False]
with_silu = [True, False]
dtypes = ["fp32", "fp16"]
blocks = [16, 32, 64, 128]
hw_sizes = [8, 16, 32, 64, 128, 256]
@ -84,14 +84,14 @@ group_pattern = "GroupNormTriton_{}_{}"
def get_function_table():
func_table = []
for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks):
swish_suffix = "Swish" if swish else "Pass"
name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp)
group = group_pattern.format(swish_suffix, dtype)
for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks):
silu_suffix = "Silu" if silu else "Pass"
name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp)
group = group_pattern.format(silu_suffix, dtype)
sig = sig_pattern.format(dtype, dtype)
kwargs = {
"num_warps": warp,
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)},
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)},
}
func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs}
func_table.append(func_desc)

View file

@ -20,115 +20,117 @@ namespace rocm {
using onnxruntime::rocm::GPU_WARP_SIZE;
template <typename T>
void groupNormNHWCSum(const GroupNormNHWCParams<T>* params) {
// Make sure the values are as we expect.
ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0);
// Make sure a group does not span multiple blocks.
ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0);
void GroupNormNHWCSum(const GroupNormNHWCTunableParams<T>* params) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params->c / params->cPerBlock;
grid.x = DivUp(params->c, params->channels_per_block);
// The number of blocks to compute all the activations in a given instance.
grid.y = CeilDiv(params->hw, params->hwPerBlock);
grid.y = DivUp(params->hw, params->hw_per_block);
// The number of instances.
grid.z = params->n;
#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \
groupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize> \
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
params->src, params->redBuffer, params->cPerBlock, \
params->hwPerBlock, params->hw, params->hwc, params->c, \
params->cPerGroup, params->groups, params->groupsPerBlock); \
#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \
GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize> \
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \
params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \
params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \
break;
switch (params->cPerBlock) {
case 320:
LAUNCH_GROUPNORM_SUM(256, 2)
case 480:
LAUNCH_GROUPNORM_SUM(256, 2)
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
switch (params->threads_per_block) {
case 256:
LAUNCH_GROUPNORM_SUM(128, 2)
LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD)
case 192:
LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD)
case 160:
LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD)
case 128:
LAUNCH_GROUPNORM_SUM(64, 2)
LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD)
case 64:
LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD)
default:
ORT_NOT_IMPLEMENTED("Not implemented");
}
}
template <typename T, int ThreadsPerBlock, int VecSize>
Status GroupNormNHWCSumOp(const GroupNormNHWCParams<T>* params) {
Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams<T>* params) {
dim3 grid;
grid.x = params->c / params->cPerBlock;
grid.y = CeilDiv(params->hw, params->hwPerBlock);
grid.x = DivUp(params->c, params->channels_per_block);
grid.y = DivUp(params->hw, params->hw_per_block);
grid.z = params->n;
groupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>
GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(
params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock,
params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock);
params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias,
params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c,
params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip);
return HIP_CALL(hipGetLastError());
}
template <typename T>
void groupNormNHWCScale(const GroupNormNHWCParams<T>* params) {
// Make sure the dimensions are aligned with what we expect.
ORT_ENFORCE(params->c % params->cPerBlock == 0);
// Make sure a group does not span multiple blocks.
ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0);
void GroupNormNHWCScale(const GroupNormNHWCTunableParams<T>* params) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params->c / params->cPerBlock;
grid.x = DivUp(params->c, params->channels_per_block);
// The number of blocks to compute all the activations in a given instance.
grid.y = CeilDiv(params->hw, params->hwPerBlock);
grid.y = DivUp(params->hw, params->hw_per_block);
// The number of instances.
grid.z = params->n;
#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \
groupNormNHWCScaleKernel<T, ThreadsPerBlock, VecSize> \
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
params->dst, params->src, params->gamma, params->beta, \
params->redBuffer, params->epsilon, params->c, params->cPerBlock, \
params->cPerGroup, params->groups, params->hwc, params->invHWC, \
params->hw, params->hwPerBlock, params->withSwish); \
#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \
GroupNormNHWCScaleKernel<T, VecSize> \
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \
params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \
params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \
params->hw, params->hw_per_block, params->use_silu); \
break;
switch (params->cPerBlock) {
case 320:
LAUNCH_GROUPNORM_SCALE(256, 2)
case 480:
LAUNCH_GROUPNORM_SCALE(256, 2)
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
switch (params->threads_per_block) {
case 256:
LAUNCH_GROUPNORM_SCALE(128, 2)
LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD)
case 192:
LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD)
case 160:
LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD)
case 128:
LAUNCH_GROUPNORM_SCALE(64, 2)
LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD)
case 64:
LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD)
default:
ORT_NOT_IMPLEMENTED("Not implemented");
}
}
template <typename T, int ThreadsPerBlock, int VecSize>
Status GroupNormNHWCScaleOp(const GroupNormNHWCParams<T>* params) {
Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams<T>* params) {
dim3 grid;
grid.x = params->c / params->cPerBlock;
grid.y = CeilDiv(params->hw, params->hwPerBlock);
grid.x = DivUp(params->c, params->channels_per_block);
grid.y = DivUp(params->hw, params->hw_per_block);
grid.z = params->n;
groupNormNHWCScaleKernel<T, ThreadsPerBlock, VecSize>
GroupNormNHWCScaleKernel<T, VecSize>
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(
params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock,
params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish);
params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace,
params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group,
params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block,
params->use_silu);
return HIP_CALL(hipGetLastError());
}
template <typename T, int ThreadsPerBlock, int VecSize>
class GroupNormNHWCOp {
public:
Status operator()(const GroupNormNHWCParams<T>* params) {
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle()));
Status operator()(const GroupNormNHWCTunableParams<T>* params) {
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
0,
GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
params->StreamHandle()));
auto status = GroupNormNHWCSumOp<T, ThreadsPerBlock, VecSize>(params);
ORT_RETURN_IF_ERROR(status);
HIP_RETURN_IF_ERROR(hipGetLastError());
@ -138,29 +140,30 @@ class GroupNormNHWCOp {
return Status::OK();
}
Status IsSupported(const GroupNormNHWCParams<T>* params) {
Status IsSupported(const GroupNormNHWCTunableParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0),
"The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup,
!(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0),
"The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group,
") isn't divisible by the number of vector size: ", VecSize);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 &&
params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0),
"The value of attributes don't meet the requirements.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize &&
params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize &&
params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
"Configuration: Threads (", ThreadsPerBlock, "), vector size (",
VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock);
VecSize, ") is redundant for the number of channels per group: ",
params->channels_per_block);
return Status::OK();
}
};
template <typename T>
Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle()));
groupNormNHWCSum<T>(params);
Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams<T>* params) {
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
0,
GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
params->StreamHandle()));
GroupNormNHWCSum<T>(params);
HIP_RETURN_IF_ERROR(hipGetLastError());
groupNormNHWCScale<T>(params);
GroupNormNHWCScale<T>(params);
HIP_RETURN_IF_ERROR(hipGetLastError());
return Status::OK();
}
@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
ADD_OP_FOR_ALL_VEC_SIZE(name, 320)
template <typename T>
class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCParams<T>> {
class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCTunableParams<T>> {
public:
GroupNormNHWCTunableOp() {
this->RegisterOp(GroupNormNHWCStaticSelection<T>);
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp)
#ifdef USE_COMPOSABLE_KERNEL
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/false>()) {
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSilu=*/false>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/true>()) {
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSilu=*/true>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif // USE_COMPOSABLE_KERNEL
#ifdef USE_TRITON_KERNEL
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSwish=*/false>()) {
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSilu=*/false>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSwish=*/true>()) {
for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSilu=*/true>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}

View file

@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,

View file

@ -35,7 +35,11 @@ def sigmoid_function(x):
return 1.0 / (1.0 + np.exp(-x))
def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip):
add_output = None
if has_skip:
input_x = input_x + skip_x + bias_x
add_output = input_x
n, h, w, c = input_x.shape
input_x = input_x.transpose([0, 3, 1, 2])
assert c % num_groups == 0
@ -45,46 +49,70 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
x = x.transpose([0, 2, 3, 1])
x = x * gamma + beta
if with_swish:
if with_silu:
x = x * sigmoid_function(x)
return x
return x, add_output
def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func):
def run_group_norm(
batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func
):
np.random.seed(0)
width = height
input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
gamma = np.random.rand(num_channels).astype(np.float32)
beta = np.random.rand(num_channels).astype(np.float32)
# the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18
workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32)
workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32)
epsilon = 1e-05
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
use_swish = swish
host_x = input_x.astype(dtype)
input_d = ke.DeviceArray(host_x)
skip_x = (
np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
if has_skip
else np.empty((0), dtype=dtype)
)
bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype)
add_output = (
np.random.rand(batch_size, height, width, num_channels).astype(dtype)
if has_skip
else np.empty((0), dtype=dtype)
)
use_silu = silu
broadcast_skip = False
channels_per_block = 0 # Compute in params initialization
input_d = ke.DeviceArray(input_x.astype(dtype))
skip_d = ke.DeviceArray(skip_x.astype(dtype))
bias_d = ke.DeviceArray(bias_x.astype(dtype))
gamma_d = ke.DeviceArray(gamma)
beta_d = ke.DeviceArray(beta)
workspace_d = ke.DeviceArray(workspace)
y_d = ke.DeviceArray(output_y)
y_add_d = ke.DeviceArray(add_output)
f = getattr(ke, func)
my_op = f(
y_d,
workspace_d,
y_add_d,
input_d,
skip_d,
bias_d,
gamma_d,
beta_d,
workspace_d,
epsilon,
batch_size,
num_channels,
height,
width,
num_channels,
num_groups,
epsilon,
use_swish,
use_silu,
broadcast_skip,
channels_per_block,
)
y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype)
y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip)
y_ref = y_ref.astype(dtype)
for impl in my_op.ListOps():
if not my_op.SelectOp(impl):
@ -95,6 +123,10 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
y_d.UpdateHostNumpyArray()
np.testing.assert_allclose(y_ref, output_y, atol=1e-02)
if has_skip:
y_add_d_ref = y_add_d_ref.astype(dtype)
y_add_d.UpdateHostNumpyArray()
np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02)
dtypes = ["float32", "float16"]
@ -102,19 +134,21 @@ dtypes = ["float32", "float16"]
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("swish", [True])
def test_group_norm(sd_sizes, dtype, swish):
@pytest.mark.parametrize("silu", [True])
@pytest.mark.parametrize("has_skip", [True, False])
def test_group_norm(sd_sizes, dtype, silu, has_skip):
for func in dtype_to_funcs(dtype):
run_group_norm(*sd_sizes, dtype, swish, func)
run_group_norm(*sd_sizes, dtype, silu, has_skip, func)
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("swish", [True])
def test_group_norm_ck(sd_sizes, dtype, swish):
swish_suffix = "Swish" if swish else "Pass"
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
run_group_norm(*sd_sizes, dtype, swish, ck_f_name)
@pytest.mark.parametrize("silu", [True])
@pytest.mark.parametrize("has_skip", [False])
def test_group_norm_ck(sd_sizes, dtype, silu, has_skip):
silu_suffix = "Silu" if silu else "Pass"
ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype)
run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name)
@dataclass
@ -136,37 +170,67 @@ class GroupNormNHWCMetric(ke.BandwidthMetric):
def profile_group_norm_func(
batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func
batch_size: int,
height: int,
width: int,
num_channels: int,
num_groups: int,
dtype: str,
silu: bool,
has_skip: bool,
func,
):
np.random.seed(0)
input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
gamma = np.random.rand(num_channels).astype(np.float32)
beta = np.random.rand(num_channels).astype(np.float32)
workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32)
workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32)
epsilon = 0.05
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
use_swish = swish
skip_x = (
np.random.rand(batch_size, height, width, num_channels).astype(dtype)
if has_skip
else np.empty((0), dtype=dtype)
)
bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype)
add_output = (
np.random.rand(batch_size, height, width, num_channels).astype(dtype)
if has_skip
else np.empty((0), dtype=dtype)
)
use_silu = silu
broadcast_skip = False
channels_per_block = 0 # Compute in params initialization
input_d = ke.DeviceArray(input_x)
skip_d = ke.DeviceArray(skip_x)
bias_d = ke.DeviceArray(bias_x)
gamma_d = ke.DeviceArray(gamma)
beta_d = ke.DeviceArray(beta)
workspace_d = ke.DeviceArray(workspace)
y_d = ke.DeviceArray(output_y)
y_add_d = ke.DeviceArray(add_output)
f = getattr(ke, func)
my_op = f(
y_d,
workspace_d,
y_add_d,
input_d,
skip_d,
bias_d,
gamma_d,
beta_d,
workspace_d,
epsilon,
batch_size,
num_channels,
height,
width,
num_channels,
num_groups,
epsilon,
use_swish,
use_silu,
broadcast_skip,
channels_per_block,
)
for impl in my_op.ListOps():
duration_ms = -1
@ -181,14 +245,14 @@ def profile_group_norm_func(
)
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True):
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True):
with ke.benchmark(sort):
for func in dtype_to_funcs(dtype):
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func)
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func)
# ck function
swish_suffix = "Swish" if swish else "Pass"
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name)
silu_suffix = "Silu" if silu else "Pass"
ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype)
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name)
sd_profile_sizes = [
@ -227,7 +291,8 @@ if __name__ == "__main__":
group.add_argument("num_channels", type=int)
group.add_argument("num_groups", type=int)
group.add_argument("dtype", choices=dtypes)
group.add_argument("--swish", action="store_true")
group.add_argument("--silu", action="store_true")
group.add_argument("--has_skip", action="store_true")
group.add_argument("--sort", action="store_true")
if len(sys.argv) == 1:
@ -241,6 +306,7 @@ if __name__ == "__main__":
args.num_channels,
args.num_groups,
args.dtype,
args.swish,
args.silu,
args.has_skip,
args.sort,
)

View file

@ -12,17 +12,21 @@
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
namespace py = pybind11;
using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes;
namespace onnxruntime {
template <typename T, int ThreadsPerBlock, int VecSize>
class GroupNormNHWC : public IKernelExplorer {
public:
GroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu,
bool broadcast_skip, int channels_per_block)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
channels_per_block) {
type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize);
}
@ -40,7 +44,7 @@ class GroupNormNHWC : public IKernelExplorer {
}
private:
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
ParamsT params_{};
contrib::rocm::GroupNormNHWCOp<T, ThreadsPerBlock, VecSize> op_{};
std::string type_string_{};
@ -49,11 +53,15 @@ class GroupNormNHWC : public IKernelExplorer {
template <typename T>
class GroupNormNHWCStaticSelection : public IKernelExplorer {
public:
GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
channels_per_block) {
type_string_ = "GroupNormNHWCStaticSelection";
}
@ -71,7 +79,7 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
}
private:
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
ParamsT params_{};
std::string type_string_{};
};
@ -79,11 +87,15 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
template <typename T>
class GroupNormNHWCTunable : public IKernelExplorer {
public:
GroupNormNHWCTunable(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
channels_per_block) {
params_.TuningContext()->EnableTunableOpAndTuning();
}
@ -100,21 +112,25 @@ class GroupNormNHWCTunable : public IKernelExplorer {
}
private:
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
ParamsT params_{};
contrib::rocm::GroupNormNHWCTunableOp<T> op_{};
};
#ifdef USE_COMPOSABLE_KERNEL
template <typename T, bool WithSwish>
template <typename T, bool WithSilu>
class CKGroupNormNHWC : public IKernelExplorer {
public:
CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSwish>()) {
CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
channels_per_block) {
for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSilu>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
@ -141,7 +157,7 @@ class CKGroupNormNHWC : public IKernelExplorer {
}
private:
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
using OpT = rocm::tunable::Op<ParamsT>;
ParamsT params_{};
std::vector<OpT> ops_;
@ -151,15 +167,19 @@ class CKGroupNormNHWC : public IKernelExplorer {
#endif // USE_COMPOSABLE_KERNEL
#ifdef USE_TRITON_KERNEL
template <typename T, bool WithSwish>
template <typename T, bool WithSilu>
class GroupNormNHWCTriton : public IKernelExplorer {
public:
GroupNormNHWCTriton(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps<T, WithSwish>()) {
GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
channels_per_block) {
for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps<T, WithSilu>()) {
name_strings_.emplace_back(name);
ops_.emplace_back(std::move(op));
}
@ -186,7 +206,7 @@ class GroupNormNHWCTriton : public IKernelExplorer {
}
private:
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
using OpT = rocm::tunable::Op<ParamsT>;
ParamsT params_{};
std::vector<OpT> ops_;
@ -198,7 +218,8 @@ class GroupNormNHWCTriton : public IKernelExplorer {
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
int, int, int, int, int, float, bool>()) \
DeviceArray&, DeviceArray&, DeviceArray&, float, \
int, int, int, int, int, bool, bool, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
@ -220,7 +241,8 @@ class GroupNormNHWCTriton : public IKernelExplorer {
#define REGISTER_COMMON(name, type, ...) \
py::class_<type<__VA_ARGS__>>(m, name) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
int, int, int, int, int, float, bool>()) \
DeviceArray&, DeviceArray&, DeviceArray&, float, \
int, int, int, int, int, bool, bool, int>()) \
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
.def("Profile", &type<__VA_ARGS__>::Profile) \
.def("Run", &type<__VA_ARGS__>::Run) \
@ -230,11 +252,11 @@ class GroupNormNHWCTriton : public IKernelExplorer {
#define REGISTER_OP_TYPED(name, type) \
REGISTER_COMMON(#name "_" #type, name, type)
#define REGISTER_CK(type, with_swish, swish_suffix) \
REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish)
#define REGISTER_CK(type, with_silu, silu_suffix) \
REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu)
#define REGISTER_TRITON(type, with_swish, swish_suffix) \
REGISTER_COMMON("GroupNormNHWCTriton" swish_suffix "_" #type, GroupNormNHWCTriton, type, with_swish)
#define REGISTER_TRITON(type, with_silu, silu_suffix) \
REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu)
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half);
@ -248,16 +270,16 @@ KE_REGISTER(m) {
#ifdef USE_COMPOSABLE_KERNEL
REGISTER_CK(half, false, "Pass");
REGISTER_CK(half, true, "Swish");
REGISTER_CK(half, true, "Silu");
REGISTER_CK(float, false, "Pass");
REGISTER_CK(float, true, "Swish");
REGISTER_CK(float, true, "Silu");
#endif // USE_COMPOSABLE_KERNEL
#ifdef USE_TRITON_KERNEL
REGISTER_TRITON(half, false, "Pass");
REGISTER_TRITON(half, true, "Swish");
REGISTER_TRITON(half, true, "Silu");
REGISTER_TRITON(float, false, "Pass");
REGISTER_TRITON(float, true, "Swish");
REGISTER_TRITON(float, true, "Silu");
#endif
}

View file

@ -114,16 +114,21 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) {
int min_cuda_architecture = 530;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
std::array<int, 2> channels_last_values = {-1, 1};
for (const int channels_last : channels_last_values) {
if (enable_cuda) {
if (enable_cuda || enable_rocm) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (enable_rocm && channels_last != 0) {
execution_providers.push_back(DefaultRocmExecutionProvider());
}
// Don't run the test if no providers are supported
if (execution_providers.empty()) {
continue;
@ -230,6 +235,7 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
int min_cuda_architecture = 530;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
std::array<bool, 2> has_add_out_values = {true, false};
std::array<int, 2> skip_dims = {2, 4};
@ -237,12 +243,16 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
constexpr int channels_last = 1;
for (const int skip_dim : skip_dims) {
for (const bool has_add_out : has_add_out_values) {
if (enable_cuda) {
if (enable_cuda || enable_rocm) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (enable_rocm && channels_last != 0) {
execution_providers.push_back(DefaultRocmExecutionProvider());
}
// Don't run the test if no providers are supported
if (execution_providers.empty()) {
continue;

View file

@ -181,6 +181,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
s = s.replace("rocm_device_prop_", "cuda_device_prop_")
s = s.replace("rocm_device_arch_", "cuda_device_arch_")
s = s.replace("HipTuningContext", "RocmTuningContext")
# We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names
# And we do this last, undoing or fixing hipify mistakes.
if "fft" in src_file_path: