[ROCm] add CK GroupNorm to GroupNormTunable (#15510)

- Add CK GroupNorm to GroupNormTunable.
- Reduce the number of GroupNormNHWCOp configurations, because the CK
implementation performs better.

Performance gain on Stable Diffusion v1.5:
Before:
```
'height': 512
'width': 512
'steps': 50
'batch_size': 1
'batch_count': 5
'num_prompts': 1
'average_latency': 2.4782688856124877
'median_latency': 2.4783748388290405
'provider': 'ROCMExecutionProvider'
'disable_safety_checker': True 
```

After:
```
'height': 512, 
'width': 512, 
'steps': 50, 
'batch_size': 1,
'batch_count': 5,
'num_prompts': 1, 
'average_latency': 2.107170510292053,
 'median_latency': 2.1067750453948975,
 'first_run_memory_MB': -1, 
'second_run_memory_MB': -1,
'provider': 'ROCMExecutionProvider', 
'disable_safety_checker': True
```
This commit is contained in:
PeixuanZuo 2023-04-19 13:54:59 +08:00 committed by GitHub
parent a66af390fa
commit 59ea35d592
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 675 additions and 65 deletions

View file

@ -1,5 +1,5 @@
set(composable_kernel_URL https://github.com/ROCmSoftwarePlatform/composable_kernel.git)
set(composable_kernel_TAG bef0cb20dba0d9b315df46899310478a81c21852) # 2023-02-16 11:54:08 -0800
set(composable_kernel_TAG ed3a2e52265e11daa366f47b082141a652b67c58) # 2023-04-10 21:02:17 +0800
set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch)

View file

@ -31,7 +31,9 @@ set(contrib_ops_excluded_files
"bert/packed_attention.cc"
"bert/packed_attention_impl.h"
"bert/packed_attention_impl.cu"
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/nhwc_conv.cc"
"math/complex_mul.cc"
"math/complex_mul.h"

View file

@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/diffusion/group_norm.h"
#include "contrib_ops/rocm/diffusion/group_norm_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Element types accepted for the "T" type constraint. The same list is handed to
// MLTypeCallDispatcher in GroupNorm::ComputeInternal, so both stay in sync.
#define GROUP_NORM_TYPES float, MLFloat16
// Registers the GroupNorm kernel class for the GroupNorm operator
// (kMSDomain, version 1) on the ROCm execution provider.
ONNX_OPERATOR_KERNEL_EX(
GroupNorm, kMSDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
using namespace ONNX_NAMESPACE;
namespace {
// Functor instantiated per tensor element type T by MLTypeCallDispatcher.
// It unwraps the tensors into raw device pointers (element type mapped to the
// corresponding HIP type) and forwards everything to LaunchGroupNormKernel.
// gamma and beta are always read as float, independent of T.
template <typename T>
struct DispatchGroupNorm {
  Status operator()(RocmTuningContext* tuning_ctx,
                    hipStream_t stream,
                    Tensor* output,
                    const Tensor* input,
                    const Tensor* gamma,
                    const Tensor* beta,
                    void* workspace,
                    float epsilon,
                    int batch_size,
                    int num_channels,
                    int height,
                    int width,
                    int num_groups,
                    bool use_swish_activation) {
    using HipT = typename ToHipType<T>::MappedType;

    HipT* const dst = reinterpret_cast<HipT*>(output->MutableData<T>());
    const HipT* const src = reinterpret_cast<const HipT*>(input->Data<T>());
    const float* const scale = gamma->Data<float>();
    const float* const bias = beta->Data<float>();

    return LaunchGroupNormKernel<HipT>(tuning_ctx, stream, dst, src, scale, bias, workspace,
                                       epsilon, batch_size, num_channels, height, width,
                                       num_groups, use_swish_activation);
  }
};
} // namespace
// Reads and validates the node attributes once, at kernel construction time.
//   epsilon    - optional float attribute, defaults to 1e-5f; must be non-negative.
//   groups     - required integer attribute; must be strictly positive.
//   activation - required integer attribute; 0 means no activation, 1 means Swish.
// Any violated requirement fails an ORT_ENFORCE check.
GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
  epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
  ORT_ENFORCE(epsilon_ >= 0);

  int64_t num_groups;
  ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK());
  // ComputeInternal evaluates `num_channels % num_groups_`; a value of zero would be an
  // integer modulo by zero (undefined behaviour), so it is rejected here.
  ORT_ENFORCE(num_groups > 0, "groups must be a positive integer, got ", num_groups);
  num_groups_ = static_cast<int>(num_groups);

  int64_t activation;
  ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
  ORT_ENFORCE(activation == 0 || activation == 1);  // 0 is None, 1 is Swish
  use_swish_activation_ = (activation == 1);
}
// Runs group normalization for one node invocation.
// Steps: fetch the three inputs, allocate the output with the input's shape, validate
// ranks and channel counts (returning INVALID_ARGUMENT on the first violation), obtain
// the scratch buffer, then dispatch on the input element type to DispatchGroupNorm.
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
  const Tensor* x = context->Input<Tensor>(0);
  const Tensor* scale = context->Input<Tensor>(1);
  const Tensor* bias = context->Input<Tensor>(2);
  Tensor* y = context->Output(0, x->Shape());

  // The input must be 4-D; its last dimension is treated as the channel count.
  const auto& x_dims = x->Shape().GetDims();
  if (x_dims.size() != 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input is expected to have 4 dimensions, got ", x_dims.size());
  }

  // gamma: 1-D with one entry per channel.
  const auto& scale_dims = scale->Shape().GetDims();
  if (scale_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "gamma is expected to have 1 dimension, got ", scale_dims.size());
  }
  if (scale_dims[0] != x_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels in gamma and input does not match");
  }

  // beta: 1-D with one entry per channel.
  const auto& bias_dims = bias->Shape().GetDims();
  if (bias_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "beta is expected to have 1 dimension, got ", bias_dims.size());
  }
  if (bias_dims[0] != x_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels in beta and input does not match");
  }

  // Input and output format is NHWC
  int batch_size = static_cast<int>(x_dims[0]);
  int height = static_cast<int>(x_dims[1]);
  int width = static_cast<int>(x_dims[2]);
  int num_channels = static_cast<int>(x_dims[3]);

  if (num_channels % num_groups_ != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "number of channels should be divisiable by num_groups");
  }

  auto scratch = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());

  utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> type_dispatcher(x->GetElementType());
  return type_dispatcher.InvokeRet<Status, DispatchGroupNorm>(
      GetTuningContext(), Stream(context),
      y, x, scale, bias, scratch.get(),
      epsilon_,
      batch_size,
      num_channels,
      height,
      width,
      num_groups_,
      use_swish_activation_);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#ifdef USE_COMPOSABLE_KERNEL
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh"
#endif // USE_COMPOSABLE_KERNEL
#include "contrib_ops/rocm/diffusion/group_norm_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
#ifdef USE_COMPOSABLE_KERNEL
// Maps an element type used by the tunable op to the element type handed to
// composable_kernel (CK). By default the type is passed through unchanged.
template <typename T>
struct DataTypeAdaptor {
using type = T;
};
// `half` is replaced by CK's own half-precision type.
template <>
struct DataTypeAdaptor<half> {
using type = ck::half_t;
};
// Output element-wise operations that can be selected for the CK instances.
using Swish = ck::tensor_operation::element_wise::Swish;
using Pass = ck::tensor_operation::element_wise::PassThrough;
// The tensor is described to CK with 5 dimensions, of which 3 are reduced
// (see in_lengths and reduce_dims in GetCKGroupNormNHWCTypeStringAndOps).
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
// Builds the list of CK-backed group-norm candidates for the tunable op.
//
// Template parameters:
//   T         - element type of the input/output tensors (mapped through DataTypeAdaptor).
//   AccT      - accumulation type forwarded to CK (mapped through DataTypeAdaptor).
//   WithSwish - selects Swish (true) or PassThrough (false) as the output element-wise op.
//
// Returns a vector of (type string, op) pairs, one per instance reported by
// internal::GetDeviceGroupNormInstances. Each op is a callable taking
// `const GroupNormNHWCParams<T>*` and returning Status.
template <typename T, typename AccT, bool WithSwish>
auto GetCKGroupNormNHWCTypeStringAndOps() {
using InDataType = typename DataTypeAdaptor<T>::type;
using OutDataType = typename DataTypeAdaptor<T>::type;
using AccDataType = typename DataTypeAdaptor<AccT>::type;
using GammaDataType = float;
using BetaDataType = float;
using Activation = std::conditional_t<WithSwish, Swish, Pass>;
std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCParams<T>>>> ret;
for (auto&& impl : internal::GetDeviceGroupNormInstances<InDataType, GammaDataType, BetaDataType, AccDataType,
OutDataType, Activation, Rank, NumReduceDim>()) {
// The type string and the invoker are obtained from `impl` before it is moved into the lambda below.
std::string swish_suffix = WithSwish ? "_Swish" : "_Pass";
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
auto invoker = impl->MakeInvokerPointer();
// The lambda takes ownership of both the instance and its invoker.
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams<T>* params) -> Status {
// An instance built for one activation refuses params that request the other one.
if constexpr (WithSwish) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!params->withSwish, "Swish version only support groupnorm with swish");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->withSwish, "Pass version only support groupnorm without swish");
}
// The tensor is presented to CK as (n, h, w, groups, cPerGroup).
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup};
// Strides of that 5-D view; presumably c == groups * cPerGroup for a contiguous NHWC buffer — TODO confirm.
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1};
// Zero strides on the first three dimensions: gamma/beta are indexed by (group, channel-in-group) only.
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->cPerGroup, 1};
// Reduce over dimensions 1, 2 and 4 of in_lengths, i.e. h, w and cPerGroup.
std::vector<ck::index_t> reduce_dims{1, 2, 4};
auto activation = Activation{};
// NOTE(review): the two nullptr arguments are presumably CK's optional saved mean /
// inverse-std outputs — confirm against the pinned CK revision.
auto arg = impl->MakeArgumentPointer(in_lengths, // lengths
in_out_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
in_out_strides, // yStrides
reduce_dims, // reduceDims
params->epsilon,
params->src,
params->gamma,
params->beta,
params->dst,
nullptr,
nullptr,
activation);
// Instances that cannot handle this problem are reported as unsupported rather than run.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
// Run on the caller's stream; the value returned by Run is not used.
invoker->Run(arg.get(), StreamConfig{params->stream});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op)));
}
return ret;
}
#endif // USE_COMPOSABLE_KERNEL
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// Modifications Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef USE_COMPOSABLE_KERNEL
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/utility/data_type.hpp"
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace internal {
// Short names for the element types used in the instance lists below.
using F16 = ck::half_t;
using F32 = float;
// Output element-wise operations the instances can be specialized with.
using Swish = ck::tensor_operation::element_wise::Swish;
using Pass = ck::tensor_operation::element_wise::PassThrough;
using ck::tensor_operation::device::DeviceNormalization; // the interface
using ck::tensor_operation::device::DeviceNormalizationImpl; // the implementation
// Tuple of candidate CK normalization kernels for float input and float output.
// OutElementwise, Rank and Reduce are forwarded unchanged to every instance; the
// instances differ only in the trailing integer tuning parameters.
template <typename OutElementwise, ck::index_t Rank, ck::index_t Reduce>
using device_normalization_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize
// NOTE(review): every instance passes 20 template arguments; the GammaSrcVectorDim / BetaSrcVectorDim
// entries in the legend above are assumed from CK's DeviceNormalizationImpl — confirm against the pinned CK revision.
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
// Tuple of candidate CK normalization kernels for half input and half output, with
// float gamma, beta and compute types. OutElementwise, Rank and Reduce are forwarded
// unchanged; the instances differ only in the trailing integer tuning parameters.
template <typename OutElementwise, ck::index_t Rank, ck::index_t Reduce>
using device_normalization_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize
// NOTE(review): every instance passes 20 template arguments; the GammaSrcVectorDim / BetaSrcVectorDim
// entries in the legend above are assumed from CK's DeviceNormalizationImpl — confirm against the pinned CK revision.
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
// Use this function to get implementation
//
// Primary template: returns an empty list, so any type combination without an
// explicit specialization simply contributes no CK instances.
template <typename InDataType,
typename GammaDataType,
typename BetaDataType,
typename AccDataType,
typename OutDataType,
typename YElementwiseOperation,
ck::index_t Rank,
ck::index_t NumReduceDim>
std::vector<std::unique_ptr<DeviceNormalization<InDataType,
GammaDataType,
BetaDataType,
AccDataType,
OutDataType,
YElementwiseOperation,
Rank,
NumReduceDim>>>
GetDeviceGroupNormInstances() {
return {};
}
// Explicit specializations are only declared here; their definitions live in
// separate translation units that include this header.
// Half input/output, Swish output op, rank 5 with 3 reduced dimensions.
template <>
std::vector<std::unique_ptr<DeviceNormalization<
F16, F32, F32, F32, F16, Swish, 5, 3>>>
GetDeviceGroupNormInstances<
F16, F32, F32, F32, F16, Swish, 5, 3>();
// Half input/output, pass-through output op.
template <>
std::vector<std::unique_ptr<DeviceNormalization<
F16, F32, F32, F32, F16, Pass, 5, 3>>>
GetDeviceGroupNormInstances<
F16, F32, F32, F32, F16, Pass, 5, 3>();
// Float input/output, Swish output op.
template <>
std::vector<std::unique_ptr<DeviceNormalization<
F32, F32, F32, F32, F32, Swish, 5, 3>>>
GetDeviceGroupNormInstances<
F32, F32, F32, F32, F32, Swish, 5, 3>();
// Float input/output, pass-through output op.
template <>
std::vector<std::unique_ptr<DeviceNormalization<
F32, F32, F32, F32, F32, Pass, 5, 3>>>
GetDeviceGroupNormInstances<
F32, F32, F32, F32, F32, Pass, 5, 3>();
} // namespace internal
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#endif // USE_COMPOSABLE_KERNEL

