add --no_attention_mask option (#4750)

output producer name and version in optimized model.
avoid removing initializer that existed in graph output
This commit is contained in:
Tianlei Wu 2020-08-12 15:56:25 -07:00 committed by GitHub
parent adda8c66d9
commit a69ca63895
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 2 deletions

View file

@ -17,6 +17,7 @@ class AttentionMaskFormat:
MaskIndexEnd = 0
MaskIndexEndAndStart = 1
AttentionMask = 2
NoMask = 3
class AttentionMask():
@ -45,6 +46,9 @@ class AttentionMask():
return next(iter(self.mask_indice))
def process_mask(self, input):
if self.mask_format == AttentionMaskFormat.NoMask:
return None
if input in self.mask_indice:
return self.mask_indice[input]
@ -136,7 +140,10 @@ class FusionAttention(Fusion):
vals=qkv_bias.flatten().tolist())
self.model.add_initializer(bias)
attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias', mask_index]
attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias']
if mask_index is not None:
attnetion_inputs.append(mask_index)
attention_node = helper.make_node('Attention',
inputs=attnetion_inputs,
outputs=[output],

View file

@ -628,7 +628,7 @@ class OnnxModel:
weights_to_remove = []
weights_to_keep = []
for initializer in graph.initializer:
if initializer.name not in remaining_input_names:
if initializer.name not in remaining_input_names and not self.find_graph_output(initializer.name):
weights_to_remove.append(initializer)
else:
weights_to_keep.append(initializer.name)

View file

@ -39,6 +39,9 @@ class BertOptimizationOptions:
def use_raw_attention_mask(self):
self.attention_mask_format = AttentionMaskFormat.AttentionMask
def disable_attention_mask(self):
self.attention_mask_format = AttentionMaskFormat.NoMask
class BertOnnxModel(OnnxModel):
def __init__(self, model, num_heads, hidden_size):

View file

@ -192,6 +192,12 @@ def _parse_arguments():
help="use raw attention mask instead of mask index in attention operator")
parser.set_defaults(use_raw_attention_mask=False)
parser.add_argument('--no_attention_mask',
required=False,
action='store_true',
help="no attention mask. Only works for model_type=bert")
parser.set_defaults(no_attention_mask=False)
parser.add_argument('--verbose', required=False, action='store_true')
parser.set_defaults(verbose=False)
@ -233,6 +239,8 @@ def _get_optimization_options(args):
optimization_options.enable_gelu_approximation = True
if args.use_raw_attention_mask:
optimization_options.use_raw_attention_mask()
if args.no_attention_mask:
optimization_options.disable_attention_mask()
return optimization_options
@ -295,6 +303,9 @@ def optimize_model(input,
os.remove(temp_model_path)
logger.debug("Remove tempoary model: {}".format(temp_model_path))
optimizer.model.producer_name = "onnxruntime_tools"
optimizer.model.producer_version = "1.4"
return optimizer