mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Fix GroupNorm fusion: skip if num of channels not supported (#17869)
Right now, GroupNorm only support limited number of channels (320, 640, 960, 1280, 1920, 2560, 128, 256, 512). Skip the fusion if number of channels are not supported. ### Motivation and Context SD XL refiner model uses number of channels 384, 768, 1152, 2304 and 3072 in GroupNorm.
This commit is contained in:
parent
25bbd8d4eb
commit
e2cd6748fc
1 changed files with 13 additions and 8 deletions
|
|
@ -88,13 +88,17 @@ class FusionGroupNorm(Fusion):
|
|||
if instance_norm_bias is None:
|
||||
return
|
||||
|
||||
if not (
|
||||
len(instance_norm_scale.shape) == 1
|
||||
and len(instance_norm_bias.shape) == 1
|
||||
and instance_norm_scale.shape == instance_norm_bias.shape
|
||||
and instance_norm_scale.shape[0] == 32
|
||||
):
|
||||
logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0])
|
||||
# Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32].
|
||||
if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32):
|
||||
logger.debug(
|
||||
"Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape)
|
||||
)
|
||||
return
|
||||
|
||||
if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32):
|
||||
logger.debug(
|
||||
"Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape)
|
||||
)
|
||||
return
|
||||
|
||||
if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
|
||||
|
|
@ -105,7 +109,8 @@ class FusionGroupNorm(Fusion):
|
|||
group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")
|
||||
|
||||
if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]:
|
||||
logger.info("GroupNorm channels=%d", weight_elements)
|
||||
logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements)
|
||||
return
|
||||
|
||||
self.add_initializer(
|
||||
name=group_norm_name + "_gamma",
|
||||
|
|
|
|||
Loading…
Reference in a new issue