View file

@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef USE_COMPOSABLE_KERNEL
#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace internal {
// Half input/output instances with Swish as the output element-wise operation.
// Instantiates every entry of device_normalization_f16_instances and returns them
// behind the DeviceNormalization interface.
template <>
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>
GetDeviceGroupNormInstances<F16, F32, F32, F32, F16, Swish, 5, 3>() {
  using OpInterface = DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>;
  namespace ck_instance = ck::tensor_operation::device::instance;

  std::vector<std::unique_ptr<OpInterface>> registered;
  ck_instance::add_device_operation_instances(registered, device_normalization_f16_instances<Swish, 5, 3>{});
  return registered;
}
// Half input/output instances with PassThrough as the output element-wise operation.
// Instantiates every entry of device_normalization_f16_instances and returns them
// behind the DeviceNormalization interface.
template <>
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Pass, 5, 3>>>
GetDeviceGroupNormInstances<F16, F32, F32, F32, F16, Pass, 5, 3>() {
  using OpInterface = DeviceNormalization<F16, F32, F32, F32, F16, Pass, 5, 3>;
  namespace ck_instance = ck::tensor_operation::device::instance;

  std::vector<std::unique_ptr<OpInterface>> registered;
  ck_instance::add_device_operation_instances(registered, device_normalization_f16_instances<Pass, 5, 3>{});
  return registered;
}
} // namespace internal
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#endif // USE_COMPOSABLE_KERNEL

