From e2cd6748fc460bd9d5ee12d9d40193a0a1ab6b77 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 11 Oct 2023 22:45:22 -0700 Subject: [PATCH] Fix GroupNorm fusion: skip if num of channels not supported (#17869) Right now, GroupNorm only support limited number of channels (320, 640, 960, 1280, 1920, 2560, 128, 256, 512). Skip the fusion if number of channels are not supported. ### Motivation and Context SD XL refiner model uses number of channels 384, 768, 1152, 2304 and 3072 in GroupNorm. --- .../tools/transformers/fusion_group_norm.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index a4491d29b3..cd7dc7017c 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -88,13 +88,17 @@ class FusionGroupNorm(Fusion): if instance_norm_bias is None: return - if not ( - len(instance_norm_scale.shape) == 1 - and len(instance_norm_bias.shape) == 1 - and instance_norm_scale.shape == instance_norm_bias.shape - and instance_norm_scale.shape[0] == 32 - ): - logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0]) + # Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32]. + if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32): + logger.debug( + "Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape) + ) + return + + if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32): + logger.debug( + "Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape) + ) return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -105,7 +109,8 @@ class FusionGroupNorm(Fusion): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]: - logger.info("GroupNorm channels=%d", weight_elements) + logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements) + return self.add_initializer( name=group_norm_name + "_gamma",