[CUDA] Update GroupNorm and Add SkipGroupNorm (#18091)

* Add a new operator SkipGroupNorm to support skip and bias inputs.
* Update GroupNorm kernel to support number of channels used in SD XL refiner.
* Add epsilon in kernel
* Add parity and performance test script
* Remove many limitations including max batch size, max number of groups, c % cPerBlock ==0 etc.

### Motivation and Context

Update GroupNorm to support SD XL Refiner and beyond.
Tianlei Wu 2023-10-31 10:27:20 -07:00 committed by GitHub
parent 29e40987e3
commit 95f053c652
14 changed files with 1548 additions and 258 deletions


@ -95,6 +95,7 @@ Do not modify directly.*
* <a href="#com.microsoft.RotaryEmbedding">com.microsoft.RotaryEmbedding</a>
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
* <a href="#com.microsoft.SkipGroupNorm">com.microsoft.SkipGroupNorm</a>
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
* <a href="#com.microsoft.SkipSimplifiedLayerNormalization">com.microsoft.SkipSimplifiedLayerNormalization</a>
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
@ -2342,7 +2343,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>activation</tt> : int (required)</dt>
<dd>Activation after group normalization: 0 for None, 1 for Swish</dd>
<dd>Activation after group normalization: 0 for None, 1 for SiLU</dd>
<dt><tt>channels_last</tt> : int</dt>
<dd>1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.</dd>
<dt><tt>epsilon</tt> : float</dt>
@ -2582,6 +2583,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
@ -5083,6 +5085,72 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.SkipGroupNorm"></a><a name="com.microsoft.skipgroupnorm">**com.microsoft.SkipGroupNorm**</a>
This operator element-wise adds x, skip and bias, then applies group normalization and an optional activation.
This operator transforms input according to
s = x + skip + bias
y = gamma * (s - mean) / sqrt(variance + epsilon) + beta
The input channels are separated into num_groups groups, each containing num_channels / num_groups channels.
The num_channels must be divisible by num_groups.
The mean and standard-deviation of s are calculated separately over each group.
The gamma and beta are per-channel affine transform parameter vectors of size num_channels.
The activation attribute can be used to enable activation after group normalization.
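
The formulas above can be sketched as a plain host-side reference. This is a minimal illustration, not the CUDA kernel from this commit: `SkipGroupNormRef` is a hypothetical name, it assumes float data in NHWC layout with `skip` of the same shape as `x`, and it assumes the biased (population) variance over each group.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not part of ONNX Runtime).
// x, skip: (N, H*W, C) flattened in NHWC order; bias (may be empty), gamma, beta: (C).
// Writes s = x + skip + bias into `s` and returns y.
std::vector<float> SkipGroupNormRef(const std::vector<float>& x,
                                    const std::vector<float>& skip,
                                    const std::vector<float>& bias,
                                    const std::vector<float>& gamma,
                                    const std::vector<float>& beta,
                                    int n, int hw, int c, int groups,
                                    float epsilon, bool silu,
                                    std::vector<float>& s) {
  assert(groups > 0 && c % groups == 0);
  const size_t total = static_cast<size_t>(n) * hw * c;
  assert(x.size() == total && skip.size() == total);
  assert(bias.empty() || bias.size() == static_cast<size_t>(c));
  assert(gamma.size() == static_cast<size_t>(c) && beta.size() == static_cast<size_t>(c));

  // s = x + skip + bias (bias is broadcast over the channel dimension).
  s.resize(total);
  for (size_t i = 0; i < total; ++i) {
    s[i] = x[i] + skip[i] + (bias.empty() ? 0.0f : bias[i % c]);
  }

  const int cpg = c / groups;  // channels per group
  const double count = static_cast<double>(hw) * cpg;
  auto index = [&](int ni, int p, int ch) {
    return (static_cast<size_t>(ni) * hw + p) * c + ch;
  };

  std::vector<float> y(total);
  for (int ni = 0; ni < n; ++ni) {
    for (int g = 0; g < groups; ++g) {
      // Accumulate sum and sum of squares over one group of one instance.
      double sum = 0.0, sum_sq = 0.0;
      for (int p = 0; p < hw; ++p) {
        for (int cj = 0; cj < cpg; ++cj) {
          const double v = s[index(ni, p, g * cpg + cj)];
          sum += v;
          sum_sq += v * v;
        }
      }
      const double mean = sum / count;
      const double variance = std::max(0.0, sum_sq / count - mean * mean);
      const double inv_std = 1.0 / std::sqrt(variance + epsilon);

      for (int p = 0; p < hw; ++p) {
        for (int cj = 0; cj < cpg; ++cj) {
          const int ch = g * cpg + cj;
          const size_t i = index(ni, p, ch);
          double v = gamma[ch] * (s[i] - mean) * inv_std + beta[ch];
          if (silu) {
            v = v / (1.0 + std::exp(-v));  // SiLU: v * sigmoid(v)
          }
          y[i] = static_cast<float>(v);
        }
      }
    }
  }
  return y;
}
```

A reference like this is mainly useful for parity checks against the fused kernel on small shapes; it makes no attempt at performance.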
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>activation</tt> : int (required)</dt>
<dd>Activation after group normalization: 0 for None, 1 for SiLU</dd>
<dt><tt>channels_last</tt> : int</dt>
<dd>1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.</dd>
<dt><tt>epsilon</tt> : float</dt>
<dd>The epsilon value to use to avoid division by zero</dd>
<dt><tt>groups</tt> : int (required)</dt>
<dd>The number of groups of channels. It should be a divisor of the number of channels C</dd>
</dl>
#### Inputs (4 - 5)
<dl>
<dt><tt>X</tt> : T</dt>
<dd>Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data</dd>
<dt><tt>gamma</tt> : M</dt>
<dd>1D gamma tensor for normalization with shape (C), where C is number of channels</dd>
<dt><tt>beta</tt> : M</dt>
<dd>1D beta tensor for normalization with shape (C), where C is number of channels</dd>
<dt><tt>skip</tt> : T</dt>
<dd>4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)</dd>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>1D bias tensor. Dimensions are (C), where C is number of channels</dd>
</dl>
#### Outputs (1 - 2)
<dl>
<dt><tt>Y</tt> : T</dt>
<dd>The output tensor of the same shape as X</dd>
<dt><tt>S</tt> (optional) : T</dt>
<dd>The element-wise sum of input x, skip and bias tensors. It has the same shape as X</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input X, skip, bias and output Y, S types to float tensors.</dd>
<dt><tt>M</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain gamma and beta to float tensors.</dd>
</dl>
### <a name="com.microsoft.SkipLayerNormalization"></a><a name="com.microsoft.skiplayernormalization">**com.microsoft.SkipLayerNormalization**</a>
Skip and Layer Normalization Fusion


@ -861,6 +861,7 @@ Do not modify directly.*
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|


@ -97,6 +97,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Samp
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
@ -269,6 +270,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,


@ -1,6 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/diffusion/group_norm.h"
#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX(
GroupNorm, kMSDomain, 1, kCudaExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
ONNX_OPERATOR_KERNEL_EX(
SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
using namespace ONNX_NAMESPACE;
namespace {
template <typename T>
struct DispatchGroupNorm {
Status operator()(cudaStream_t stream,
Tensor* output,
Tensor* add_out,
const Tensor* input,
const Tensor* skip,
const Tensor* bias,
const Tensor* gamma,
const Tensor* beta,
void* workspace,
@ -32,12 +39,17 @@ struct DispatchGroupNorm {
int height,
int width,
int num_groups,
bool use_swish_activation) {
bool use_swish_activation,
bool broadcast_skip,
int channels_per_block) {
typedef typename ToCudaType<T>::MappedType CudaT;
return LaunchGroupNormKernel<CudaT>(
stream,
reinterpret_cast<CudaT*>(output->MutableData<T>()),
add_out == nullptr ? nullptr : reinterpret_cast<CudaT*>(add_out->MutableData<T>()),
reinterpret_cast<const CudaT*>(input->Data<T>()),
skip == nullptr ? nullptr : reinterpret_cast<const CudaT*>(skip->Data<T>()),
bias == nullptr ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>()),
gamma->Data<float>(),
beta->Data<float>(),
workspace,
@ -47,13 +59,21 @@ struct DispatchGroupNorm {
height,
width,
num_groups,
use_swish_activation);
use_swish_activation,
broadcast_skip,
channels_per_block);
}
};
} // namespace
GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
has_skip_ = false;
const std::string& op_name = op_info.GetKernelDef().OpName();
if (op_name == "SkipGroupNorm") {
has_skip_ = true;
}
epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
ORT_ENFORCE(epsilon_ >= 0);
@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
use_swish_activation_ = (activation == 1);
channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
channels_per_block_ = 0;
}
Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/,
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// Compute and cache cPerBlock using number of channels from gamma tensor shape.
if (input_idx == 1) {
auto gamma_shape = tensor.Shape();
if (gamma_shape.NumDimensions() == 1) {
channels_per_block_ = GetChannelsPerBlock(static_cast<int>(gamma_shape[0]), num_groups_);
}
}
return Status::OK();
}
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
Tensor* output = context->Output(0, input->Shape());
if (!channels_last_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"only the channels_last layout is supported");
}
if (!gamma->IsDataType<float>() || !beta->IsDataType<float>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"GroupNorm only supports gamma and beta in float type");
}
const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input is expected to have 4 dimensions, got ", input_dims.size());
}
// Only support NHWC format right now.
int batch_size = static_cast<int>(input_dims[0]);
int height = static_cast<int>(input_dims[1]);
int width = static_cast<int>(input_dims[2]);
int num_channels = static_cast<int>(input_dims[3]);
if (num_channels % num_groups_ != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"number of channels should be divisible by num_groups");
}
const auto& gamma_dims = gamma->Shape().GetDims();
if (gamma_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
}
if (gamma_dims[0] != input_dims[3]) {
if (gamma_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in gamma and input does not match");
}
@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"beta is expected to have 1 dimension, got ", beta_dims.size());
}
if (beta_dims[0] != input_dims[3]) {
if (beta_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in beta and input does not match");
}
// Input and output format is NHWC
int batch_size = static_cast<int>(input_dims[0]);
int num_channels = static_cast<int>(input_dims[3]);
int height = static_cast<int>(input_dims[1]);
int width = static_cast<int>(input_dims[2]);
if (num_channels % num_groups_ != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"number of channels should be divisible by num_groups");
}
if (context->GetUseDeterministicCompute()) {
static std::once_flag log_warning;
std::call_once(log_warning, []() {
@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
});
}
auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());
const Tensor* skip = nullptr;
const Tensor* bias = nullptr;
Tensor* add_out = nullptr;
bool broadcast_skip = false;
if (has_skip_) {
skip = context->Input<Tensor>(3);
bias = context->Input<Tensor>(4);
add_out = context->Output(1, input->Shape());
if (bias != nullptr) { // Bias is optional
// If provided, bias has shape (C).
const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"bias is expected to have 1 dimension, got ", bias_dims.size());
}
if (bias_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in bias and input does not match");
}
}
// Check whether skip can be broadcasted to input shape.
if (skip->Shape() != input->Shape()) {
const auto& dims = skip->Shape().GetDims();
// The shape of skip can be (N, C) or (N, 1, 1, C) for broadcast.
const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels);
const bool b4 = (dims.size() == 4 && dims[0] == batch_size &&
dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels);
broadcast_skip = b2 || b4;
if (!broadcast_skip) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)");
}
}
}
auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_),
context->GetComputeStream());
utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, input, gamma, beta, workspace.get(),
return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, add_out, input, skip, bias,
gamma, beta, workspace.get(),
epsilon_,
batch_size,
num_channels,
height,
width,
num_groups_,
use_swish_activation_);
use_swish_activation_,
broadcast_skip,
channels_per_block_);
}
} // namespace cuda
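
The skip shape check in `ComputeInternal` above, together with the skip offsets used later by the sum kernel, can be sketched in isolation. This is a minimal host-side illustration; `SkipLayout`, `ClassifySkipShape` and `SkipOffset` are hypothetical names that do not appear in the commit.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class SkipLayout { kSameAsInput, kBroadcast, kInvalid };

// Hypothetical helper: classify a skip shape against an NHWC input shape.
// Accepted shapes are (N, H, W, C), (N, 1, 1, C) and (N, C).
SkipLayout ClassifySkipShape(const std::vector<int64_t>& input_dims,
                             const std::vector<int64_t>& skip_dims) {
  assert(input_dims.size() == 4);
  if (skip_dims == input_dims) {
    return SkipLayout::kSameAsInput;
  }
  const int64_t n = input_dims[0];
  const int64_t c = input_dims[3];
  const bool b2 = (skip_dims.size() == 2 && skip_dims[0] == n && skip_dims[1] == c);
  const bool b4 = (skip_dims.size() == 4 && skip_dims[0] == n &&
                   skip_dims[1] == 1 && skip_dims[2] == 1 && skip_dims[3] == c);
  return (b2 || b4) ? SkipLayout::kBroadcast : SkipLayout::kInvalid;
}

// Hypothetical helper: element offset into the skip buffer for instance ni,
// spatial position hwi and channel ci. A broadcast skip has one value per
// (instance, channel); otherwise skip is indexed exactly like the input.
int64_t SkipOffset(bool broadcast, int64_t ni, int64_t hwi, int64_t ci,
                   int64_t hw, int64_t c) {
  return broadcast ? ni * c + ci : (ni * hw + hwi) * c + ci;
}
```

Both broadcast shapes, (N, C) and (N, 1, 1, C), share the same memory layout, which is why a single `broadcast_skip` flag is enough for the kernel.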


@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel {
GroupNorm(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) override;
private:
bool use_swish_activation_;
bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization?
float epsilon_;
int num_groups_;
bool channels_last_;
bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm
int channels_per_block_;
};
} // namespace cuda


@ -16,18 +16,45 @@
*/
// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <cub/cub.cuh>
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
using namespace onnxruntime::cuda;
namespace onnxruntime {
namespace contrib {
namespace cuda {
static inline int32_t divUp(int32_t m, int32_t n) {
namespace {
// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time.
constexpr static int32_t CHANNELS_PER_THREAD = 2;
constexpr static int kSizes[] = {128, 256, 320, 384, 512};
constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr static int kMaxSize = kSizes[kNumOfSizes - 1];
int NextSize(int x) {
for (size_t i = 0; i < kNumOfSizes; ++i) {
if (x <= kSizes[i]) {
return kSizes[i];
}
}
return x;
}
} // namespace
static inline int32_t DivUp(int32_t m, int32_t n) {
return (m + n - 1) / n;
}
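
As a host-side sketch of how the two helpers above behave, the snippet below reproduces `NextSize` and `DivUp` and combines them into a launch shape. `LaunchShape` and `ChannelLaunchShape` are hypothetical, and deriving the thread count as `NextSize(channels_per_block) / CHANNELS_PER_THREAD` is an assumption made for illustration; the diff shown here only states that the thread counts are half of the values in `kSizes`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr int32_t CHANNELS_PER_THREAD = 2;
constexpr int kSizes[] = {128, 256, 320, 384, 512};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);

// Round x up to the nearest supported size; values above the largest
// supported size are returned unchanged.
int NextSize(int x) {
  for (size_t i = 0; i < kNumOfSizes; ++i) {
    if (x <= kSizes[i]) {
      return kSizes[i];
    }
  }
  return x;
}

// Ceiling division for positive integers.
int32_t DivUp(int32_t m, int32_t n) {
  return (m + n - 1) / n;
}

// Hypothetical helper: number of blocks along the channel dimension and
// threads per block for a given channels_per_block (assumed mapping).
struct LaunchShape {
  int32_t blocks_x;
  int32_t threads_per_block;
};

LaunchShape ChannelLaunchShape(int32_t c, int32_t channels_per_block) {
  assert(c > 0 && channels_per_block > 0);
  return {DivUp(c, channels_per_block),
          NextSize(channels_per_block) / CHANNELS_PER_THREAD};
}
```

With `DivUp` in the grid computation, the channel count no longer has to be an exact multiple of the channels per block; the last block simply has idle threads.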
@ -41,14 +68,14 @@ struct GroupSums {
// The sum.
float sum;
// The sum of squares.
float sumSq;
float sum_sq;
};
struct GroupSumsOp {
inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) {
GroupSums dst;
dst.sum = b.flag ? b.sum : (a.sum + b.sum);
dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
dst.flag = a.flag + b.flag;
return dst;
}
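
The segmented scan that `GroupSumsOp` drives can be emulated sequentially on the host. This is a minimal sketch: the loop stands in for `cub::BlockScan::InclusiveScan`, `Combine` and `InclusiveScanRef` are hypothetical names, and the `int32_t` type of `flag` is an assumption because its declaration lies outside the hunk. A flag of 1 marks the first thread of a group, so the running sums restart there and the last thread of each group ends up holding that group's totals.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct GroupSums {
  int32_t flag;  // 1 if this thread starts a new group, 0 otherwise (assumed type)
  float sum;
  float sum_sq;
};

// Same combining rule as GroupSumsOp: restart the sums when b begins a group.
GroupSums Combine(const GroupSums& a, const GroupSums& b) {
  GroupSums dst;
  dst.sum = b.flag ? b.sum : (a.sum + b.sum);
  dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
  dst.flag = a.flag + b.flag;
  return dst;
}

// Sequential stand-in for the block-wide inclusive scan.
std::vector<GroupSums> InclusiveScanRef(const std::vector<GroupSums>& in) {
  std::vector<GroupSums> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    out[i] = (i == 0) ? in[0] : Combine(out[i - 1], in[i]);
  }
  return out;
}
```

For example, with four channels per group and two channels per thread, threads 0 and 2 carry flag 1; after the scan, threads 1 and 3 hold the sums of groups 0 and 1 respectively.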
@ -56,54 +83,85 @@ struct GroupSumsOp {
template <typename T>
struct GroupNormNHWCParams {
// The output buffer. Layout NHWC.
// The output buffer. Shape is (n, h, w, c).
T* dst;
// The input buffer. Layout NHWC.
// Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c).
T* add_out;
// The input buffer. Shape is (n, h, w, c).
T const* src;
// Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c).
T const* skip;
// Optional input buffer for bias tensor. Shape is (c).
T const* bias;
// The gamma scaling factor.
float const* gamma;
// The beta term to add in GN.
float const* beta;
// The temporary buffer to do the global parallel reduction. Size:
// BLOCKS_PER_BATCH x C x 2.
float* redBuffer;
// The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups.
float* group_sum_buffer;
// The number of instances in the batch.
int32_t n;
// The height and width of each activation map.
int32_t h;
int32_t w;
// The number of channels.
// Number of channels.
int32_t c;
// The number of groups.
// Number of groups.
int32_t groups;
// Do we apply the Swish activation function?
bool withSwish;
// Do we apply the SiLU activation function?
bool use_silu;
// Precomputed values and parameters to control the execution of the kernels.
// The number of activations per instance (h * w) and the number of
// activations per block.
// Number of activations per instance (h * w)
int32_t hw;
int32_t hwPerBlock;
// The number of channels per group and blocks per activation in the C
// dimension.
int32_t cPerBlock;
int32_t cPerGroup;
// Number of activations per block
int32_t hw_per_block;
// Number of channels per block in the C dimension.
int32_t channels_per_block;
// Number of channels per group in the C dimension.
int32_t channels_per_group;
// The precomputed stride between instances.
int32_t hwc;
// The inverse of hwc in floats (to compute mean/var).
float invHWC;
// The inverse of hw*channels_per_group to compute mean of a group.
float inv_hw_channels_per_group;
// The precomputed number of groups per block.
int32_t groupsPerBlock;
int32_t groups_per_block;
// Number of threads per block
int32_t threads_per_block;
// Epsilon to get stable variance in normalization.
float epsilon;
// Whether skip needs to be broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise.
bool broadcast_skip;
// For SkipGroupNorm, it points to the intermediate result of adding skip and bias.
T* skip_workspace;
};
template <typename T>
inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq);
inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq);
template <>
inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) {
inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) {
// Fetch two channels per thread.
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl
sum += f2.x + f2.y;
// Update the sum of squares.
sumSq += f2.x * f2.x + f2.y * f2.y;
sum_sq += f2.x * f2.x + f2.y * f2.y;
}
template <>
inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) {
inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) {
// Fetch two channels per thread.
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f
sum += f2.x + f2.y;
// Update the sum of squares.
sumSq += f2.x * f2.x + f2.y * f2.y;
sum_sq += f2.x * f2.x + f2.y * f2.y;
}
template <typename T, int32_t tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset]
template <typename T>
inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias,
int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq);
template <>
inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias,
int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
// Fetch two channels per thread.
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
__half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]);
__half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]);
h2 = h2 + b;
h2 = h2 + s;
*reinterpret_cast<__half2*>(&add_out[offset]) = h2;
float2 f2 = __half22float2(h2);
sum += f2.x + f2.y;
sum_sq += f2.x * f2.x + f2.y * f2.y;
}
template <>
inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias,
int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
float2 s = *reinterpret_cast<float2 const*>(&skip[skip_offset]);
float2 b = *reinterpret_cast<float2 const*>(&bias[bias_offset]);
f2.x += s.x + b.x;
f2.y += s.y + b.y;
*reinterpret_cast<float2*>(&add_out[offset]) = f2;
sum += f2.x + f2.y;
sum_sq += f2.x * f2.x + f2.y * f2.y;
}
// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset]
template <typename T>
inline __device__ void AddSkip(T* add_out, const T* src, const T* skip,
int64_t offset, int64_t skip_offset, float& sum, float& sum_sq);
template <>
inline __device__ void AddSkip(half* add_out, const half* src, const half* skip,
int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
__half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]);
h2 = h2 + s;
*reinterpret_cast<__half2*>(&add_out[offset]) = h2;
float2 f2 = __half22float2(h2);
sum += f2.x + f2.y;
sum_sq += f2.x * f2.x + f2.y * f2.y;
}
template <>
inline __device__ void AddSkip(float* add_out, const float* src, const float* skip,
int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
float2 s = *reinterpret_cast<float2 const*>(&skip[skip_offset]);
f2.x += s.x;
f2.y += s.y;
*reinterpret_cast<float2*>(&add_out[offset]) = f2;
sum += f2.x + f2.y;
sum_sq += f2.x * f2.x + f2.y * f2.y;
}
template <typename T, int32_t THREADS_PER_BLOCK>
__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
// The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
typedef cub::BlockScan<GroupSums, THREADS_PER_BLOCK> BlockScan;
// Allocate shared memory for BlockScan.
__shared__ typename BlockScan::TempStorage tempStorage;
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__ float2 smem[tTHREADS_PER_BLOCK];
__shared__ typename BlockScan::TempStorage temp_storage;
// Allocate shared memory for the groups. We could reduce the amount of shared memory reserved.
__shared__ float2 smem[THREADS_PER_BLOCK];
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The channel loaded by that thread.
int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD;
if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) {
return;
}
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
int32_t hw_begin = blockIdx.y * params.hw_per_block;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
float sum_sq = 0.F;
// Iterate over the activations to compute the sums.
if (ci < params.c) {
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
UpdateSum(params.src, offset, sum, sumSq);
int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hw_begin) * params.c + ci;
if (params.skip != nullptr) {
// SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c)
const int64_t bias_offset = static_cast<int64_t>(ci);
T* add_out = params.skip_workspace;
if (params.broadcast_skip) {
const int64_t skip_offset = static_cast<int64_t>(ni) * params.c + ci;
if (params.bias != nullptr) {
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq);
}
} else {
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq);
}
}
} else {
if (params.bias != nullptr) {
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq);
}
} else {
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq);
}
}
}
} else { // GroupNorm
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
UpdateSum(params.src, offset, sum, sum_sq);
}
}
// The group that thread works on and the channel in the group (modulus).
int32_t gi = threadIdx.x * 2 / params.cPerGroup;
int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
// The group index relative to the first group within the same block.
int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group;
// The channel in the group.
int32_t cj = ci % params.channels_per_group;
// The data for the summations.
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq};
// Do the segmented scan.
// Do the segmented scan. InclusiveScan is not deterministic.
GroupSums out;
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp());
// Store the results for the groups in shared memory (to produce coalesced
// stores later).
if (cj == params.cPerGroup - 2) { //2 channels per thread
smem[gi] = make_float2(out.sum, out.sumSq);
// Store the results for the groups in shared memory (to produce coalesced stores later).
// For each group, only the last thread of that group is picked to save sum to shared memory.
if (cj == params.channels_per_group - CHANNELS_PER_THREAD) {
smem[gi] = make_float2(out.sum, out.sum_sq);
}
// Make sure the data is in shared memory.
__syncthreads();
// The global group index.
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
// Threads that have nothing left to do, exit.
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
if (threadIdx.x >= params.groups_per_block) {
return;
}
// The first threads (those storing to global memory, load the values).
float2 sums = smem[threadIdx.x];
// The global group index.
// Use neighboring threads for coalesced write.
int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x;
// Store to global memory.
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
template <typename T>
void groupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
// Make sure the values are as we expect.
ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0);
// Make sure a group does not span multiple blocks.
ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
groupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
break;
default:
ORT_NOT_IMPLEMENTED("Not implemented");
if (gj < params.groups) {
float2 sums = smem[threadIdx.x];
const int index = (2 * ni) * params.groups + gj;
atomicAdd(&params.group_sum_buffer[index], sums.x);
atomicAdd(&params.group_sum_buffer[index + params.groups], sums.y);
}
}
template <typename T>
__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish);
void GroupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = DivUp(params.c, params.channels_per_block);
// The number of blocks to compute all the activations in a given instance.
grid.y = DivUp(params.hw, params.hw_per_block);
// The number of instances.
grid.z = params.n;
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
switch (params.threads_per_block) {
case 256:
GroupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
break;
case 192:
GroupNormNHWCSumKernel<T, 192><<<grid, 192, 0, stream>>>(params);
break;
case 160:
GroupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
break;
case 128:
GroupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
break;
case 64:
GroupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
break;
}
}
template <typename T>
__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev,
float2& gamma_f2, float2& beta_f2, bool silu);
template <>
__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev,
float2& gammaF2, float2& betaF2, bool swish) {
__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev,
float2& gamma_f2, float2& beta_f2, bool silu) {
// Fetch two channels per thread.
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
f2.x = (f2.x - mean) * inv_std_dev;
f2.y = (f2.y - mean) * inv_std_dev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
f2.x = gamma_f2.x * f2.x + beta_f2.x;
f2.y = gamma_f2.y * f2.y + beta_f2.y;
// Apply Swish if needed.
if (swish) {
// Apply SiLU activation if needed.
if (silu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo
}
template <>
__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev,
float2& gammaF2, float2& betaF2, bool swish) {
__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev,
float2& gamma_f2, float2& beta_f2, bool silu) {
// Fetch two channels per thread.
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
f2.x = (f2.x - mean) * inv_std_dev;
f2.y = (f2.y - mean) * inv_std_dev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
f2.x = gamma_f2.x * f2.x + beta_f2.x;
f2.y = gamma_f2.y * f2.y + beta_f2.y;
// Apply Swish if needed.
if (swish) {
// Apply SiLU activation if needed.
if (silu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f
*reinterpret_cast<float2*>(&dst[offset]) = f2;
}
template <typename T, int32_t tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
if (ci >= params.c) {
template <typename T>
__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// The channel loaded by that thread.
int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD;
if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) {
return;
}
// The instance in the batch.
int32_t ni = blockIdx.z;
// The group that thread works on and the channel in the group (modulus).
int32_t gi = ci / params.cPerGroup;
// The group that thread works on.
int32_t gi = ci / params.channels_per_group;
// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
float sum = 0.F, sum_sq = 0.F;
if (gi < params.groups) {
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
const int index = (2 * ni) * params.groups + gi;
sum = params.group_sum_buffer[index];
sum_sq = params.group_sum_buffer[index + params.groups];
}
// Load gamma/beta.
float2 gammaF2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
float2 betaF2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
// Load gamma/beta. Fetch two per thread.
float2 gamma_f2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
float2 beta_f2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
// Compute the mean.
float mean = sum * params.invHWC;
float mean = sum * params.inv_hw_channels_per_group;
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean);
// Compute the inverse of the stddev.
float invStdDev = var <= 0.F ? 1.F : rsqrtf(var);
float inv_std_dev = rsqrtf(var + params.epsilon);
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
int32_t hw_begin = blockIdx.y * params.hw_per_block;
int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
computeGroupNorm<T>(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish);
const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src;
int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hw_begin) * params.c + ci;
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
ComputeGroupNorm<T>(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu);
}
}
template <typename T>
void groupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
// Make sure the dimensions are aligned with what we expect.
ORT_ENFORCE(params.c % params.cPerBlock == 0);
// Make sure a group does not span multiple blocks.
ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
void GroupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
grid.x = DivUp(params.c, params.channels_per_block);
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
grid.y = DivUp(params.hw, params.hw_per_block);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNHWCScaleKernel<T, 160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNHWCScaleKernel<T, 256><<<grid, 256, 0, stream>>>(params);
break;
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
switch (params.threads_per_block) {
case 256:
groupNormNHWCScaleKernel<T, 128><<<grid, 128, 0, stream>>>(params);
GroupNormNHWCScaleKernel<T><<<grid, 256, 0, stream>>>(params);
break;
case 192:
GroupNormNHWCScaleKernel<T><<<grid, 192, 0, stream>>>(params);
break;
case 160:
GroupNormNHWCScaleKernel<T><<<grid, 160, 0, stream>>>(params);
break;
case 128:
groupNormNHWCScaleKernel<T, 64><<<grid, 64, 0, stream>>>(params);
GroupNormNHWCScaleKernel<T><<<grid, 128, 0, stream>>>(params);
break;
case 64:
GroupNormNHWCScaleKernel<T><<<grid, 64, 0, stream>>>(params);
break;
default:
ORT_NOT_IMPLEMENTED("Not implemented");
}
}
int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) {
int32_t max_divisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) {
max_divisor = divisor1;
}
if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
maxDivisor = divisor2;
if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) {
max_divisor = divisor2;
}
}
}
return maxDivisor;
return max_divisor;
}
// Find a suitable channels-per-block value using a cost function: the cost is the number of channels corresponding to
// threads that are allocated but have no channels assigned to them. If the cost is zero, every thread has
// work to do, which is the ideal case.
int FindChannelsPerBlock(int num_channels, int channels_per_group) {
int min_cost = -1;
int best_candidate = -1;
for (size_t i = kNumOfSizes; i > 0; --i) {
if (kSizes[i - 1] < channels_per_group) {
break;
}
int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group;
int blocks = (num_channels + channels_per_block - 1) / channels_per_block;
int cost = blocks * kSizes[i - 1] - num_channels;
if (cost == 0) {
return channels_per_block;
}
if (min_cost == -1 || cost < min_cost) {
min_cost = cost;
best_candidate = channels_per_block;
}
}
return best_candidate;
}
int GetChannelsPerBlock(int num_channels, int num_groups) {
int32_t channels_per_group = num_channels / num_groups;
int32_t channels_per_block = channels_per_group;
if (channels_per_group < kMaxSize / 2) {
channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group);
}
return channels_per_block;
}
template <typename T>
Status LaunchGroupNormKernel(
cudaStream_t stream,
T* output,
T* add_out,
const T* input,
const T* skip,
const T* bias,
const float* gamma,
const float* beta,
void* workspace,
@ -397,79 +588,94 @@ Status LaunchGroupNormKernel(
int height,
int width,
int num_groups,
bool use_swish_activation) {
if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"only support batch_size <= 32. Got", batch_size);
}
if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"only num_groups=32 is supported. Got", num_groups);
}
bool use_silu,
bool broadcast_skip,
int channels_per_block) {
GroupNormNHWCParams<T> params;
int32_t cPerBlock = 320;
int32_t maxBlocksPerHW = 1024;
switch (num_channels) {
case 960:
case 1920:
cPerBlock = 480;
break;
case 512:
case 256:
cPerBlock = 256;
break;
case 128:
cPerBlock = 128;
break;
default:
cPerBlock = 320;
int32_t channels_per_group = num_channels / num_groups;
// channels_per_block is computed in PrePack.
// If gamma is not an initializer, channels_per_block might be zero after PrePack. If that happens, compute it here.
if (channels_per_block < channels_per_group) {
channels_per_block = GetChannelsPerBlock(num_channels, num_groups);
}
params.withSwish = use_swish_activation;
// TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases
if (channels_per_block % channels_per_group != 0 ||
channels_per_block > kMaxSize ||
(channels_per_group % CHANNELS_PER_THREAD != 0)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"GroupNorm in CUDA does not support the input: n=", batch_size,
" h=", height,
" w=", width,
" c=", num_channels,
" groups=", num_groups);
}
params.use_silu = use_silu;
params.dst = output;
params.add_out = add_out;
params.src = input;
params.skip = skip;
params.bias = bias;
params.gamma = gamma;
params.beta = beta;
params.redBuffer = reinterpret_cast<float*>(workspace);
params.group_sum_buffer = reinterpret_cast<float*>(workspace);
params.n = batch_size;
params.h = height;
params.w = width;
params.c = num_channels;
params.groups = num_groups;
params.hw = params.h * params.w;
const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW);
params.hwPerBlock = divUp(params.hw, blocksPerHW);
params.cPerBlock = cPerBlock;
params.cPerGroup = params.c / params.groups;
// This will allocate as many blocks as possible to partition HW.
// For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw.
// TODO: tune this logic to find proper blocks when hw is small.
constexpr int32_t max_blocks_per_hw = 1024;
const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw);
params.hw_per_block = DivUp(params.hw, blocks_per_hw);
params.channels_per_block = channels_per_block;
params.channels_per_group = channels_per_group;
params.hwc = params.hw * params.c;
params.invHWC = 1.F / (float)(params.hw * params.cPerGroup);
params.groupsPerBlock = cPerBlock / params.cPerGroup;
params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group);
params.groups_per_block = channels_per_block / params.channels_per_group;
params.epsilon = epsilon;
params.broadcast_skip = broadcast_skip;
// Workspace for SkipGroupNorm to store intermediate results of src+skip+bias.
params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst;
params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD;
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(
params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));
GroupNormNHWCSum<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
DUMP_TENSOR_INIT();
DUMP_TENSOR("input", input, batch_size, num_channels, height * width);
DUMP_TENSOR("gamma", gamma, 1, num_channels);
DUMP_TENSOR("beta", beta, 1, num_channels);
cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream);
groupNormNHWCSum<T>(params, stream);
DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2);
DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups);
GroupNormNHWCScale<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
groupNormNHWCScale<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
DUMP_TENSOR("output", output, batch_size, num_channels, height * width);
return Status::OK();
}
template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output,
const half* input, const float* gamma, const float* beta, void* workspace,
template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output, half* add_out,
const half* input, const half* skip, const half* bias,
const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);
int height, int width, int num_groups, bool silu,
bool broadcast_skip, int channels_per_block);
template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output,
const float* input, const float* gamma, const float* beta, void* workspace,
template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output, float* add_out,
const float* input, const float* skip, const float* bias,
const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);
int height, int width, int num_groups, bool silu,
bool broadcast_skip, int channels_per_block);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
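
The block-size selection above can be exercised on the host without a GPU. The following is a minimal sketch, not the file's actual code: the `kSizes` table and `kMaxSize` are assumptions inferred from the thread counts in the `switch` statements (256, 192, 160, 128, 64, each doubled because `CHANNELS_PER_THREAD = 2`), and the argument check in `GetChannelsPerBlock` is added for illustration. The helper bodies otherwise follow the diff.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Assumed table: inferred from the kernel switch cases (threads per block x 2); not shown in this diff.
constexpr int32_t kSizes[] = {128, 256, 320, 384, 512};
constexpr std::size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int32_t kMaxSize = kSizes[kNumOfSizes - 1];

// Largest divisor of n that is strictly below max_allowed_divisor, or -1 if there is none.
int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) {
  int32_t max_divisor = -1;
  for (int32_t i = 1; i <= std::sqrt(n); i++) {
    if (n % i == 0) {
      int32_t divisor1 = n / i;
      int32_t divisor2 = i;
      if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) {
        max_divisor = divisor1;
      }
      if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) {
        max_divisor = divisor2;
      }
    }
  }
  return max_divisor;
}

// Cost = channel slots covered by launched threads minus the channels that actually exist.
int FindChannelsPerBlock(int num_channels, int channels_per_group) {
  int min_cost = -1;
  int best_candidate = -1;
  for (std::size_t i = kNumOfSizes; i > 0; --i) {
    if (kSizes[i - 1] < channels_per_group) {
      break;
    }
    // Round the candidate size down to a whole number of groups.
    int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group;
    int blocks = (num_channels + channels_per_block - 1) / channels_per_block;
    int cost = blocks * kSizes[i - 1] - num_channels;
    if (cost == 0) {
      return channels_per_block;
    }
    if (min_cost == -1 || cost < min_cost) {
      min_cost = cost;
      best_candidate = channels_per_block;
    }
  }
  return best_candidate;
}

int GetChannelsPerBlock(int num_channels, int num_groups) {
  // Illustrative precondition, not part of the original function.
  assert(num_groups > 0 && num_channels % num_groups == 0);
  int32_t channels_per_group = num_channels / num_groups;
  int32_t channels_per_block = channels_per_group;
  if (channels_per_group < kMaxSize / 2) {
    channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group);
  }
  return channels_per_block;
}
```

Under these assumptions, `GetChannelsPerBlock(320, 32)` yields 320 and `GetChannelsPerBlock(1920, 32)` yields 480, matching the values the removed hard-coded `switch` used for those channel counts. `FindMaxDivisor(4096, 1024)` yields 512 rather than 1024 because the comparison against `max_allowed_divisor` is strict, so a 64 x 64 latent would be split into 512 blocks of 8 positions each.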


@ -12,29 +12,33 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
constexpr size_t kMaxGroupNormBatchSize = 32;
constexpr size_t kGroupNormNumberOfGroups = 32;
constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) {
// Two buffers for sum and squared sum
return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
return (sizeof(float) * 2) * batch_size * num_groups;
}
int GetChannelsPerBlock(int num_channels, int num_groups);
template <typename T>
Status LaunchGroupNormKernel(
cudaStream_t stream,
T* output, // normalized output tensor
const T* input, // input tensor
const float* gamma, // gamma (also known as weight or scale)
const float* beta, // beta (also known as bias)
void* workspace, // Work space
float epsilon, // epsilon used normalization
int batch_size, // N
int num_channels, // C
int height, // H
int width, // W
int num_groups, // number of groups
bool use_swish_activation // Whether there is Swish activation after group normalization
T* output, // normalized output tensor. Shape is (n, h, w, c)
T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c)
const T* input, // input tensor. Shape is (n, h, w, c)
const T* skip, // optional skip tensor. Shape is (n, h, w, c)
const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm
const float* gamma, // gamma (also known as weight or scale). Shape is (c)
const float* beta, // beta (also known as bias). Shape is (c)
void* workspace, // Work space
float epsilon, // epsilon used in normalization
int batch_size, // N
int num_channels, // C
int height, // H
int width, // W
int num_groups, // number of groups
bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization
bool broadcast_skip, // Whether skip needs broadcast. When skip has shape (n, c) or (n, 1, 1, c), it needs broadcast.
int channels_per_block // Pre-computed channels per block.
);
} // namespace cuda
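
The workspace sized by `GetGroupNormWorkspaceSizeInBytes` holds, for each batch instance, one row of per-group sums followed by one row of per-group squared sums. A minimal sketch of that layout follows, based on the index arithmetic in the kernels (`(2 * ni) * groups + gi`, plus `groups` for the squared sum). The helpers `GroupSumIndex` and `GroupSumSqIndex` are hypothetical names introduced only for illustration.

```cpp
#include <cstddef>

// Same formula as the header in this diff: two float buffers (sum and squared sum) per (instance, group).
constexpr std::size_t GetGroupNormWorkspaceSizeInBytes(std::size_t batch_size, std::size_t num_groups) {
  return (sizeof(float) * 2) * batch_size * num_groups;
}

// Hypothetical helper: position of the sum for instance ni and group gi.
constexpr std::size_t GroupSumIndex(std::size_t ni, std::size_t gi, std::size_t num_groups) {
  return (2 * ni) * num_groups + gi;
}

// Hypothetical helper: the squared sum sits one row of num_groups floats after the sum.
constexpr std::size_t GroupSumSqIndex(std::size_t ni, std::size_t gi, std::size_t num_groups) {
  return GroupSumIndex(ni, gi, num_groups) + num_groups;
}
```

With `batch_size = 2` and `num_groups = 32`, the largest index touched is `GroupSumSqIndex(1, 31, 32) = 127`, which is the last of the 128 floats the workspace provides.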


@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
}
Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
return Status::OK();
}
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* gamma = context->Input<Tensor>(1);


@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"The number of groups of channels. It should be a divisor of the number of channels C",
AttributeProto::INT)
.Attr("activation",
"Activation after group normalization: 0 for None, 1 for Swish",
"Activation after group normalization: 0 for None, 1 for SiLU",
AttributeProto::INT)
.Attr("channels_last",
"1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.",
@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
constexpr const char* SkipGroupNorm_ver1_doc = R"DOC(
This operator element-wise adds x, skip and bias, then applies group normalization and optional activation.
This operator transforms input according to
s = x + skip + bias
y = gamma * (s - mean) / sqrt(variance + epsilon) + beta
The input channels are separated into num_groups groups, each containing num_channels / num_groups channels.
The num_channels must be divisible by num_groups.
The mean and standard deviation of s are calculated separately over each group.
The gamma and beta are per-channel affine transform parameter vectors of size num_channels.
The activation attribute can be used to enable activation after group normalization.
)DOC";
ONNX_MS_OPERATOR_SET_SCHEMA(
SkipGroupNorm, 1,
OpSchema()
.SetDoc(SkipGroupNorm_ver1_doc)
.Attr("epsilon", "The epsilon value to use to avoid division by zero",
AttributeProto::FLOAT, static_cast<float>(1e-5))
.Attr("groups",
"The number of groups of channels. It should be a divisor of the number of channels C",
AttributeProto::INT)
.Attr("activation",
"Activation after group normalization: 0 for None, 1 for SiLU",
AttributeProto::INT)
.Attr("channels_last",
"1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(0,
"X",
"Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1"
" or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels,"
" and H and W are the height and width of the data",
"T")
.Input(1,
"gamma",
"1D gamma tensor for normalization with shape (C), where C is number of channels",
"M")
.Input(2,
"beta",
"1D beta tensor for normalization with shape (C), where C is number of channels",
"M")
.Input(3,
"skip",
"4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)",
"T")
.Input(4,
"bias",
"1D bias tensor. Dimensions are (C), where C is number of channels",
"T",
OpSchema::Optional)
.Output(0,
"Y",
"The output tensor of the same shape as X",
"T")
.Output(1,
"S",
"The element-wise sum of input x, skip and bias tensors. It has the same shape as X",
"T",
OpSchema::Optional)
.TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.")
.TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (ctx.getNumOutputs() > 1) {
propagateElemTypeFromInputToOutput(ctx, 0, 1);
}
if (hasInputShape(ctx, 0)) {
propagateShapeFromInputToOutput(ctx, 0, 0);
if (ctx.getNumOutputs() > 1) {
propagateShapeFromInputToOutput(ctx, 0, 1);
}
}
}));
constexpr const char* BiasSplitGelu_ver1_doc = R"DOC(
A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left
tensor multiplies the Gelu activation result of right tensor.
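
The SkipGroupNorm formula in the schema doc (`s = x + skip + bias`, then `y = gamma * (s - mean) / sqrt(variance + epsilon) + beta`) can be checked against a plain CPU loop. The following is a minimal reference sketch, not the CUDA implementation: it assumes float NHWC tensors flattened into vectors, a `skip` tensor with the same shape as `x` (no broadcast), and an optional `bias` of size C passed as an empty vector when absent. `SkipGroupNormResult` and `SkipGroupNormReference` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct SkipGroupNormResult {
  std::vector<float> y;  // normalized output
  std::vector<float> s;  // element-wise sum x + skip + bias
};

// Hypothetical CPU reference for NHWC data with shape (n, hw, c).
SkipGroupNormResult SkipGroupNormReference(const std::vector<float>& x, const std::vector<float>& skip,
                                           const std::vector<float>& bias, const std::vector<float>& gamma,
                                           const std::vector<float>& beta, std::size_t n, std::size_t hw,
                                           std::size_t c, std::size_t groups, float epsilon, bool use_silu) {
  assert(groups > 0 && c % groups == 0);
  assert(x.size() == n * hw * c && skip.size() == x.size());
  const std::size_t cpg = c / groups;  // channels per group
  SkipGroupNormResult r;
  r.s.resize(x.size());
  r.y.resize(x.size());
  // In NHWC layout the channel is the fastest-varying index, so i % c is the channel.
  for (std::size_t i = 0; i < x.size(); ++i) {
    r.s[i] = x[i] + skip[i] + (bias.empty() ? 0.0f : bias[i % c]);
  }
  for (std::size_t ni = 0; ni < n; ++ni) {
    for (std::size_t g = 0; g < groups; ++g) {
      // First pass: sum and sum of squares over all positions and channels of the group.
      double sum = 0.0, sum_sq = 0.0;
      for (std::size_t hwi = 0; hwi < hw; ++hwi) {
        for (std::size_t k = 0; k < cpg; ++k) {
          const double v = r.s[(ni * hw + hwi) * c + g * cpg + k];
          sum += v;
          sum_sq += v * v;
        }
      }
      const double count = static_cast<double>(hw * cpg);
      const double mean = sum / count;
      const double var = sum_sq / count - mean * mean;
      const double inv_std_dev = 1.0 / std::sqrt(var + epsilon);
      // Second pass: normalize, apply the per-channel affine transform, then optional SiLU.
      for (std::size_t hwi = 0; hwi < hw; ++hwi) {
        for (std::size_t k = 0; k < cpg; ++k) {
          const std::size_t ch = g * cpg + k;
          const std::size_t idx = (ni * hw + hwi) * c + ch;
          double v = gamma[ch] * (r.s[idx] - mean) * inv_std_dev + beta[ch];
          if (use_silu) {
            v = v / (1.0 + std::exp(-v));  // v * sigmoid(v)
          }
          r.y[idx] = static_cast<float>(v);
        }
      }
    }
  }
  return r;
}
```

The two-pass structure mirrors the split between the sum kernel and the scale kernel in the CUDA code, with the variance computed as `E[s^2] - E[s]^2` in both. The sketch accumulates in double precision, so small differences from the float16 GPU path are expected.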


@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
@ -205,6 +206,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());


@ -200,6 +200,7 @@ class SymbolicShapeInference:
"GemmFastGelu": self._infer_GemmFastGelu,
"GemmFloat8": self._infer_GemmFloat8,
"GroupNorm": self._infer_GroupNorm,
"SkipGroupNorm": self._infer_SkipGroupNorm,
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
"MultiHeadAttention": self._infer_MultiHeadAttention,
@ -2376,6 +2377,11 @@ class SymbolicShapeInference:
def _infer_GroupNorm(self, node): # noqa: N802
self._propagate_shape_and_type(node)
def _infer_SkipGroupNorm(self, node): # noqa: N802
self._propagate_shape_and_type(node, 0, 0)
if len(node.output) > 1:
self._propagate_shape_and_type(node, 0, 1)
def _infer_BiasSplitGelu(self, node): # noqa: N802
input_shape = self._get_shape(node, 0)
bias_shape = self._get_shape(node, 1)


@ -1,6 +1,6 @@
import logging
from collections import OrderedDict
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple, Union
import numpy
import torch
@ -229,7 +229,7 @@ class CudaSession:
del self.io_binding
del self.ort_session
def allocate_buffers(self, shape_dict: Dict[str, tuple]):
def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
"""Allocate tensors for I/O Binding"""
if self.enable_cuda_graph:
for name, shape in shape_dict.items():


@ -0,0 +1,286 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <random>
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/framework/test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
using namespace std;
namespace onnxruntime {
namespace test {
TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) {
constexpr int64_t B = 2;
constexpr int64_t C = 16;
constexpr int64_t H = 2;
constexpr int64_t W = 2;
std::vector<int64_t> dims_nhwc{B, H, W, C};
std::vector<float> input_data_nhwc = {
-0.768555f, 1.575195f, -0.698242f, 1.587891f, 0.371826f, -0.280029f, -1.328125f, 0.127197f,
-0.197144f, 0.982422f, -0.671387f, -1.925781f, 1.800781f, -0.020218f, -0.782227f, 1.291992f,
-0.935059f, 1.782227f, -0.674316f, -1.943359f, -0.218994f, 0.054138f, -1.539062f, -0.546387f,
-2.160156f, 1.195312f, 1.653320f, -0.674316f, 0.224731f, -0.093262f, 1.160156f, -0.389404f,
1.748047f, 0.766113f, 0.234375f, 0.011177f, -0.055847f, -0.930664f, -0.490234f, -0.655762f,
-0.382568f, -0.554688f, 0.910645f, -0.227295f, 1.687500f, 0.028397f, -0.241699f, -0.480957f,
-0.355713f, -2.095703f, -0.443359f, -0.126221f, -0.815918f, 0.792969f, -0.450439f, -0.952148f,
-1.174805f, 0.242798f, 0.138550f, -0.237061f, -0.994141f, 0.346436f, 0.147705f, 0.125854f,
-0.517090f, 0.253906f, 0.400146f, -0.540039f, -0.788574f, 0.146606f, -0.409668f, 0.281982f,
1.444336f, 0.044434f, -0.366699f, 2.250000f, -0.453613f, -0.652344f, 1.828125f, -0.244751f,
0.307129f, -0.051361f, 0.106384f, 0.844727f, 1.648438f, -0.904785f, -0.353760f, 0.510742f,
0.074829f, -0.311279f, 0.274902f, 1.594727f, 1.367188f, 0.098755f, 0.043304f, -0.207397f,
0.068298f, -0.601074f, 0.083008f, 0.264893f, -0.659180f, -0.216797f, -0.086548f, -0.683594f,
-0.964844f, -2.591797f, -0.817383f, -0.461914f, -1.840820f, -0.712402f, -0.052094f, -0.583008f,
1.114258f, 0.190308f, 1.087891f, 0.005146f, 1.041992f, 1.363281f, -0.273682f, -0.465576f,
-0.027618f, 1.345703f, 0.789551f, -0.015991f, 0.401611f, 0.726562f, 0.598633f, 0.133667f};
std::vector<float> gamma_data = {
0.241255f, 0.556660f, -0.835532f, 0.564596f, -1.338308f, -0.278924f, 0.357326f, -1.745484f,
0.277184f, 0.101415f, -0.018637f, -0.526188f, -0.011698f, -2.349411f, 0.206578f, 0.357679f};
std::vector<float> beta_data = {
-1.194839f, 0.209146f, -0.677225f, -0.547338f, 1.275685f, -1.099577f, 0.470916f, 0.293907f,
-1.094209f, 2.350204f, -1.633769f, 0.248753f, -0.180166f, 0.365134f, -0.555731f, 1.843083f};
std::vector<float> skip_data_nhwc = {
0.892578f, -0.471924f, -0.423096f, 1.277344f, 0.257080f, -1.366211f, 1.552734f, 0.441406f,
-0.033142f, -0.059418f, 1.536133f, -0.225464f, 1.472656f, 0.591309f, -0.386230f, -2.197266f,
0.089600f, -0.256592f, -1.873047f, 0.916992f, 0.392090f, 0.015526f, -0.949219f, 0.566895f,
-0.220459f, 1.262695f, -0.437744f, -2.283203f, -0.264893f, -0.660156f, 2.353516f, 1.992188f,
0.865723f, -0.854004f, -1.014648f, 0.899414f, -1.041016f, 1.378906f, -0.075073f, -2.541016f,
-0.883789f, -0.428711f, 0.981934f, -0.072754f, 2.214844f, 0.658203f, 0.170166f, -1.727539f,
-0.672363f, -1.373047f, 0.318115f, 0.422363f, 0.260742f, -0.547852f, 0.545898f, -0.155762f,
0.679688f, 2.861328f, -0.300781f, -0.504883f, 1.548828f, 0.353760f, -0.387695f, -1.595703f,
-0.170166f, -0.002897f, 0.273193f, -0.383545f, -1.082031f, -0.894043f, -1.048828f, -0.044708f,
0.049286f, 0.220215f, 0.272705f, -0.853027f, -0.489258f, 0.513672f, 0.977051f, 0.310547f,
-0.577148f, -0.479004f, 0.838867f, 0.872559f, -0.510254f, 0.101807f, -0.299805f, -1.179688f,
-1.555664f, 0.668457f, 0.939453f, 0.118103f, -0.376709f, 0.735352f, -0.214233f, -1.987305f,
-0.931152f, 1.268555f, 1.427734f, -0.757812f, -1.324219f, 0.375488f, 1.364258f, -1.708008f,
0.976562f, -0.037659f, -1.779297f, -0.196655f, 1.636719f, 0.690430f, 0.941895f, -1.882812f,
0.431641f, 0.203857f, 1.306641f, -0.126343f, 1.408203f, 1.188477f, 0.432861f, -2.296875f,
-0.475342f, 1.517578f, -0.824219f, 1.288086f, -0.028244f, 1.918945f, 0.352295f, 0.693359f};
std::vector<float> bias_data = {
-0.537598f, 0.500488f, -0.252441f, -0.460693f, -1.640625f, -1.298828f, 0.331787f, -1.588867f,
1.000977f, 1.458984f, 0.702637f, 0.147827f, 1.143555f, 0.533691f, -0.072510f, 0.511230f};
std::vector<float> norm_data_nhwc = {
-1.213867f, 0.856445f, -0.119141f, 0.386475f, 0.714355f, -0.804688f,
1.048828f, -0.426270f, -1.091797f, 2.435547f, -1.641602f, 0.989746f,
-0.200928f, 0.267334f, -0.800781f, 1.577148f, -1.357422f, 1.000977f,
0.613281f, -0.963867f, 1.179688f, -1.169922f, 0.308350f, 0.304199f,
-1.396484f, 2.513672f, -1.644531f, 1.206055f, -0.180664f, 1.896484f,
-0.294678f, 2.046875f, -0.844238f, 0.448486f, -0.294189f, -0.291504f,
2.480469f, -1.250977f, 0.833008f, 4.593750f, -1.238281f, 2.335938f,
-1.651367f, 0.491943f, -0.204834f, 0.125610f, -0.682129f, 1.333984f,
-1.384766f, -0.708008f, -0.630859f, -0.504883f, 1.924805f, -1.208008f,
1.013672f, 1.809570f, -1.128906f, 2.546875f, -1.631836f, 0.610840f,
-0.184326f, 0.110046f, -0.700195f, 1.471680f, -1.511719f, 0.492188f,
-0.847168f, -1.373047f, 2.837891f, -0.998047f, 0.521484f, 0.262207f,
-0.810547f, 2.400391f, -1.628906f, 0.049896f, -0.174927f, 1.076172f,
-0.252197f, 1.784180f, -1.418945f, 0.090820f, -1.056641f, 0.002945f,
0.627441f, -0.989746f, 0.679199f, 1.130859f, -1.371094f, 2.408203f,
-1.645508f, -0.062988f, -0.192017f, -0.655762f, -0.718262f, 1.170898f,
-1.550781f, 0.706055f, -1.492188f, -1.148438f, 2.921875f, -1.136719f,
1.058594f, 2.781250f, -1.089844f, 2.201172f, -1.597656f, 0.785645f,
-0.181396f, 0.868164f, -0.552246f, 1.097656f, -1.015625f, 0.565430f,
-2.173828f, -0.955078f, -0.336426f, -1.503906f, 0.838867f, 3.136719f,
-1.186523f, 2.580078f, -1.629883f, 0.094604f, -0.186523f, -3.884766f,
-0.542480f, 1.990234f};
std::vector<float> add_out_data_nhwc = {
-0.414062f, 1.604492f, -1.374023f, 2.404297f, -1.011719f, -2.945312f, 0.556641f, -1.020508f,
0.770508f, 2.382812f, 1.567383f, -2.003906f, 4.417969f, 1.105469f, -1.240234f, -0.394531f,
-1.382812f, 2.027344f, -2.800781f, -1.487305f, -1.466797f, -1.229492f, -2.156250f, -1.568359f,
-1.379883f, 3.917969f, 1.917969f, -2.808594f, 1.103516f, -0.219727f, 3.441406f, 2.113281f,
2.076172f, 0.412598f, -1.033203f, 0.449951f, -2.738281f, -0.851562f, -0.233521f, -4.785156f,
-0.265625f, 0.475586f, 2.595703f, -0.152222f, 5.046875f, 1.220703f, -0.144043f, -1.697266f,
-1.566406f, -2.968750f, -0.377686f, -0.164551f, -2.195312f, -1.053711f, 0.427246f, -2.697266f,
0.505859f, 4.562500f, 0.540527f, -0.594238f, 1.698242f, 1.233398f, -0.312500f, -0.958496f,
-1.224609f, 0.751465f, 0.420898f, -1.384766f, -3.511719f, -2.046875f, -1.126953f, -1.351562f,
2.494141f, 1.724609f, 0.608398f, 1.544922f, 0.200684f, 0.395020f, 2.732422f, 0.577148f,
-0.807617f, -0.029785f, 0.692871f, 1.256836f, -0.502441f, -2.101562f, -0.321777f, -2.257812f,
-0.479492f, 1.816406f, 1.916992f, 1.860352f, 2.134766f, 1.367188f, -0.243408f, -1.683594f,
-1.400391f, 1.167969f, 1.257812f, -0.953613f, -3.625000f, -1.140625f, 1.609375f, -3.980469f,
1.012695f, -1.170898f, -1.894531f, -0.510742f, 0.939453f, 0.511719f, 0.817383f, -1.955078f,
1.007812f, 0.894531f, 2.142578f, -0.582031f, 0.809570f, 1.252930f, 0.490967f, -4.351562f,
0.497803f, 4.320312f, 0.667969f, 1.419922f, 1.516602f, 3.179688f, 0.878906f, 1.337891f};
  int min_cuda_architecture = 530;
  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);

  std::array<int, 2> channels_last_values = {-1, 1};

  for (const int channels_last : channels_last_values) {
    if (enable_cuda) {
      std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
      if (enable_cuda && channels_last != 0) {
        execution_providers.push_back(DefaultCudaExecutionProvider());
      }

      // Don't run the test if no providers are supported
      if (execution_providers.empty()) {
        continue;
      }

      OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain);
      test.AddAttribute<float>("epsilon", 1e-05f);
      test.AddAttribute<int64_t>("groups", 4);
      test.AddAttribute<int64_t>("activation", 0);

      // We interpret channels_last==-1 as the attribute not being provided
      if (channels_last != -1) {
        test.AddAttribute<int64_t>("channels_last", channels_last);
      }

      test.AddInput<MLFloat16>("X", dims_nhwc, ToFloat16(input_data_nhwc));
      test.AddInput<float>("gamma", {C}, gamma_data);
      test.AddInput<float>("beta", {C}, beta_data);
      test.AddInput<MLFloat16>("skip", dims_nhwc, ToFloat16(skip_data_nhwc));
      test.AddInput<MLFloat16>("bias", {C}, ToFloat16(bias_data));

      constexpr float rel_error = 0.0f;
      constexpr float abs_error = 0.02f;
      test.AddOutput<MLFloat16>("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error);
      test.AddOutput<MLFloat16>("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error);

      test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
    }
  }
}
TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
constexpr int64_t B = 1;
constexpr int64_t C = 64;
constexpr int64_t H = 1;
constexpr int64_t W = 1;
std::vector<int64_t> dims_nhwc{B, H, W, C};
std::vector<float> input_data_nhwc = {
0.588867f, 0.896484f, -0.213623f, 0.803223f, 0.659180f, -0.216187f, 1.197266f, -0.486084f,
-0.718750f, 0.332031f, -0.364746f, -0.831543f, -0.031219f, -1.059570f, 0.161621f, 1.519531f,
0.169312f, 1.048828f, 1.330078f, 0.450195f, -2.867188f, -1.456055f, 0.708496f, -1.120117f,
-1.208984f, -1.199219f, -1.505859f, -0.549316f, 0.505371f, 0.723145f, -0.359131f, -0.250977f,
-0.879883f, -0.305664f, 0.709473f, 0.815430f, 0.617676f, -0.638672f, 0.066772f, -2.330078f,
-1.316406f, 1.744141f, 1.122070f, -0.633789f, -1.802734f, -0.825684f, 0.622559f, -0.481689f,
-1.364258f, -0.536621f, -0.464111f, 0.247437f, -0.213989f, 0.384521f, 0.556641f, -0.303711f,
-0.160034f, 0.882324f, -0.212036f, -0.796387f, 0.153076f, -1.311523f, 2.212891f, 0.685059f};
std::vector<float> gamma_data = {
0.789682f, 0.869051f, -0.010169f, -0.021685f, 0.506611f, 1.267444f, -0.312695f, 0.877844f,
0.598637f, 0.598314f, -1.721544f, -0.593328f, 0.986705f, -0.419391f, -0.852584f, -0.572351f,
0.912797f, -0.586863f, 0.477761f, -0.484418f, -0.193835f, 0.347757f, 0.327637f, -1.100304f,
1.233108f, -0.272569f, -0.688656f, 0.687245f, 0.398386f, 0.888089f, -0.792587f, -0.769029f,
-0.427778f, 0.100768f, -2.187060f, 1.279301f, 1.109054f, 0.375992f, 1.514775f, 1.271436f,
0.822896f, -0.476750f, 0.475507f, -1.011297f, 1.177197f, 1.586540f, -1.059944f, -0.145351f,
0.841555f, -2.014113f, -0.230498f, 0.302128f, -0.180508f, 0.980534f, -0.126871f, 0.203151f,
-0.754841f, 0.420570f, -1.085798f, 1.335042f, -0.674930f, 2.453507f, 2.139259f, 1.087436f};
std::vector<float> beta_data = {
-0.064518f, -0.262683f, 0.827528f, -0.960938f, 1.062519f, 2.417941f, 0.212789f, -1.638430f,
1.875453f, -0.883058f, -0.006704f, 0.424894f, -0.869972f, 0.727008f, 0.879303f, -3.024141f,
-2.610873f, 1.269641f, 0.883006f, 0.804167f, -1.510324f, 2.258091f, -0.006750f, -1.553668f,
-1.659453f, 0.579603f, 0.652358f, 0.007077f, 0.099180f, 0.418658f, -0.273778f, -1.036199f,
-1.128691f, -0.296022f, -0.224056f, 1.476306f, 0.577624f, -0.372049f, -0.581659f, -1.841807f,
-0.361721f, 0.051160f, -0.749332f, -2.634807f, 0.562719f, -0.738667f, 0.024864f, -1.135937f,
-1.368144f, -1.458886f, -0.946683f, 1.953936f, -1.198661f, 0.166648f, 0.447206f, -0.458140f,
-0.553395f, 0.112900f, 0.255989f, -0.184551f, 1.254163f, -0.260479f, -1.232429f, 1.902575f};
std::vector<float> skip_data = {
0.952148f, 1.342773f, -0.172974f, -0.395264f, 1.119141f, 0.330566f,
0.281494f, 0.472900f, -0.692871f, -0.634766f, 0.013504f, -1.866211f,
-0.428223f, 0.669922f, -0.323486f, 0.713867f, -0.350586f, 0.659180f,
-0.288574f, 0.324219f, -0.300781f, -0.789551f, -0.216431f, -0.221436f,
-0.086670f, 0.366211f, -0.643555f, -0.977051f, 0.001021f, 0.415527f,
-0.271729f, 0.836426f, 0.035370f, -0.806152f, 0.936035f, -0.021332f,
-1.095703f, 0.971680f, 1.648438f, 0.840820f, 0.837402f, 0.607910f,
-1.894531f, 0.666016f, -0.171143f, 1.625977f, -0.620117f, -0.039581f,
1.702148f, -2.410156f, 1.565430f, -0.756348f, 1.446289f, 0.583496f,
-0.497559f, -0.271729f, -0.956055f, -1.642578f, 0.833496f, -1.136719f,
1.248047f, -2.515625f, 0.080383f, 0.376221f};
std::vector<float> norm_data_nhwc = {
0.494873f, 1.017578f, 0.841797f, -0.949219f, 1.552734f, 1.333984f, 0.012703f, -2.511719f,
1.424805f, -0.818359f, -0.128418f, 1.462891f, -0.882812f, 0.709961f, 0.693848f, -4.210938f,
-2.505859f, 0.513184f, 1.300781f, 0.460938f, -1.172852f, 1.851562f, 0.167969f, -0.885254f,
-2.535156f, 0.656738f, 1.683594f, -0.627441f, 0.478271f, 1.782227f, -0.196777f, -1.824219f,
-0.791016f, -0.398682f, -3.197266f, 2.275391f, 0.052704f, -0.286865f, 1.567383f, -3.552734f,
-0.646973f, -0.927734f, -1.032227f, -2.722656f, -1.337891f, 0.432129f, -0.040253f, -1.080078f,
-1.118164f, 3.123047f, -1.153320f, 1.843750f, -1.378906f, 0.941406f, 0.437256f, -0.542969f,
-0.218872f, 0.006115f, -0.265869f, -1.356445f, 0.649902f, -4.882812f, 1.696289f, 2.679688f};
std::vector<float> add_out_data_nhwc = {
1.541016f, 2.238281f, -0.386719f, 0.407959f, 1.778320f, 0.114380f,
1.478516f, -0.013184f, -1.412109f, -0.302734f, -0.351318f, -2.697266f,
-0.459473f, -0.389648f, -0.161865f, 2.234375f, -0.181274f, 1.708008f,
1.041016f, 0.774414f, -3.167969f, -2.246094f, 0.492188f, -1.341797f,
-1.295898f, -0.833008f, -2.148438f, -1.526367f, 0.506348f, 1.138672f,
-0.630859f, 0.585449f, -0.844727f, -1.111328f, 1.645508f, 0.793945f,
-0.478027f, 0.333008f, 1.714844f, -1.489258f, -0.479004f, 2.351562f,
-0.772461f, 0.032227f, -1.973633f, 0.800293f, 0.002441f, -0.521484f,
0.337891f, -2.947266f, 1.101562f, -0.508789f, 1.232422f, 0.967773f,
0.059082f, -0.575195f, -1.116211f, -0.760254f, 0.621582f, -1.933594f,
1.401367f, -3.828125f, 2.292969f, 1.061523f};
  int min_cuda_architecture = 530;
  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);

  std::array<bool, 2> has_add_out_values = {true, false};
  std::array<int, 2> skip_dims = {2, 4};

  constexpr int channels_last = 1;
  for (const int skip_dim : skip_dims) {
    for (const bool has_add_out : has_add_out_values) {
      if (enable_cuda) {
        std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
        if (enable_cuda && channels_last != 0) {
          execution_providers.push_back(DefaultCudaExecutionProvider());
        }

        // Don't run the test if no providers are supported
        if (execution_providers.empty()) {
          continue;
        }

        OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain);
        test.AddAttribute<float>("epsilon", 1e-05f);
        test.AddAttribute<int64_t>("groups", 8);
        test.AddAttribute<int64_t>("activation", 0);

        // We interpret channels_last==-1 as the attribute not being provided
        if (channels_last != -1) {
          test.AddAttribute<int64_t>("channels_last", channels_last);
        }

        test.AddInput<MLFloat16>("X", dims_nhwc, ToFloat16(input_data_nhwc));
        test.AddInput<float>("gamma", {C}, gamma_data);
        test.AddInput<float>("beta", {C}, beta_data);
        if (skip_dim == 2) {
          test.AddInput<MLFloat16>("skip", {B, C}, ToFloat16(skip_data));
        } else {
          test.AddInput<MLFloat16>("skip", {B, 1, 1, C}, ToFloat16(skip_data));
        }
        // no bias

        constexpr float rel_error = 0.0f;
        constexpr float abs_error = 0.02f;
        test.AddOutput<MLFloat16>("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error);

        if (has_add_out) {
          test.AddOutput<MLFloat16>("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error);
        }

        test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
      }
    }
  }
}
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,541 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import statistics
from dataclasses import dataclass
from enum import Enum
from time import perf_counter
from typing import Optional, Tuple
import numpy
import torch
from onnx import TensorProto, helper
from onnxruntime import InferenceSession
from onnxruntime.transformers.io_binding_helper import CudaSession
torch.manual_seed(0)


class GroupNormOpType(Enum):
    GROUP_NORM = 1
    SKIP_GROUP_NORM = 2


@dataclass
class GroupNormConfig:
    batch_size: int
    height: int
    width: int
    channels: int
    epsilon: float = 1e-5
    num_groups: int = 32
    activation: bool = False
    channels_last: bool = True
    fp16: bool = False

    op_type: GroupNormOpType = GroupNormOpType.GROUP_NORM
    has_bias: bool = False
    has_add_out: bool = False
    # Skip shape: 0 for (N, H, W, C) (no broadcast), 2 for (N, C), 4 for (N, 1, 1, C)
    broadcast_skip: int = 0

    def get_skip_symbolic_shape(self):
        skip_shape = {0: ["N", "H", "W", "C"], 2: ["N", "C"], 4: ["N", 1, 1, "C"]}
        return skip_shape[self.broadcast_skip]

    def get_skip_shape(self):
        skip_shape = {
            0: [self.batch_size, self.height, self.width, self.channels],
            2: [self.batch_size, self.channels],
            4: [self.batch_size, 1, 1, self.channels],
        }
        return skip_shape[self.broadcast_skip]

    def broadcast(self, skip: torch.Tensor):
        # Reshape a (N, C) skip to (N, 1, 1, C) so that it broadcasts against an NHWC input.
        if self.broadcast_skip == 2:
            return skip.reshape(self.batch_size, 1, 1, self.channels)

        return skip

    @staticmethod
    def create(
        b: int,
        h: int,
        w: int,
        c: int,
        fp16: bool = False,
        activation: bool = False,
        template: int = 0,
        num_groups: int = 32,
    ):
        # template 0: GroupNorm; templates 1-5: SkipGroupNorm variants that differ in
        # has_bias, has_add_out and the skip shape. Any other value returns None.
        if template == 0:
            return GroupNormConfig(
                b, h, w, c, fp16=fp16, activation=activation, op_type=GroupNormOpType.GROUP_NORM, num_groups=num_groups
            )

        if template == 1:
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=True,
                has_add_out=True,
                broadcast_skip=0,
                num_groups=num_groups,
            )

        if template == 2:
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=False,
                has_add_out=False,
                broadcast_skip=2,
                num_groups=num_groups,
            )

        if template == 3:
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=True,
                has_add_out=False,
                broadcast_skip=4,
                num_groups=num_groups,
            )

        if template == 4:  # No bias
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=False,
                has_add_out=True,
                broadcast_skip=0,
                num_groups=num_groups,
            )

        if template == 5:  # No bias, no add_out
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=False,
                has_add_out=False,
                broadcast_skip=0,
                num_groups=num_groups,
            )

        return None


def create_group_norm_graph(config: GroupNormConfig) -> bytes:
    inputs = ["input", "gamma", "beta"]
    outputs = ["output"]
    op_type = "GroupNorm"
    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        op_type = "SkipGroupNorm"
        inputs = [*inputs, "skip"]

        if config.has_bias:
            inputs = [*inputs, "bias"]

        if config.has_add_out:
            outputs = [*outputs, "add_out"]

    nodes = [
        helper.make_node(
            op_type,
            inputs,
            outputs,
            op_type + "_0",
            activation=int(config.activation),
            channels_last=int(config.channels_last),
            epsilon=config.epsilon,
            groups=config.num_groups,
            domain="com.microsoft",
        ),
    ]

    float_type = TensorProto.FLOAT16 if config.fp16 else TensorProto.FLOAT

    input_shapes = [
        helper.make_tensor_value_info("input", float_type, ["N", "H", "W", "C"]),
        helper.make_tensor_value_info("gamma", TensorProto.FLOAT, ["C"]),
        helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["C"]),
    ]
    output_shapes = [
        helper.make_tensor_value_info("output", float_type, ["N", "H", "W", "C"]),
    ]

    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        input_shapes = [
            *input_shapes,
            helper.make_tensor_value_info("skip", float_type, config.get_skip_symbolic_shape()),
        ]
        if config.has_bias:
            input_shapes = [*input_shapes, helper.make_tensor_value_info("bias", float_type, ["C"])]

        if config.has_add_out:
            output_shapes = [*output_shapes, helper.make_tensor_value_info("add_out", float_type, ["N", "H", "W", "C"])]

    graph = helper.make_graph(
        nodes,
        "Group_Norm_Graph",
        input_shapes,
        output_shapes,
    )

    model = helper.make_model(graph)
    return model.SerializeToString()


def group_norm_ort(
    src: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    skip: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    config: GroupNormConfig,
    measure_latency=False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]:
    onnx_model_str = create_group_norm_graph(config)
    ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"])

    session = CudaSession(ort_session, device=torch.device("cuda:0"))

    io_shape = {
        "input": [config.batch_size, config.height, config.width, config.channels],
        "gamma": [config.channels],
        "beta": [config.channels],
        "output": [config.batch_size, config.height, config.width, config.channels],
    }

    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        io_shape["skip"] = config.get_skip_shape()
        if config.has_bias:
            io_shape["bias"] = [config.channels]
        if config.has_add_out:
            io_shape["add_out"] = [config.batch_size, config.height, config.width, config.channels]

    session.allocate_buffers(io_shape)

    ort_inputs = {
        "input": src,
        "gamma": gamma,
        "beta": beta,
    }

    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        ort_inputs["skip"] = skip
        if config.has_bias:
            ort_inputs["bias"] = bias

    ort_outputs = session.infer(ort_inputs)
    output = ort_outputs["output"]

    add_out = (
        ort_outputs["add_out"] if config.op_type == GroupNormOpType.SKIP_GROUP_NORM and config.has_add_out else None
    )

    if measure_latency:
        latency_list = []
        for _ in range(10000):
            start_time = perf_counter()
            session.infer(ort_inputs)
            end_time = perf_counter()
            latency_list.append(end_time - start_time)
        average_latency = statistics.mean(latency_list)
        return output, add_out, average_latency

    return output, add_out, None


def group_norm_torch(
    src: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    skip: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    config: GroupNormConfig,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    add_out = src

    if skip is not None:
        assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM
        add_out = add_out + config.broadcast(skip)

    if bias is not None:
        assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM
        add_out = add_out + bias.reshape(1, 1, 1, bias.shape[0])

    x = add_out
    if config.channels_last:
        x = add_out.clone().permute(0, 3, 1, 2)  # from NHWC to NCHW

    weight = gamma.to(x.dtype)
    bias = beta.to(x.dtype)
    output = torch.nn.functional.group_norm(x, config.num_groups, weight=weight, bias=bias, eps=config.epsilon)
    if config.activation:
        torch.nn.functional.silu(output, inplace=True)

    if config.channels_last:
        output = output.permute(0, 2, 3, 1)  # from NCHW to NHWC

    return output, add_out


def print_tensor(name, tensor):
    # Print in the format that could be directly added to unit tests in C++.
    torch.set_printoptions(precision=6, sci_mode=False, linewidth=100, profile="full", threshold=1000)
    print(name)
    if tensor is not None:
        print("shape", tensor.shape)
        text = str(tensor.clone().flatten())
        print(text.replace("[", "[\n").replace("]", ",\n]").replace(",", "f,"))
    else:
        print(tensor)


def run_parity(config, measure_latency=True, verbose=False):
    float_type = torch.float16 if config.fp16 else torch.float32

    input_tensor = torch.randn(
        config.batch_size,
        config.height,
        config.width,
        config.channels,
        device="cuda",
        dtype=float_type,
        requires_grad=False,
    )

    gamma = torch.randn(
        config.channels,
        device="cuda",
        dtype=torch.float32,
        requires_grad=False,
    )

    beta = torch.randn(
        config.channels,
        device="cuda",
        dtype=torch.float32,
        requires_grad=False,
    )

    skip = None
    bias = None
    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        skip = torch.randn(
            *config.get_skip_shape(),
            device="cuda",
            dtype=float_type,
            requires_grad=False,
        )
        if config.has_bias:
            bias = torch.randn(
                config.channels,
                device="cuda",
                dtype=float_type,
                requires_grad=False,
            )

    if verbose:
        print(config)
        print_tensor("input", input_tensor)
        print_tensor("gamma", gamma)
        print_tensor("beta", beta)
        print_tensor("skip", skip)
        print_tensor("bias", bias)

    out_ort, ort_add_out, latency = group_norm_ort(
        input_tensor, gamma, beta, skip, bias, config, measure_latency=measure_latency
    )

    if verbose:
        print_tensor("out_ort", out_ort)
        print_tensor("ort_add_out", ort_add_out)

    torch_out, torch_add_out = group_norm_torch(input_tensor, gamma, beta, skip, bias, config)

    if verbose:
        print_tensor("torch_out", torch_out)
        print_tensor("torch_add_out", torch_add_out)

    average_diff = numpy.mean(numpy.abs(out_ort.detach().cpu().numpy() - torch_out.detach().cpu().numpy()))

    is_close = numpy.allclose(
        out_ort.detach().cpu().numpy(),
        torch_out.detach().cpu().numpy(),
        rtol=1e-1 if config.fp16 else 1e-3,
        atol=1e-1 if config.fp16 else 1e-3,
        equal_nan=True,
    )

    is_add_out_close = (
        numpy.allclose(
            ort_add_out.detach().cpu().numpy(),
            torch_add_out.detach().cpu().numpy(),
            rtol=1e-1 if config.fp16 else 1e-3,
            atol=1e-1 if config.fp16 else 1e-3,
            equal_nan=True,
        )
        if ort_add_out is not None
        else ""
    )

    # Compare results
    print(
        config.op_type.name,
        " B:",
        config.batch_size,
        " H:",
        config.height,
        " W:",
        config.width,
        " C:",
        config.channels,
        " G:",
        config.num_groups,
        " activation:",
        int(config.activation),
        " channels_last:",
        int(config.channels_last),
        " fp16:",
        int(config.fp16),
        f" Latency(μs): {int(latency * 1e6)}" if isinstance(latency, float) else "",
        " AvgDiff:",
        average_diff,
        " Pass:",
        is_close,
        is_add_out_close,
    )


def get_latent_height_width():
    default_size = [(512, 512), (768, 768), (1024, 1024)]
    small_img_size = [(512, 768), (768, 512)]
    xl_img_size = [
        (1152, 896),
        (896, 1152),
        (1216, 832),
        (832, 1216),
        (1344, 768),
        (768, 1344),
        (1536, 640),
        (640, 1536),
    ]
    return [(h // 8, w // 8) for (h, w) in default_size + small_img_size + xl_img_size]


def get_channels():
    return [128, 256, 512, 1024, 2048, 320, 640, 960, 1920, 2560, 384, 768, 1536, 3072, 1152, 2304]


def run_activation(template: int, fp16, measure_latency=False):
    print("Test GroupNorm with Silu Activation for ", "fp16" if fp16 else "fp32")
    for b in [2]:
        for h, w in get_latent_height_width():
            for c in get_channels():
                config = GroupNormConfig.create(b, h, w, c, fp16=fp16, activation=True, template=template)
                run_parity(config, measure_latency=measure_latency)


def run_no_activation(template: int, fp16, measure_latency=False):
    print("Test GroupNorm without Activation for ", "fp16" if fp16 else "fp32")
    for b in [1, 2, 4]:
        for h, w in get_latent_height_width():
            for c in get_channels():
                config = GroupNormConfig.create(b, h, w, c, fp16=fp16, template=template)
                run_parity(config, measure_latency=measure_latency)


def run_all_groups(template: int, fp16, measure_latency=False):
    # Note: each value here is passed as num_groups (the number of groups).
    group_sizes = [1, 2, 4, 8, 16, 32]
    print("Test GroupNorm for different group sizes:", group_sizes)
    for group_size in group_sizes:
        for h, w in get_latent_height_width()[:3]:
            for c in get_channels()[:2]:
                config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=group_size, template=template)
                run_parity(config, measure_latency=measure_latency)


def run_odd_channels(template: int, fp16, measure_latency=False):
    # Test some less common channel counts that are divisible by 2 * num_groups.
    for h, w in get_latent_height_width():
        for c in [448, 704, 832, 1664, 2240, 2688, 2880, 3008]:
            config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=template)
            run_parity(config, measure_latency=measure_latency)


def run_small_inputs(template: int, fp16):
    config = GroupNormConfig.create(2, 2, 2, 16, fp16=fp16, activation=False, num_groups=4, template=template)
    run_parity(config, measure_latency=False)

    config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=False, num_groups=8, template=template)
    run_parity(config, measure_latency=False)

    config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=True, num_groups=8, template=template)
    run_parity(config, measure_latency=False)


def run_performance(fp16):
    # Run perf test to tune parameters for given number of channels.
    for h, w in get_latent_height_width()[:3]:
        for c in get_channels():
            config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=0)
            run_parity(config, measure_latency=True)


def run_all(template: int):
    for fp16 in [True, False]:
        run_small_inputs(template, fp16)
        run_odd_channels(template, fp16)
        run_all_groups(template, fp16)
        run_activation(template, fp16)
        run_no_activation(template, fp16)
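

def run_not_implemented():
    # Expect failure. Check that the error message is the expected one,
    # and fail explicitly if no error is raised at all.
    try:
        config = GroupNormConfig(1, 2, 2, 513, num_groups=3)
        run_parity(config)
    except RuntimeError as e:
        assert "GroupNorm in CUDA does not support the input: n=1 h=2 w=2 c=513 groups=3" in str(e)
    else:
        raise AssertionError("Expected a RuntimeError for unsupported input: c=513 groups=3")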


def main():
    run_performance(True)

    run_not_implemented()

    for template in range(6):
        run_all(template)


if __name__ == "__main__":
    main()
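
The computation that `group_norm_torch` checks against can also be written out by hand, which may help when reading the hard-coded expected values in the C++ unit tests. The following is a minimal, dependency-free sketch for a single sample with 1x1 spatial size (the shape used by `SkipGroupNorm_no_bias_broadcast_skip`). The function name `skip_group_norm_1x1` and its signature are hypothetical and not part of ONNX Runtime, and it computes in Python floats rather than fp16, so small differences from the CUDA kernel are expected.

```python
import math
from typing import List, Optional, Tuple


def skip_group_norm_1x1(
    x: List[float],
    skip: List[float],
    gamma: List[float],
    beta: List[float],
    groups: int,
    bias: Optional[List[float]] = None,
    epsilon: float = 1e-5,
    silu: bool = False,
) -> Tuple[List[float], List[float]]:
    """Hypothetical reference for one (1, 1, 1, C) sample; returns (Y, S)."""
    channels = len(x)
    assert channels % groups == 0, "channels must be divisible by groups"

    # S = X + skip (+ bias): the optional second output of SkipGroupNorm.
    add_out = [
        x[c] + skip[c] + (bias[c] if bias is not None else 0.0)
        for c in range(channels)
    ]

    per_group = channels // groups
    out: List[float] = []
    for g in range(groups):
        chunk = add_out[g * per_group : (g + 1) * per_group]
        mean = sum(chunk) / per_group
        # Biased (population) variance over the group.
        var = sum((v - mean) ** 2 for v in chunk) / per_group
        inv_std = 1.0 / math.sqrt(var + epsilon)
        for i, v in enumerate(chunk):
            c = g * per_group + i
            y = (v - mean) * inv_std * gamma[c] + beta[c]
            if silu:
                y = y / (1.0 + math.exp(-y))  # SiLU: y * sigmoid(y)
            out.append(y)
    return out, add_out
```

Because each group is normalized independently, passing only the first eight channels of `input_data_nhwc`, `skip_data`, `gamma_data`, and `beta_data` from that unit test with `groups=1` should give values close to the first eight entries of `norm_data_nhwc` and `add_out_data_nhwc`, within the test's 0.02 absolute tolerance.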