View file

@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef USE_COMPOSABLE_KERNEL
#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace internal {
// Float input/output instances with Swish as the output element-wise operation.
// Instantiates every entry of device_normalization_f32_instances and returns them
// behind the DeviceNormalization interface.
template <>
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Swish, 5, 3>() {
  using OpInterface = DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>;
  namespace ck_instance = ck::tensor_operation::device::instance;

  std::vector<std::unique_ptr<OpInterface>> registered;
  ck_instance::add_device_operation_instances(registered, device_normalization_f32_instances<Swish, 5, 3>{});
  return registered;
}
// Float input/output instances with PassThrough as the output element-wise operation.
// Instantiates every entry of device_normalization_f32_instances and returns them
// behind the DeviceNormalization interface.
template <>
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>>
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Pass, 5, 3>() {
  using OpInterface = DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>;
  namespace ck_instance = ck::tensor_operation::device::instance;

  std::vector<std::unique_ptr<OpInterface>> registered;
  ck_instance::add_device_operation_instances(registered, device_normalization_f32_instances<Pass, 5, 3>{});
  return registered;
}
} // namespace internal
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#endif // USE_COMPOSABLE_KERNEL

View file

@ -65,7 +65,8 @@ struct GroupNormNHWCParams : OpParams {
}
std::string Signature() const override {
std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups);
std::string swish_suffix = withSwish ? "_Swish" : "_Pass";
std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix;
return sig;
}

