mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
[ROCm] SkipLayerNorm: add more configs for block size; loosen constraints (#14900)
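### Description
* Add more configs for `threads_per_block` in SkipLayerNorm, and mirror them in kernel explorer.
* Loosen the constraint on `hidden_size` so that `SkipLayerNormSmallOp` can be selected for larger hidden sizes.
* Add a flag for the optional output in kernel_explorer.

### Motivation and Context
Loosening the hard constraint on `ld` (hidden_size) lets more of the *Small kernels be considered, which could offer better performance in some combinations of `ThreadsPerBlock` and `VecSize`.
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->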
This commit is contained in:
parent
d55ae490e1
commit
bf2cc808a1
3 changed files with 45 additions and 13 deletions
|
|
@ -23,7 +23,7 @@ struct SkipLayerNormParams : OpParams {
|
|||
SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input,
|
||||
const T* skip, const T* gamma, const T* beta,
|
||||
const T* bias, float epsilon, int ld, int element_count)
|
||||
: OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip),
|
||||
: OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip),
|
||||
gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {}
|
||||
|
||||
std::string Signature() const override {
|
||||
|
|
@ -45,8 +45,10 @@ struct SkipLayerNormParams : OpParams {
|
|||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
|
||||
// Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels,
|
||||
// which could offer better performance in some combinations of ThreadsPerBlock and VecSize.
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!((params->ld <= 1024 && params->ld % VecSize == 0 &&
|
||||
!((params->ld <= 8192 && params->ld % VecSize == 0 &&
|
||||
params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)));
|
||||
SkipLayerNormKernelSmall<T, ThreadsPerBlock, VecSize><<<dim3(CeilDiv(params->element_count, params->ld)),
|
||||
dim3(ThreadsPerBlock),
|
||||
|
|
@ -159,7 +161,16 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T>* params) {
|
|||
ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 384)
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 1024)
|
||||
|
||||
template <typename T>
|
||||
class SkipLayerNormTunableOp : public TunableOp<SkipLayerNormParams<T>> {
|
||||
|
|
|
|||
|
|
@ -136,7 +136,16 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384)
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 576) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 640) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 704) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 768) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 832) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 896) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 1024)
|
||||
|
||||
#define REGISTER_OP_TYPED(name, type) \
|
||||
py::class_<name<type>>(m, #name "_" #type) \
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon):
|
|||
return output, val
|
||||
|
||||
|
||||
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func):
|
||||
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, has_optional_output=False):
|
||||
np.random.seed(0)
|
||||
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
|
|
@ -56,7 +56,11 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
|
|||
# Because of rocm FMAs calculation issue with float16, epsilon should be larger when hidden_size is small
|
||||
epsilon = 0.05 if hidden_size < 8 else 0.0005
|
||||
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
output_optional = (
|
||||
np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
if has_optional_output
|
||||
else np.empty((0), dtype=dtype)
|
||||
)
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
skip_d = ke.DeviceArray(skip)
|
||||
|
|
@ -86,8 +90,9 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
|
|||
optional_d.UpdateHostNumpyArray()
|
||||
|
||||
y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
|
||||
np.testing.assert_almost_equal(y_ref, output_y, decimal=1e-05)
|
||||
np.testing.assert_almost_equal(y_optional, output_optional, decimal=1e-05)
|
||||
np.testing.assert_almost_equal(y_ref, output_y, decimal=1)
|
||||
if has_optional_output:
|
||||
np.testing.assert_almost_equal(y_optional, output_optional, decimal=1)
|
||||
|
||||
|
||||
dtypes = ["float32", "float16"]
|
||||
|
|
@ -113,7 +118,7 @@ class SkipLayerNormMetric(ke.BandwidthMetric):
|
|||
return prefix + "not supported or redundant"
|
||||
|
||||
|
||||
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
||||
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output):
|
||||
np.random.seed(0)
|
||||
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
|
|
@ -122,7 +127,11 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
bias = np.random.rand(hidden_size).astype(dtype)
|
||||
epsilon = 0.0005
|
||||
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
output_optional = (
|
||||
np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
if has_optional_output
|
||||
else np.empty((0), dtype=dtype)
|
||||
)
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
skip_d = ke.DeviceArray(skip)
|
||||
|
|
@ -154,10 +163,10 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size))
|
||||
|
||||
|
||||
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True):
|
||||
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False):
|
||||
with ke.benchmark(sort):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func)
|
||||
profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output)
|
||||
|
||||
|
||||
def profile():
|
||||
|
|
@ -177,9 +186,12 @@ if __name__ == "__main__":
|
|||
group.add_argument("hidden_size", type=int)
|
||||
group.add_argument("dtype", choices=dtypes)
|
||||
group.add_argument("--sort", action="store_true")
|
||||
group.add_argument("--has_optional_output", "-o", action="store_true")
|
||||
|
||||
if len(sys.argv) == 1:
|
||||
profile()
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort)
|
||||
profile_with_args(
|
||||
args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort, args.has_optional_output
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue