[ROCm] SkipLayerNorm: add more configs for block size; loosen constraints (#14900)

### Description
* add more configs for `threads_per_block` in SkipLayerNorm, also in
kernel explorer.
* loosen constraints for hidden_size, so that `SkipLayerNormSmallOp` can
be selected for larger hidden sizes.
* add flag for optional output in kernel_explorer


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
mindest 2023-03-09 22:27:01 +08:00 committed by GitHub
parent d55ae490e1
commit bf2cc808a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 13 deletions

View file

@ -23,7 +23,7 @@ struct SkipLayerNormParams : OpParams {
SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input,
const T* skip, const T* gamma, const T* beta,
const T* bias, float epsilon, int ld, int element_count)
: OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip),
: OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip),
gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {}
std::string Signature() const override {
@ -45,8 +45,10 @@ struct SkipLayerNormParams : OpParams {
template <typename T, int ThreadsPerBlock, int VecSize>
Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
// Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels,
// which could offer better performance in some combinations of ThreadsPerBlock and VecSize.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!((params->ld <= 1024 && params->ld % VecSize == 0 &&
!((params->ld <= 8192 && params->ld % VecSize == 0 &&
params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)));
SkipLayerNormKernelSmall<T, ThreadsPerBlock, VecSize><<<dim3(CeilDiv(params->element_count, params->ld)),
dim3(ThreadsPerBlock),
@ -159,7 +161,16 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T>* params) {
ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 384)
ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 1024)
template <typename T>
class SkipLayerNormTunableOp : public TunableOp<SkipLayerNormParams<T>> {

View file

@ -136,7 +136,16 @@ class SkipLayerNormTunable : public IKernelExplorer {
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384)
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 576) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 640) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 704) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 768) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 832) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 896) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 1024)
#define REGISTER_OP_TYPED(name, type) \
py::class_<name<type>>(m, #name "_" #type) \

View file

@ -46,7 +46,7 @@ def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon):
return output, val
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func):
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, has_optional_output=False):
np.random.seed(0)
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
@ -56,7 +56,11 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
# Because of rocm FMAs calculation issue with float16, epsilon should be larger when hidden_size is small
epsilon = 0.05 if hidden_size < 8 else 0.0005
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
output_optional = (
np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
if has_optional_output
else np.empty((0), dtype=dtype)
)
input_d = ke.DeviceArray(input_x)
skip_d = ke.DeviceArray(skip)
@ -86,8 +90,9 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
optional_d.UpdateHostNumpyArray()
y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
np.testing.assert_almost_equal(y_ref, output_y, decimal=1e-05)
np.testing.assert_almost_equal(y_optional, output_optional, decimal=1e-05)
np.testing.assert_almost_equal(y_ref, output_y, decimal=1)
if has_optional_output:
np.testing.assert_almost_equal(y_optional, output_optional, decimal=1)
dtypes = ["float32", "float16"]
@ -113,7 +118,7 @@ class SkipLayerNormMetric(ke.BandwidthMetric):
return prefix + "not supported or redundant"
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output):
np.random.seed(0)
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
@ -122,7 +127,11 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
bias = np.random.rand(hidden_size).astype(dtype)
epsilon = 0.0005
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
output_optional = (
np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
if has_optional_output
else np.empty((0), dtype=dtype)
)
input_d = ke.DeviceArray(input_x)
skip_d = ke.DeviceArray(skip)
@ -154,10 +163,10 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size))
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True):
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False):
with ke.benchmark(sort):
for func in dtype_to_funcs(dtype):
profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func)
profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output)
def profile():
@ -177,9 +186,12 @@ if __name__ == "__main__":
group.add_argument("hidden_size", type=int)
group.add_argument("dtype", choices=dtypes)
group.add_argument("--sort", action="store_true")
group.add_argument("--has_optional_output", "-o", action="store_true")
if len(sys.argv) == 1:
profile()
else:
args = parser.parse_args()
profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort)
profile_with_args(
args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort, args.has_optional_output
)