mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-21 02:18:09 +00:00
Fix attention fusion for UNet onnx model export when using LoRA weights (#17249)
### Description Tested with stable diffusion unet models exported by both pytorch 2.1.0 (nightly) and pytorch 1.13.1, with and without LoRA weights. ### Motivation and Context LoRA weights modifiy the unet model by adding matmul and scale operations to every q/k/v/out tensors, which breaks the current MHA pattern recognition.
This commit is contained in:
parent
761c4333b5
commit
4880f1da46
1 changed files with 670 additions and 20 deletions
|
|
@ -375,6 +375,481 @@ class FusionAttentionUnet(Fusion):
|
|||
self.increase_counter(counter_name)
|
||||
return attention_node
|
||||
|
||||
def create_attention_node_lora(
|
||||
self,
|
||||
q_matmul_add: NodeProto,
|
||||
k_matmul_add: NodeProto,
|
||||
v_matmul_add: NodeProto,
|
||||
num_heads: int,
|
||||
hidden_size: int,
|
||||
input: str,
|
||||
output: str,
|
||||
) -> Union[NodeProto, None]:
|
||||
"""Create an Attention node.
|
||||
|
||||
Args:
|
||||
q_matmul (NodeProto): MatMul node in fully connection for Q
|
||||
k_matmul (NodeProto): MatMul node in fully connection for K
|
||||
v_matmul (NodeProto): MatMul node in fully connection for V
|
||||
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
|
||||
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
|
||||
input (str): input name
|
||||
output (str): output name
|
||||
|
||||
Returns:
|
||||
Union[NodeProto, None]: the node created or None if failed.
|
||||
"""
|
||||
is_self_attention = not self.is_cross_attention
|
||||
|
||||
q_matmul = self.model.match_parent(q_matmul_add, "MatMul", 0)
|
||||
k_matmul = self.model.match_parent(k_matmul_add, "MatMul", 0)
|
||||
v_matmul = self.model.match_parent(v_matmul_add, "MatMul", 0)
|
||||
|
||||
q_lora_nodes = self.match_lora_path(q_matmul_add)
|
||||
if q_lora_nodes is None:
|
||||
return None
|
||||
(q_lora_last_node, q_lora_matmul_1) = q_lora_nodes
|
||||
|
||||
k_lora_nodes = self.match_lora_path(k_matmul_add)
|
||||
if k_lora_nodes is None:
|
||||
return None
|
||||
(k_lora_last_node, k_lora_matmul_1) = k_lora_nodes
|
||||
|
||||
v_lora_nodes = self.match_lora_path(v_matmul_add)
|
||||
if v_lora_nodes is None:
|
||||
return None
|
||||
(v_lora_last_node, v_lora_matmul_1) = v_lora_nodes
|
||||
|
||||
if is_self_attention:
|
||||
if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
|
||||
logger.debug(
|
||||
"For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
|
||||
q_matmul.input[0],
|
||||
k_matmul.input[0],
|
||||
v_matmul.input[0],
|
||||
)
|
||||
return None
|
||||
|
||||
if (
|
||||
q_lora_matmul_1.input[0] != input
|
||||
or k_lora_matmul_1.input[0] != input
|
||||
or v_lora_matmul_1.input[0] != input
|
||||
):
|
||||
logger.debug(
|
||||
"For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %s",
|
||||
q_lora_matmul_1.input[0],
|
||||
k_lora_matmul_1.input[0],
|
||||
v_lora_matmul_1.input[0],
|
||||
)
|
||||
return None
|
||||
else:
|
||||
if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
|
||||
logger.debug(
|
||||
"For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
|
||||
q_matmul.input[0],
|
||||
k_matmul.input[0],
|
||||
v_matmul.input[0],
|
||||
)
|
||||
return None
|
||||
|
||||
if (
|
||||
q_lora_matmul_1.input[0] != input
|
||||
or (k_lora_matmul_1.input[0] != v_lora_matmul_1.input[0])
|
||||
or (k_matmul.input[0] == input)
|
||||
):
|
||||
logger.debug(
|
||||
(
|
||||
"For cross attention, input hidden state for LoRA q and k/v weights shall be different. "
|
||||
"Got %s, %s, %s"
|
||||
),
|
||||
q_lora_matmul_1.input[0],
|
||||
k_lora_matmul_1.input[0],
|
||||
v_lora_matmul_1.input[0],
|
||||
)
|
||||
return None
|
||||
|
||||
if hidden_size > 0 and (hidden_size % num_heads) != 0:
|
||||
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
|
||||
return None
|
||||
|
||||
q_weight = self.model.get_initializer(q_matmul.input[1])
|
||||
k_weight = self.model.get_initializer(k_matmul.input[1])
|
||||
v_weight = self.model.get_initializer(v_matmul.input[1])
|
||||
if not (q_weight and k_weight and v_weight):
|
||||
return None
|
||||
|
||||
# Sometimes weights are stored in fp16
|
||||
if q_weight.data_type == 10:
|
||||
logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
|
||||
return None
|
||||
|
||||
qw = NumpyHelper.to_array(q_weight)
|
||||
kw = NumpyHelper.to_array(k_weight)
|
||||
vw = NumpyHelper.to_array(v_weight)
|
||||
logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
|
||||
|
||||
# assert q and k have same shape as expected
|
||||
if is_self_attention:
|
||||
if qw.shape != kw.shape or qw.shape != vw.shape:
|
||||
return None
|
||||
|
||||
qw_in_size = qw.shape[0]
|
||||
|
||||
if hidden_size > 0 and hidden_size != qw_in_size:
|
||||
raise ValueError(
|
||||
f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
|
||||
"Please provide a correct input hidden size or pass in 0"
|
||||
)
|
||||
|
||||
# All the matrices can have the same shape or q, k matrics can have the same shape with v being different
|
||||
# For 2d weights, the shapes would be [in_size, out_size].
|
||||
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
|
||||
qw_out_size = int(np.prod(qw.shape[1:]))
|
||||
|
||||
if self.enable_packed_qkv:
|
||||
attention_node_name = self.model.create_node_name("MultiHeadAttention")
|
||||
|
||||
c = qw_in_size
|
||||
n = num_heads
|
||||
h = qw_out_size // num_heads
|
||||
|
||||
# Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape
|
||||
qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
|
||||
c, n * 3 * h
|
||||
)
|
||||
|
||||
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
|
||||
weight = helper.make_tensor(
|
||||
name=matmul_node_name + "_weight",
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
|
||||
vals=qkv_weight.flatten().tolist(),
|
||||
)
|
||||
|
||||
self.model.add_initializer(weight, self.this_graph_name)
|
||||
|
||||
matmul_node = helper.make_node(
|
||||
"MatMul",
|
||||
inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
|
||||
outputs=[matmul_node_name + "_out"],
|
||||
name=matmul_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
|
||||
|
||||
# Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
|
||||
# the Q/K/V weights to be changed without having to re-run the optimizer.
|
||||
lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
|
||||
lora_weight_shape_tensor = helper.make_tensor(
|
||||
name=lora_weight_shape_tensor_name,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[4],
|
||||
vals=[0, 0, n, h],
|
||||
)
|
||||
self.model.add_initializer(lora_weight_shape_tensor, self.this_graph_name)
|
||||
|
||||
# Reshape the LoRA Q weights
|
||||
q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q")
|
||||
q_lora_reshape_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[q_lora_last_node.output[0], lora_weight_shape_tensor_name],
|
||||
outputs=[q_lora_reshape_node_name + "_out"],
|
||||
name=q_lora_reshape_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[q_lora_reshape_node.name] = self.this_graph_name
|
||||
|
||||
# Reshape the LoRA K weights
|
||||
k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
|
||||
k_lora_reshape_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[k_lora_last_node.output[0], lora_weight_shape_tensor_name],
|
||||
outputs=[k_lora_reshape_node_name + "_out"],
|
||||
name=k_lora_reshape_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
|
||||
|
||||
# Reshape the LoRA V weights
|
||||
v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
|
||||
v_lora_reshape_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[v_lora_last_node.output[0], lora_weight_shape_tensor_name],
|
||||
outputs=[v_lora_reshape_node_name + "_out"],
|
||||
name=v_lora_reshape_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
|
||||
|
||||
# Concat the reshaped LoRA Q/K/V weights together on the third axis
|
||||
qkv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_QKV")
|
||||
qkv_lora_concat_node = helper.make_node(
|
||||
"Concat",
|
||||
inputs=[
|
||||
q_lora_reshape_node.output[0],
|
||||
k_lora_reshape_node.output[0],
|
||||
v_lora_reshape_node.output[0],
|
||||
],
|
||||
outputs=[qkv_lora_concat_node_name + "_out"],
|
||||
name=qkv_lora_concat_node_name,
|
||||
)
|
||||
qkv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
|
||||
self.node_name_to_graph_name[qkv_lora_concat_node.name] = self.this_graph_name
|
||||
|
||||
# Reshape the LoRA concatenated weights to [..., n * 3 * h]
|
||||
reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape"
|
||||
reshaped_lora_weights_shape_tensor = helper.make_tensor(
|
||||
name=reshaped_lora_weights_shape_tensor_name,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[3],
|
||||
vals=[0, 0, n * 3 * h],
|
||||
)
|
||||
self.model.add_initializer(reshaped_lora_weights_shape_tensor, self.this_graph_name)
|
||||
|
||||
qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV")
|
||||
qkv_lora_reshaped_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[qkv_lora_concat_node.output[0], reshaped_lora_weights_shape_tensor_name],
|
||||
outputs=[qkv_lora_reshaped_node_name + "_out"],
|
||||
name=qkv_lora_reshaped_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[qkv_lora_reshaped_node.name] = self.this_graph_name
|
||||
|
||||
# Add the LoRA Q/K/V weights to the base Q/K/V weights
|
||||
add_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_QKV")
|
||||
add_weights_node = helper.make_node(
|
||||
"Add",
|
||||
inputs=[qkv_lora_reshaped_node.output[0], matmul_node.output[0]],
|
||||
outputs=[add_weights_node_name + "_out"],
|
||||
name=add_weights_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[add_weights_node.name] = self.this_graph_name
|
||||
|
||||
# Finally, reshape the concatenated Q/K/V result to 5D
|
||||
shape_tensor_name = add_weights_node_name + "_reshape_shape"
|
||||
shape_tensor = helper.make_tensor(
|
||||
name=shape_tensor_name,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[5],
|
||||
vals=[0, 0, n, 3, h],
|
||||
)
|
||||
self.model.add_initializer(shape_tensor, self.this_graph_name)
|
||||
|
||||
reshape_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[add_weights_node.output[0], shape_tensor_name],
|
||||
outputs=[attention_node_name + "_qkv_input"],
|
||||
name=add_weights_node_name + "_reshape",
|
||||
)
|
||||
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
|
||||
|
||||
self.nodes_to_add.extend(
|
||||
[
|
||||
matmul_node,
|
||||
q_lora_reshape_node,
|
||||
k_lora_reshape_node,
|
||||
v_lora_reshape_node,
|
||||
qkv_lora_concat_node,
|
||||
qkv_lora_reshaped_node,
|
||||
add_weights_node,
|
||||
reshape_node,
|
||||
]
|
||||
)
|
||||
self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul, q_matmul_add, k_matmul_add, v_matmul_add])
|
||||
else:
|
||||
# TODO: Support non-packed QKV
|
||||
return None
|
||||
else: # cross attention
|
||||
attention_node_name = self.model.create_node_name("MultiHeadAttention")
|
||||
if self.enable_packed_kv:
|
||||
if kw.shape != vw.shape:
|
||||
return None
|
||||
|
||||
kw_in_size = kw.shape[0]
|
||||
vw_in_size = vw.shape[0]
|
||||
assert kw_in_size == vw_in_size
|
||||
|
||||
qw_out_size = qw.shape[1]
|
||||
kw_out_size = kw.shape[1]
|
||||
vw_out_size = vw.shape[1]
|
||||
assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
|
||||
|
||||
c = kw_in_size
|
||||
n = num_heads
|
||||
h = kw_out_size // num_heads
|
||||
|
||||
# Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
|
||||
kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
|
||||
|
||||
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
|
||||
weight = helper.make_tensor(
|
||||
name=matmul_node_name + "_weight",
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[kv_weight.shape[0], kv_weight.shape[1]],
|
||||
vals=kv_weight.flatten().tolist(),
|
||||
)
|
||||
|
||||
self.model.add_initializer(weight, self.this_graph_name)
|
||||
|
||||
matmul_node = helper.make_node(
|
||||
"MatMul",
|
||||
inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
|
||||
outputs=[matmul_node_name + "_out"],
|
||||
name=matmul_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
|
||||
|
||||
# Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
|
||||
# the Q/K/V weights to be changed without having to re-run the optimizer.
|
||||
kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
|
||||
lora_weight_shape_tensor = helper.make_tensor(
|
||||
name=kv_lora_weight_shape_tensor_name,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[4],
|
||||
vals=[0, 0, n, h],
|
||||
)
|
||||
self.model.add_initializer(lora_weight_shape_tensor, self.this_graph_name)
|
||||
|
||||
# Reshape the LoRA K weights
|
||||
k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
|
||||
k_lora_reshape_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[k_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
|
||||
outputs=[k_lora_reshape_node_name + "_out"],
|
||||
name=k_lora_reshape_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
|
||||
|
||||
# Reshape the LoRA V weights
|
||||
v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
|
||||
v_lora_reshape_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[v_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
|
||||
outputs=[v_lora_reshape_node_name + "_out"],
|
||||
name=v_lora_reshape_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
|
||||
|
||||
# Concat the reshaped LoRA K/V weights together on the third axis
|
||||
kv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_KV")
|
||||
kv_lora_concat_node = helper.make_node(
|
||||
"Concat",
|
||||
inputs=[k_lora_reshape_node.output[0], v_lora_reshape_node.output[0]],
|
||||
outputs=[kv_lora_concat_node_name + "_out"],
|
||||
name=kv_lora_concat_node_name,
|
||||
)
|
||||
kv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
|
||||
self.node_name_to_graph_name[kv_lora_concat_node.name] = self.this_graph_name
|
||||
|
||||
# Reshape the LoRA concatenated weights to [..., n * 2 * h]
|
||||
reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape"
|
||||
reshaped_kv_lora_weights_shape_tensor = helper.make_tensor(
|
||||
name=reshaped_kv_lora_weights_shape_tensor_name,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[3],
|
||||
vals=[0, 0, n * 2 * h],
|
||||
)
|
||||
self.model.add_initializer(reshaped_kv_lora_weights_shape_tensor, self.this_graph_name)
|
||||
|
||||
kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV")
|
||||
kv_lora_reshaped_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[kv_lora_concat_node.output[0], reshaped_kv_lora_weights_shape_tensor_name],
|
||||
outputs=[kv_lora_reshaped_node_name + "_out"],
|
||||
name=kv_lora_reshaped_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[kv_lora_reshaped_node.name] = self.this_graph_name
|
||||
|
||||
# Add the LoRA K/V weights to the base K/V weights
|
||||
add_kv_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_KV")
|
||||
add_kv_weights_node = helper.make_node(
|
||||
"Add",
|
||||
inputs=[kv_lora_reshaped_node.output[0], matmul_node.output[0]],
|
||||
outputs=[add_kv_weights_node_name + "_out"],
|
||||
name=add_kv_weights_node_name,
|
||||
)
|
||||
self.node_name_to_graph_name[add_kv_weights_node.name] = self.this_graph_name
|
||||
|
||||
# Finally, reshape the concatenated K/V result to 5D
|
||||
shape_tensor_name = add_kv_weights_node_name + "_reshape_shape"
|
||||
shape_tensor = helper.make_tensor(
|
||||
name=shape_tensor_name,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[5],
|
||||
vals=[0, 0, n, 2, h],
|
||||
)
|
||||
self.model.add_initializer(shape_tensor, self.this_graph_name)
|
||||
|
||||
reshape_node = helper.make_node(
|
||||
"Reshape",
|
||||
inputs=[add_kv_weights_node.output[0], shape_tensor_name],
|
||||
outputs=[attention_node_name + "_kv_input"],
|
||||
name=add_kv_weights_node_name + "_reshape",
|
||||
)
|
||||
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
|
||||
self.nodes_to_add.extend(
|
||||
[
|
||||
matmul_node,
|
||||
k_lora_reshape_node,
|
||||
v_lora_reshape_node,
|
||||
kv_lora_concat_node,
|
||||
kv_lora_reshaped_node,
|
||||
add_kv_weights_node,
|
||||
reshape_node,
|
||||
]
|
||||
)
|
||||
self.nodes_to_remove.extend([k_matmul, v_matmul, k_matmul_add, v_matmul_add])
|
||||
else:
|
||||
# TODO: Support non-packed KV
|
||||
return None
|
||||
|
||||
# No bias, use zeros
|
||||
qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
|
||||
qkv_bias_dim = 3 * hidden_size
|
||||
|
||||
bias = helper.make_tensor(
|
||||
name=attention_node_name + "_qkv_bias",
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[qkv_bias_dim],
|
||||
vals=qkv_bias.flatten().tolist(),
|
||||
)
|
||||
self.model.add_initializer(bias, self.this_graph_name)
|
||||
|
||||
if is_self_attention:
|
||||
if not self.enable_packed_qkv:
|
||||
# TODO: Support non-packed QKV
|
||||
return None
|
||||
else:
|
||||
attention_inputs = [attention_node_name + "_qkv_input"]
|
||||
else:
|
||||
if not self.enable_packed_kv:
|
||||
# TODO: Support non-packed QKV
|
||||
return None
|
||||
else:
|
||||
attention_inputs = [
|
||||
q_matmul_add.output[0],
|
||||
attention_node_name + "_kv_input",
|
||||
]
|
||||
|
||||
attention_node = helper.make_node(
|
||||
"Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
|
||||
inputs=attention_inputs,
|
||||
outputs=[output],
|
||||
name=attention_node_name,
|
||||
)
|
||||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
||||
|
||||
counter_name = (
|
||||
"Attention (self attention)"
|
||||
if is_self_attention and not self.enable_packed_qkv
|
||||
else "MultiHeadAttention ({})".format(
|
||||
"self attention with packed qkv"
|
||||
if self.enable_packed_qkv
|
||||
else "cross attention with packed kv"
|
||||
if self.enable_packed_kv
|
||||
else "cross attention"
|
||||
)
|
||||
)
|
||||
self.increase_counter(counter_name)
|
||||
return attention_node
|
||||
|
||||
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
|
||||
node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
|
||||
|
||||
|
|
@ -397,30 +872,62 @@ class FusionAttentionUnet(Fusion):
|
|||
return
|
||||
|
||||
match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
|
||||
if match_qkv is None:
|
||||
return
|
||||
if match_qkv is not None:
|
||||
is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
|
||||
|
||||
is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
|
||||
attention_last_node = reshape_qkv
|
||||
|
||||
attention_last_node = reshape_qkv
|
||||
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
|
||||
if q_num_heads <= 0:
|
||||
logger.debug("fuse_attention: failed to detect num_heads")
|
||||
return
|
||||
|
||||
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
|
||||
if q_num_heads <= 0:
|
||||
logger.debug("fuse_attention: failed to detect num_heads")
|
||||
return
|
||||
# number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
|
||||
new_node = self.create_attention_node(
|
||||
matmul_q,
|
||||
matmul_k,
|
||||
matmul_v,
|
||||
q_num_heads,
|
||||
q_hidden_size,
|
||||
input=normalize_node.output[0],
|
||||
output=attention_last_node.output[0],
|
||||
)
|
||||
if new_node is None:
|
||||
return
|
||||
else:
|
||||
# Check if we have a LoRA pattern
|
||||
match_qkv = self.match_qkv_torch1_lora(root_input, skip_add) or self.match_qkv_torch2_lora(
|
||||
root_input, skip_add
|
||||
)
|
||||
if match_qkv is None:
|
||||
return
|
||||
|
||||
# number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
|
||||
new_node = self.create_attention_node(
|
||||
matmul_q,
|
||||
matmul_k,
|
||||
matmul_v,
|
||||
q_num_heads,
|
||||
q_hidden_size,
|
||||
input=normalize_node.output[0],
|
||||
output=attention_last_node.output[0],
|
||||
)
|
||||
if new_node is None:
|
||||
return
|
||||
is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v = match_qkv
|
||||
|
||||
attention_last_node = reshape_qkv
|
||||
|
||||
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
|
||||
if q_num_heads <= 0:
|
||||
logger.debug("fuse_attention: failed to detect num_heads")
|
||||
return
|
||||
|
||||
# number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
|
||||
new_node = self.create_attention_node_lora(
|
||||
matmul_add_q,
|
||||
matmul_add_k,
|
||||
matmul_add_v,
|
||||
q_num_heads,
|
||||
q_hidden_size,
|
||||
input=normalize_node.output[0],
|
||||
output=attention_last_node.output[0],
|
||||
)
|
||||
if new_node is None:
|
||||
return
|
||||
|
||||
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
|
||||
if q_num_heads <= 0:
|
||||
logger.debug("fuse_attention: failed to detect num_heads")
|
||||
return
|
||||
|
||||
self.nodes_to_add.append(new_node)
|
||||
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
||||
|
|
@ -530,3 +1037,146 @@ class FusionAttentionUnet(Fusion):
|
|||
return None
|
||||
|
||||
return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
|
||||
|
||||
def match_qkv_torch1_lora(self, root_input, skip_add):
|
||||
"""Match Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*"""
|
||||
another_input = 1 if skip_add.input[0] == root_input else 0
|
||||
qkv_nodes = self.model.match_parent_path(
|
||||
skip_add,
|
||||
["Add", "Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
|
||||
[another_input, 0, None, None, 0, 0, 0],
|
||||
)
|
||||
if qkv_nodes is None:
|
||||
return None
|
||||
|
||||
(_, _, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
|
||||
|
||||
# No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
|
||||
v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0])
|
||||
if v_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match LoRA v path")
|
||||
return None
|
||||
(_, _, _, matmul_add_v) = v_nodes
|
||||
|
||||
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
|
||||
if qk_nodes is not None:
|
||||
(_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
|
||||
else:
|
||||
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
|
||||
if qk_nodes is not None:
|
||||
(_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
|
||||
else:
|
||||
logger.debug("fuse_attention: failed to match LoRA qk path")
|
||||
return None
|
||||
|
||||
q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "Add"], [0, 0, 0, 0])
|
||||
if q_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match LoRA q path")
|
||||
return None
|
||||
(_, _transpose_q, reshape_q, matmul_add_q) = q_nodes
|
||||
|
||||
k_nodes = self.model.match_parent_path(
|
||||
matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0, 0]
|
||||
)
|
||||
if k_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match LoRA k path")
|
||||
return None
|
||||
|
||||
(_, _, _, _, matmul_add_k) = k_nodes
|
||||
|
||||
return False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
|
||||
|
||||
def match_qkv_torch2_lora(self, root_input, skip_add):
|
||||
"""Match Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*"""
|
||||
another_input = 1 if skip_add.input[0] == root_input else 0
|
||||
qkv_nodes = self.model.match_parent_path(
|
||||
skip_add,
|
||||
["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
|
||||
[another_input, 0, None, None, 0, 0],
|
||||
)
|
||||
if qkv_nodes is None:
|
||||
return None
|
||||
|
||||
(_, _, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
|
||||
|
||||
v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add"], [1, 0, 0])
|
||||
if v_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match LoRA v path")
|
||||
return None
|
||||
(_, _, matmul_add_v) = v_nodes
|
||||
|
||||
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
|
||||
if qk_nodes is not None:
|
||||
(_softmax_qk, matmul_qk) = qk_nodes
|
||||
else:
|
||||
logger.debug("fuse_attention: failed to match LoRA qk path")
|
||||
return None
|
||||
|
||||
q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [0, None, 0, 0])
|
||||
if q_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match LoRA q path")
|
||||
return None
|
||||
(mul_q, _transpose_q, reshape_q, matmul_add_q) = q_nodes
|
||||
|
||||
k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [1, None, 0, 0])
|
||||
if k_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match LoRA k path")
|
||||
return None
|
||||
|
||||
(_mul_k, _, _, matmul_add_k) = k_nodes
|
||||
|
||||
# The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
|
||||
mul_q_nodes = self.model.match_parent_path(
|
||||
mul_q,
|
||||
["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
|
||||
[None, 0, 1, 0, 0, 0, 0, 0],
|
||||
)
|
||||
if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
|
||||
logger.debug("fuse_attention: failed to match LoRA mul_q path")
|
||||
return None
|
||||
|
||||
return True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
|
||||
|
||||
def match_lora_path(
|
||||
self,
|
||||
add_node: NodeProto,
|
||||
):
|
||||
# Lora paths can look like one of the following options:
|
||||
# MatMul -> MatMul -> Add
|
||||
# MatMul -> MatMul -> Mul -> Add
|
||||
# MatMul -> MatMul -> Mul -> Mul -> Add
|
||||
|
||||
# Try matching MatMul -> MatMul -> Add
|
||||
lora_nodes = self.model.match_parent_path(
|
||||
add_node,
|
||||
["MatMul", "MatMul"],
|
||||
[1, 0],
|
||||
)
|
||||
|
||||
if lora_nodes is not None:
|
||||
(lora_matmul_2_node, lora_matmul_1_node) = lora_nodes
|
||||
return (lora_matmul_2_node, lora_matmul_1_node)
|
||||
|
||||
# Try matching MatMul -> MatMul -> Mul -> Add
|
||||
lora_nodes = self.model.match_parent_path(
|
||||
add_node,
|
||||
["Mul", "MatMul", "MatMul"],
|
||||
[1, 0, 0],
|
||||
)
|
||||
|
||||
if lora_nodes is not None:
|
||||
(lora_mul_node, _, lora_matmul_1_node) = lora_nodes
|
||||
return (lora_mul_node, lora_matmul_1_node)
|
||||
|
||||
# Try matching MatMul -> MatMul -> Mul -> Mul -> Add
|
||||
lora_nodes = self.model.match_parent_path(
|
||||
add_node,
|
||||
["Mul", "Mul", "MatMul", "MatMul"],
|
||||
[1, 0, 0, 0],
|
||||
)
|
||||
|
||||
if lora_nodes is not None:
|
||||
(lora_mul_node, _, _, lora_matmul_1_node) = lora_nodes
|
||||
return (lora_mul_node, lora_matmul_1_node)
|
||||
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in a new issue