View file

@ -14,6 +14,7 @@ namespace rocm {
template <typename T>
Status LaunchGroupNormKernel(
RocmTuningContext* tuning_ctx,
hipStream_t stream,
T* output,
const T* input,
@ -37,18 +38,23 @@ Status LaunchGroupNormKernel(
"only num_groups=32 is supported. Got", num_groups);
}
GroupNormNHWCParams<T> params(nullptr, stream, output, reinterpret_cast<float*>(workspace), input, gamma, beta,
GroupNormNHWCParams<T> params(tuning_ctx, stream, output, reinterpret_cast<float*>(workspace), input, gamma, beta,
batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation);
if (tuning_ctx->IsTunableOpEnabled()) {
static GroupNormNHWCTunableOp<T> op;
return op(&params);
}
return GroupNormNHWCStaticSelection(&params);
}
template Status LaunchGroupNormKernel<half>(hipStream_t stream, half* output,
template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, hipStream_t stream, half* output,
const half* input, const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);
template Status LaunchGroupNormKernel<float>(hipStream_t stream, float* output,
template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, hipStream_t stream, float* output,
const float* input, const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);

View file

@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdint>
#include <hip/hip_runtime.h>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
using onnxruntime::rocm::tunable::RocmTuningContext;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Limits used to size the scratch workspace.
constexpr size_t kMaxGroupNormBatchSize = 32;
constexpr size_t kGroupNormNumberOfGroups = 32;

// Size in bytes of the scratch workspace: two float buffers (sum and squared sum),
// each holding one value per (batch, group) pair at the maximum batch size.
constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
  constexpr size_t kNumBuffers = 2;  // sum + squared sum
  constexpr size_t kBytesPerSlot = sizeof(float) * kNumBuffers;
  return kBytesPerSlot * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
}
// Launches group normalization on the given stream and returns a Status.
// Only declared here; T is the device element type of the input and output tensors,
// while gamma and beta are always float.
// NOTE(review): callers in this change allocate `workspace` with
// GetGroupNormWorkspaceSizeInBytes() bytes — presumably that is the required minimum; confirm.
template <typename T>
Status LaunchGroupNormKernel(
RocmTuningContext* tuning_ctx,
hipStream_t stream,
T* output, // normalized output tensor
const T* input, // input tensor
const float* gamma, // gamma (also known as weight or scale)
const float* beta, // beta (also known as bias)
void* workspace, // Work space
float epsilon, // epsilon used in normalization
int batch_size, // N
int num_channels, // C
int height, // H
int width, // W
int num_groups, // number of groups
bool use_swish_activation // Whether there is Swish activation after group normalization
);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -7,6 +7,7 @@
#include <hip/hip_runtime_api.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh"
#include "contrib_ops/rocm/diffusion/group_norm_common.h"
#include "contrib_ops/rocm/diffusion/group_norm_impl.h"
#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh"
@ -136,12 +137,16 @@ class GroupNormNHWCOp {
Status IsSupported(const GroupNormNHWCParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0));
!(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0),
"The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup,
") isn't divisible by the number of vector size: ", VecSize);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 &&
params->c % params->cPerBlock == 0 &&
params->hw % params->hwPerBlock == 0));
params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0),
"The value of attributes don't meet the requirements.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize &&
params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize));
params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
"Configuration: Threads (", ThreadsPerBlock, "), vector size (",
VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock);
return Status::OK();
}
@ -160,19 +165,14 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \
this->RegisterOp(name<T, threads_per_block, 1>{}); \
this->RegisterOp(name<T, threads_per_block, 2>{}); \
this->RegisterOp(name<T, threads_per_block, 4>{}); \
this->RegisterOp(name<T, threads_per_block, 8>{}); \
this->RegisterOp(name<T, threads_per_block, 16>{});
this->RegisterOp(name<T, threads_per_block, 4>{});
#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 512)
ADD_OP_FOR_ALL_VEC_SIZE(name, 320)
template <typename T>
class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCParams<T>> {
@ -180,6 +180,18 @@ class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCParams<T>> {
GroupNormNHWCTunableOp() {
this->RegisterOp(GroupNormNHWCStaticSelection<T>);
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp)
#ifdef USE_COMPOSABLE_KERNEL
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/false>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/true>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif // USE_COMPOSABLE_KERNEL
}
};

