diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index a4491d29b3..cd7dc7017c 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -88,13 +88,17 @@ class FusionGroupNorm(Fusion): if instance_norm_bias is None: return - if not ( - len(instance_norm_scale.shape) == 1 - and len(instance_norm_bias.shape) == 1 - and instance_norm_scale.shape == instance_norm_bias.shape - and instance_norm_scale.shape[0] == 32 - ): - logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0]) + # Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32]. + if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32): + logger.debug( + "Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape) + ) + return + + if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32): + logger.debug( + "Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape) + ) return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -105,7 +109,8 @@ class FusionGroupNorm(Fusion): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]: - logger.info("GroupNorm channels=%d", weight_elements) + logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements) + return self.add_initializer( name=group_norm_name + "_gamma",