mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Add onnx support for VisionEncoderDecoder (#19254)
* Add onnx support for VisionEncoderDecoder * Add onnx support for VisionEncoderDecoder * Removed unused import * Rename encoder hidden state Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update docstrings and removed redundant code * Added test function for enc-dec models * Update doc string text Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * fixed code style Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
parent
298f6a98c2
commit
3080bb4754
7 changed files with 305 additions and 35 deletions
|
|
@ -96,6 +96,7 @@ Ready-made configurations include the following architectures:
|
|||
- SqueezeBERT
|
||||
- Swin Transformer
|
||||
- T5
|
||||
- Vision Encoder decoder
|
||||
- ViT
|
||||
- XLM
|
||||
- XLM-RoBERTa
|
||||
|
|
@ -294,6 +295,13 @@ that can be used for fast autoregressive decoding.
|
|||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
For `VisionEncoderDecoder` type models, the encoder and decoder parts are
|
||||
exported separately as two ONNX files named `encoder_model.onnx` and `decoder_model.onnx` respectively.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## Exporting a model for an unsupported architecture
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,9 @@ from ...utils import (
|
|||
)
|
||||
|
||||
|
||||
_import_structure = {"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"]}
|
||||
_import_structure = {
|
||||
"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig", "VisionEncoderDecoderOnnxConfig"]
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
|
|
@ -54,7 +56,7 @@ else:
|
|||
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
|
||||
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig, VisionEncoderDecoderOnnxConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
|
|
|
|||
|
|
@ -15,12 +15,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ... import PreTrainedTokenizerBase, TensorType
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
|
@ -119,3 +126,97 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
|
|||
output["decoder"] = self.decoder.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
|
||||
class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
|
||||
torch_onnx_minimum_version = version.parse("1.11")
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
return OrderedDict(
|
||||
[
|
||||
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def atol_for_validation(self) -> float:
|
||||
return 1e-4
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
return OrderedDict({"last_hidden_state": {0: "batch", 1: "encoder_sequence"}})
|
||||
|
||||
|
||||
class VisionEncoderDecoderDecoderOnnxConfig(OnnxConfig):
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict()
|
||||
common_inputs["input_ids"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||
common_inputs["encoder_hidden_states"] = {0: "batch", 1: "encoder_sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizerBase",
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional["TensorType"] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
import torch
|
||||
|
||||
common_inputs = OrderedDict()
|
||||
|
||||
dummy_input = super().generate_dummy_inputs(
|
||||
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||
)
|
||||
|
||||
batch, encoder_sequence = dummy_input["input_ids"].shape
|
||||
encoder_hidden_states_shape = (batch, encoder_sequence, self._config.encoder_hidden_size)
|
||||
common_inputs["input_ids"] = dummy_input.pop("input_ids")
|
||||
common_inputs["attention_mask"] = dummy_input.pop("attention_mask")
|
||||
common_inputs["encoder_hidden_states"] = torch.zeros(encoder_hidden_states_shape)
|
||||
|
||||
return common_inputs
|
||||
|
||||
|
||||
class VisionEncoderDecoderOnnxConfig(OnnxConfig):
|
||||
@property
|
||||
def inputs(self) -> None:
|
||||
pass
|
||||
|
||||
def get_encoder_config(self, encoder_config: PretrainedConfig) -> OnnxConfig:
|
||||
r"""
|
||||
Returns ONNX encoder config for `VisionEncoderDecoder` model.
|
||||
|
||||
Args:
|
||||
encoder_config (`PretrainedConfig`):
|
||||
The encoder model's configuration to use when exporting to ONNX.
|
||||
|
||||
Returns:
|
||||
[`VisionEncoderDecoderEncoderOnnxConfig`]: An instance of the ONNX configuration object
|
||||
"""
|
||||
return VisionEncoderDecoderEncoderOnnxConfig(encoder_config)
|
||||
|
||||
def get_decoder_config(
|
||||
self, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, feature: str = "default"
|
||||
) -> OnnxConfig:
|
||||
r"""
|
||||
Returns ONNX decoder config for `VisionEncoderDecoder` model.
|
||||
|
||||
Args:
|
||||
encoder_config (`PretrainedConfig`):
|
||||
The encoder model's configuration to use when exporting to ONNX.
|
||||
decoder_config (`PretrainedConfig`):
|
||||
The decoder model's configuration to use when exporting to ONNX
|
||||
feature (`str`, *optional*):
|
||||
The type of feature to export the model with.
|
||||
|
||||
Returns:
|
||||
[`VisionEncoderDecoderDecoderOnnxConfig`]: An instance of the ONNX configuration object.
|
||||
"""
|
||||
decoder_config.encoder_hidden_size = encoder_config.hidden_size
|
||||
return VisionEncoderDecoderDecoderOnnxConfig(decoder_config, feature)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ from .convert import export, validate_model_outputs
|
|||
from .features import FeaturesManager
|
||||
|
||||
|
||||
ENCODER_DECODER_MODELS = ["vision-encoder-decoder"]
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser("Hugging Face Transformers ONNX exporter")
|
||||
parser.add_argument(
|
||||
|
|
@ -65,48 +68,110 @@ def main():
|
|||
if not args.output.parent.exists():
|
||||
args.output.parent.mkdir(parents=True)
|
||||
|
||||
# Instantiate the appropriate preprocessor
|
||||
if args.preprocessor == "auto":
|
||||
preprocessor = get_preprocessor(args.model)
|
||||
elif args.preprocessor == "tokenizer":
|
||||
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
||||
elif args.preprocessor == "feature_extractor":
|
||||
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
||||
elif args.preprocessor == "processor":
|
||||
preprocessor = AutoProcessor.from_pretrained(args.model)
|
||||
else:
|
||||
raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
|
||||
|
||||
# Allocate the model
|
||||
model = FeaturesManager.get_model_from_feature(
|
||||
args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
|
||||
onnx_config = model_onnx_config(model.config)
|
||||
|
||||
# Ensure the requested opset is sufficient
|
||||
if args.opset is None:
|
||||
args.opset = onnx_config.default_onnx_opset
|
||||
if model_kind in ENCODER_DECODER_MODELS:
|
||||
encoder_model = model.get_encoder()
|
||||
decoder_model = model.get_decoder()
|
||||
|
||||
if args.opset < onnx_config.default_onnx_opset:
|
||||
raise ValueError(
|
||||
f"Opset {args.opset} is not sufficient to export {model_kind}. "
|
||||
f"At least {onnx_config.default_onnx_opset} is required."
|
||||
encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
|
||||
decoder_onnx_config = onnx_config.get_decoder_config(
|
||||
encoder_model.config, decoder_model.config, feature=args.feature
|
||||
)
|
||||
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor,
|
||||
model,
|
||||
onnx_config,
|
||||
args.opset,
|
||||
args.output,
|
||||
)
|
||||
if args.opset is None:
|
||||
args.opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)
|
||||
|
||||
if args.atol is None:
|
||||
args.atol = onnx_config.atol_for_validation
|
||||
if args.opset < min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset):
|
||||
raise ValueError(
|
||||
f"Opset {args.opset} is not sufficient to export {model_kind}. At least "
|
||||
f" {min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)} is required."
|
||||
)
|
||||
|
||||
validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
|
||||
logger.info(f"All good, model saved at: {args.output.as_posix()}")
|
||||
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
||||
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor,
|
||||
encoder_model,
|
||||
encoder_onnx_config,
|
||||
args.opset,
|
||||
args.output.parent.joinpath("encoder_model.onnx"),
|
||||
)
|
||||
|
||||
validate_model_outputs(
|
||||
encoder_onnx_config,
|
||||
preprocessor,
|
||||
encoder_model,
|
||||
args.output.parent.joinpath("encoder_model.onnx"),
|
||||
onnx_outputs,
|
||||
args.atol if args.atol else encoder_onnx_config.atol_for_validation,
|
||||
)
|
||||
|
||||
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
||||
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor,
|
||||
decoder_model,
|
||||
decoder_onnx_config,
|
||||
args.opset,
|
||||
args.output.parent.joinpath("decoder_model.onnx"),
|
||||
)
|
||||
|
||||
validate_model_outputs(
|
||||
decoder_onnx_config,
|
||||
preprocessor,
|
||||
decoder_model,
|
||||
args.output.parent.joinpath("decoder_model.onnx"),
|
||||
onnx_outputs,
|
||||
args.atol if args.atol else decoder_onnx_config.atol_for_validation,
|
||||
)
|
||||
logger.info(
|
||||
f"All good, model saved at: {args.output.parent.joinpath('encoder_model.onnx').as_posix()},"
|
||||
f" {args.output.parent.joinpath('decoder_model.onnx').as_posix()}"
|
||||
)
|
||||
|
||||
else:
|
||||
# Instantiate the appropriate preprocessor
|
||||
if args.preprocessor == "auto":
|
||||
preprocessor = get_preprocessor(args.model)
|
||||
elif args.preprocessor == "tokenizer":
|
||||
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
||||
elif args.preprocessor == "feature_extractor":
|
||||
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
||||
elif args.preprocessor == "processor":
|
||||
preprocessor = AutoProcessor.from_pretrained(args.model)
|
||||
else:
|
||||
raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
|
||||
|
||||
# Ensure the requested opset is sufficient
|
||||
if args.opset is None:
|
||||
args.opset = onnx_config.default_onnx_opset
|
||||
|
||||
if args.opset < onnx_config.default_onnx_opset:
|
||||
raise ValueError(
|
||||
f"Opset {args.opset} is not sufficient to export {model_kind}. "
|
||||
f"At least {onnx_config.default_onnx_opset} is required."
|
||||
)
|
||||
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor,
|
||||
model,
|
||||
onnx_config,
|
||||
args.opset,
|
||||
args.output,
|
||||
)
|
||||
|
||||
if args.atol is None:
|
||||
args.atol = onnx_config.atol_for_validation
|
||||
|
||||
validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
|
||||
logger.info(f"All good, model saved at: {args.output.as_posix()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ class OnnxConfig(ABC):
|
|||
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
|
||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
"vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
}
|
||||
|
||||
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
|
||||
|
|
@ -451,7 +452,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
|||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
# TODO: should we set seq_length = 1 when self.use_past = True?
|
||||
common_inputs = super().generate_dummy_inputs(
|
||||
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||
|
|
@ -577,7 +577,6 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
|
|||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ if is_torch_available():
|
|||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForVision2Seq,
|
||||
)
|
||||
if is_tf_available():
|
||||
from transformers.models.auto import (
|
||||
|
|
@ -98,6 +99,7 @@ class FeaturesManager:
|
|||
"image-segmentation": AutoModelForImageSegmentation,
|
||||
"masked-im": AutoModelForMaskedImageModeling,
|
||||
"semantic-segmentation": AutoModelForSemanticSegmentation,
|
||||
"vision2seq-lm": AutoModelForVision2Seq,
|
||||
}
|
||||
if is_tf_available():
|
||||
_TASKS_TO_TF_AUTOMODELS = {
|
||||
|
|
@ -481,6 +483,9 @@ class FeaturesManager:
|
|||
"seq2seq-lm-with-past",
|
||||
onnx_config_cls="models.t5.T5OnnxConfig",
|
||||
),
|
||||
"vision-encoder-decoder": supported_features_mapping(
|
||||
"vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
|
||||
),
|
||||
"vit": supported_features_mapping(
|
||||
"default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
|
||||
),
|
||||
|
|
@ -582,6 +587,7 @@ class FeaturesManager:
|
|||
raise KeyError(
|
||||
f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
|
||||
)
|
||||
|
||||
return task_to_automodel[task]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -161,7 +161,6 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||
"""
|
||||
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
||||
with self.subTest(name):
|
||||
|
||||
# without past
|
||||
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
|
||||
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
||||
|
|
@ -220,6 +219,10 @@ PYTORCH_EXPORT_MODELS = {
|
|||
("swin", "microsoft/swin-tiny-patch4-window7-224"),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
|
||||
("vision-encoder-decoder", "nlpconnect/vit-gpt2-image-captioning"),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
("bloom", "bigscience/bloom-560m"),
|
||||
("gpt2", "gpt2"),
|
||||
|
|
@ -347,6 +350,70 @@ class OnnxExportTestCaseV2(TestCase):
|
|||
except (RuntimeError, ValueError) as e:
|
||||
self.fail(f"{name}, {feature} -> {e}")
|
||||
|
||||
def _onnx_export_encoder_decoder_models(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"
|
||||
):
|
||||
from transformers import AutoFeatureExtractor, AutoTokenizer
|
||||
from transformers.onnx import export
|
||||
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
model = model_class.from_config(config)
|
||||
|
||||
onnx_config = onnx_config_class_constructor(model.config)
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.utils import torch_version
|
||||
|
||||
if torch_version < onnx_config.torch_onnx_minimum_version:
|
||||
pytest.skip(
|
||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
||||
)
|
||||
|
||||
encoder_model = model.get_encoder()
|
||||
decoder_model = model.get_decoder()
|
||||
|
||||
encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
|
||||
decoder_onnx_config = onnx_config.get_decoder_config(encoder_model.config, decoder_model.config, feature)
|
||||
|
||||
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
|
||||
|
||||
onnx_opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)
|
||||
|
||||
with NamedTemporaryFile("w") as encoder_output:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor, encoder_model, encoder_onnx_config, onnx_opset, Path(encoder_output.name), device=device
|
||||
)
|
||||
validate_model_outputs(
|
||||
encoder_onnx_config,
|
||||
preprocessor,
|
||||
encoder_model,
|
||||
Path(encoder_output.name),
|
||||
onnx_outputs,
|
||||
encoder_onnx_config.atol_for_validation,
|
||||
)
|
||||
|
||||
preprocessor = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
with NamedTemporaryFile("w") as decoder_output:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor,
|
||||
decoder_model,
|
||||
decoder_onnx_config,
|
||||
onnx_config.default_onnx_opset,
|
||||
Path(decoder_output.name),
|
||||
device=device,
|
||||
)
|
||||
validate_model_outputs(
|
||||
decoder_onnx_config,
|
||||
preprocessor,
|
||||
decoder_model,
|
||||
Path(decoder_output.name),
|
||||
onnx_outputs,
|
||||
decoder_onnx_config.atol_for_validation,
|
||||
)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
|
|
@ -363,6 +430,28 @@ class OnnxExportTestCaseV2(TestCase):
|
|||
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
@require_vision
|
||||
@require_rjieba
|
||||
def test_pytorch_export_encoder_decoder_models(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export_encoder_decoder_models(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
@require_vision
|
||||
@require_rjieba
|
||||
def test_pytorch_export_encoder_decoder_models_on_cuda(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export_encoder_decoder_models(
|
||||
test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda"
|
||||
)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
|
|
|
|||
Loading…
Reference in a new issue