View file

@ -11,12 +11,12 @@ from itertools import product
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_bytes
from utils import dtype_to_bytes, dtype_to_suffix
def get_sd_sizes():
batch_sizes = [1, 2]
height = [8, 16, 32, 64]
height = [8, 16, 32]
num_channels = [320, 640, 1280, 1920, 2560]
num_groups = [32]
@ -25,8 +25,8 @@ def get_sd_sizes():
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: re.search("GroupNormNHWC.*_half", x), dir(ke))),
"float32": list(filter(lambda x: re.search("GroupNormNHWC.*_float", x), dir(ke))),
"float16": list(filter(lambda x: re.match("GroupNormNHWC.*_half", x), dir(ke))),
"float32": list(filter(lambda x: re.match("GroupNormNHWC.*_float", x), dir(ke))),
}
return type_map[dtype]
@ -52,7 +52,7 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
return x
def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, func):
def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func):
np.random.seed(0)
width = height
input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
@ -62,7 +62,7 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32)
epsilon = 1e-05
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
use_swish = True
use_swish = swish
host_x = input_x.astype(dtype)
input_d = ke.DeviceArray(host_x)
@ -86,12 +86,16 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
epsilon,
use_swish,
)
if my_op.IsSupported():
y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype)
for impl in my_op.ListOps():
if not my_op.SelectOp(impl):
continue
my_op.Run()
y_d.UpdateHostNumpyArray()
y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype)
np.testing.assert_allclose(y_ref, output_y, atol=1e-02)
@ -100,9 +104,19 @@ dtypes = ["float32", "float16"]
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
@pytest.mark.parametrize("dtype", dtypes)
def test_skip_layer_norm(sd_sizes, dtype):
@pytest.mark.parametrize("swish", [True])
def test_group_norm(sd_sizes, dtype, swish):
for func in dtype_to_funcs(dtype):
run_group_norm(*sd_sizes, dtype, func)
run_group_norm(*sd_sizes, dtype, swish, func)
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("swish", [True])
def test_group_norm_ck(sd_sizes, dtype, swish):
swish_suffix = "Swish" if swish else "Pass"
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
run_group_norm(*sd_sizes, dtype, swish, ck_f_name)
@dataclass
@ -124,7 +138,7 @@ class GroupNormNHWCMetric(ke.BandwidthMetric):
def profile_group_norm_func(
batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, func
batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func
):
np.random.seed(0)
input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
@ -133,7 +147,7 @@ def profile_group_norm_func(
workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32)
epsilon = 0.05
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
use_swish = True
use_swish = swish
input_d = ke.DeviceArray(input_x)
gamma_d = ke.DeviceArray(gamma)
@ -156,21 +170,27 @@ def profile_group_norm_func(
epsilon,
use_swish,
)
for impl in my_op.ListOps():
duration_ms = -1
if my_op.SelectOp(impl):
duration_ms = my_op.Profile()
total_bytes = (input_x.size * 2 + gamma.size * 2) * dtype_to_bytes(dtype)
duration_ms = -1
if my_op.IsSupported():
duration_ms = my_op.Profile()
total_bytes = (input_x.size * 2 + gamma.size * 2) * dtype_to_bytes(dtype)
ke.report(
GroupNormNHWCMetric(func, dtype, duration_ms, total_bytes, batch_size, height, width, num_channels, num_groups)
)
ke.report(
GroupNormNHWCMetric(
impl, dtype, duration_ms, total_bytes, batch_size, height, width, num_channels, num_groups
)
)
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, sort=True):
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True):
with ke.benchmark(sort):
for func in dtype_to_funcs(dtype):
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, func)
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func)
# ck function
swish_suffix = "Swish" if swish else "Pass"
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name)
sd_profile_sizes = [
@ -209,6 +229,7 @@ if __name__ == "__main__":
group.add_argument("num_channels", type=int)
group.add_argument("num_groups", type=int)
group.add_argument("dtype", choices=dtypes)
group.add_argument("--swish", action="store_true")
group.add_argument("--sort", action="store_true")
if len(sys.argv) == 1:
@ -216,5 +237,12 @@ if __name__ == "__main__":
else:
args = parser.parse_args()
profile_with_args(
args.batch_size, args.height, args.width, args.num_channels, args.num_groups, args.dtype, args.sort
args.batch_size,
args.height,
args.width,
args.num_channels,
args.num_groups,
args.dtype,
args.swish,
args.sort,
)

View file

@ -1,10 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <stdio.h>
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh"
#include "contrib_ops/rocm/diffusion/group_norm_common.h"
#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h"
#include "python/tools/kernel_explorer/device_array.h"
@ -21,21 +22,28 @@ class GroupNormNHWC : public IKernelExplorer {
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {}
bool IsSupported() {
Status status = op_.IsSupported(&params_);
return status.IsOK();
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize);
}
// Invokes op_ on the stored params_ and hands the returned status to
// ORT_THROW_IF_ERROR (presumably raises when the status is not OK; confirm
// against the macro definition).
void Run() override {
ORT_THROW_IF_ERROR(op_(&params_));
}
// Returns the one selectable implementation name for this instantiation:
// the type_string_ assembled in the constructor.
std::vector<std::string> ListOps() const {
  return std::vector<std::string>(1, type_string_);
}
bool SelectOp(const std::string& name) {
Status status = op_.IsSupported(&params_);
return status.IsOK() && name == type_string_;
}
private:
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
ParamsT params_{};
contrib::rocm::GroupNormNHWCOp<T, ThreadsPerBlock, VecSize> op_{};
std::string type_string_{};
};
template <typename T>
@ -45,20 +53,27 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {}
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
type_string_ = "GroupNormNHWCStaticSelection";
}
// Calls contrib::rocm::GroupNormNHWCStaticSelection<T> on the stored params_
// and hands the returned status to ORT_THROW_IF_ERROR (presumably raises when
// the status is not OK; confirm against the macro definition).
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::GroupNormNHWCStaticSelection<T>(&params_)));
}
bool IsSupported() {
// Returns the one selectable implementation name for this wrapper: the
// type_string_ set in the constructor.
std::vector<std::string> ListOps() const {
  return std::vector<std::string>(1, type_string_);
}
bool SelectOp(const std::string& name) {
Status status = contrib::rocm::GroupNormNHWCStaticSelection<T>(&params_);
return status.IsOK();
return status.IsOK() && name == type_string_;
}
private:
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
ParamsT params_{};
std::string type_string_{};
};
template <typename T>
@ -69,15 +84,19 @@ class GroupNormNHWCTunable : public IKernelExplorer {
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
// Invokes the tunable op_ on the stored params_ and hands the returned status
// to ORT_THROW_IF_ERROR (presumably raises when the status is not OK; confirm
// against the macro definition).
void Run() override {
ORT_THROW_IF_ERROR(op_(&params_));
}
bool IsSupported() {
return true;
// Returns the single fixed implementation name exposed by the tunable wrapper.
std::vector<std::string> ListOps() const {
  std::vector<std::string> names;
  names.emplace_back("GroupNormNHWCTunable");
  return names;
}
// Accepts only the exact name "GroupNormNHWCTunable"; every other string is
// rejected. No support query is made here.
bool SelectOp(const std::string& name) {
  const bool is_tunable_name = (name.compare("GroupNormNHWCTunable") == 0);
  return is_tunable_name;
}
private:
@ -86,6 +105,51 @@ class GroupNormNHWCTunable : public IKernelExplorer {
contrib::rocm::GroupNormNHWCTunableOp<T> op_{};
};
#ifdef USE_COMPOSABLE_KERNEL
// Kernel-explorer wrapper over the implementations returned by
// contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSwish>().
// Each implementation is addressable by its type string through SelectOp(),
// and Run() executes whichever one was selected last (index 0 by default).
template <typename T, bool WithSwish>
class CKGroupNormNHWC : public IKernelExplorer {
 public:
  // Stores the launch parameters and collects every (type string, op) pair
  // into two parallel vectors that share the same index.
  CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
                  int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
                static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
                batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
    for (auto&& [instance_name, instance_op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSwish>()) {
      type_strings_.emplace_back(std::move(instance_name));
      ops_.emplace_back(std::move(instance_op));
    }
  }

  // Executes the currently selected implementation on params_; the returned
  // status goes through ORT_THROW_IF_ERROR.
  void Run() override {
    auto& chosen_op = ops_[selected_op_];
    ORT_THROW_IF_ERROR(chosen_op(&params_));
  }

  // Returns a copy of all collected implementation names.
  std::vector<std::string> ListOps() const {
    return type_strings_;
  }

  // Selects the implementation whose type string equals `name`, invokes it
  // once on params_, and returns whether that invocation's status was OK.
  // The selection is recorded before the invocation, so it persists even when
  // false is returned. Reaches ORT_THROW if no type string matches.
  bool SelectOp(const std::string& name) {
    for (size_t idx = 0; idx < ops_.size(); ++idx) {
      if (type_strings_[idx] != name) {
        continue;
      }
      selected_op_ = idx;
      const bool usable = ops_[idx](&params_).IsOK();
      return usable;
    }

    ORT_THROW("Cannot find implementation ", name);
  }

 private:
  using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
  using OpT = rocm::tunable::Op<ParamsT>;
  ParamsT params_{};
  std::vector<OpT> ops_;                   // callables, index-aligned with type_strings_
  std::vector<std::string> type_strings_;  // names, index-aligned with ops_
  size_t selected_op_{};                   // index used by Run(); updated by SelectOp()
};
#endif // USE_COMPOSABLE_KERNEL
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
@ -93,33 +157,36 @@ class GroupNormNHWCTunable : public IKernelExplorer {
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
.def("ListOps", &name<type, threads_per_block, vec_size>::ListOps) \
.def("SelectOp", &name<type, threads_per_block, vec_size>::SelectOp);
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
REGISTER_OP(name, type, threads_per_block, 1) \
REGISTER_OP(name, type, threads_per_block, 2) \
REGISTER_OP(name, type, threads_per_block, 4) \
REGISTER_OP(name, type, threads_per_block, 8) \
REGISTER_OP(name, type, threads_per_block, 16)
REGISTER_OP(name, type, threads_per_block, 4)
#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512)
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320)
#define REGISTER_OP_TYPED(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
#define REGISTER_COMMON(name, type, ...) \
py::class_<type<__VA_ARGS__>>(m, name) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
int, int, int, int, int, float, bool>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
.def("Run", &name<type>::Run) \
.def("IsSupported", &name<type>::IsSupported);
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
.def("Profile", &type<__VA_ARGS__>::Profile) \
.def("Run", &type<__VA_ARGS__>::Run) \
.def("ListOps", &type<__VA_ARGS__>::ListOps) \
.def("SelectOp", &type<__VA_ARGS__>::SelectOp);
// Registers the single-type-parameter class template `name<type>` through
// REGISTER_COMMON; the Python-visible class name is the stringized
// "<name>_<type>".
#define REGISTER_OP_TYPED(name, type) \
REGISTER_COMMON(#name "_" #type, name, type)
// Registers CKGroupNormNHWC<type, with_swish> through REGISTER_COMMON under
// the Python-visible name "CKGroupNormNHWC<swish_suffix>_<type>".
// `swish_suffix` must be a string literal because it is joined to its
// neighbours by adjacent-literal concatenation.
#define REGISTER_CK(type, with_swish, swish_suffix) \
REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish)
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half);
@ -130,6 +197,13 @@ KE_REGISTER(m) {
REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, half);
REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, float);
#ifdef USE_COMPOSABLE_KERNEL
REGISTER_CK(half, false, "Pass");
REGISTER_CK(half, true, "Swish");
REGISTER_CK(float, false, "Pass");
REGISTER_CK(float, true, "Swish");
#endif // USE_COMPOSABLE_KERNEL
}
} // namespace onnxruntime