mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-16 01:33:39 +00:00
add --no_attention_mask option (#4750)
Output the producer name and version in the optimized model. Avoid removing an initializer that is also a graph output.
This commit is contained in:
parent
adda8c66d9
commit
a69ca63895
4 changed files with 23 additions and 2 deletions
|
|
@ -17,6 +17,7 @@ class AttentionMaskFormat:
|
|||
MaskIndexEnd = 0
|
||||
MaskIndexEndAndStart = 1
|
||||
AttentionMask = 2
|
||||
NoMask = 3
|
||||
|
||||
|
||||
class AttentionMask():
|
||||
|
|
@ -45,6 +46,9 @@ class AttentionMask():
|
|||
return next(iter(self.mask_indice))
|
||||
|
||||
def process_mask(self, input):
|
||||
if self.mask_format == AttentionMaskFormat.NoMask:
|
||||
return None
|
||||
|
||||
if input in self.mask_indice:
|
||||
return self.mask_indice[input]
|
||||
|
||||
|
|
@ -136,7 +140,10 @@ class FusionAttention(Fusion):
|
|||
vals=qkv_bias.flatten().tolist())
|
||||
self.model.add_initializer(bias)
|
||||
|
||||
attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias', mask_index]
|
||||
attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias']
|
||||
if mask_index is not None:
|
||||
attnetion_inputs.append(mask_index)
|
||||
|
||||
attention_node = helper.make_node('Attention',
|
||||
inputs=attnetion_inputs,
|
||||
outputs=[output],
|
||||
|
|
|
|||
|
|
@ -628,7 +628,7 @@ class OnnxModel:
|
|||
weights_to_remove = []
|
||||
weights_to_keep = []
|
||||
for initializer in graph.initializer:
|
||||
if initializer.name not in remaining_input_names:
|
||||
if initializer.name not in remaining_input_names and not self.find_graph_output(initializer.name):
|
||||
weights_to_remove.append(initializer)
|
||||
else:
|
||||
weights_to_keep.append(initializer.name)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ class BertOptimizationOptions:
|
|||
def use_raw_attention_mask(self):
|
||||
self.attention_mask_format = AttentionMaskFormat.AttentionMask
|
||||
|
||||
def disable_attention_mask(self):
|
||||
self.attention_mask_format = AttentionMaskFormat.NoMask
|
||||
|
||||
|
||||
class BertOnnxModel(OnnxModel):
|
||||
def __init__(self, model, num_heads, hidden_size):
|
||||
|
|
|
|||
|
|
@ -192,6 +192,12 @@ def _parse_arguments():
|
|||
help="use raw attention mask instead of mask index in attention operator")
|
||||
parser.set_defaults(use_raw_attention_mask=False)
|
||||
|
||||
parser.add_argument('--no_attention_mask',
|
||||
required=False,
|
||||
action='store_true',
|
||||
help="no attention mask. Only works for model_type=bert")
|
||||
parser.set_defaults(no_attention_mask=False)
|
||||
|
||||
parser.add_argument('--verbose', required=False, action='store_true')
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
|
|
@ -233,6 +239,8 @@ def _get_optimization_options(args):
|
|||
optimization_options.enable_gelu_approximation = True
|
||||
if args.use_raw_attention_mask:
|
||||
optimization_options.use_raw_attention_mask()
|
||||
if args.no_attention_mask:
|
||||
optimization_options.disable_attention_mask()
|
||||
|
||||
return optimization_options
|
||||
|
||||
|
|
@ -295,6 +303,9 @@ def optimize_model(input,
|
|||
os.remove(temp_model_path)
|
||||
logger.debug("Remove tempoary model: {}".format(temp_model_path))
|
||||
|
||||
optimizer.model.producer_name = "onnxruntime_tools"
|
||||
optimizer.model.producer_version = "1.4"
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue