diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index b2331582f9..fe57a5b532 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -1,5 +1,5 @@ set(composable_kernel_URL https://github.com/ROCmSoftwarePlatform/composable_kernel.git) -set(composable_kernel_TAG bef0cb20dba0d9b315df46899310478a81c21852) # 2023-02-16 11:54:08 -0800 +set(composable_kernel_TAG ed3a2e52265e11daa366f47b082141a652b67c58) # 2023-04-10 21:02:17 +0800 set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 7bf641d151..df2d68a6a4 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -31,7 +31,9 @@ set(contrib_ops_excluded_files "bert/packed_attention.cc" "bert/packed_attention_impl.h" "bert/packed_attention_impl.cu" + "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" + "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" "math/complex_mul.cc" "math/complex_mul.h" diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc new file mode 100644 index 0000000000..9b0e6251b3 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm.h" +#include "contrib_ops/rocm/diffusion/group_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define GROUP_NORM_TYPES float, MLFloat16 + +ONNX_OPERATOR_KERNEL_EX( + GroupNorm, kMSDomain, 1, kRocmExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + +using namespace ONNX_NAMESPACE; + +namespace { +template +struct DispatchGroupNorm { + Status operator()(RocmTuningContext* tuning_ctx, + hipStream_t stream, + Tensor* output, + const Tensor* input, + const Tensor* gamma, + const Tensor* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation) { + typedef typename ToHipType::MappedType HipT; + return LaunchGroupNormKernel( + tuning_ctx, + stream, + reinterpret_cast(output->MutableData()), + reinterpret_cast(input->Data()), + gamma->Data(), + beta->Data(), + workspace, + epsilon, + batch_size, + num_channels, + height, + width, + num_groups, + use_swish_activation); + } +}; + +} // namespace + +GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); + ORT_ENFORCE(epsilon_ >= 0); + + int64_t num_groups; + ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); + ORT_ENFORCE(num_groups >= 0); + num_groups_ = static_cast(num_groups); + + int64_t activation; + ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); + ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish + use_swish_activation_ = (activation == 1); +} + +Status GroupNorm::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* gamma = context->Input(1); + const Tensor* beta = context->Input(2); + Tensor* output = context->Output(0, input->Shape()); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 4 dimensions, got ", input_dims.size()); + } + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != input_dims[3]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in gamma and input does not match"); + } + + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != input_dims[3]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in beta and input does not match"); + } + + // Input and output format is NHWC + int batch_size = static_cast(input_dims[0]); + int num_channels = static_cast(input_dims[3]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + + utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); + return dispatcher.InvokeRet(GetTuningContext(), Stream(context), + output, input, gamma, beta, workspace.get(), + epsilon_, + batch_size, + num_channels, + height, + width, + num_groups_, + use_swish_activation_); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh new file mode 100644 index 0000000000..4ce7b1284d --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#ifdef USE_COMPOSABLE_KERNEL +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#endif // USE_COMPOSABLE_KERNEL + +#include "contrib_ops/rocm/diffusion/group_norm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#ifdef USE_COMPOSABLE_KERNEL + +template +struct DataTypeAdaptor { + using type = T; +}; + +template <> +struct DataTypeAdaptor { + using type = ck::half_t; +}; + +using Swish = ck::tensor_operation::element_wise::Swish; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +template +auto GetCKGroupNormNHWCTypeStringAndOps() { + using InDataType = typename DataTypeAdaptor::type; + using OutDataType = typename DataTypeAdaptor::type; + using AccDataType = typename DataTypeAdaptor::type; + using GammaDataType = float; + using BetaDataType = float; + + using Activation = std::conditional_t; + + std::vector>>> ret; + for (auto&& impl : internal::GetDeviceGroupNormInstances()) { + std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; + auto invoker = impl->MakeInvokerPointer(); + + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams* params) -> Status { + if constexpr (WithSwish) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->withSwish, "Swish version only support groupnorm with swish"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->withSwish, "Pass version only support groupnorm without swish"); + } + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->cPerGroup, 1}; + std::vector reduce_dims{1, 2, 4}; + + auto activation = Activation{}; + + auto arg = impl->MakeArgumentPointer(in_lengths, // lengths + in_out_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + in_out_strides, // yStrides + reduce_dims, // reduceDims + params->epsilon, + params->src, + params->gamma, + params->beta, + params->dst, + nullptr, + nullptr, + activation); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->stream}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op))); + } + return ret; +} +#endif // USE_COMPOSABLE_KERNEL + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh new file mode 100644 index 0000000000..88443478cf --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifdef USE_COMPOSABLE_KERNEL +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/utility/data_type.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using F16 = ck::half_t; +using F32 = float; + +using Swish = ck::tensor_operation::element_wise::Swish; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +using ck::tensor_operation::device::DeviceNormalization; // the interface +using ck::tensor_operation::device::DeviceNormalizationImpl; // the implementation + +template +using device_normalization_f32_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + // clang-format on + >; + +template +using device_normalization_f16_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + // clang-format on + >; + +// Use this function to get implementation +template +std::vector>> +GetDeviceGroupNormInstances() { + return {}; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F16, F32, F32, F32, F16, Swish, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F16, F32, F32, F32, F16, Pass, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F32, F32, F32, F32, F32, Swish, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F32, F32, F32, F32, F32, Pass, 5, 3>(); + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu new file mode 100644 index 0000000000..d1dd78e345 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f16_instances{}); + + return instances; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f16_instances{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu new file mode 100644 index 0000000000..97baed34a3 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f32_instances{}); + + return instances; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f32_instances{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h index 3ccc527526..c0fbe6c7c6 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -65,7 +65,8 @@ struct GroupNormNHWCParams : OpParams { } std::string Signature() const override { - std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups); + std::string swish_suffix = withSwish ? "_Swish" : "_Pass"; + std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix; return sig; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu index 1cec66a9a3..5ee7ad40b0 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -14,6 +14,7 @@ namespace rocm { template Status LaunchGroupNormKernel( + RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input, @@ -37,18 +38,23 @@ Status LaunchGroupNormKernel( "only num_groups=32 is supported. Got", num_groups); } - GroupNormNHWCParams params(nullptr, stream, output, reinterpret_cast(workspace), input, gamma, beta, + GroupNormNHWCParams params(tuning_ctx, stream, output, reinterpret_cast(workspace), input, gamma, beta, batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation); + if (tuning_ctx->IsTunableOpEnabled()) { + static GroupNormNHWCTunableOp op; + return op(¶ms); + } + return GroupNormNHWCStaticSelection(¶ms); } -template Status LaunchGroupNormKernel(hipStream_t stream, half* output, +template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, const half* input, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, bool swish); -template Status LaunchGroupNormKernel(hipStream_t stream, float* output, +template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, const float* input, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, bool swish); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h new file mode 100644 index 0000000000..ff5f2f19df --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/common/common.h" +#include "core/common/status.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::tunable::RocmTuningContext; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +constexpr size_t kMaxGroupNormBatchSize = 32; +constexpr size_t kGroupNormNumberOfGroups = 32; + +constexpr size_t GetGroupNormWorkspaceSizeInBytes() { + // Two buffers for sum and squared sum + return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; +} + +template +Status LaunchGroupNormKernel( + RocmTuningContext* tuning_ctx, + hipStream_t stream, + T* output, // normalized output tensor + const T* input, // input tensor + const float* gamma, // gamma (also known as weight or scale) + const float* beta, // beta (also known as bias) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_swish_activation // Whether there is Swish activation after group normalization +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h index adeaf2c126..7ca38608b8 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -7,6 +7,7 @@ #include #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" #include "contrib_ops/rocm/diffusion/group_norm_common.h" #include "contrib_ops/rocm/diffusion/group_norm_impl.h" #include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh" @@ -136,12 +137,16 @@ class GroupNormNHWCOp { Status IsSupported(const GroupNormNHWCParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0)); + !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup, + ") isn't divisible by the number of vector size: ", VecSize); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 && - params->c % params->cPerBlock == 0 && - params->hw % params->hwPerBlock == 0)); + params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0), + "The value of attributes don't meet the requirements."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize && - params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); + params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + "Configuration: Threads (", ThreadsPerBlock, "), vector size (", + VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock); return Status::OK(); } @@ -160,19 +165,14 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { #define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ this->RegisterOp(name{}); \ this->RegisterOp(name{}); \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); + this->RegisterOp(name{}); #define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 512) + ADD_OP_FOR_ALL_VEC_SIZE(name, 320) template class GroupNormNHWCTunableOp : public TunableOp> { @@ -180,6 +180,18 @@ class GroupNormNHWCTunableOp : public TunableOp> { GroupNormNHWCTunableOp() { this->RegisterOp(GroupNormNHWCStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) + +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif // USE_COMPOSABLE_KERNEL } }; diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index 3506c64f47..b7a017ee72 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -11,12 +11,12 @@ from itertools import product import kernel_explorer as ke import numpy as np import pytest -from utils import dtype_to_bytes +from utils import dtype_to_bytes, dtype_to_suffix def get_sd_sizes(): batch_sizes = [1, 2] - height = [8, 16, 32, 64] + height = [8, 16, 32] num_channels = [320, 640, 1280, 1920, 2560] num_groups = [32] @@ -25,8 +25,8 @@ def get_sd_sizes(): def dtype_to_funcs(dtype): type_map = { - "float16": list(filter(lambda x: re.search("GroupNormNHWC.*_half", x), dir(ke))), - "float32": list(filter(lambda x: re.search("GroupNormNHWC.*_float", x), dir(ke))), + "float16": list(filter(lambda x: re.match("GroupNormNHWC.*_half", x), dir(ke))), + "float32": list(filter(lambda x: re.match("GroupNormNHWC.*_float", x), dir(ke))), } return type_map[dtype] @@ -52,7 +52,7 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): return x -def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, func): +def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func): np.random.seed(0) width = height input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32) @@ -62,7 +62,7 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32) epsilon = 1e-05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = True + use_swish = swish host_x = input_x.astype(dtype) input_d = ke.DeviceArray(host_x) @@ -86,12 +86,16 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: epsilon, use_swish, ) - if my_op.IsSupported(): + y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype) + + for impl in my_op.ListOps(): + if not my_op.SelectOp(impl): + continue + my_op.Run() y_d.UpdateHostNumpyArray() - y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype) np.testing.assert_allclose(y_ref, output_y, atol=1e-02) @@ -100,9 +104,19 @@ dtypes = ["float32", "float16"] @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -def test_skip_layer_norm(sd_sizes, dtype): +@pytest.mark.parametrize("swish", [True]) +def test_group_norm(sd_sizes, dtype, swish): for func in dtype_to_funcs(dtype): - run_group_norm(*sd_sizes, dtype, func) + run_group_norm(*sd_sizes, dtype, swish, func) + + +@pytest.mark.parametrize("sd_sizes", get_sd_sizes()) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("swish", [True]) +def test_group_norm_ck(sd_sizes, dtype, swish): + swish_suffix = "Swish" if swish else "Pass" + ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) + run_group_norm(*sd_sizes, dtype, swish, ck_f_name) @dataclass @@ -124,7 +138,7 @@ class GroupNormNHWCMetric(ke.BandwidthMetric): def profile_group_norm_func( - batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, func + batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func ): np.random.seed(0) input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype) @@ -133,7 +147,7 @@ def profile_group_norm_func( workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32) epsilon = 0.05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = True + use_swish = swish input_d = ke.DeviceArray(input_x) gamma_d = ke.DeviceArray(gamma) @@ -156,21 +170,27 @@ def profile_group_norm_func( epsilon, use_swish, ) + for impl in my_op.ListOps(): + duration_ms = -1 + if my_op.SelectOp(impl): + duration_ms = my_op.Profile() + total_bytes = (input_x.size * 2 + gamma.size * 2) * dtype_to_bytes(dtype) - duration_ms = -1 - if my_op.IsSupported(): - duration_ms = my_op.Profile() - total_bytes = (input_x.size * 2 + gamma.size * 2) * dtype_to_bytes(dtype) - - ke.report( - GroupNormNHWCMetric(func, dtype, duration_ms, total_bytes, batch_size, height, width, num_channels, num_groups) - ) + ke.report( + GroupNormNHWCMetric( + impl, dtype, duration_ms, total_bytes, batch_size, height, width, num_channels, num_groups + ) + ) -def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, sort=True): +def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True): with ke.benchmark(sort): for func in dtype_to_funcs(dtype): - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, func) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func) + # ck function + swish_suffix = "Swish" if swish else "Pass" + ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name) sd_profile_sizes = [ @@ -209,6 +229,7 @@ if __name__ == "__main__": group.add_argument("num_channels", type=int) group.add_argument("num_groups", type=int) group.add_argument("dtype", choices=dtypes) + group.add_argument("--swish", action="store_true") group.add_argument("--sort", action="store_true") if len(sys.argv) == 1: @@ -216,5 +237,12 @@ if __name__ == "__main__": else: args = parser.parse_args() profile_with_args( - args.batch_size, args.height, args.width, args.num_channels, args.num_groups, args.dtype, args.sort + args.batch_size, + args.height, + args.width, + args.num_channels, + args.num_groups, + args.dtype, + args.swish, + args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/groupnorm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu similarity index 58% rename from onnxruntime/python/tools/kernel_explorer/kernels/rocm/groupnorm.cu rename to onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu index b878d692e7..fd1aaa97fb 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/groupnorm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu @@ -1,10 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include #include +#include +#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" #include "contrib_ops/rocm/diffusion/group_norm_common.h" #include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" #include "python/tools/kernel_explorer/device_array.h" @@ -21,21 +22,28 @@ class GroupNormNHWC : public IKernelExplorer { int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {} - - bool IsSupported() { - Status status = op_.IsSupported(¶ms_); - return status.IsOK(); + batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize); } void Run() override { ORT_THROW_IF_ERROR(op_(¶ms_)); } + std::vector ListOps() const { + return {type_string_}; + } + + bool SelectOp(const std::string& name) { + Status status = op_.IsSupported(¶ms_); + return status.IsOK() && name == type_string_; + } + private: using ParamsT = contrib::rocm::GroupNormNHWCParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCOp op_{}; + std::string type_string_{}; }; template @@ -45,20 +53,27 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {} + batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + type_string_ = "GroupNormNHWCStaticSelection"; + } void Run() override { ORT_THROW_IF_ERROR((contrib::rocm::GroupNormNHWCStaticSelection(¶ms_))); } - bool IsSupported() { + std::vector ListOps() const { + return {type_string_}; + } + + bool SelectOp(const std::string& name) { Status status = contrib::rocm::GroupNormNHWCStaticSelection(¶ms_); - return status.IsOK(); + return status.IsOK() && name == type_string_; } private: using ParamsT = contrib::rocm::GroupNormNHWCParams; ParamsT params_{}; + std::string type_string_{}; }; template @@ -69,15 +84,19 @@ class GroupNormNHWCTunable : public IKernelExplorer { : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } void Run() override { ORT_THROW_IF_ERROR(op_(¶ms_)); } - bool IsSupported() { - return true; + std::vector ListOps() const { + return {"GroupNormNHWCTunable"}; + } + + bool SelectOp(const std::string& name) { + return name == "GroupNormNHWCTunable"; } private: @@ -86,6 +105,51 @@ class GroupNormNHWCTunable : public IKernelExplorer { contrib::rocm::GroupNormNHWCTunableOp op_{}; }; +#ifdef USE_COMPOSABLE_KERNEL +template +class CKGroupNormNHWC : public IKernelExplorer { + public: + CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, + int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), + static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), + batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + } + + void Run() override { + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); + } + + std::vector ListOps() const { + return type_strings_; + } + + bool SelectOp(const std::string& name) { + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); + } + + private: + using ParamsT = contrib::rocm::GroupNormNHWCParams; + using OpT = rocm::tunable::Op; + ParamsT params_{}; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; +}; +#endif // USE_COMPOSABLE_KERNEL + #define REGISTER_OP(name, type, threads_per_block, vec_size) \ py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ .def(py::init::SetRepeats) \ .def("Profile", &name::Profile) \ .def("Run", &name::Run) \ - .def("IsSupported", &name::IsSupported); + .def("ListOps", &name::ListOps) \ + .def("SelectOp", &name::SelectOp); #define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ REGISTER_OP(name, type, threads_per_block, 1) \ REGISTER_OP(name, type, threads_per_block, 2) \ - REGISTER_OP(name, type, threads_per_block, 4) \ - REGISTER_OP(name, type, threads_per_block, 8) \ - REGISTER_OP(name, type, threads_per_block, 16) + REGISTER_OP(name, type, threads_per_block, 4) #define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512) + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) -#define REGISTER_OP_TYPED(name, type) \ - py::class_>(m, #name "_" #type) \ +#define REGISTER_COMMON(name, type, ...) \ + py::class_>(m, name) \ .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ - .def("IsSupported", &name::IsSupported); + .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ + .def("Profile", &type<__VA_ARGS__>::Profile) \ + .def("Run", &type<__VA_ARGS__>::Run) \ + .def("ListOps", &type<__VA_ARGS__>::ListOps) \ + .def("SelectOp", &type<__VA_ARGS__>::SelectOp); + +#define REGISTER_OP_TYPED(name, type) \ + REGISTER_COMMON(#name "_" #type, name, type) + +#define REGISTER_CK(type, with_swish, swish_suffix) \ + REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish) KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half); @@ -130,6 +197,13 @@ KE_REGISTER(m) { REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, half); REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, float); + +#ifdef USE_COMPOSABLE_KERNEL + REGISTER_CK(half, false, "Pass"); + REGISTER_CK(half, true, "Swish"); + REGISTER_CK(float, false, "Pass"); + REGISTER_CK(float, true, "Swish"); +#endif // USE_COMPOSABLE_KERNEL } } // namespace onnxruntime