mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
[CUDA] Update GroupNorm and Add SkipGroupNorm (#18091)
* Add a new operator SkipGroupNorm to support skip and bias inputs.
* Update GroupNorm kernel to support the number of channels used in the SD XL refiner.
* Add epsilon in kernel.
* Add parity and performance test script.
* Remove many limitations, including max batch size, max number of groups, c % cPerBlock == 0, etc.

### Motivation and Context

Update GroupNorm to support SD XL Refiner and beyond.
parent
29e40987e3
commit
95f053c652
14 changed files with 1548 additions and 258 deletions
|
|
@ -95,6 +95,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.RotaryEmbedding">com.microsoft.RotaryEmbedding</a>
|
||||
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
|
||||
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
|
||||
* <a href="#com.microsoft.SkipGroupNorm">com.microsoft.SkipGroupNorm</a>
|
||||
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
|
||||
* <a href="#com.microsoft.SkipSimplifiedLayerNormalization">com.microsoft.SkipSimplifiedLayerNormalization</a>
|
||||
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
|
||||
|
|
@ -2342,7 +2343,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
<dl>
|
||||
<dt><tt>activation</tt> : int (required)</dt>
|
||||
<dd>Activation after group normalization: 0 for None, 1 for Swish</dd>
|
||||
<dd>Activation after group normalization: 0 for None, 1 for SiLU</dd>
|
||||
<dt><tt>channels_last</tt> : int</dt>
|
||||
<dd>1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.</dd>
|
||||
<dt><tt>epsilon</tt> : float</dt>
|
||||
|
|
@ -2582,6 +2583,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
|
||||
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
|
||||
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
|
@ -5083,6 +5085,72 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.SkipGroupNorm"></a><a name="com.microsoft.skipgroupnorm">**com.microsoft.SkipGroupNorm**</a>
|
||||
|
||||
This operator element-wise adds x, skip and bias, then applies group normalization and an optional activation.
|
||||
|
||||
This operator transforms input according to
|
||||
s = x + skip + bias
|
||||
y = gamma * (s - mean) / sqrt(variance + epsilon) + beta
|
||||
|
||||
The input channels are separated into num_groups groups, each containing num_channels / num_groups channels.
|
||||
The num_channels must be divisible by num_groups.
|
||||
The mean and standard deviation of s are calculated separately over each group.
|
||||
The gamma and beta are per-channel affine transform parameter vectors of size num_channels.
|
||||
|
||||
The activation attribute can be used to enable activation after group normalization.
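
The formulas above can be illustrated with a small CPU reference. This is only a sketch, not the CUDA kernel: the function name `SkipGroupNormRef` and its signature are hypothetical, the layout is assumed to be NHWC with `skip` having the same shape as `X`, `bias` is assumed to be present, and the variance is assumed to be the biased (population) variance of each group.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for SkipGroupNorm in NHWC layout (illustration only).
// x and skip have shape (N, H*W, C) flattened; bias, gamma and beta have shape (C).
// Returns Y and writes the element-wise sum S = x + skip + bias to s_out.
std::vector<float> SkipGroupNormRef(const std::vector<float>& x, const std::vector<float>& skip,
                                    const std::vector<float>& bias, const std::vector<float>& gamma,
                                    const std::vector<float>& beta, int n, int hw, int c, int groups,
                                    float epsilon, bool silu, std::vector<float>& s_out) {
  const int cpg = c / groups;  // channels per group; c is assumed divisible by groups
  s_out.resize(x.size());
  std::vector<float> y(x.size());
  for (int ni = 0; ni < n; ++ni) {
    for (int g = 0; g < groups; ++g) {
      // First pass: compute s and accumulate the group's sum and sum of squares.
      double sum = 0.0, sum_sq = 0.0;
      for (int p = 0; p < hw; ++p) {
        for (int cj = 0; cj < cpg; ++cj) {
          const int ci = g * cpg + cj;
          const std::size_t i = (static_cast<std::size_t>(ni) * hw + p) * c + ci;
          s_out[i] = x[i] + skip[i] + bias[ci];
          sum += s_out[i];
          sum_sq += static_cast<double>(s_out[i]) * s_out[i];
        }
      }
      const double count = static_cast<double>(hw) * cpg;
      const double mean = sum / count;
      const double variance = sum_sq / count - mean * mean;
      const double inv_std = 1.0 / std::sqrt(variance + epsilon);
      // Second pass: normalize, apply gamma/beta, then the optional SiLU activation.
      for (int p = 0; p < hw; ++p) {
        for (int cj = 0; cj < cpg; ++cj) {
          const int ci = g * cpg + cj;
          const std::size_t i = (static_cast<std::size_t>(ni) * hw + p) * c + ci;
          double v = gamma[ci] * (s_out[i] - mean) * inv_std + beta[ci];
          if (silu) v = v / (1.0 + std::exp(-v));  // SiLU(v) = v * sigmoid(v)
          y[i] = static_cast<float>(v);
        }
      }
    }
  }
  return y;
}
```

A reference like this is mainly useful for parity checks against the fused kernel on small inputs; it makes no attempt to be fast.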
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>activation</tt> : int (required)</dt>
|
||||
<dd>Activation after group normalization: 0 for None, 1 for SiLU</dd>
|
||||
<dt><tt>channels_last</tt> : int</dt>
|
||||
<dd>1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.</dd>
|
||||
<dt><tt>epsilon</tt> : float</dt>
|
||||
<dd>The epsilon value to use to avoid division by zero</dd>
|
||||
<dt><tt>groups</tt> : int (required)</dt>
|
||||
<dd>The number of groups of channels. It should be a divisor of the number of channels C</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (4 - 5)
|
||||
|
||||
<dl>
|
||||
<dt><tt>X</tt> : T</dt>
|
||||
<dd>Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data</dd>
|
||||
<dt><tt>gamma</tt> : M</dt>
|
||||
<dd>1D gamma tensor for normalization with shape (C), where C is number of channels</dd>
|
||||
<dt><tt>beta</tt> : M</dt>
|
||||
<dd>1D beta tensor for normalization with shape (C), where C is number of channels</dd>
|
||||
<dt><tt>skip</tt> : T</dt>
|
||||
<dd>4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)</dd>
|
||||
<dt><tt>bias</tt> (optional) : T</dt>
|
||||
<dd>1D bias tensor. Dimensions are (C), where C is number of channels</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - 2)
|
||||
|
||||
<dl>
|
||||
<dt><tt>Y</tt> : T</dt>
|
||||
<dd>The output tensor of the same shape as X</dd>
|
||||
<dt><tt>S</tt> (optional) : T</dt>
|
||||
<dd>The element-wise sum of input x, skip and bias tensors. It has the same shape as X</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
|
||||
<dd>Constrain input X, skip, bias and output Y, S types to float tensors.</dd>
|
||||
<dt><tt>M</tt> : tensor(float16), tensor(float)</dt>
|
||||
<dd>Constrain gamma and beta to float tensors.</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.SkipLayerNormalization"></a><a name="com.microsoft.skiplayernormalization">**com.microsoft.SkipLayerNormalization**</a>
|
||||
|
||||
Skip and Layer Normalization Fusion
|
||||
|
|
|
|||
|
|
@ -861,6 +861,7 @@ Do not modify directly.*
|
|||
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|
||||
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Samp
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
|
||||
|
|
@ -269,6 +270,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cuda/diffusion/group_norm.h"
|
||||
#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
|
||||
|
|
@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
GroupNorm, kMSDomain, 1, kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct DispatchGroupNorm {
|
||||
Status operator()(cudaStream_t stream,
|
||||
Tensor* output,
|
||||
Tensor* add_out,
|
||||
const Tensor* input,
|
||||
const Tensor* skip,
|
||||
const Tensor* bias,
|
||||
const Tensor* gamma,
|
||||
const Tensor* beta,
|
||||
void* workspace,
|
||||
|
|
@ -32,12 +39,17 @@ struct DispatchGroupNorm {
|
|||
int height,
|
||||
int width,
|
||||
int num_groups,
|
||||
bool use_swish_activation) {
|
||||
bool use_swish_activation,
|
||||
bool broadcast_skip,
|
||||
int channels_per_block) {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
return LaunchGroupNormKernel<CudaT>(
|
||||
stream,
|
||||
reinterpret_cast<CudaT*>(output->MutableData<T>()),
|
||||
add_out == nullptr ? nullptr : reinterpret_cast<CudaT*>(add_out->MutableData<T>()),
|
||||
reinterpret_cast<const CudaT*>(input->Data<T>()),
|
||||
skip == nullptr ? nullptr : reinterpret_cast<const CudaT*>(skip->Data<T>()),
|
||||
bias == nullptr ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>()),
|
||||
gamma->Data<float>(),
|
||||
beta->Data<float>(),
|
||||
workspace,
|
||||
|
|
@ -47,13 +59,21 @@ struct DispatchGroupNorm {
|
|||
height,
|
||||
width,
|
||||
num_groups,
|
||||
use_swish_activation);
|
||||
use_swish_activation,
|
||||
broadcast_skip,
|
||||
channels_per_block);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
|
||||
has_skip_ = false;
|
||||
const std::string& op_name = op_info.GetKernelDef().OpName();
|
||||
if (op_name == "SkipGroupNorm") {
|
||||
has_skip_ = true;
|
||||
}
|
||||
|
||||
epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
|
||||
ORT_ENFORCE(epsilon_ >= 0);
|
||||
|
||||
|
|
@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
|
|||
use_swish_activation_ = (activation == 1);
|
||||
|
||||
channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
|
||||
|
||||
channels_per_block_ = 0;
|
||||
}
|
||||
|
||||
Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/,
|
||||
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
|
||||
is_packed = false;
|
||||
|
||||
// Compute and cache cPerBlock using number of channels from gamma tensor shape.
|
||||
if (input_idx == 1) {
|
||||
auto gamma_shape = tensor.Shape();
|
||||
if (gamma_shape.NumDimensions() == 1) {
|
||||
channels_per_block_ = GetChannelsPerBlock(static_cast<int>(gamma_shape[0]), num_groups_);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
|
||||
|
|
@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
|
|||
Tensor* output = context->Output(0, input->Shape());
|
||||
|
||||
if (!channels_last_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"only the channels_last layout is supported");
|
||||
}
|
||||
|
||||
if (!gamma->IsDataType<float>() || !beta->IsDataType<float>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"GroupNorm only supports gamma and beta in float type");
|
||||
}
|
||||
|
||||
const auto& input_dims = input->Shape().GetDims();
|
||||
if (input_dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input is expected to have 4 dimensions, got ", input_dims.size());
|
||||
}
|
||||
|
||||
// Only support NHWC format right now.
|
||||
int batch_size = static_cast<int>(input_dims[0]);
|
||||
int height = static_cast<int>(input_dims[1]);
|
||||
int width = static_cast<int>(input_dims[2]);
|
||||
int num_channels = static_cast<int>(input_dims[3]);
|
||||
|
||||
if (num_channels % num_groups_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"number of channels should be divisible by num_groups");
|
||||
}
|
||||
|
||||
const auto& gamma_dims = gamma->Shape().GetDims();
|
||||
if (gamma_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
|
||||
}
|
||||
if (gamma_dims[0] != input_dims[3]) {
|
||||
if (gamma_dims[0] != num_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Number of channels in gamma and input does not match");
|
||||
}
|
||||
|
|
@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"beta is expected to have 1 dimension, got ", beta_dims.size());
|
||||
}
|
||||
if (beta_dims[0] != input_dims[3]) {
|
||||
if (beta_dims[0] != num_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Number of channels in beta and input does not match");
|
||||
}
|
||||
|
||||
// Input and output format is NHWC
|
||||
int batch_size = static_cast<int>(input_dims[0]);
|
||||
int num_channels = static_cast<int>(input_dims[3]);
|
||||
int height = static_cast<int>(input_dims[1]);
|
||||
int width = static_cast<int>(input_dims[2]);
|
||||
|
||||
if (num_channels % num_groups_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"number of channels should be divisible by num_groups");
|
||||
}
|
||||
|
||||
if (context->GetUseDeterministicCompute()) {
|
||||
static std::once_flag log_warning;
|
||||
std::call_once(log_warning, []() {
|
||||
|
|
@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
|
|||
});
|
||||
}
|
||||
|
||||
auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());
|
||||
const Tensor* skip = nullptr;
|
||||
const Tensor* bias = nullptr;
|
||||
Tensor* add_out = nullptr;
|
||||
|
||||
bool broadcast_skip = false;
|
||||
if (has_skip_) {
|
||||
skip = context->Input<Tensor>(3);
|
||||
bias = context->Input<Tensor>(4);
|
||||
add_out = context->Output(1, input->Shape());
|
||||
|
||||
if (bias != nullptr) { // Bias is optional
|
||||
// If provided, bias has shape (C).
|
||||
const auto& bias_dims = bias->Shape().GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"bias is expected to have 1 dimension, got ", bias_dims.size());
|
||||
}
|
||||
if (bias_dims[0] != num_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Number of channels in bias and input does not match");
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether skip can be broadcast to the input shape.
|
||||
if (skip->Shape() != input->Shape()) {
|
||||
const auto& dims = skip->Shape().GetDims();
|
||||
// The shape of skip can be (N, C) or (N, 1, 1, C) for broadcast.
|
||||
const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels);
|
||||
const bool b4 = (dims.size() == 4 && dims[0] == batch_size &&
|
||||
dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels);
|
||||
broadcast_skip = b2 || b4;
|
||||
if (!broadcast_skip) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_),
|
||||
context->GetComputeStream());
|
||||
|
||||
utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
|
||||
return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, input, gamma, beta, workspace.get(),
|
||||
return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, add_out, input, skip, bias,
|
||||
gamma, beta, workspace.get(),
|
||||
epsilon_,
|
||||
batch_size,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
num_groups_,
|
||||
use_swish_activation_);
|
||||
use_swish_activation_,
|
||||
broadcast_skip,
|
||||
channels_per_block_);
|
||||
}
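
The skip-shape check in `ComputeInternal` can be read as a small classification: the skip tensor either matches the input shape, is broadcast from (N, C) or (N, 1, 1, C), or is rejected. The standalone sketch below restates that logic on plain dimension vectors. The names `SkipLayout` and `ClassifySkipShape` are hypothetical and are not part of the kernel.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the skip-shape check; illustration only.
enum class SkipLayout { kSameShape, kBroadcast, kInvalid };

// input_dims is expected to be NHWC, that is (N, H, W, C).
SkipLayout ClassifySkipShape(const std::vector<int64_t>& input_dims,
                             const std::vector<int64_t>& skip_dims) {
  if (input_dims.size() != 4) {
    return SkipLayout::kInvalid;
  }
  if (skip_dims == input_dims) {
    return SkipLayout::kSameShape;  // no broadcast needed
  }
  const int64_t batch_size = input_dims[0];
  const int64_t num_channels = input_dims[3];
  // (N, C): one value per batch item and channel.
  const bool b2 = skip_dims.size() == 2 && skip_dims[0] == batch_size &&
                  skip_dims[1] == num_channels;
  // (N, 1, 1, C): the same data with explicit singleton spatial dimensions.
  const bool b4 = skip_dims.size() == 4 && skip_dims[0] == batch_size &&
                  skip_dims[1] == 1 && skip_dims[2] == 1 && skip_dims[3] == num_channels;
  return (b2 || b4) ? SkipLayout::kBroadcast : SkipLayout::kInvalid;
}
```

For example, an input of shape (2, 8, 8, 320) accepts a skip of shape (2, 320) or (2, 1, 1, 320) as a broadcast, while (1, 320) is rejected because the batch size differs.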
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel {
|
|||
GroupNorm(const OpKernelInfo& op_kernel_info);
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
||||
bool& is_packed, PrePackedWeights* prepacked_weights) override;
|
||||
|
||||
private:
|
||||
bool use_swish_activation_;
|
||||
bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization?
|
||||
float epsilon_;
|
||||
int num_groups_;
|
||||
bool channels_last_;
|
||||
bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm
|
||||
int channels_per_block_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -16,18 +16,45 @@
|
|||
*/
|
||||
|
||||
// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
|
||||
// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
static inline int32_t divUp(int32_t m, int32_t n) {
|
||||
namespace {
|
||||
|
||||
// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time.
|
||||
constexpr static int32_t CHANNELS_PER_THREAD = 2;
|
||||
|
||||
constexpr static int kSizes[] = {128, 256, 320, 384, 512};
|
||||
constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
|
||||
constexpr static int kMaxSize = kSizes[kNumOfSizes - 1];
|
||||
|
||||
int NextSize(int x) {
|
||||
for (size_t i = 0; i < kNumOfSizes; ++i) {
|
||||
if (x <= kSizes[i]) {
|
||||
return kSizes[i];
|
||||
}
|
||||
}
|
||||
|
||||
return x;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static inline int32_t DivUp(int32_t m, int32_t n) {
|
||||
return (m + n - 1) / n;
|
||||
}
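
As a quick illustration of the two helpers above, the host-only sketch below repeats them outside the CUDA file so their behavior can be checked in isolation: `NextSize` rounds up to the next entry of `kSizes` and returns values above `kMaxSize` unchanged, and `DivUp` is a ceiling division for positive integers. The copy is for demonstration only.

```cpp
#include <cstddef>
#include <cstdint>

// Standalone copy of the size table and helpers, for illustration only.
constexpr int kSizes[] = {128, 256, 320, 384, 512};
constexpr std::size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];

// Returns the smallest table entry that is >= x, or x itself when x exceeds every entry.
int NextSize(int x) {
  for (std::size_t i = 0; i < kNumOfSizes; ++i) {
    if (x <= kSizes[i]) {
      return kSizes[i];
    }
  }
  return x;
}

// Ceiling division for positive m and n.
int32_t DivUp(int32_t m, int32_t n) {
  return (m + n - 1) / n;
}
```

With these definitions, `NextSize(300)` is 320, `NextSize(513)` stays 513, and `DivUp(10, 3)` is 4.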
|
||||
|
||||
|
|
@ -41,14 +68,14 @@ struct GroupSums {
|
|||
// The sum.
|
||||
float sum;
|
||||
// The sum of squares.
|
||||
float sumSq;
|
||||
float sum_sq;
|
||||
};
|
||||
|
||||
struct GroupSumsOp {
|
||||
inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) {
|
||||
GroupSums dst;
|
||||
dst.sum = b.flag ? b.sum : (a.sum + b.sum);
|
||||
dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
|
||||
dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
|
||||
dst.flag = a.flag + b.flag;
|
||||
return dst;
|
||||
}
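
The combine operator above implements a segmented sum: a set `flag` on the right-hand operand starts a new segment, so the running totals restart at each group boundary. The CPU sketch below applies the same rule in a sequential inclusive scan. It is an illustration under assumptions, not the kernel: `Combine` and `SequentialInclusiveScan` are hypothetical names, and the field order of `GroupSums` (flag, sum, sum_sq) is inferred from how the struct is initialized in the kernel.

```cpp
#include <cstddef>
#include <vector>

// Host-side copy of the scan element; field order is an assumption.
struct GroupSums {
  int flag;      // 1 if this element starts a new group, 0 otherwise
  float sum;     // partial sum
  float sum_sq;  // partial sum of squares
};

// Same rule as GroupSumsOp: restart the totals when b begins a new group.
GroupSums Combine(const GroupSums& a, const GroupSums& b) {
  GroupSums dst;
  dst.sum = b.flag ? b.sum : (a.sum + b.sum);
  dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
  dst.flag = a.flag + b.flag;
  return dst;
}

// Sequential stand-in for a block-wide inclusive scan.
std::vector<GroupSums> SequentialInclusiveScan(const std::vector<GroupSums>& in) {
  std::vector<GroupSums> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = (i == 0) ? in[0] : Combine(out[i - 1], in[i]);
  }
  return out;
}
```

After the scan, the last element of each group holds that group's totals, which is why only the last thread of a group needs to publish its result.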
|
||||
|
|
@ -56,54 +83,85 @@ struct GroupSumsOp {
|
|||
|
||||
template <typename T>
|
||||
struct GroupNormNHWCParams {
|
||||
// The output buffer. Layout NHWC.
|
||||
// The output buffer. Shape is (n, h, w, c).
|
||||
T* dst;
|
||||
// The input buffer. Layout NHWC.
|
||||
|
||||
// Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c).
|
||||
T* add_out;
|
||||
|
||||
// The input buffer. Shape is (n, h, w, c).
|
||||
T const* src;
|
||||
|
||||
// Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c).
|
||||
T const* skip;
|
||||
|
||||
// Optional input buffer for bias tensor. Shape is (c).
|
||||
T const* bias;
|
||||
|
||||
// The gamma scaling factor.
|
||||
float const* gamma;
|
||||
|
||||
// The beta term to add in GN.
|
||||
float const* beta;
|
||||
// The temporary buffer to do the global parallel reduction. Size:
|
||||
// BLOCKS_PER_BATCH x C x 2.
|
||||
float* redBuffer;
|
||||
|
||||
// The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups.
|
||||
float* group_sum_buffer;
|
||||
|
||||
// The number of instances in the batch.
|
||||
int32_t n;
|
||||
|
||||
// The height and width of each activation map.
|
||||
int32_t h;
|
||||
int32_t w;
|
||||
// The number of channels.
|
||||
|
||||
// Number of channels.
|
||||
int32_t c;
|
||||
// The number of groups.
|
||||
|
||||
// Number of groups.
|
||||
int32_t groups;
|
||||
// Do we apply the Swish activation function?
|
||||
bool withSwish;
|
||||
|
||||
// Do we apply the SiLU activation function?
|
||||
bool use_silu;
|
||||
|
||||
// Precomputed values and parameters to control the execution of the kernels.
|
||||
|
||||
// The number of activations per instance (h * w) and the number of
|
||||
// activations per block.
|
||||
// Number of activations per instance (h * w)
|
||||
int32_t hw;
|
||||
int32_t hwPerBlock;
|
||||
// The number of channels per group and blocks per activation in the C
|
||||
// dimension.
|
||||
int32_t cPerBlock;
|
||||
int32_t cPerGroup;
|
||||
|
||||
// Number of activations per block
|
||||
int32_t hw_per_block;
|
||||
|
||||
// Number of channels per block in the C dimension.
|
||||
int32_t channels_per_block;
|
||||
|
||||
// Number of channels per group in the C dimension.
|
||||
int32_t channels_per_group;
|
||||
|
||||
// The precomputed stride between instances.
|
||||
int32_t hwc;
|
||||
// The inverse of hwc in floats (to compute mean/var).
|
||||
float invHWC;
|
||||
// The inverse of hw*channels_per_group to compute mean of a group.
|
||||
float inv_hw_channels_per_group;
|
||||
// The precomputed number of groups per block.
|
||||
int32_t groupsPerBlock;
|
||||
int32_t groups_per_block;
|
||||
|
||||
// Number of threads per block
|
||||
int32_t threads_per_block;
|
||||
|
||||
// Epsilon to get stable variance in normalization.
|
||||
float epsilon;
|
||||
|
||||
// Whether skip needs broadcasting. True if the shape of skip is (N, C) or (N, 1, 1, C); false otherwise.
|
||||
bool broadcast_skip;
|
||||
|
||||
// For SkipGroupNorm, it points to the intermediate result of adding skip and bias.
|
||||
T* skip_workspace;
|
||||
};
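
Several fields of `GroupNormNHWCParams` are derived from `h`, `w`, `c` and `groups`. The sketch below shows one way those derived values could be computed, and how a group's mean and variance could be recovered from its `sum` and `sum_sq`. The helper names are hypothetical, and the variance identity E[s^2] - mean^2 is an assumption about how the sums are consumed, since that kernel is not part of this hunk.

```cpp
#include <cstdint>

// Hypothetical subset of the precomputed fields; illustration only.
struct DerivedParams {
  int32_t hw;                       // activations per instance (h * w)
  int32_t hwc;                      // stride between instances (h * w * c)
  int32_t channels_per_group;       // c / groups
  float inv_hw_channels_per_group;  // 1 / (hw * channels_per_group)
};

DerivedParams ComputeDerived(int32_t h, int32_t w, int32_t c, int32_t groups) {
  DerivedParams p;
  p.hw = h * w;
  p.hwc = p.hw * c;
  p.channels_per_group = c / groups;  // c is assumed divisible by groups
  p.inv_hw_channels_per_group = 1.0f / static_cast<float>(p.hw * p.channels_per_group);
  return p;
}

struct MeanVar {
  float mean;
  float variance;
};

// Recover mean and (biased) variance of a group from its sum and sum of squares.
MeanVar MeanVarFromSums(float sum, float sum_sq, float inv_count) {
  MeanVar r;
  r.mean = sum * inv_count;
  r.variance = sum_sq * inv_count - r.mean * r.mean;
  return r;
}
```

Accumulating only two floats per group is what allows the reduction buffer to have shape (n, 2, g).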
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq);
|
||||
inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq);
|
||||
|
||||
template <>
|
||||
inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) {
|
||||
inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) {
|
||||
// Fetch two channels per thread.
|
||||
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
|
||||
|
||||
|
|
@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl
|
|||
sum += f2.x + f2.y;
|
||||
|
||||
// Update the sum of squares.
|
||||
sumSq += f2.x * f2.x + f2.y * f2.y;
|
||||
sum_sq += f2.x * f2.x + f2.y * f2.y;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) {
|
||||
inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) {
|
||||
// Fetch two channels per thread.
|
||||
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
|
||||
|
||||
|
|
@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f
|
|||
sum += f2.x + f2.y;
|
||||
|
||||
// Update the sum of squares.
|
||||
sumSq += f2.x * f2.x + f2.y * f2.y;
|
||||
sum_sq += f2.x * f2.x + f2.y * f2.y;
|
||||
}
|
||||
|
||||
template <typename T, int32_t tTHREADS_PER_BLOCK>
|
||||
__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
|
||||
// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset]
|
||||
template <typename T>
|
||||
inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias,
|
||||
int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq);
|
||||
|
||||
template <>
|
||||
inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias,
|
||||
int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
|
||||
// Fetch two channels per thread.
|
||||
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
|
||||
__half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]);
|
||||
__half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]);
|
||||
h2 = h2 + b;
|
||||
h2 = h2 + s;
|
||||
|
||||
*reinterpret_cast<__half2*>(&add_out[offset]) = h2;
|
||||
|
||||
float2 f2 = __half22float2(h2);
|
||||
sum += f2.x + f2.y;
|
||||
sum_sq += f2.x * f2.x + f2.y * f2.y;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias,
|
||||
int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
|
||||
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
|
||||
float2 s = *reinterpret_cast<float2 const*>(&skip[skip_offset]);
|
||||
float2 b = *reinterpret_cast<float2 const*>(&bias[bias_offset]);
|
||||
f2.x += s.x + b.x;
|
||||
f2.y += s.y + b.y;
|
||||
|
||||
*reinterpret_cast<float2*>(&add_out[offset]) = f2;
|
||||
|
||||
sum += f2.x + f2.y;
|
||||
sum_sq += f2.x * f2.x + f2.y * f2.y;
|
||||
}
|
||||
|
||||
// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset]
|
||||
template <typename T>
|
||||
inline __device__ void AddSkip(T* add_out, const T* src, const T* skip,
|
||||
int64_t offset, int64_t skip_offset, float& sum, float& sum_sq);
|
||||
|
||||
template <>
|
||||
inline __device__ void AddSkip(half* add_out, const half* src, const half* skip,
|
||||
int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
|
||||
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
|
||||
__half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]);
|
||||
h2 = h2 + s;
|
||||
|
||||
*reinterpret_cast<__half2*>(&add_out[offset]) = h2;
|
||||
|
||||
float2 f2 = __half22float2(h2);
|
||||
sum += f2.x + f2.y;
|
||||
sum_sq += f2.x * f2.x + f2.y * f2.y;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ void AddSkip(float* add_out, const float* src, const float* skip,
|
||||
int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
|
||||
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
|
||||
float2 s = *reinterpret_cast<float2 const*>(&skip[skip_offset]);
|
||||
f2.x += s.x;
|
||||
f2.y += s.y;
|
||||
*reinterpret_cast<float2*>(&add_out[offset]) = f2;
|
||||
sum += f2.x + f2.y;
|
||||
sum_sq += f2.x * f2.x + f2.y * f2.y;
|
||||
}
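
The `AddSkipBias` and `AddSkip` helpers take three separate offsets because the three tensors have different shapes. The host-side sketch below shows how those offsets relate for an NHWC input. The function names are hypothetical, and the formulas follow the shapes described in the comments: input (n, h, w, c), broadcast skip (n, c), bias (c).

```cpp
#include <cstdint>

// Flat offset of element (ni, hwi, ci) in an NHWC tensor with hw = h * w positions.
int64_t ElementOffset(int32_t ni, int32_t hwi, int32_t ci, int32_t hw, int32_t c) {
  return static_cast<int64_t>(ni) * hw * c + static_cast<int64_t>(hwi) * c + ci;
}

// Offset into skip: a broadcast skip has one row of c values per batch item,
// otherwise skip is indexed exactly like the input.
int64_t SkipOffset(bool broadcast_skip, int32_t ni, int32_t hwi, int32_t ci, int32_t hw, int32_t c) {
  return broadcast_skip ? static_cast<int64_t>(ni) * c + ci
                        : ElementOffset(ni, hwi, ci, hw, c);
}

// Offset into bias: bias has shape (c), so only the channel index matters.
int64_t BiasOffset(int32_t ci) {
  return static_cast<int64_t>(ci);
}
```

In the broadcast case the skip offset does not depend on the spatial position, so it can be computed once per thread rather than once per loop iteration.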
|
||||
|
||||
template <typename T, int32_t THREADS_PER_BLOCK>
|
||||
__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
|
||||
// The object in charge of doing the sums for the different blocks.
|
||||
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
|
||||
typedef cub::BlockScan<GroupSums, THREADS_PER_BLOCK> BlockScan;
|
||||
|
||||
// Allocate shared memory for BlockScan.
|
||||
__shared__ typename BlockScan::TempStorage tempStorage;
|
||||
// Allocate shared memory for the groups. We could reduce the amount of shared
|
||||
// memory reserved.
|
||||
__shared__ float2 smem[tTHREADS_PER_BLOCK];
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
|
||||
// Allocate shared memory for the groups. We could reduce the amount of shared memory reserved.
|
||||
__shared__ float2 smem[THREADS_PER_BLOCK];
|
||||
|
||||
// The instance in the batch.
|
||||
int32_t ni = blockIdx.z;
|
||||
// The channel loaded by that thread (2 channels per thread for F16x2).
|
||||
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
|
||||
|
||||
// The channel loaded by that thread.
|
||||
int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD;
|
||||
|
||||
if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The first activation loaded by that block.
|
||||
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
|
||||
int32_t hw_begin = blockIdx.y * params.hw_per_block;
|
||||
// The last activation loaded by that block.
|
||||
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
|
||||
int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
|
||||
|
||||
// The sums.
|
||||
float sum = 0.F;
|
||||
float sumSq = 0.F;
|
||||
float sum_sq = 0.F;
|
||||
|
||||
// Iterate over the activations to compute the sums.
|
||||
if (ci < params.c) {
|
||||
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
|
||||
// The offset.
|
||||
int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
|
||||
UpdateSum(params.src, offset, sum, sumSq);
|
||||
int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hw_begin) * params.c + ci;
|
||||
if (params.skip != nullptr) {
|
||||
// SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c)
|
||||
const int64_t bias_offset = static_cast<int64_t>(ci);
|
||||
T* add_out = params.skip_workspace;
|
||||
if (params.broadcast_skip) {
|
||||
const int64_t skip_offset = static_cast<int64_t>(ni) * params.c + ci;
|
||||
|
||||
if (params.bias != nullptr) {
|
||||
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
|
||||
AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq);
|
||||
}
|
||||
} else {
|
||||
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
|
||||
AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (params.bias != nullptr) {
|
||||
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
|
||||
AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq);
|
||||
}
|
||||
} else {
|
||||
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
|
||||
AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // GroupNorm
|
||||
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
|
||||
UpdateSum(params.src, offset, sum, sum_sq);
|
||||
}
|
||||
}
|
||||
|
||||
// The group that thread works on and the channel in the group (modulus).
|
||||
int32_t gi = threadIdx.x * 2 / params.cPerGroup;
|
||||
int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
|
||||
// The group index relative to the first group within the same block.
|
||||
int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group;
|
||||
// The channel in the group.
|
||||
int32_t cj = ci % params.channels_per_group;
|
||||
|
||||
// The data for the summations.
|
||||
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
|
||||
GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq};
|
||||
|
||||
// Do the segmented scan.
|
||||
// Do the segmented scan. InclusiveScan is not deterministic.
|
||||
GroupSums out;
|
||||
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
|
||||
BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp());
|
||||
|
||||
// Store the results for the groups in shared memory (to produce coalesced
|
||||
// stores later).
|
||||
if (cj == params.cPerGroup - 2) { //2 channels per thread
|
||||
smem[gi] = make_float2(out.sum, out.sumSq);
|
||||
// Store the results for the groups in shared memory (to produce coalesced stores later).
|
||||
// For each group, only the last thread of that group is picked to save sum to shared memory.
|
||||
if (cj == params.channels_per_group - CHANNELS_PER_THREAD) {
|
||||
smem[gi] = make_float2(out.sum, out.sum_sq);
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
// The global group index.
|
||||
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
|
||||
|
||||
// Threads that have nothing left to do, exit.
|
||||
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
|
||||
if (threadIdx.x >= params.groups_per_block) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The first threads (those storing to global memory, load the values).
|
||||
float2 sums = smem[threadIdx.x];
|
||||
// The global group index.
|
||||
// Use neighboring threads for coalesced write.
|
||||
int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x;
|
||||
|
||||
// Store to global memory.
|
||||
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
|
||||
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void groupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
|
||||
// Make sure the values are as we expect.
|
||||
ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0);
|
||||
// Make sure a group does not span multiple blocks.
|
||||
ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
|
||||
|
||||
dim3 grid;
|
||||
|
||||
// The number of blocks to compute all the channels.
|
||||
grid.x = params.c / params.cPerBlock;
|
||||
// The number of blocks to compute all the activations in a given instance.
|
||||
grid.y = divUp(params.hw, params.hwPerBlock);
|
||||
// The number of instances.
|
||||
grid.z = params.n;
|
||||
|
||||
switch (params.cPerBlock) {
|
||||
case 320:
|
||||
groupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
|
||||
break;
|
||||
case 480:
|
||||
groupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
|
||||
break;
|
||||
case 256:
|
||||
groupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
|
||||
break;
|
||||
case 128:
|
||||
groupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
|
||||
break;
|
||||
default:
|
||||
ORT_NOT_IMPLEMENTED("Not implemented");
|
||||
if (gj < params.groups) {
|
||||
float2 sums = smem[threadIdx.x];
|
||||
const int index = (2 * ni) * params.groups + gj;
|
||||
atomicAdd(&params.group_sum_buffer[index], sums.x);
|
||||
atomicAdd(&params.group_sum_buffer[index + params.groups], sums.y);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish);
|
||||
void GroupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
|
||||
dim3 grid;
|
||||
|
||||
// The number of blocks to compute all the channels.
|
||||
grid.x = DivUp(params.c, params.channels_per_block);
|
||||
|
||||
// The number of blocks to compute all the activations in a given instance.
|
||||
grid.y = DivUp(params.hw, params.hw_per_block);
|
||||
|
||||
// The number of instances.
|
||||
grid.z = params.n;
|
||||
|
||||
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
|
||||
switch (params.threads_per_block) {
|
||||
case 256:
|
||||
GroupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
|
||||
break;
|
||||
case 192:
|
||||
GroupNormNHWCSumKernel<T, 192><<<grid, 192, 0, stream>>>(params);
|
||||
break;
|
||||
case 160:
|
||||
GroupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
|
||||
break;
|
||||
case 128:
|
||||
GroupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
|
||||
break;
|
||||
case 64:
|
||||
GroupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev,
|
||||
float2& gamma_f2, float2& beta_f2, bool silu);
|
||||
|
||||
template <>
|
||||
__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev,
|
||||
float2& gammaF2, float2& betaF2, bool swish) {
|
||||
__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev,
|
||||
float2& gamma_f2, float2& beta_f2, bool silu) {
|
||||
// Fetch two channels per thread.
|
||||
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
|
||||
|
||||
|
|
@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo
|
|||
float2 f2 = __half22float2(h2);
|
||||
|
||||
// Normalize the channels.
|
||||
f2.x = (f2.x - mean) * invStdDev;
|
||||
f2.y = (f2.y - mean) * invStdDev;
|
||||
f2.x = (f2.x - mean) * inv_std_dev;
|
||||
f2.y = (f2.y - mean) * inv_std_dev;
|
||||
|
||||
// Scale by gamma and add beta.
|
||||
f2.x = gammaF2.x * f2.x + betaF2.x;
|
||||
f2.y = gammaF2.y * f2.y + betaF2.y;
|
||||
f2.x = gamma_f2.x * f2.x + beta_f2.x;
|
||||
f2.y = gamma_f2.y * f2.y + beta_f2.y;
|
||||
|
||||
// Apply Swish if needed.
|
||||
if (swish) {
|
||||
// Apply SiLU activation if needed.
|
||||
if (silu) {
|
||||
f2.x = f2.x * sigmoid(f2.x);
|
||||
f2.y = f2.y * sigmoid(f2.y);
|
||||
}
|
||||
|
|
@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo
|
|||
}
|
||||
|
||||
template <>
|
||||
__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev,
|
||||
float2& gammaF2, float2& betaF2, bool swish) {
|
||||
__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev,
|
||||
float2& gamma_f2, float2& beta_f2, bool silu) {
|
||||
// Fetch two channels per thread.
|
||||
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
|
||||
|
||||
// Normalize the channels.
|
||||
f2.x = (f2.x - mean) * invStdDev;
|
||||
f2.y = (f2.y - mean) * invStdDev;
|
||||
f2.x = (f2.x - mean) * inv_std_dev;
|
||||
f2.y = (f2.y - mean) * inv_std_dev;
|
||||
|
||||
// Scale by gamma and add beta.
|
||||
f2.x = gammaF2.x * f2.x + betaF2.x;
|
||||
f2.y = gammaF2.y * f2.y + betaF2.y;
|
||||
f2.x = gamma_f2.x * f2.x + beta_f2.x;
|
||||
f2.y = gamma_f2.y * f2.y + beta_f2.y;
|
||||
|
||||
// Apply Swish if needed.
|
||||
if (swish) {
|
||||
// Apply SiLU activation if needed.
|
||||
if (silu) {
|
||||
f2.x = f2.x * sigmoid(f2.x);
|
||||
f2.y = f2.y * sigmoid(f2.y);
|
||||
}
|
||||
|
|
@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f
|
|||
*reinterpret_cast<float2*>(&dst[offset]) = f2;
|
||||
}
|
||||
|
||||
template <typename T, int32_t tTHREADS_PER_BLOCK>
|
||||
__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
|
||||
// The channel loaded by that thread (2 channels per thread for F16x2).
|
||||
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
|
||||
if (ci >= params.c) {
|
||||
template <typename T>
|
||||
__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
|
||||
// The channel loaded by that thread.
|
||||
int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD;
|
||||
if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The instance in the batch.
|
||||
int32_t ni = blockIdx.z;
|
||||
|
||||
// The group that thread works on and the channel in the group (modulus).
|
||||
int32_t gi = ci / params.cPerGroup;
|
||||
// The group that thread works on.
|
||||
int32_t gi = ci / params.channels_per_group;
|
||||
|
||||
// Load the sum and sum of squares for the group.
|
||||
float sum = 0.F, sumSq = 0.F;
|
||||
float sum = 0.F, sum_sq = 0.F;
|
||||
if (gi < params.groups) {
|
||||
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
|
||||
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
|
||||
const int index = (2 * ni) * params.groups + gi;
|
||||
sum = params.group_sum_buffer[index];
|
||||
sum_sq = params.group_sum_buffer[index + params.groups];
|
||||
}
|
||||
|
||||
// Load gamma/beta.
|
||||
float2 gammaF2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
|
||||
float2 betaF2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
|
||||
// Load gamma/beta. Fetch two per thread.
|
||||
float2 gamma_f2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
|
||||
float2 beta_f2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
|
||||
|
||||
// Compute the mean.
|
||||
float mean = sum * params.invHWC;
|
||||
float mean = sum * params.inv_hw_channels_per_group;
|
||||
// Compute the variance.
|
||||
float var = sumSq * params.invHWC - (mean * mean);
|
||||
float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean);
|
||||
// Compute the inverse of the stddev.
|
||||
float invStdDev = var <= 0.F ? 1.F : rsqrtf(var);
|
||||
float inv_std_dev = rsqrtf(var + params.epsilon);
|
||||
|
||||
// The first activation loaded by that block.
|
||||
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
|
||||
// The last activation loaded by that block.
|
||||
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
|
||||
int32_t hw_begin = blockIdx.y * params.hw_per_block;
|
||||
int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
|
||||
|
||||
// Iterate over the activations to compute the sums.
|
||||
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
|
||||
// The src/dst offset.
|
||||
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
|
||||
|
||||
// Fetch two channels per thread.
|
||||
computeGroupNorm<T>(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish);
|
||||
const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src;
|
||||
int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hw_begin) * params.c + ci;
|
||||
for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
|
||||
ComputeGroupNorm<T>(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void groupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
|
||||
// Make sure the dimensions are aligned with what we expect.
|
||||
ORT_ENFORCE(params.c % params.cPerBlock == 0);
|
||||
// Make sure a group does not span multiple blocks.
|
||||
ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
|
||||
|
||||
void GroupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
|
||||
dim3 grid;
|
||||
|
||||
// The number of blocks to compute all the channels.
|
||||
grid.x = params.c / params.cPerBlock;
|
||||
grid.x = DivUp(params.c, params.channels_per_block);
|
||||
// The number of blocks to compute all the activations in a given instance.
|
||||
grid.y = divUp(params.hw, params.hwPerBlock);
|
||||
grid.y = DivUp(params.hw, params.hw_per_block);
|
||||
// The number of instances.
|
||||
grid.z = params.n;
|
||||
|
||||
switch (params.cPerBlock) {
|
||||
case 320:
|
||||
groupNormNHWCScaleKernel<T, 160><<<grid, 160, 0, stream>>>(params);
|
||||
break;
|
||||
case 480:
|
||||
groupNormNHWCScaleKernel<T, 256><<<grid, 256, 0, stream>>>(params);
|
||||
break;
|
||||
// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
|
||||
switch (params.threads_per_block) {
|
||||
case 256:
|
||||
groupNormNHWCScaleKernel<T, 128><<<grid, 128, 0, stream>>>(params);
|
||||
GroupNormNHWCScaleKernel<T><<<grid, 256, 0, stream>>>(params);
|
||||
break;
|
||||
case 192:
|
||||
GroupNormNHWCScaleKernel<T><<<grid, 192, 0, stream>>>(params);
|
||||
break;
|
||||
case 160:
|
||||
GroupNormNHWCScaleKernel<T><<<grid, 160, 0, stream>>>(params);
|
||||
break;
|
||||
case 128:
|
||||
groupNormNHWCScaleKernel<T, 64><<<grid, 64, 0, stream>>>(params);
|
||||
GroupNormNHWCScaleKernel<T><<<grid, 128, 0, stream>>>(params);
|
||||
break;
|
||||
case 64:
|
||||
GroupNormNHWCScaleKernel<T><<<grid, 64, 0, stream>>>(params);
|
||||
break;
|
||||
default:
|
||||
ORT_NOT_IMPLEMENTED("Not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
|
||||
int32_t maxDivisor = -1;
|
||||
int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) {
|
||||
int32_t max_divisor = -1;
|
||||
for (int32_t i = 1; i <= std::sqrt(n); i++) {
|
||||
if (n % i == 0) {
|
||||
int32_t divisor1 = n / i;
|
||||
int32_t divisor2 = i;
|
||||
|
||||
if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
|
||||
maxDivisor = divisor1;
|
||||
if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) {
|
||||
max_divisor = divisor1;
|
||||
}
|
||||
if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
|
||||
maxDivisor = divisor2;
|
||||
if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) {
|
||||
max_divisor = divisor2;
|
||||
}
|
||||
}
|
||||
}
|
||||
return maxDivisor;
|
||||
return max_divisor;
|
||||
}
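The divisor search above decides how many blocks split the spatial (H*W) dimension. The following is a minimal host-side sketch of that partitioning, assuming the launcher logic shown later in this diff (`max_blocks_per_hw = 1024` and `hw_per_block = DivUp(hw, blocks_per_hw)`). `HwPartition` and `PartitionHw` are hypothetical names introduced only for illustration, and the loop bound uses an integer comparison in place of `std::sqrt`, which should be equivalent.

```cpp
#include <cstdint>

// Largest divisor of n that is strictly smaller than max_allowed_divisor,
// or -1 if no such divisor exists.
int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) {
  int32_t max_divisor = -1;
  for (int32_t i = 1; static_cast<int64_t>(i) * i <= n; i++) {
    if (n % i == 0) {
      const int32_t divisor1 = n / i;
      const int32_t divisor2 = i;
      if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) {
        max_divisor = divisor1;
      }
      if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) {
        max_divisor = divisor2;
      }
    }
  }
  return max_divisor;
}

// Hypothetical helper type: how the H*W dimension is split across grid.y.
struct HwPartition {
  int32_t blocks_per_hw;  // number of blocks along H*W
  int32_t hw_per_block;   // spatial positions handled by each block
};

// Hypothetical helper: mirrors the two launcher lines that compute hw_per_block.
HwPartition PartitionHw(int32_t hw, int32_t max_blocks_per_hw = 1024) {
  const int32_t blocks = FindMaxDivisor(hw, max_blocks_per_hw);
  return {blocks, (hw + blocks - 1) / blocks};
}
```

Because the comparison against the limit is strict, a 64x64 latent (hw = 4096) appears to be split into 512 blocks of 8 positions each, not 1024 blocks of 4.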
|
||||
|
||||
// Find proper channels per block based on a cost function: The cost is number of channels corresponding to
|
||||
// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has
|
||||
// work to do so it is ideal case.
|
||||
int FindChannelsPerBlock(int num_channels, int channels_per_group) {
|
||||
int min_cost = -1;
|
||||
int best_candidate = -1;
|
||||
for (size_t i = kNumOfSizes; i > 0; --i) {
|
||||
if (kSizes[i - 1] < channels_per_group) {
|
||||
break;
|
||||
}
|
||||
|
||||
int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group;
|
||||
int blocks = (num_channels + channels_per_block - 1) / channels_per_block;
|
||||
int cost = blocks * kSizes[i - 1] - num_channels;
|
||||
if (cost == 0) {
|
||||
return channels_per_block;
|
||||
}
|
||||
|
||||
if (min_cost == -1 || cost < min_cost) {
|
||||
min_cost = cost;
|
||||
best_candidate = channels_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
return best_candidate;
|
||||
}
|
||||
|
||||
int GetChannelsPerBlock(int num_channels, int num_groups) {
|
||||
int32_t channels_per_group = num_channels / num_groups;
|
||||
int32_t channels_per_block = channels_per_group;
|
||||
if (channels_per_group < kMaxSize / 2) {
|
||||
channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group);
|
||||
}
|
||||
return channels_per_block;
|
||||
}
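To make the cost function concrete, here is a standalone sketch of the two functions above that can be compiled on the host. The constants `kSizes`, `kNumOfSizes` and `kMaxSize` are not visible in this part of the diff; the values below are an assumption inferred from the thread counts in the launch switch (256, 192, 160, 128, 64) and the comment that threads per block is half of the values in `kSizes`.

```cpp
#include <cstddef>
#include <cstdint>

// ASSUMPTION: candidate block sizes (in channels), inferred from the launch switch.
constexpr int32_t kSizes[] = {128, 256, 320, 384, 512};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int32_t kMaxSize = kSizes[kNumOfSizes - 1];

// Cost = channels covered by allocated threads minus real channels, i.e. idle capacity.
int FindChannelsPerBlock(int num_channels, int channels_per_group) {
  int min_cost = -1;
  int best_candidate = -1;
  for (size_t i = kNumOfSizes; i > 0; --i) {
    if (kSizes[i - 1] < channels_per_group) {
      break;  // A block must hold at least one whole group.
    }
    // Round the candidate size down to a whole number of groups.
    const int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group;
    const int blocks = (num_channels + channels_per_block - 1) / channels_per_block;
    const int cost = blocks * kSizes[i - 1] - num_channels;
    if (cost == 0) {
      return channels_per_block;  // Every thread has work to do.
    }
    if (min_cost == -1 || cost < min_cost) {
      min_cost = cost;
      best_candidate = channels_per_block;
    }
  }
  return best_candidate;
}

int GetChannelsPerBlock(int num_channels, int num_groups) {
  const int channels_per_group = num_channels / num_groups;
  int channels_per_block = channels_per_group;
  if (channels_per_group < kMaxSize / 2) {
    channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group);
  }
  return channels_per_block;
}
```

Under these assumed sizes, 320 channels in 32 groups gives a zero-cost block of 320 channels, and 1536 channels in 32 groups gives 384. When no candidate reaches zero cost, the first candidate with the lowest cost wins, so larger sizes are preferred on ties.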
|
||||
|
||||
template <typename T>
|
||||
Status LaunchGroupNormKernel(
|
||||
cudaStream_t stream,
|
||||
T* output,
|
||||
T* add_out,
|
||||
const T* input,
|
||||
const T* skip,
|
||||
const T* bias,
|
||||
const float* gamma,
|
||||
const float* beta,
|
||||
void* workspace,
|
||||
|
|
@ -397,79 +588,94 @@ Status LaunchGroupNormKernel(
|
|||
int height,
|
||||
int width,
|
||||
int num_groups,
|
||||
bool use_swish_activation) {
|
||||
if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
|
||||
"only support batch_size <= 32. Got", batch_size);
|
||||
}
|
||||
|
||||
if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
|
||||
"only num_groups=32 is supported. Got", num_groups);
|
||||
}
|
||||
|
||||
bool use_silu,
|
||||
bool broadcast_skip,
|
||||
int channels_per_block) {
|
||||
GroupNormNHWCParams<T> params;
|
||||
int32_t cPerBlock = 320;
|
||||
int32_t maxBlocksPerHW = 1024;
|
||||
switch (num_channels) {
|
||||
case 960:
|
||||
case 1920:
|
||||
cPerBlock = 480;
|
||||
break;
|
||||
case 512:
|
||||
case 256:
|
||||
cPerBlock = 256;
|
||||
break;
|
||||
case 128:
|
||||
cPerBlock = 128;
|
||||
break;
|
||||
default:
|
||||
cPerBlock = 320;
|
||||
|
||||
int32_t channels_per_group = num_channels / num_groups;
|
||||
// channels_per_block is computed in PrePack.
|
||||
// If gamma is not an initializer, channels_per_block might be zero after PrePack. If that happens, compute it here.
|
||||
if (channels_per_block < channels_per_group) {
|
||||
channels_per_block = GetChannelsPerBlock(num_channels, num_groups);
|
||||
}
|
||||
|
||||
params.withSwish = use_swish_activation;
|
||||
// TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases
|
||||
if (channels_per_block % channels_per_group != 0 ||
|
||||
channels_per_block > kMaxSize ||
|
||||
(channels_per_group % CHANNELS_PER_THREAD != 0)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"GroupNorm in CUDA does not support the input: n=", batch_size,
|
||||
" h=", height,
|
||||
" w=", width,
|
||||
" c=", num_channels,
|
||||
" groups=", num_groups);
|
||||
}
|
||||
|
||||
params.use_silu = use_silu;
|
||||
params.dst = output;
|
||||
params.add_out = add_out;
|
||||
params.src = input;
|
||||
params.skip = skip;
|
||||
params.bias = bias;
|
||||
params.gamma = gamma;
|
||||
params.beta = beta;
|
||||
params.redBuffer = reinterpret_cast<float*>(workspace);
|
||||
params.group_sum_buffer = reinterpret_cast<float*>(workspace);
|
||||
params.n = batch_size;
|
||||
params.h = height;
|
||||
params.w = width;
|
||||
params.c = num_channels;
|
||||
params.groups = num_groups;
|
||||
params.hw = params.h * params.w;
|
||||
const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW);
|
||||
params.hwPerBlock = divUp(params.hw, blocksPerHW);
|
||||
params.cPerBlock = cPerBlock;
|
||||
params.cPerGroup = params.c / params.groups;
|
||||
|
||||
// This will allocate as many blocks as possible to partition HW.
|
||||
// For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw.
|
||||
// TODO: tune this logic to find proper blocks when hw is small.
|
||||
constexpr int32_t max_blocks_per_hw = 1024;
|
||||
const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw);
|
||||
params.hw_per_block = DivUp(params.hw, blocks_per_hw);
|
||||
|
||||
params.channels_per_block = channels_per_block;
|
||||
params.channels_per_group = channels_per_group;
|
||||
params.hwc = params.hw * params.c;
|
||||
params.invHWC = 1.F / (float)(params.hw * params.cPerGroup);
|
||||
params.groupsPerBlock = cPerBlock / params.cPerGroup;
|
||||
params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group);
|
||||
params.groups_per_block = channels_per_block / params.channels_per_group;
|
||||
params.epsilon = epsilon;
|
||||
params.broadcast_skip = broadcast_skip;
|
||||
|
||||
// Workspace for SkipGroupNorm to store intermediate results of src+skip+bias.
|
||||
params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst;
|
||||
|
||||
params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD;
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(
|
||||
params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));
|
||||
|
||||
GroupNormNHWCSum<T>(params, stream);
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("input", input, batch_size, num_channels, height * width);
|
||||
DUMP_TENSOR("gamma", gamma, 1, num_channels);
|
||||
DUMP_TENSOR("beta", beta, 1, num_channels);
|
||||
cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream);
|
||||
groupNormNHWCSum<T>(params, stream);
|
||||
DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2);
|
||||
DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups);
|
||||
|
||||
GroupNormNHWCScale<T>(params, stream);
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
groupNormNHWCScale<T>(params, stream);
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
DUMP_TENSOR("output", output, batch_size, num_channels, height * width);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output,
|
||||
const half* input, const float* gamma, const float* beta, void* workspace,
|
||||
template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output, half* add_out,
|
||||
const half* input, const half* skip, const half* bias,
|
||||
const float* gamma, const float* beta, void* workspace,
|
||||
float epsilon, int batch_size, int num_channels,
|
||||
int height, int width, int num_groups, bool swish);
|
||||
int height, int width, int num_groups, bool silu,
|
||||
bool broadcast_skip, int channels_per_block);
|
||||
|
||||
template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output,
|
||||
const float* input, const float* gamma, const float* beta, void* workspace,
|
||||
template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output, float* add_out,
|
||||
const float* input, const float* skip, const float* bias,
|
||||
const float* gamma, const float* beta, void* workspace,
|
||||
float epsilon, int batch_size, int num_channels,
|
||||
int height, int width, int num_groups, bool swish);
|
||||
int height, int width, int num_groups, bool silu,
|
||||
bool broadcast_skip, int channels_per_block);
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -12,29 +12,33 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
constexpr size_t kMaxGroupNormBatchSize = 32;
|
||||
constexpr size_t kGroupNormNumberOfGroups = 32;
|
||||
|
||||
constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
|
||||
constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) {
|
||||
// Two buffers for sum and squared sum
|
||||
return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
|
||||
return (sizeof(float) * 2) * batch_size * num_groups;
|
||||
}
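The workspace holds two float buffers per batch item: the per-group sums followed by the per-group sums of squares. The sketch below restates the size function and adds two index helpers that mirror the indexing used by the kernels in this commit (`(2 * ni) * groups + gi` for the sum, plus `groups` for the squared sum). `SumIndex` and `SumSqIndex` are hypothetical names, not part of the commit.

```cpp
#include <cstddef>

// Two buffers (sum and squared sum) for every (batch item, group) pair.
constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) {
  return (sizeof(float) * 2) * batch_size * num_groups;
}

// Hypothetical helper: float index of the sum for batch item ni and group gi.
constexpr size_t SumIndex(size_t ni, size_t gi, size_t num_groups) {
  return (2 * ni) * num_groups + gi;
}

// Hypothetical helper: the squared sum sits num_groups floats after the sum.
constexpr size_t SumSqIndex(size_t ni, size_t gi, size_t num_groups) {
  return SumIndex(ni, gi, num_groups) + num_groups;
}
```

With this layout, the last squared-sum slot for batch size 2 and 32 groups is float index 127, which is the final element of the 128-float workspace.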
|
||||
|
||||
int GetChannelsPerBlock(int num_channels, int num_groups);
|
||||
|
||||
template <typename T>
|
||||
Status LaunchGroupNormKernel(
|
||||
cudaStream_t stream,
|
||||
T* output, // normalized output tensor
|
||||
const T* input, // input tensor
|
||||
const float* gamma, // gamma (also known as weight or scale)
|
||||
const float* beta, // beta (also known as bias)
|
||||
void* workspace, // Work space
|
||||
float epsilon, // epsilon used normalization
|
||||
int batch_size, // N
|
||||
int num_channels, // C
|
||||
int height, // H
|
||||
int width, // W
|
||||
int num_groups, // number of groups
|
||||
bool use_swish_activation // Whether there is Swish activation after group normalization
|
||||
T* output, // normalized output tensor. Shape is (n, h, w, c)
|
||||
T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c)
|
||||
const T* input, // input tensor. Shape is (n, h, w, c)
|
||||
const T* skip, // optional skip tensor. Shape is (n, h, w, c)
|
||||
const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm
|
||||
const float* gamma, // gamma (also known as weight or scale). Shape is (c)
|
||||
const float* beta, // beta (also known as bias). Shape is (c)
|
||||
void* workspace, // Work space
|
||||
float epsilon, // epsilon used normalization
|
||||
int batch_size, // N
|
||||
int num_channels, // C
|
||||
int height, // H
|
||||
int width, // W
|
||||
int num_groups, // number of groups
|
||||
bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization
|
||||
bool broadcast_skip, // Whether skip needs broadcast. When skip has shape (n, c) or (n, 1, 1, c), it needs broadcast.
|
||||
int channels_per_block // Pre-computed channels per block.
|
||||
);
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
|
|||
channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
|
||||
}
|
||||
|
||||
Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
|
||||
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
|
||||
is_packed = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* gamma = context->Input<Tensor>(1);
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"The number of groups of channels. It should be a divisor of the number of channels C",
|
||||
AttributeProto::INT)
|
||||
.Attr("activation",
|
||||
"Activation after group normalization: 0 for None, 1 for Swish",
|
||||
"Activation after group normalization: 0 for None, 1 for SiLU",
|
||||
AttributeProto::INT)
|
||||
.Attr("channels_last",
|
||||
"1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.",
|
||||
|
|
@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
.TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
|
||||
|
||||
constexpr const char* SkipGroupNorm_ver1_doc = R"DOC(
|
||||
This operator element-wise adds x, skip and bias, then applies group normalization and optional activation.
|
||||
|
||||
This operator transforms input according to
|
||||
s = x + skip + bias
|
||||
y = gamma * (s - mean) / sqrt(variance + epsilon) + beta
|
||||
|
||||
The input channels are separated into num_groups groups, each containing num_channels / num_groups channels.
|
||||
The num_channels must be divisible by num_groups.
|
||||
The mean and standard-deviation of s are calculated separately over each group.
|
||||
The weight and bias are per-channel affine transform parameter vectors of size num_channels.
|
||||
|
||||
The activation attribute can be used to enable activation after group normalization.
|
||||
)DOC";
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(
|
||||
SkipGroupNorm, 1,
|
||||
OpSchema()
|
||||
.SetDoc(SkipGroupNorm_ver1_doc)
|
||||
.Attr("epsilon", "The epsilon value to use to avoid division by zero",
|
||||
AttributeProto::FLOAT, static_cast<float>(1e-5))
|
||||
.Attr("groups",
|
||||
"The number of groups of channels. It should be a divisor of the number of channels C",
|
||||
AttributeProto::INT)
|
||||
.Attr("activation",
|
||||
"Activation after group normalization: 0 for None, 1 for SiLU",
|
||||
AttributeProto::INT)
|
||||
.Attr("channels_last",
|
||||
"1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Input(0,
|
||||
"X",
|
||||
"Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 "
|
||||
" or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels,"
|
||||
" and H and W are the height and width of the data",
|
||||
"T")
|
||||
.Input(1,
|
||||
"gamma",
|
||||
"1D gamma tensor for normalization with shape (C), where C is number of channels",
|
||||
"M")
|
||||
.Input(2,
|
||||
"beta",
|
||||
"1D beta tensor for normalization with shape (C), where C is number of channels",
|
||||
"M")
|
||||
.Input(3,
|
||||
"skip",
|
||||
"4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)",
|
||||
"T")
|
||||
.Input(4,
|
||||
"bias",
|
||||
"1D bias tensor. Dimensions are (C), where C is number of channels",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Output(0,
|
||||
"Y",
|
||||
"The output tensor of the same shape as X",
|
||||
"T")
|
||||
.Output(1,
|
||||
"S",
|
||||
"The element-wise sum of input x, skip and bias tensors. It has the same shape as X",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.")
|
||||
.TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 1);
|
||||
}
|
||||
|
||||
if (hasInputShape(ctx, 0)) {
|
||||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
propagateShapeFromInputToOutput(ctx, 0, 1);
|
||||
}
|
||||
}
|
||||
}));
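The operator documentation above can be checked against a plain CPU reference. The following is a minimal sketch for the NHWC layout, assuming float tensors stored as flat vectors; `SkipGroupNormReference` is a hypothetical function written for illustration and is not the CUDA kernel from this commit. It computes the variance as E[s^2] - mean^2, the same form the kernel uses.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not part of the commit).
// x: (n, hw, c). skip: (n, hw, c), or (n, c) when broadcast_skip is true.
// bias: (c) or empty. gamma, beta: (c). Returns y; writes s = x + skip + bias to *s_out.
std::vector<float> SkipGroupNormReference(
    const std::vector<float>& x, const std::vector<float>& skip, const std::vector<float>& bias,
    const std::vector<float>& gamma, const std::vector<float>& beta,
    int n, int hw, int c, int groups, float epsilon, bool use_silu, bool broadcast_skip,
    std::vector<float>* s_out) {
  const int cpg = c / groups;  // channels per group; c must be divisible by groups
  std::vector<float> s(x.size());
  for (int ni = 0; ni < n; ++ni) {
    for (int hwi = 0; hwi < hw; ++hwi) {
      for (int ci = 0; ci < c; ++ci) {
        const size_t offset = (static_cast<size_t>(ni) * hw + hwi) * c + ci;
        const size_t skip_offset = broadcast_skip ? static_cast<size_t>(ni) * c + ci : offset;
        s[offset] = x[offset] + skip[skip_offset] + (bias.empty() ? 0.0f : bias[ci]);
      }
    }
  }

  std::vector<float> y(x.size());
  const double inv_count = 1.0 / (static_cast<double>(hw) * cpg);
  for (int ni = 0; ni < n; ++ni) {
    for (int gi = 0; gi < groups; ++gi) {
      // First pass: sum and sum of squares over the group.
      double sum = 0.0, sum_sq = 0.0;
      for (int hwi = 0; hwi < hw; ++hwi) {
        for (int cj = 0; cj < cpg; ++cj) {
          const double v = s[(static_cast<size_t>(ni) * hw + hwi) * c + gi * cpg + cj];
          sum += v;
          sum_sq += v * v;
        }
      }
      const double mean = sum * inv_count;
      const double var = sum_sq * inv_count - mean * mean;
      const double inv_std_dev = 1.0 / std::sqrt(var + epsilon);
      // Second pass: normalize, apply the per-channel affine transform and optional SiLU.
      for (int hwi = 0; hwi < hw; ++hwi) {
        for (int cj = 0; cj < cpg; ++cj) {
          const int ci = gi * cpg + cj;
          const size_t offset = (static_cast<size_t>(ni) * hw + hwi) * c + ci;
          double v = gamma[ci] * (s[offset] - mean) * inv_std_dev + beta[ci];
          if (use_silu) {
            v = v / (1.0 + std::exp(-v));  // SiLU: v * sigmoid(v)
          }
          y[offset] = static_cast<float>(v);
        }
      }
    }
  }
  if (s_out != nullptr) {
    *s_out = s;
  }
  return y;
}
```

A reference like this is mainly useful for parity testing on small shapes. It accumulates in double precision, so small differences from a float16 GPU result are to be expected.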
|
||||
|
||||
constexpr const char* BiasSplitGelu_ver1_doc = R"DOC(
|
||||
A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left
|
||||
tensor multiplies the Gelu activation result of right tensor.
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
|
|||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
|
||||
|
|
@ -205,6 +206,7 @@ class OpSet_Microsoft_ver1 {
|
|||
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());
|
||||
|
|
|
|||
|
|
@ -200,6 +200,7 @@ class SymbolicShapeInference:
            "GemmFastGelu": self._infer_GemmFastGelu,
            "GemmFloat8": self._infer_GemmFloat8,
            "GroupNorm": self._infer_GroupNorm,
            "SkipGroupNorm": self._infer_SkipGroupNorm,
            "LayerNormalization": self._infer_LayerNormalization,
            "LongformerAttention": self._infer_LongformerAttention,
            "MultiHeadAttention": self._infer_MultiHeadAttention,
|
||||
|
|
@ -2376,6 +2377,11 @@ class SymbolicShapeInference:
|
|||
    def _infer_GroupNorm(self, node):  # noqa: N802
        self._propagate_shape_and_type(node)

    def _infer_SkipGroupNorm(self, node):  # noqa: N802
        self._propagate_shape_and_type(node, 0, 0)
        if len(node.output) > 1:
            self._propagate_shape_and_type(node, 0, 1)

    def _infer_BiasSplitGelu(self, node):  # noqa: N802
        input_shape = self._get_shape(node, 0)
        bias_shape = self._get_shape(node, 1)
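
Both the C++ schema and `_infer_SkipGroupNorm` encode the same rule: every declared output of SkipGroupNorm takes its element type and shape from input 0. A minimal standalone sketch of that rule follows; `infer_skip_group_norm_outputs` and the `(type, shape)` tuple representation are hypothetical and are not part of `SymbolicShapeInference`.

```python
from typing import List, Optional, Tuple, Union

Dim = Union[int, str]  # a concrete size or a symbolic name such as "N"
TypeAndShape = Tuple[str, Optional[List[Dim]]]


def infer_skip_group_norm_outputs(input_0: TypeAndShape, num_outputs: int) -> List[TypeAndShape]:
    """Hypothetical helper: each output copies input 0's element type and, if known, its shape."""
    elem_type, shape = input_0
    return [(elem_type, None if shape is None else list(shape)) for _ in range(num_outputs)]


result = infer_skip_group_norm_outputs(("float16", ["N", "H", "W", "C"]), 2)
print(result)
```

With two outputs, both the normalized tensor and the optional sum output are reported with the type and shape of the first input, even when `skip` has a smaller, broadcastable shape.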
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
 import logging
 from collections import OrderedDict
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple, Union

 import numpy
 import torch
|
||||
|
|
@ -229,7 +229,7 @@ class CudaSession:
|
|||
         del self.io_binding
         del self.ort_session

-    def allocate_buffers(self, shape_dict: Dict[str, tuple]):
+    def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
         """Allocate tensors for I/O Binding"""
         if self.enable_cuda_graph:
             for name, shape in shape_dict.items():
|
||||
|
|
|
|||
286
onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <random>
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/framework/test_utils.h"
#include "test/providers/provider_test_utils.h"

#include "gtest/gtest.h"
#include "gmock/gmock.h"

using namespace std;

namespace onnxruntime {
namespace test {

TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) {
  constexpr int64_t B = 2;
  constexpr int64_t C = 16;
  constexpr int64_t H = 2;
  constexpr int64_t W = 2;

  std::vector<int64_t> dims_nhwc{B, H, W, C};
|
||||
std::vector<float> input_data_nhwc = {
|
||||
-0.768555f, 1.575195f, -0.698242f, 1.587891f, 0.371826f, -0.280029f, -1.328125f, 0.127197f,
|
||||
-0.197144f, 0.982422f, -0.671387f, -1.925781f, 1.800781f, -0.020218f, -0.782227f, 1.291992f,
|
||||
-0.935059f, 1.782227f, -0.674316f, -1.943359f, -0.218994f, 0.054138f, -1.539062f, -0.546387f,
|
||||
-2.160156f, 1.195312f, 1.653320f, -0.674316f, 0.224731f, -0.093262f, 1.160156f, -0.389404f,
|
||||
1.748047f, 0.766113f, 0.234375f, 0.011177f, -0.055847f, -0.930664f, -0.490234f, -0.655762f,
|
||||
-0.382568f, -0.554688f, 0.910645f, -0.227295f, 1.687500f, 0.028397f, -0.241699f, -0.480957f,
|
||||
-0.355713f, -2.095703f, -0.443359f, -0.126221f, -0.815918f, 0.792969f, -0.450439f, -0.952148f,
|
||||
-1.174805f, 0.242798f, 0.138550f, -0.237061f, -0.994141f, 0.346436f, 0.147705f, 0.125854f,
|
||||
-0.517090f, 0.253906f, 0.400146f, -0.540039f, -0.788574f, 0.146606f, -0.409668f, 0.281982f,
|
||||
1.444336f, 0.044434f, -0.366699f, 2.250000f, -0.453613f, -0.652344f, 1.828125f, -0.244751f,
|
||||
0.307129f, -0.051361f, 0.106384f, 0.844727f, 1.648438f, -0.904785f, -0.353760f, 0.510742f,
|
||||
0.074829f, -0.311279f, 0.274902f, 1.594727f, 1.367188f, 0.098755f, 0.043304f, -0.207397f,
|
||||
0.068298f, -0.601074f, 0.083008f, 0.264893f, -0.659180f, -0.216797f, -0.086548f, -0.683594f,
|
||||
-0.964844f, -2.591797f, -0.817383f, -0.461914f, -1.840820f, -0.712402f, -0.052094f, -0.583008f,
|
||||
1.114258f, 0.190308f, 1.087891f, 0.005146f, 1.041992f, 1.363281f, -0.273682f, -0.465576f,
|
||||
-0.027618f, 1.345703f, 0.789551f, -0.015991f, 0.401611f, 0.726562f, 0.598633f, 0.133667f};
|
||||
|
||||
std::vector<float> gamma_data = {
|
||||
0.241255f, 0.556660f, -0.835532f, 0.564596f, -1.338308f, -0.278924f, 0.357326f, -1.745484f,
|
||||
0.277184f, 0.101415f, -0.018637f, -0.526188f, -0.011698f, -2.349411f, 0.206578f, 0.357679f};
|
||||
|
||||
std::vector<float> beta_data = {
|
||||
-1.194839f, 0.209146f, -0.677225f, -0.547338f, 1.275685f, -1.099577f, 0.470916f, 0.293907f,
|
||||
-1.094209f, 2.350204f, -1.633769f, 0.248753f, -0.180166f, 0.365134f, -0.555731f, 1.843083f};
|
||||
|
||||
std::vector<float> skip_data_nhwc = {
|
||||
0.892578f, -0.471924f, -0.423096f, 1.277344f, 0.257080f, -1.366211f, 1.552734f, 0.441406f,
|
||||
-0.033142f, -0.059418f, 1.536133f, -0.225464f, 1.472656f, 0.591309f, -0.386230f, -2.197266f,
|
||||
0.089600f, -0.256592f, -1.873047f, 0.916992f, 0.392090f, 0.015526f, -0.949219f, 0.566895f,
|
||||
-0.220459f, 1.262695f, -0.437744f, -2.283203f, -0.264893f, -0.660156f, 2.353516f, 1.992188f,
|
||||
0.865723f, -0.854004f, -1.014648f, 0.899414f, -1.041016f, 1.378906f, -0.075073f, -2.541016f,
|
||||
-0.883789f, -0.428711f, 0.981934f, -0.072754f, 2.214844f, 0.658203f, 0.170166f, -1.727539f,
|
||||
-0.672363f, -1.373047f, 0.318115f, 0.422363f, 0.260742f, -0.547852f, 0.545898f, -0.155762f,
|
||||
0.679688f, 2.861328f, -0.300781f, -0.504883f, 1.548828f, 0.353760f, -0.387695f, -1.595703f,
|
||||
-0.170166f, -0.002897f, 0.273193f, -0.383545f, -1.082031f, -0.894043f, -1.048828f, -0.044708f,
|
||||
0.049286f, 0.220215f, 0.272705f, -0.853027f, -0.489258f, 0.513672f, 0.977051f, 0.310547f,
|
||||
-0.577148f, -0.479004f, 0.838867f, 0.872559f, -0.510254f, 0.101807f, -0.299805f, -1.179688f,
|
||||
-1.555664f, 0.668457f, 0.939453f, 0.118103f, -0.376709f, 0.735352f, -0.214233f, -1.987305f,
|
||||
-0.931152f, 1.268555f, 1.427734f, -0.757812f, -1.324219f, 0.375488f, 1.364258f, -1.708008f,
|
||||
0.976562f, -0.037659f, -1.779297f, -0.196655f, 1.636719f, 0.690430f, 0.941895f, -1.882812f,
|
||||
0.431641f, 0.203857f, 1.306641f, -0.126343f, 1.408203f, 1.188477f, 0.432861f, -2.296875f,
|
||||
-0.475342f, 1.517578f, -0.824219f, 1.288086f, -0.028244f, 1.918945f, 0.352295f, 0.693359f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.537598f, 0.500488f, -0.252441f, -0.460693f, -1.640625f, -1.298828f, 0.331787f, -1.588867f,
|
||||
1.000977f, 1.458984f, 0.702637f, 0.147827f, 1.143555f, 0.533691f, -0.072510f, 0.511230f};
|
||||
|
||||
std::vector<float> norm_data_nhwc = {
|
||||
-1.213867f, 0.856445f, -0.119141f, 0.386475f, 0.714355f, -0.804688f,
|
||||
1.048828f, -0.426270f, -1.091797f, 2.435547f, -1.641602f, 0.989746f,
|
||||
-0.200928f, 0.267334f, -0.800781f, 1.577148f, -1.357422f, 1.000977f,
|
||||
0.613281f, -0.963867f, 1.179688f, -1.169922f, 0.308350f, 0.304199f,
|
||||
-1.396484f, 2.513672f, -1.644531f, 1.206055f, -0.180664f, 1.896484f,
|
||||
-0.294678f, 2.046875f, -0.844238f, 0.448486f, -0.294189f, -0.291504f,
|
||||
2.480469f, -1.250977f, 0.833008f, 4.593750f, -1.238281f, 2.335938f,
|
||||
-1.651367f, 0.491943f, -0.204834f, 0.125610f, -0.682129f, 1.333984f,
|
||||
-1.384766f, -0.708008f, -0.630859f, -0.504883f, 1.924805f, -1.208008f,
|
||||
1.013672f, 1.809570f, -1.128906f, 2.546875f, -1.631836f, 0.610840f,
|
||||
-0.184326f, 0.110046f, -0.700195f, 1.471680f, -1.511719f, 0.492188f,
|
||||
-0.847168f, -1.373047f, 2.837891f, -0.998047f, 0.521484f, 0.262207f,
|
||||
-0.810547f, 2.400391f, -1.628906f, 0.049896f, -0.174927f, 1.076172f,
|
||||
-0.252197f, 1.784180f, -1.418945f, 0.090820f, -1.056641f, 0.002945f,
|
||||
0.627441f, -0.989746f, 0.679199f, 1.130859f, -1.371094f, 2.408203f,
|
||||
-1.645508f, -0.062988f, -0.192017f, -0.655762f, -0.718262f, 1.170898f,
|
||||
-1.550781f, 0.706055f, -1.492188f, -1.148438f, 2.921875f, -1.136719f,
|
||||
1.058594f, 2.781250f, -1.089844f, 2.201172f, -1.597656f, 0.785645f,
|
||||
-0.181396f, 0.868164f, -0.552246f, 1.097656f, -1.015625f, 0.565430f,
|
||||
-2.173828f, -0.955078f, -0.336426f, -1.503906f, 0.838867f, 3.136719f,
|
||||
-1.186523f, 2.580078f, -1.629883f, 0.094604f, -0.186523f, -3.884766f,
|
||||
-0.542480f, 1.990234f};
|
||||
|
||||
std::vector<float> add_out_data_nhwc = {
|
||||
-0.414062f, 1.604492f, -1.374023f, 2.404297f, -1.011719f, -2.945312f, 0.556641f, -1.020508f,
|
||||
0.770508f, 2.382812f, 1.567383f, -2.003906f, 4.417969f, 1.105469f, -1.240234f, -0.394531f,
|
||||
-1.382812f, 2.027344f, -2.800781f, -1.487305f, -1.466797f, -1.229492f, -2.156250f, -1.568359f,
|
||||
-1.379883f, 3.917969f, 1.917969f, -2.808594f, 1.103516f, -0.219727f, 3.441406f, 2.113281f,
|
||||
2.076172f, 0.412598f, -1.033203f, 0.449951f, -2.738281f, -0.851562f, -0.233521f, -4.785156f,
|
||||
-0.265625f, 0.475586f, 2.595703f, -0.152222f, 5.046875f, 1.220703f, -0.144043f, -1.697266f,
|
||||
-1.566406f, -2.968750f, -0.377686f, -0.164551f, -2.195312f, -1.053711f, 0.427246f, -2.697266f,
|
||||
0.505859f, 4.562500f, 0.540527f, -0.594238f, 1.698242f, 1.233398f, -0.312500f, -0.958496f,
|
||||
-1.224609f, 0.751465f, 0.420898f, -1.384766f, -3.511719f, -2.046875f, -1.126953f, -1.351562f,
|
||||
2.494141f, 1.724609f, 0.608398f, 1.544922f, 0.200684f, 0.395020f, 2.732422f, 0.577148f,
|
||||
-0.807617f, -0.029785f, 0.692871f, 1.256836f, -0.502441f, -2.101562f, -0.321777f, -2.257812f,
|
||||
-0.479492f, 1.816406f, 1.916992f, 1.860352f, 2.134766f, 1.367188f, -0.243408f, -1.683594f,
|
||||
-1.400391f, 1.167969f, 1.257812f, -0.953613f, -3.625000f, -1.140625f, 1.609375f, -3.980469f,
|
||||
1.012695f, -1.170898f, -1.894531f, -0.510742f, 0.939453f, 0.511719f, 0.817383f, -1.955078f,
|
||||
1.007812f, 0.894531f, 2.142578f, -0.582031f, 0.809570f, 1.252930f, 0.490967f, -4.351562f,
|
||||
0.497803f, 4.320312f, 0.667969f, 1.419922f, 1.516602f, 3.179688f, 0.878906f, 1.337891f};
|
||||
|
||||
  int min_cuda_architecture = 530;
  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);

  std::array<int, 2> channels_last_values = {-1, 1};

  for (const int channels_last : channels_last_values) {
    if (enable_cuda) {
      std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
      if (enable_cuda && channels_last != 0) {
        execution_providers.push_back(DefaultCudaExecutionProvider());
      }

      // Don't run the test if no providers are supported
      if (execution_providers.empty()) {
        continue;
      }

      OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain);
      test.AddAttribute<float>("epsilon", 1e-05f);
      test.AddAttribute<int64_t>("groups", 4);
      test.AddAttribute<int64_t>("activation", 0);

      // We interpret channels_last==-1 as the attribute not being provided
      if (channels_last != -1) {
        test.AddAttribute<int64_t>("channels_last", channels_last);
      }

      test.AddInput<MLFloat16>("X", dims_nhwc, ToFloat16(input_data_nhwc));
      test.AddInput<float>("gamma", {C}, gamma_data);
      test.AddInput<float>("beta", {C}, beta_data);
      test.AddInput<MLFloat16>("skip", dims_nhwc, ToFloat16(skip_data_nhwc));
      test.AddInput<MLFloat16>("bias", {C}, ToFloat16(bias_data));

      constexpr float rel_error = 0.0f;
      constexpr float abs_error = 0.02f;
      test.AddOutput<MLFloat16>("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error);
      test.AddOutput<MLFloat16>("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error);

      test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
    }
  }
}

TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
  constexpr int64_t B = 1;
  constexpr int64_t C = 64;
  constexpr int64_t H = 1;
  constexpr int64_t W = 1;

  std::vector<int64_t> dims_nhwc{B, H, W, C};
|
||||
std::vector<float> input_data_nhwc = {
|
||||
0.588867f, 0.896484f, -0.213623f, 0.803223f, 0.659180f, -0.216187f, 1.197266f, -0.486084f,
|
||||
-0.718750f, 0.332031f, -0.364746f, -0.831543f, -0.031219f, -1.059570f, 0.161621f, 1.519531f,
|
||||
0.169312f, 1.048828f, 1.330078f, 0.450195f, -2.867188f, -1.456055f, 0.708496f, -1.120117f,
|
||||
-1.208984f, -1.199219f, -1.505859f, -0.549316f, 0.505371f, 0.723145f, -0.359131f, -0.250977f,
|
||||
-0.879883f, -0.305664f, 0.709473f, 0.815430f, 0.617676f, -0.638672f, 0.066772f, -2.330078f,
|
||||
-1.316406f, 1.744141f, 1.122070f, -0.633789f, -1.802734f, -0.825684f, 0.622559f, -0.481689f,
|
||||
-1.364258f, -0.536621f, -0.464111f, 0.247437f, -0.213989f, 0.384521f, 0.556641f, -0.303711f,
|
||||
-0.160034f, 0.882324f, -0.212036f, -0.796387f, 0.153076f, -1.311523f, 2.212891f, 0.685059f};
|
||||
|
||||
std::vector<float> gamma_data = {
|
||||
0.789682f, 0.869051f, -0.010169f, -0.021685f, 0.506611f, 1.267444f, -0.312695f, 0.877844f,
|
||||
0.598637f, 0.598314f, -1.721544f, -0.593328f, 0.986705f, -0.419391f, -0.852584f, -0.572351f,
|
||||
0.912797f, -0.586863f, 0.477761f, -0.484418f, -0.193835f, 0.347757f, 0.327637f, -1.100304f,
|
||||
1.233108f, -0.272569f, -0.688656f, 0.687245f, 0.398386f, 0.888089f, -0.792587f, -0.769029f,
|
||||
-0.427778f, 0.100768f, -2.187060f, 1.279301f, 1.109054f, 0.375992f, 1.514775f, 1.271436f,
|
||||
0.822896f, -0.476750f, 0.475507f, -1.011297f, 1.177197f, 1.586540f, -1.059944f, -0.145351f,
|
||||
0.841555f, -2.014113f, -0.230498f, 0.302128f, -0.180508f, 0.980534f, -0.126871f, 0.203151f,
|
||||
-0.754841f, 0.420570f, -1.085798f, 1.335042f, -0.674930f, 2.453507f, 2.139259f, 1.087436f};
|
||||
|
||||
std::vector<float> beta_data = {
|
||||
-0.064518f, -0.262683f, 0.827528f, -0.960938f, 1.062519f, 2.417941f, 0.212789f, -1.638430f,
|
||||
1.875453f, -0.883058f, -0.006704f, 0.424894f, -0.869972f, 0.727008f, 0.879303f, -3.024141f,
|
||||
-2.610873f, 1.269641f, 0.883006f, 0.804167f, -1.510324f, 2.258091f, -0.006750f, -1.553668f,
|
||||
-1.659453f, 0.579603f, 0.652358f, 0.007077f, 0.099180f, 0.418658f, -0.273778f, -1.036199f,
|
||||
-1.128691f, -0.296022f, -0.224056f, 1.476306f, 0.577624f, -0.372049f, -0.581659f, -1.841807f,
|
||||
-0.361721f, 0.051160f, -0.749332f, -2.634807f, 0.562719f, -0.738667f, 0.024864f, -1.135937f,
|
||||
-1.368144f, -1.458886f, -0.946683f, 1.953936f, -1.198661f, 0.166648f, 0.447206f, -0.458140f,
|
||||
-0.553395f, 0.112900f, 0.255989f, -0.184551f, 1.254163f, -0.260479f, -1.232429f, 1.902575f};
|
||||
|
||||
std::vector<float> skip_data = {
|
||||
0.952148f, 1.342773f, -0.172974f, -0.395264f, 1.119141f, 0.330566f,
|
||||
0.281494f, 0.472900f, -0.692871f, -0.634766f, 0.013504f, -1.866211f,
|
||||
-0.428223f, 0.669922f, -0.323486f, 0.713867f, -0.350586f, 0.659180f,
|
||||
-0.288574f, 0.324219f, -0.300781f, -0.789551f, -0.216431f, -0.221436f,
|
||||
-0.086670f, 0.366211f, -0.643555f, -0.977051f, 0.001021f, 0.415527f,
|
||||
-0.271729f, 0.836426f, 0.035370f, -0.806152f, 0.936035f, -0.021332f,
|
||||
-1.095703f, 0.971680f, 1.648438f, 0.840820f, 0.837402f, 0.607910f,
|
||||
-1.894531f, 0.666016f, -0.171143f, 1.625977f, -0.620117f, -0.039581f,
|
||||
1.702148f, -2.410156f, 1.565430f, -0.756348f, 1.446289f, 0.583496f,
|
||||
-0.497559f, -0.271729f, -0.956055f, -1.642578f, 0.833496f, -1.136719f,
|
||||
1.248047f, -2.515625f, 0.080383f, 0.376221f};
|
||||
|
||||
std::vector<float> norm_data_nhwc = {
|
||||
0.494873f, 1.017578f, 0.841797f, -0.949219f, 1.552734f, 1.333984f, 0.012703f, -2.511719f,
|
||||
1.424805f, -0.818359f, -0.128418f, 1.462891f, -0.882812f, 0.709961f, 0.693848f, -4.210938f,
|
||||
-2.505859f, 0.513184f, 1.300781f, 0.460938f, -1.172852f, 1.851562f, 0.167969f, -0.885254f,
|
||||
-2.535156f, 0.656738f, 1.683594f, -0.627441f, 0.478271f, 1.782227f, -0.196777f, -1.824219f,
|
||||
-0.791016f, -0.398682f, -3.197266f, 2.275391f, 0.052704f, -0.286865f, 1.567383f, -3.552734f,
|
||||
-0.646973f, -0.927734f, -1.032227f, -2.722656f, -1.337891f, 0.432129f, -0.040253f, -1.080078f,
|
||||
-1.118164f, 3.123047f, -1.153320f, 1.843750f, -1.378906f, 0.941406f, 0.437256f, -0.542969f,
|
||||
-0.218872f, 0.006115f, -0.265869f, -1.356445f, 0.649902f, -4.882812f, 1.696289f, 2.679688f};
|
||||
|
||||
std::vector<float> add_out_data_nhwc = {
|
||||
1.541016f, 2.238281f, -0.386719f, 0.407959f, 1.778320f, 0.114380f,
|
||||
1.478516f, -0.013184f, -1.412109f, -0.302734f, -0.351318f, -2.697266f,
|
||||
-0.459473f, -0.389648f, -0.161865f, 2.234375f, -0.181274f, 1.708008f,
|
||||
1.041016f, 0.774414f, -3.167969f, -2.246094f, 0.492188f, -1.341797f,
|
||||
-1.295898f, -0.833008f, -2.148438f, -1.526367f, 0.506348f, 1.138672f,
|
||||
-0.630859f, 0.585449f, -0.844727f, -1.111328f, 1.645508f, 0.793945f,
|
||||
-0.478027f, 0.333008f, 1.714844f, -1.489258f, -0.479004f, 2.351562f,
|
||||
-0.772461f, 0.032227f, -1.973633f, 0.800293f, 0.002441f, -0.521484f,
|
||||
0.337891f, -2.947266f, 1.101562f, -0.508789f, 1.232422f, 0.967773f,
|
||||
0.059082f, -0.575195f, -1.116211f, -0.760254f, 0.621582f, -1.933594f,
|
||||
1.401367f, -3.828125f, 2.292969f, 1.061523f};
|
||||
|
||||
  int min_cuda_architecture = 530;
  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);

  std::array<bool, 2> has_add_out_values = {true, false};
  std::array<int, 2> skip_dims = {2, 4};

  constexpr int channels_last = 1;
  for (const int skip_dim : skip_dims) {
    for (const bool has_add_out : has_add_out_values) {
      if (enable_cuda) {
        std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
        if (enable_cuda && channels_last != 0) {
          execution_providers.push_back(DefaultCudaExecutionProvider());
        }

        // Don't run the test if no providers are supported
        if (execution_providers.empty()) {
          continue;
        }

        OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain);
        test.AddAttribute<float>("epsilon", 1e-05f);
        test.AddAttribute<int64_t>("groups", 8);
        test.AddAttribute<int64_t>("activation", 0);

        // We interpret channels_last==-1 as the attribute not being provided
        if (channels_last != -1) {
          test.AddAttribute<int64_t>("channels_last", channels_last);
        }

        test.AddInput<MLFloat16>("X", dims_nhwc, ToFloat16(input_data_nhwc));
        test.AddInput<float>("gamma", {C}, gamma_data);
        test.AddInput<float>("beta", {C}, beta_data);
        if (skip_dim == 2) {
          test.AddInput<MLFloat16>("skip", {B, C}, ToFloat16(skip_data));
        } else {
          test.AddInput<MLFloat16>("skip", {B, 1, 1, C}, ToFloat16(skip_data));
        }
        // no bias

        constexpr float rel_error = 0.0f;
        constexpr float abs_error = 0.02f;
        test.AddOutput<MLFloat16>("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error);

        if (has_add_out) {
          test.AddOutput<MLFloat16>("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error);
        }

        test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
      }
    }
  }
}

}  // namespace test
}  // namespace onnxruntime
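
The expected values in these tests follow the usual SkipGroupNorm semantics: `S = X + skip + bias`, then group normalization of `S` over NHWC data. A minimal NumPy sketch of that computation follows, under stated assumptions: the function name `skip_group_norm_reference` and its keyword arguments are hypothetical, statistics use the biased variance, and the result is a plain float reference rather than the fp16 CUDA kernel.

```python
import numpy as np


def skip_group_norm_reference(x, gamma, beta, skip, bias=None, groups=32, epsilon=1e-5, silu=False):
    """Hypothetical NHWC reference. Returns (normalized output Y, sum output S)."""
    n, h, w, c = x.shape
    if skip.ndim == 2:  # an (N, C) skip is broadcast as (N, 1, 1, C)
        skip = skip.reshape(n, 1, 1, c)
    add_out = x + skip
    if bias is not None:
        add_out = add_out + bias.reshape(1, 1, 1, c)

    # Statistics are taken per sample and per group, over H, W and the C / groups channels of the group.
    grouped = add_out.reshape(n, h * w, groups, c // groups)
    mean = grouped.mean(axis=(1, 3), keepdims=True)
    var = grouped.var(axis=(1, 3), keepdims=True)  # biased variance (ddof=0)
    normalized = ((grouped - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)

    y = normalized * gamma.reshape(1, 1, 1, c) + beta.reshape(1, 1, 1, c)
    if silu:
        y = y / (1.0 + np.exp(-y))  # SiLU: y * sigmoid(y)
    return y, add_out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 2, 2, 16))
skip = rng.standard_normal((2, 16))
y, s = skip_group_norm_reference(x, np.ones(16), np.zeros(16), skip, groups=4)
print(y.shape, s.shape)
```

As a spot check against the first test above, the first element of `add_out_data_nhwc` (-0.414062) matches -0.768555 + 0.892578 - 0.537598 within fp16 rounding.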
|
||||
541
onnxruntime/test/python/transformers/test_group_norm.py
Normal file
|
|
@ -0,0 +1,541 @@
|
|||
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import statistics
from dataclasses import dataclass
from enum import Enum
from time import perf_counter
from typing import Optional, Tuple

import numpy
import torch
from onnx import TensorProto, helper

from onnxruntime import InferenceSession
from onnxruntime.transformers.io_binding_helper import CudaSession

torch.manual_seed(0)


class GroupNormOpType(Enum):
    GROUP_NORM = 1
    SKIP_GROUP_NORM = 2


@dataclass
class GroupNormConfig:
    batch_size: int
    height: int
    width: int
    channels: int
    epsilon: float = 1e-5
    num_groups: int = 32
    activation: bool = False
    channels_last: bool = True
    fp16: bool = False

    op_type: GroupNormOpType = GroupNormOpType.GROUP_NORM
    has_bias: bool = False
    has_add_out: bool = False
    broadcast_skip: int = 0  # 2 for (N, C), 4 for (N, 1, 1, C)

    def get_skip_symbolic_shape(self):
        skip_shape = {0: ["N", "H", "W", "C"], 2: ["N", "C"], 4: ["N", 1, 1, "C"]}
        return skip_shape[self.broadcast_skip]

    def get_skip_shape(self):
        skip_shape = {
            0: [self.batch_size, self.height, self.width, self.channels],
            2: [self.batch_size, self.channels],
            4: [self.batch_size, 1, 1, self.channels],
        }
        return skip_shape[self.broadcast_skip]

    def broadcast(self, skip: torch.Tensor):
        if self.broadcast_skip == 2:
            return skip.reshape(self.batch_size, 1, 1, self.channels)

        return skip

    @staticmethod
    def create(
        b: int,
        h: int,
        w: int,
        c: int,
        fp16: bool = False,
        activation: bool = False,
        template: int = 0,
        num_groups: int = 32,
    ):
        if template == 0:
            return GroupNormConfig(
                b, h, w, c, fp16=fp16, activation=activation, op_type=GroupNormOpType.GROUP_NORM, num_groups=num_groups
            )

        if template == 1:
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=True,
                has_add_out=True,
                broadcast_skip=0,
                num_groups=num_groups,
            )

        if template == 2:
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=False,
                has_add_out=False,
                broadcast_skip=2,
                num_groups=num_groups,
            )

        if template == 3:
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=True,
                has_add_out=False,
                broadcast_skip=4,
                num_groups=num_groups,
            )

        if template == 4:  # No bias
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=False,
                has_add_out=True,
                broadcast_skip=0,
                num_groups=num_groups,
            )

        if template == 5:  # No bias, no add_out
            return GroupNormConfig(
                b,
                h,
                w,
                c,
                fp16=fp16,
                activation=activation,
                op_type=GroupNormOpType.SKIP_GROUP_NORM,
                has_bias=False,
                has_add_out=False,
                broadcast_skip=0,
                num_groups=num_groups,
            )

        return None


def create_group_norm_graph(config: GroupNormConfig) -> bytes:
    inputs = ["input", "gamma", "beta"]
    outputs = ["output"]
    op_type = "GroupNorm"
    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        op_type = "SkipGroupNorm"
        inputs = [*inputs, "skip"]
        if config.has_bias:
            inputs = [*inputs, "bias"]
        if config.has_add_out:
            outputs = [*outputs, "add_out"]

    nodes = [
        helper.make_node(
            op_type,
            inputs,
            outputs,
            op_type + "_0",
            activation=int(config.activation),
            channels_last=int(config.channels_last),
            epsilon=config.epsilon,
            groups=config.num_groups,
            domain="com.microsoft",
        ),
    ]

    float_type = TensorProto.FLOAT16 if config.fp16 else TensorProto.FLOAT

    input_shapes = [
        helper.make_tensor_value_info("input", float_type, ["N", "H", "W", "C"]),
        helper.make_tensor_value_info("gamma", TensorProto.FLOAT, ["C"]),
        helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["C"]),
    ]
    output_shapes = [
        helper.make_tensor_value_info("output", float_type, ["N", "H", "W", "C"]),
    ]

    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        input_shapes = [
            *input_shapes,
            helper.make_tensor_value_info("skip", float_type, config.get_skip_symbolic_shape()),
        ]
        if config.has_bias:
            input_shapes = [*input_shapes, helper.make_tensor_value_info("bias", float_type, ["C"])]
        if config.has_add_out:
            output_shapes = [*output_shapes, helper.make_tensor_value_info("add_out", float_type, ["N", "H", "W", "C"])]

    graph = helper.make_graph(
        nodes,
        "Group_Norm_Graph",
        input_shapes,
        output_shapes,
    )

    model = helper.make_model(graph)
    return model.SerializeToString()


def group_norm_ort(
    src: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    skip: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    config: GroupNormConfig,
    measure_latency=False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]:
    onnx_model_str = create_group_norm_graph(config)
    ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"])

    session = CudaSession(ort_session, device=torch.device("cuda:0"))

    io_shape = {
        "input": [config.batch_size, config.height, config.width, config.channels],
        "gamma": [config.channels],
        "beta": [config.channels],
        "output": [config.batch_size, config.height, config.width, config.channels],
    }

    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        io_shape["skip"] = config.get_skip_shape()
        if config.has_bias:
            io_shape["bias"] = [config.channels]
        if config.has_add_out:
            io_shape["add_out"] = [config.batch_size, config.height, config.width, config.channels]

    session.allocate_buffers(io_shape)

    ort_inputs = {
        "input": src,
        "gamma": gamma,
        "beta": beta,
    }

    if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
        ort_inputs["skip"] = skip
        if config.has_bias:
            ort_inputs["bias"] = bias

    ort_outputs = session.infer(ort_inputs)
    output = ort_outputs["output"]

    add_out = (
        ort_outputs["add_out"] if config.op_type == GroupNormOpType.SKIP_GROUP_NORM and config.has_add_out else None
    )

    if measure_latency:
        latency_list = []
        for _ in range(10000):
            start_time = perf_counter()
            session.infer(ort_inputs)
            end_time = perf_counter()
            latency_list.append(end_time - start_time)
        average_latency = statistics.mean(latency_list)
        return output, add_out, average_latency

    return output, add_out, None


def group_norm_torch(
    src: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    skip: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    config: GroupNormConfig,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    add_out = src

    if skip is not None:
        assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM
        add_out = add_out + config.broadcast(skip)

    if bias is not None:
        assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM
        add_out = add_out + bias.reshape(1, 1, 1, bias.shape[0])

    x = add_out
    if config.channels_last:
        x = add_out.clone().permute(0, 3, 1, 2)  # from NHWC to NCHW

    weight = gamma.to(x.dtype)
    bias = beta.to(x.dtype)
    output = torch.nn.functional.group_norm(x, config.num_groups, weight=weight, bias=bias, eps=config.epsilon)

    if config.activation:
        torch.nn.functional.silu(output, inplace=True)

    if config.channels_last:
        output = output.permute(0, 2, 3, 1)  # from NCHW to NHWC

    return output, add_out


def print_tensor(name, tensor):
    # Print in the format that could be directly added to unit tests in C++.
    torch.set_printoptions(precision=6, sci_mode=False, linewidth=100, profile="full", threshold=1000)
    print(name)
    if tensor is not None:
        print("shape", tensor.shape)
        text = str(tensor.clone().flatten())
        print(text.replace("[", "[\n").replace("]", ",\n]").replace(",", "f,"))
    else:
        print(tensor)
|
||||
|
||||
|
||||
def run_parity(config, measure_latency=True, verbose=False):
|
||||
float_type = torch.float16 if config.fp16 else torch.float32
|
||||
|
||||
input_tensor = torch.randn(
|
||||
config.batch_size,
|
||||
config.height,
|
||||
config.width,
|
||||
config.channels,
|
||||
device="cuda",
|
||||
dtype=float_type,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
gamma = torch.randn(
|
||||
config.channels,
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
beta = torch.randn(
|
||||
config.channels,
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
skip = None
|
||||
bias = None
|
||||
if config.op_type == GroupNormOpType.SKIP_GROUP_NORM:
|
||||
skip = torch.randn(
|
||||
*config.get_skip_shape(),
|
||||
device="cuda",
|
||||
dtype=float_type,
|
||||
requires_grad=False,
|
||||
)
|
||||
if config.has_bias:
|
||||
bias = torch.randn(
|
||||
config.channels,
|
||||
device="cuda",
|
||||
dtype=float_type,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print(config)
|
||||
print_tensor("input", input_tensor)
|
||||
print_tensor("gamma", gamma)
|
||||
print_tensor("beta", beta)
|
||||
print_tensor("skip", skip)
|
||||
print_tensor("bias", bias)
|
||||
|
||||
out_ort, ort_add_out, latency = group_norm_ort(
|
||||
input_tensor, gamma, beta, skip, bias, config, measure_latency=measure_latency
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print_tensor("out_ort", out_ort)
|
||||
print_tensor("ort_add_out", ort_add_out)
|
||||
|
||||
torch_out, torch_add_out = group_norm_torch(input_tensor, gamma, beta, skip, bias, config)
|
||||
|
||||
if verbose:
|
||||
print_tensor("torch_out", torch_out)
|
||||
print_tensor("torch_add_out", torch_add_out)
|
||||
|
||||
average_diff = numpy.mean(numpy.abs(out_ort.detach().cpu().numpy() - torch_out.detach().cpu().numpy()))
|
||||
|
||||
is_close = numpy.allclose(
|
||||
out_ort.detach().cpu().numpy(),
|
||||
torch_out.detach().cpu().numpy(),
|
||||
rtol=1e-1 if config.fp16 else 1e-3,
|
||||
atol=1e-1 if config.fp16 else 1e-3,
|
||||
equal_nan=True,
|
||||
)
|
||||
|
||||
is_add_out_close = (
|
||||
numpy.allclose(
|
||||
ort_add_out.detach().cpu().numpy(),
|
||||
torch_add_out.detach().cpu().numpy(),
|
||||
rtol=1e-1 if config.fp16 else 1e-3,
|
||||
atol=1e-1 if config.fp16 else 1e-3,
|
||||
equal_nan=True,
|
||||
)
|
||||
if ort_add_out is not None
|
||||
else ""
|
||||
)
|
||||
|
||||
    # Compare results
    print(
        config.op_type.name,
        " B:",
        config.batch_size,
        " H:",
        config.height,
        " W:",
        config.width,
        " C:",
        config.channels,
        " G:",
        config.num_groups,
        " activation:",
        int(config.activation),
        " channels_last:",
        int(config.channels_last),
        " fp16:",
        int(config.fp16),
        f" Latency(μs): {int(latency * 1e6)}" if isinstance(latency, float) else "",
        " AvgDiff:",
        average_diff,
        " Pass:",
        is_close,
        is_add_out_close,
    )


def get_latent_height_width():
    # Image sizes in pixels, as (height, width).
    default_size = [(512, 512), (768, 768), (1024, 1024)]
    small_img_size = [(512, 768), (768, 512)]
    xl_img_size = [
        (1152, 896),
        (896, 1152),
        (1216, 832),
        (832, 1216),
        (1344, 768),
        (768, 1344),
        (1536, 640),
        (640, 1536),
    ]
    # The latent height and width are 1/8 of the image height and width.
    return [(int(h / 8), int(w / 8)) for (h, w) in default_size + small_img_size + xl_img_size]


def get_channels():
    return [128, 256, 512, 1024, 2048, 320, 640, 960, 1920, 2560, 384, 768, 1536, 3072, 1152, 2304]


def run_activation(template: int, fp16, measure_latency=False):
    print("Test GroupNorm with Silu Activation for ", "fp16" if fp16 else "fp32")
    for b in [2]:
        for h, w in get_latent_height_width():
            for c in get_channels():
                config = GroupNormConfig.create(b, h, w, c, fp16=fp16, activation=True, template=template)
                run_parity(config, measure_latency=measure_latency)


def run_no_activation(template: int, fp16, measure_latency=False):
    print("Test GroupNorm without Activation for ", "fp16" if fp16 else "fp32")
    for b in [1, 2, 4]:
        for h, w in get_latent_height_width():
            for c in get_channels():
                config = GroupNormConfig.create(b, h, w, c, fp16=fp16, template=template)
                run_parity(config, measure_latency=measure_latency)


def run_all_groups(template: int, fp16, measure_latency=False):
    group_sizes = [1, 2, 4, 8, 16, 32]
    print("Test GroupNorm for different group sizes:", group_sizes)
    for group_size in group_sizes:
        for h, w in get_latent_height_width()[:3]:
            for c in get_channels()[:2]:
                config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=group_size, template=template)
                run_parity(config, measure_latency=measure_latency)


def run_odd_channels(template: int, fp16, measure_latency=False):
    # Test less common channel counts, each divisible by 2 * num_groups (here 2 * 32 = 64).
    for h, w in get_latent_height_width():
        for c in [448, 704, 832, 1664, 2240, 2688, 2880, 3008]:
            config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=template)
            run_parity(config, measure_latency=measure_latency)


def run_small_inputs(template: int, fp16):
    config = GroupNormConfig.create(2, 2, 2, 16, fp16=fp16, activation=False, num_groups=4, template=template)
    run_parity(config, measure_latency=False)

    config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=False, num_groups=8, template=template)
    run_parity(config, measure_latency=False)

    config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=True, num_groups=8, template=template)
    run_parity(config, measure_latency=False)


def run_performance(fp16):
    # Run perf test to tune parameters for given number of channels.
    for h, w in get_latent_height_width()[:3]:
        for c in get_channels():
            config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=0)
            run_parity(config, measure_latency=True)


def run_all(template: int):
    for fp16 in [True, False]:
        run_small_inputs(template, fp16)
        run_odd_channels(template, fp16)
        run_all_groups(template, fp16)
        run_activation(template, fp16)
        run_no_activation(template, fp16)


def run_not_implemented():
    # Expect failure. Check whether the error message is expected.
    try:
        config = GroupNormConfig(1, 2, 2, 513, num_groups=3)
        run_parity(config)
    except RuntimeError as e:
        assert "GroupNorm in CUDA does not support the input: n=1 h=2 w=2 c=513 groups=3" in str(e)


def main():
    run_performance(True)

    run_not_implemented()

    for template in range(6):
        run_all(template)


if __name__ == "__main__":
    main()