diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index e147ef7904..16019f3a24 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -17,6 +17,7 @@ class AttentionMaskFormat: MaskIndexEnd = 0 MaskIndexEndAndStart = 1 AttentionMask = 2 + NoMask = 3 class AttentionMask(): @@ -45,6 +46,9 @@ class AttentionMask(): return next(iter(self.mask_indice)) def process_mask(self, input): + if self.mask_format == AttentionMaskFormat.NoMask: + return None + if input in self.mask_indice: return self.mask_indice[input] @@ -136,7 +140,10 @@ class FusionAttention(Fusion): vals=qkv_bias.flatten().tolist()) self.model.add_initializer(bias) - attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias', mask_index] + attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias'] + if mask_index is not None: + attnetion_inputs.append(mask_index) + attention_node = helper.make_node('Attention', inputs=attnetion_inputs, outputs=[output], diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 58954f0bda..f7cb09e059 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -628,7 +628,7 @@ class OnnxModel: weights_to_remove = [] weights_to_keep = [] for initializer in graph.initializer: - if initializer.name not in remaining_input_names: + if initializer.name not in remaining_input_names and not self.find_graph_output(initializer.name): weights_to_remove.append(initializer) else: weights_to_keep.append(initializer.name) diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 4b7f24c80a..8ecf75f85f 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -39,6 +39,9 @@ class BertOptimizationOptions: def use_raw_attention_mask(self): self.attention_mask_format = AttentionMaskFormat.AttentionMask + def disable_attention_mask(self): + self.attention_mask_format = AttentionMaskFormat.NoMask + class BertOnnxModel(OnnxModel): def __init__(self, model, num_heads, hidden_size): diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ad92c04c45..e7900e877f 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -192,6 +192,12 @@ def _parse_arguments(): help="use raw attention mask instead of mask index in attention operator") parser.set_defaults(use_raw_attention_mask=False) + parser.add_argument('--no_attention_mask', + required=False, + action='store_true', + help="no attention mask. Only works for model_type=bert") + parser.set_defaults(no_attention_mask=False) + parser.add_argument('--verbose', required=False, action='store_true') parser.set_defaults(verbose=False) @@ -233,6 +239,8 @@ def _get_optimization_options(args): optimization_options.enable_gelu_approximation = True if args.use_raw_attention_mask: optimization_options.use_raw_attention_mask() + if args.no_attention_mask: + optimization_options.disable_attention_mask() return optimization_options @@ -295,6 +303,9 @@ def optimize_model(input, os.remove(temp_model_path) logger.debug("Remove tempoary model: {}".format(temp_model_path)) + optimizer.model.producer_name = "onnxruntime_tools" + optimizer.model.producer_version = "1.4" + return optimizer