mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[ROCm] add CK GroupNorm to GroupNormTunable (#15510)
- Add CK GroupNorm to GroupNormTunable. - Reduce configuration of GroupNormNHWCOp because CK implementation is better. The performance gain on stable diffusion v1.5. Before: ``` 'height': 512 'width': 512 'steps': 50 'batch_size': 1 'batch_count': 5 'num_prompts': 1 'average_latency': 2.4782688856124877 'median_latency': 2.4783748388290405 'provider': 'ROCMExecutionProvider' 'disable_safety_checker': True ``` After: ``` 'height': 512, 'width': 512, 'steps': 50, 'batch_size': 1, 'batch_count': 5, 'num_prompts': 1, 'average_latency': 2.107170510292053, 'median_latency': 2.1067750453948975, 'first_run_memory_MB': -1, 'second_run_memory_MB': -1, 'provider': 'ROCMExecutionProvider', 'disable_safety_checker': True ```
This commit is contained in:
parent
a66af390fa
commit
59ea35d592
13 changed files with 675 additions and 65 deletions
2
cmake/external/composable_kernel.cmake
vendored
2
cmake/external/composable_kernel.cmake
vendored
|
|
@ -1,5 +1,5 @@
|
|||
set(composable_kernel_URL https://github.com/ROCmSoftwarePlatform/composable_kernel.git)
|
||||
set(composable_kernel_TAG bef0cb20dba0d9b315df46899310478a81c21852) # 2023-02-16 11:54:08 -0800
|
||||
set(composable_kernel_TAG ed3a2e52265e11daa366f47b082141a652b67c58) # 2023-04-10 21:02:17 +0800
|
||||
|
||||
set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,9 @@ set(contrib_ops_excluded_files
|
|||
"bert/packed_attention.cc"
|
||||
"bert/packed_attention_impl.h"
|
||||
"bert/packed_attention_impl.cu"
|
||||
"diffusion/group_norm.cc"
|
||||
"diffusion/group_norm_impl.cu"
|
||||
"diffusion/group_norm_impl.h"
|
||||
"diffusion/nhwc_conv.cc"
|
||||
"math/complex_mul.cc"
|
||||
"math/complex_mul.h"
|
||||
|
|
|
|||
132
onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc
Normal file
132
onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
#define GROUP_NORM_TYPES float, MLFloat16
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
GroupNorm, kMSDomain, 1, kRocmExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct DispatchGroupNorm {
|
||||
Status operator()(RocmTuningContext* tuning_ctx,
|
||||
hipStream_t stream,
|
||||
Tensor* output,
|
||||
const Tensor* input,
|
||||
const Tensor* gamma,
|
||||
const Tensor* beta,
|
||||
void* workspace,
|
||||
float epsilon,
|
||||
int batch_size,
|
||||
int num_channels,
|
||||
int height,
|
||||
int width,
|
||||
int num_groups,
|
||||
bool use_swish_activation) {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
return LaunchGroupNormKernel<HipT>(
|
||||
tuning_ctx,
|
||||
stream,
|
||||
reinterpret_cast<HipT*>(output->MutableData<T>()),
|
||||
reinterpret_cast<const HipT*>(input->Data<T>()),
|
||||
gamma->Data<float>(),
|
||||
beta->Data<float>(),
|
||||
workspace,
|
||||
epsilon,
|
||||
batch_size,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
num_groups,
|
||||
use_swish_activation);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Reads and validates the operator attributes:
//   epsilon    - optional, defaults to 1e-5, must be non-negative.
//   groups     - required, must be strictly positive.
//   activation - required, 0 (None) or 1 (Swish).
// Any violation aborts kernel construction via ORT_ENFORCE.
GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
  epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
  ORT_ENFORCE(epsilon_ >= 0);

  int64_t num_groups = 0;
  ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK());
  // Zero must be rejected here: ComputeInternal evaluates `num_channels % num_groups_`,
  // which is a modulo by zero when groups == 0.
  ORT_ENFORCE(num_groups > 0, "groups must be a positive integer, got ", num_groups);
  num_groups_ = static_cast<int>(num_groups);

  int64_t activation = 0;
  ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
  ORT_ENFORCE(activation == 0 || activation == 1);  // 0 is None, 1 is Swish
  use_swish_activation_ = (activation == 1);
}
|
||||
|
||||
// Checks the shapes of input (rank 4, NHWC), gamma and beta (rank 1, length equal to the channel
// dimension), verifies that the channel count is a multiple of the group count, acquires the
// scratch workspace and dispatches the launch on the input element type.
// Returns INVALID_ARGUMENT for any shape violation, otherwise the Status of the launch.
Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* gamma = context->Input<Tensor>(1);
  const Tensor* beta = context->Input<Tensor>(2);
  // The output is requested with the input's shape before the validation below runs.
  Tensor* output = context->Output(0, input->Shape());

  const auto& x_dims = input->Shape().GetDims();
  if (x_dims.size() != 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input is expected to have 4 dimensions, got ", x_dims.size());
  }

  const auto& scale_dims = gamma->Shape().GetDims();
  if (scale_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "gamma is expected to have 1 dimension, got ", scale_dims.size());
  }
  if (scale_dims[0] != x_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels in gamma and input does not match");
  }

  const auto& bias_dims = beta->Shape().GetDims();
  if (bias_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "beta is expected to have 1 dimension, got ", bias_dims.size());
  }
  if (bias_dims[0] != x_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels in beta and input does not match");
  }

  // Input and output format is NHWC: dims are [N, H, W, C].
  int batch_size = static_cast<int>(x_dims[0]);
  int height = static_cast<int>(x_dims[1]);
  int width = static_cast<int>(x_dims[2]);
  int num_channels = static_cast<int>(x_dims[3]);

  if (num_channels % num_groups_ != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "number of channels should be divisiable by num_groups");
  }

  auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());

  utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> type_dispatcher(input->GetElementType());
  return type_dispatcher.InvokeRet<Status, DispatchGroupNorm>(GetTuningContext(), Stream(context),
                                                              output, input, gamma, beta, workspace.get(),
                                                              epsilon_,
                                                              batch_size,
                                                              num_channels,
                                                              height,
                                                              width,
                                                              num_groups_,
                                                              use_swish_activation_);
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
100
onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
Normal file
100
onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh"
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
|
||||
// Maps an element type used by the tunable op to the element type the CK instances are
// requested with. The primary template maps every type to itself.
template <typename T>
|
||||
struct DataTypeAdaptor {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
// Specialization: `half` is translated to ck::half_t.
template <>
|
||||
struct DataTypeAdaptor<half> {
|
||||
using type = ck::half_t;
|
||||
};
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
constexpr int Rank = 5;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
template <typename T, typename AccT, bool WithSwish>
|
||||
auto GetCKGroupNormNHWCTypeStringAndOps() {
|
||||
using InDataType = typename DataTypeAdaptor<T>::type;
|
||||
using OutDataType = typename DataTypeAdaptor<T>::type;
|
||||
using AccDataType = typename DataTypeAdaptor<AccT>::type;
|
||||
using GammaDataType = float;
|
||||
using BetaDataType = float;
|
||||
|
||||
using Activation = std::conditional_t<WithSwish, Swish, Pass>;
|
||||
|
||||
std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCParams<T>>>> ret;
|
||||
for (auto&& impl : internal::GetDeviceGroupNormInstances<InDataType, GammaDataType, BetaDataType, AccDataType,
|
||||
OutDataType, Activation, Rank, NumReduceDim>()) {
|
||||
std::string swish_suffix = WithSwish ? "_Swish" : "_Pass";
|
||||
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
|
||||
auto invoker = impl->MakeInvokerPointer();
|
||||
|
||||
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams<T>* params) -> Status {
|
||||
if constexpr (WithSwish) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!params->withSwish, "Swish version only support groupnorm with swish");
|
||||
} else {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->withSwish, "Pass version only support groupnorm without swish");
|
||||
}
|
||||
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup};
|
||||
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1};
|
||||
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->cPerGroup, 1};
|
||||
std::vector<ck::index_t> reduce_dims{1, 2, 4};
|
||||
|
||||
auto activation = Activation{};
|
||||
|
||||
auto arg = impl->MakeArgumentPointer(in_lengths, // lengths
|
||||
in_out_strides, // xStrides
|
||||
gamma_beta_strides, // gammaStrides
|
||||
gamma_beta_strides, // betaStrides
|
||||
in_out_strides, // yStrides
|
||||
reduce_dims, // reduceDims
|
||||
params->epsilon,
|
||||
params->src,
|
||||
params->gamma,
|
||||
params->beta,
|
||||
params->dst,
|
||||
nullptr,
|
||||
nullptr,
|
||||
activation);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
|
||||
impl->GetTypeString(), " does not support ", params->Signature());
|
||||
invoker->Run(arg.get(), StreamConfig{params->stream});
|
||||
return Status::OK();
|
||||
};
|
||||
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// Modifications Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
namespace internal {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ck::tensor_operation::device::DeviceNormalization; // the interface
|
||||
using ck::tensor_operation::device::DeviceNormalizationImpl; // the implementation
|
||||
|
||||
// Tuple of candidate DeviceNormalizationImpl instantiations with F32 input and F32 output,
// parameterized on the output element-wise operation, the tensor rank and the number of reduced
// dimensions. Each entry differs only in its block/cluster/slice/vector configuration.
template <typename OutElementwise, ck::index_t Rank, ck::index_t Reduce>
|
||||
using device_normalization_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Tuple of candidate DeviceNormalizationImpl instantiations with F16 input and F16 output
// (gamma, beta and compute types are F32), parameterized on the output element-wise operation,
// the tensor rank and the number of reduced dimensions. Each entry differs only in its
// block/cluster/slice/vector configuration.
template <typename OutElementwise, ck::index_t Rank, ck::index_t Reduce>
|
||||
using device_normalization_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Use this function to get implementation
|
||||
template <typename InDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename YElementwiseOperation,
|
||||
ck::index_t Rank,
|
||||
ck::index_t NumReduceDim>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<InDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
YElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>>>
|
||||
GetDeviceGroupNormInstances() {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<
|
||||
F16, F32, F32, F32, F16, Swish, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<
|
||||
F16, F32, F32, F32, F16, Swish, 5, 3>();
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<
|
||||
F16, F32, F32, F32, F16, Pass, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<
|
||||
F16, F32, F32, F32, F16, Pass, 5, 3>();
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<
|
||||
F32, F32, F32, F32, F32, Swish, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<
|
||||
F32, F32, F32, F32, F32, Swish, 5, 3>();
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<
|
||||
F32, F32, F32, F32, F32, Pass, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<
|
||||
F32, F32, F32, F32, F32, Pass, 5, 3>();
|
||||
|
||||
} // namespace internal
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F16, F32, F32, F32, F16, Swish, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>> instances;
|
||||
ck::tensor_operation::device::instance::add_device_operation_instances(
|
||||
instances,
|
||||
device_normalization_f16_instances<Swish, 5, 3>{});
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Pass, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F16, F32, F32, F32, F16, Pass, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Pass, 5, 3>>> instances;
|
||||
ck::tensor_operation::device::instance::add_device_operation_instances(
|
||||
instances,
|
||||
device_normalization_f16_instances<Pass, 5, 3>{});
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Swish, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>> instances;
|
||||
ck::tensor_operation::device::instance::add_device_operation_instances(
|
||||
instances,
|
||||
device_normalization_f32_instances<Swish, 5, 3>{});
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>>
|
||||
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Pass, 5, 3>() {
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>> instances;
|
||||
ck::tensor_operation::device::instance::add_device_operation_instances(
|
||||
instances,
|
||||
device_normalization_f32_instances<Pass, 5, 3>{});
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
|
@ -65,7 +65,8 @@ struct GroupNormNHWCParams : OpParams {
|
|||
}
|
||||
|
||||
std::string Signature() const override {
|
||||
std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups);
|
||||
std::string swish_suffix = withSwish ? "_Swish" : "_Pass";
|
||||
std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix;
|
||||
return sig;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ namespace rocm {
|
|||
|
||||
template <typename T>
|
||||
Status LaunchGroupNormKernel(
|
||||
RocmTuningContext* tuning_ctx,
|
||||
hipStream_t stream,
|
||||
T* output,
|
||||
const T* input,
|
||||
|
|
@ -37,18 +38,23 @@ Status LaunchGroupNormKernel(
|
|||
"only num_groups=32 is supported. Got", num_groups);
|
||||
}
|
||||
|
||||
GroupNormNHWCParams<T> params(nullptr, stream, output, reinterpret_cast<float*>(workspace), input, gamma, beta,
|
||||
GroupNormNHWCParams<T> params(tuning_ctx, stream, output, reinterpret_cast<float*>(workspace), input, gamma, beta,
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation);
|
||||
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
static GroupNormNHWCTunableOp<T> op;
|
||||
return op(¶ms);
|
||||
}
|
||||
|
||||
return GroupNormNHWCStaticSelection(¶ms);
|
||||
}
|
||||
|
||||
template Status LaunchGroupNormKernel<half>(hipStream_t stream, half* output,
|
||||
template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, hipStream_t stream, half* output,
|
||||
const half* input, const float* gamma, const float* beta, void* workspace,
|
||||
float epsilon, int batch_size, int num_channels,
|
||||
int height, int width, int num_groups, bool swish);
|
||||
|
||||
template Status LaunchGroupNormKernel<float>(hipStream_t stream, float* output,
|
||||
template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, hipStream_t stream, float* output,
|
||||
const float* input, const float* gamma, const float* beta, void* workspace,
|
||||
float epsilon, int batch_size, int num_channels,
|
||||
int height, int width, int num_groups, bool swish);
|
||||
|
|
|
|||
47
onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h
Normal file
47
onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
|
||||
using onnxruntime::rocm::tunable::RocmTuningContext;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
// Dimensions the fixed-size workspace is computed from.
constexpr size_t kMaxGroupNormBatchSize = 32;
constexpr size_t kGroupNormNumberOfGroups = 32;

// Returns the workspace size in bytes: two float buffers (sum and squared sum), each holding
// kMaxGroupNormBatchSize * kGroupNormNumberOfGroups elements.
constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
  constexpr size_t kNumBuffers = 2;
  constexpr size_t kElementsPerBuffer = kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
  return sizeof(float) * kNumBuffers * kElementsPerBuffer;
}
|
||||
|
||||
template <typename T>
|
||||
Status LaunchGroupNormKernel(
|
||||
RocmTuningContext* tuning_ctx,
|
||||
hipStream_t stream,
|
||||
T* output, // normalized output tensor
|
||||
const T* input, // input tensor
|
||||
const float* gamma, // gamma (also known as weight or scale)
|
||||
const float* beta, // beta (also known as bias)
|
||||
void* workspace, // Work space
|
||||
float epsilon, // epsilon used normalization
|
||||
int batch_size, // N
|
||||
int num_channels, // C
|
||||
int height, // H
|
||||
int width, // W
|
||||
int num_groups, // number of groups
|
||||
bool use_swish_activation // Whether there is Swish activation after group normalization
|
||||
);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
#include <hip/hip_runtime_api.h>
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_common.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_impl.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh"
|
||||
|
|
@ -136,12 +137,16 @@ class GroupNormNHWCOp {
|
|||
|
||||
Status IsSupported(const GroupNormNHWCParams<T>* params) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0));
|
||||
!(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0),
|
||||
"The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup,
|
||||
") isn't divisible by the number of vector size: ", VecSize);
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 &&
|
||||
params->c % params->cPerBlock == 0 &&
|
||||
params->hw % params->hwPerBlock == 0));
|
||||
params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0),
|
||||
"The value of attributes don't meet the requirements.");
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize &&
|
||||
params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize));
|
||||
params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
|
||||
"Configuration: Threads (", ThreadsPerBlock, "), vector size (",
|
||||
VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -160,19 +165,14 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
|
|||
#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \
|
||||
this->RegisterOp(name<T, threads_per_block, 1>{}); \
|
||||
this->RegisterOp(name<T, threads_per_block, 2>{}); \
|
||||
this->RegisterOp(name<T, threads_per_block, 4>{}); \
|
||||
this->RegisterOp(name<T, threads_per_block, 8>{}); \
|
||||
this->RegisterOp(name<T, threads_per_block, 16>{});
|
||||
this->RegisterOp(name<T, threads_per_block, 4>{});
|
||||
|
||||
#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 512)
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 320)
|
||||
|
||||
template <typename T>
|
||||
class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCParams<T>> {
|
||||
|
|
@ -180,6 +180,18 @@ class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCParams<T>> {
|
|||
GroupNormNHWCTunableOp() {
|
||||
this->RegisterOp(GroupNormNHWCStaticSelection<T>);
|
||||
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp)
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/false>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
|
||||
for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/true>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -11,12 +11,12 @@ from itertools import product
|
|||
import kernel_explorer as ke
|
||||
import numpy as np
|
||||
import pytest
|
||||
from utils import dtype_to_bytes
|
||||
from utils import dtype_to_bytes, dtype_to_suffix
|
||||
|
||||
|
||||
def get_sd_sizes():
|
||||
batch_sizes = [1, 2]
|
||||
height = [8, 16, 32, 64]
|
||||
height = [8, 16, 32]
|
||||
num_channels = [320, 640, 1280, 1920, 2560]
|
||||
|
||||
num_groups = [32]
|
||||
|
|
@ -25,8 +25,8 @@ def get_sd_sizes():
|
|||
|
||||
def dtype_to_funcs(dtype):
|
||||
type_map = {
|
||||
"float16": list(filter(lambda x: re.search("GroupNormNHWC.*_half", x), dir(ke))),
|
||||
"float32": list(filter(lambda x: re.search("GroupNormNHWC.*_float", x), dir(ke))),
|
||||
"float16": list(filter(lambda x: re.match("GroupNormNHWC.*_half", x), dir(ke))),
|
||||
"float32": list(filter(lambda x: re.match("GroupNormNHWC.*_float", x), dir(ke))),
|
||||
}
|
||||
return type_map[dtype]
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
|
|||
return x
|
||||
|
||||
|
||||
def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, func):
|
||||
def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func):
|
||||
np.random.seed(0)
|
||||
width = height
|
||||
input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
|
||||
|
|
@ -62,7 +62,7 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
|
|||
workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32)
|
||||
epsilon = 1e-05
|
||||
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
use_swish = True
|
||||
use_swish = swish
|
||||
|
||||
host_x = input_x.astype(dtype)
|
||||
input_d = ke.DeviceArray(host_x)
|
||||
|
|
@ -86,12 +86,16 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
|
|||
epsilon,
|
||||
use_swish,
|
||||
)
|
||||
if my_op.IsSupported():
|
||||
y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype)
|
||||
|
||||
for impl in my_op.ListOps():
|
||||
if not my_op.SelectOp(impl):
|
||||
continue
|
||||
|
||||
my_op.Run()
|
||||
|
||||
y_d.UpdateHostNumpyArray()
|
||||
|
||||
y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype)
|
||||
np.testing.assert_allclose(y_ref, output_y, atol=1e-02)
|
||||
|
||||
|
||||
|
|
@ -100,9 +104,19 @@ dtypes = ["float32", "float16"]
|
|||
|
||||
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
def test_skip_layer_norm(sd_sizes, dtype):
|
||||
@pytest.mark.parametrize("swish", [True])
|
||||
def test_group_norm(sd_sizes, dtype, swish):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
run_group_norm(*sd_sizes, dtype, func)
|
||||
run_group_norm(*sd_sizes, dtype, swish, func)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_sizes", get_sd_sizes())
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
@pytest.mark.parametrize("swish", [True])
|
||||
def test_group_norm_ck(sd_sizes, dtype, swish):
|
||||
swish_suffix = "Swish" if swish else "Pass"
|
||||
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
|
||||
run_group_norm(*sd_sizes, dtype, swish, ck_f_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -124,7 +138,7 @@ class GroupNormNHWCMetric(ke.BandwidthMetric):
|
|||
|
||||
|
||||
def profile_group_norm_func(
|
||||
batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, func
|
||||
batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func
|
||||
):
|
||||
np.random.seed(0)
|
||||
input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
|
|
@ -133,7 +147,7 @@ def profile_group_norm_func(
|
|||
workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32)
|
||||
epsilon = 0.05
|
||||
output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
|
||||
use_swish = True
|
||||
use_swish = swish
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
gamma_d = ke.DeviceArray(gamma)
|
||||
|
|
@ -156,21 +170,27 @@ def profile_group_norm_func(
|
|||
epsilon,
|
||||
use_swish,
|
||||
)
|
||||
for impl in my_op.ListOps():
|
||||
duration_ms = -1
|
||||
if my_op.SelectOp(impl):
|
||||
duration_ms = my_op.Profile()
|
||||
total_bytes = (input_x.size * 2 + gamma.size * 2) * dtype_to_bytes(dtype)
|
||||
|
||||
duration_ms = -1
|
||||
if my_op.IsSupported():
|
||||
duration_ms = my_op.Profile()
|
||||
total_bytes = (input_x.size * 2 + gamma.size * 2) * dtype_to_bytes(dtype)
|
||||
|
||||
ke.report(
|
||||
GroupNormNHWCMetric(func, dtype, duration_ms, total_bytes, batch_size, height, width, num_channels, num_groups)
|
||||
)
|
||||
ke.report(
|
||||
GroupNormNHWCMetric(
|
||||
impl, dtype, duration_ms, total_bytes, batch_size, height, width, num_channels, num_groups
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, sort=True):
|
||||
def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True):
|
||||
with ke.benchmark(sort):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, func)
|
||||
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func)
|
||||
# ck function
|
||||
swish_suffix = "Swish" if swish else "Pass"
|
||||
ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
|
||||
profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name)
|
||||
|
||||
|
||||
sd_profile_sizes = [
|
||||
|
|
@ -209,6 +229,7 @@ if __name__ == "__main__":
|
|||
group.add_argument("num_channels", type=int)
|
||||
group.add_argument("num_groups", type=int)
|
||||
group.add_argument("dtype", choices=dtypes)
|
||||
group.add_argument("--swish", action="store_true")
|
||||
group.add_argument("--sort", action="store_true")
|
||||
|
||||
if len(sys.argv) == 1:
|
||||
|
|
@ -216,5 +237,12 @@ if __name__ == "__main__":
|
|||
else:
|
||||
args = parser.parse_args()
|
||||
profile_with_args(
|
||||
args.batch_size, args.height, args.width, args.num_channels, args.num_groups, args.dtype, args.sort
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.num_channels,
|
||||
args.num_groups,
|
||||
args.dtype,
|
||||
args.swish,
|
||||
args.sort,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <stdio.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_common.h"
|
||||
#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h"
|
||||
#include "python/tools/kernel_explorer/device_array.h"
|
||||
|
|
@ -21,21 +22,28 @@ class GroupNormNHWC : public IKernelExplorer {
|
|||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = op_.IsSupported(&params_);
|
||||
return status.IsOK();
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize);
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR(op_(&params_));
|
||||
}
|
||||
|
||||
std::vector<std::string> ListOps() const {
|
||||
return {type_string_};
|
||||
}
|
||||
|
||||
bool SelectOp(const std::string& name) {
|
||||
Status status = op_.IsSupported(&params_);
|
||||
return status.IsOK() && name == type_string_;
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
ParamsT params_{};
|
||||
contrib::rocm::GroupNormNHWCOp<T, ThreadsPerBlock, VecSize> op_{};
|
||||
std::string type_string_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -45,20 +53,27 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
|
|||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {}
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
type_string_ = "GroupNormNHWCStaticSelection";
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::GroupNormNHWCStaticSelection<T>(&params_)));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
std::vector<std::string> ListOps() const {
|
||||
return {type_string_};
|
||||
}
|
||||
|
||||
bool SelectOp(const std::string& name) {
|
||||
Status status = contrib::rocm::GroupNormNHWCStaticSelection<T>(&params_);
|
||||
return status.IsOK();
|
||||
return status.IsOK() && name == type_string_;
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
ParamsT params_{};
|
||||
std::string type_string_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -69,15 +84,19 @@ class GroupNormNHWCTunable : public IKernelExplorer {
|
|||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
params_.TuningContext()->EnableTunableOp();
|
||||
params_.TuningContext()->EnableTunableOpAndTuning();
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR(op_(&params_));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
return true;
|
||||
std::vector<std::string> ListOps() const {
|
||||
return {"GroupNormNHWCTunable"};
|
||||
}
|
||||
|
||||
bool SelectOp(const std::string& name) {
|
||||
return name == "GroupNormNHWCTunable";
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -86,6 +105,51 @@ class GroupNormNHWCTunable : public IKernelExplorer {
|
|||
contrib::rocm::GroupNormNHWCTunableOp<T> op_{};
|
||||
};
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
template <typename T, bool WithSwish>
|
||||
class CKGroupNormNHWC : public IKernelExplorer {
|
||||
public:
|
||||
CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
|
||||
int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
|
||||
batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
|
||||
for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSwish>()) {
|
||||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
|
||||
}
|
||||
|
||||
std::vector<std::string> ListOps() const {
|
||||
return type_strings_;
|
||||
}
|
||||
|
||||
bool SelectOp(const std::string& name) {
|
||||
for (size_t i = 0; i < ops_.size(); i++) {
|
||||
if (type_strings_[i] == name) {
|
||||
selected_op_ = i;
|
||||
Status status = ops_[i](&params_);
|
||||
return status.IsOK();
|
||||
}
|
||||
}
|
||||
|
||||
ORT_THROW("Cannot find implementation ", name);
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
|
||||
using OpT = rocm::tunable::Op<ParamsT>;
|
||||
ParamsT params_{};
|
||||
std::vector<OpT> ops_;
|
||||
std::vector<std::string> type_strings_;
|
||||
size_t selected_op_{};
|
||||
};
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
|
|
@ -93,33 +157,36 @@ class GroupNormNHWCTunable : public IKernelExplorer {
|
|||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
|
||||
.def("ListOps", &name<type, threads_per_block, vec_size>::ListOps) \
|
||||
.def("SelectOp", &name<type, threads_per_block, vec_size>::SelectOp);
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
|
||||
REGISTER_OP(name, type, threads_per_block, 1) \
|
||||
REGISTER_OP(name, type, threads_per_block, 2) \
|
||||
REGISTER_OP(name, type, threads_per_block, 4) \
|
||||
REGISTER_OP(name, type, threads_per_block, 8) \
|
||||
REGISTER_OP(name, type, threads_per_block, 16)
|
||||
REGISTER_OP(name, type, threads_per_block, 4)
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512)
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320)
|
||||
|
||||
#define REGISTER_OP_TYPED(name, type) \
|
||||
py::class_<name<type>>(m, #name "_" #type) \
|
||||
#define REGISTER_COMMON(name, type, ...) \
|
||||
py::class_<type<__VA_ARGS__>>(m, name) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
int, int, int, int, int, float, bool>()) \
|
||||
.def("SetRepeats", &name<type>::SetRepeats) \
|
||||
.def("Profile", &name<type>::Profile) \
|
||||
.def("Run", &name<type>::Run) \
|
||||
.def("IsSupported", &name<type>::IsSupported);
|
||||
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
|
||||
.def("Profile", &type<__VA_ARGS__>::Profile) \
|
||||
.def("Run", &type<__VA_ARGS__>::Run) \
|
||||
.def("ListOps", &type<__VA_ARGS__>::ListOps) \
|
||||
.def("SelectOp", &type<__VA_ARGS__>::SelectOp);
|
||||
|
||||
#define REGISTER_OP_TYPED(name, type) \
|
||||
REGISTER_COMMON(#name "_" #type, name, type)
|
||||
|
||||
#define REGISTER_CK(type, with_swish, swish_suffix) \
|
||||
REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish)
|
||||
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half);
|
||||
|
|
@ -130,6 +197,13 @@ KE_REGISTER(m) {
|
|||
|
||||
REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, half);
|
||||
REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, float);
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
REGISTER_CK(half, false, "Pass");
|
||||
REGISTER_CK(half, true, "Swish");
|
||||
REGISTER_CK(float, false, "Pass");
|
||||
REGISTER_CK(float, true, "Swish");
|
||||
#endif // USE_COMPOSABLE_KERNEL
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue