Fix GroupNorm fusion: skip if num of channels not supported (#17869)

Right now, GroupNorm only support limited number of channels (320, 640,
960, 1280, 1920, 2560, 128, 256, 512). Skip the fusion if number of
channels are not supported.

### Motivation and Context
SD XL refiner model uses number of channels 384, 768, 1152, 2304 and
3072 in GroupNorm.
This commit is contained in:
Tianlei Wu 2023-10-11 22:45:22 -07:00 committed by GitHub
parent 25bbd8d4eb
commit e2cd6748fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -88,13 +88,17 @@ class FusionGroupNorm(Fusion):
if instance_norm_bias is None:
return
if not (
len(instance_norm_scale.shape) == 1
and len(instance_norm_bias.shape) == 1
and instance_norm_scale.shape == instance_norm_bias.shape
and instance_norm_scale.shape[0] == 32
):
logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0])
# Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32].
if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape)
)
return
if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape)
)
return
if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
@ -105,7 +109,8 @@ class FusionGroupNorm(Fusion):
group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")
if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]:
logger.info("GroupNorm channels=%d", weight_elements)
logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements)
return
self.add_initializer(
name=group_norm_name + "_gamma",