mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add ONNX support for swin transformer (#19390)
* swin transformer onnx support * Updated image dimensions as dynamic Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
parent
969534af4b
commit
e162cebfa3
5 changed files with 30 additions and 2 deletions
|
|
@ -94,6 +94,7 @@ Ready-made configurations include the following architectures:
|
|||
- RoFormer
|
||||
- SegFormer
|
||||
- SqueezeBERT
|
||||
- Swin Transformer
|
||||
- T5
|
||||
- ViT
|
||||
- XLM
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
|
|||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"]}
|
||||
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig", "SwinOnnxConfig"]}
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -53,7 +53,7 @@ else:
|
|||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
|
||||
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig, SwinOnnxConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
|
|
|
|||
|
|
@ -14,7 +14,13 @@
|
|||
# limitations under the License.
|
||||
""" Swin Transformer model configuration"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Mapping
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
|
|
@ -145,3 +151,20 @@ class SwinConfig(PretrainedConfig):
|
|||
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
|
||||
# this indicates the channel dimension after the last stage of the model
|
||||
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
|
||||
|
||||
|
||||
class SwinOnnxConfig(OnnxConfig):
    """ONNX export configuration for the Swin Transformer model.

    Declares the model's input signature with fully dynamic axes and the
    numeric tolerance used when validating the exported graph against the
    reference PyTorch model.
    """

    # Minimum torch version whose ONNX exporter supports the operators
    # Swin needs (e.g. roll); older versions fail to export.
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map each model input to its dynamic-axis names.

        All four axes of ``pixel_values`` are dynamic so the exported model
        accepts variable batch size, channel count, and image dimensions.
        """
        dynamic_axes = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict([("pixel_values", dynamic_axes)])

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance for comparing ONNX outputs to PyTorch outputs."""
        return 1e-4
|
||||
|
|
|
|||
|
|
@ -471,6 +471,9 @@ class FeaturesManager:
|
|||
"question-answering",
|
||||
onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
|
||||
),
|
||||
"swin": supported_features_mapping(
|
||||
"default", "image-classification", "masked-im", onnx_config_cls="models.swin.SwinOnnxConfig"
|
||||
),
|
||||
"t5": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
|
|
|
|||
|
|
@ -217,6 +217,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||
("longformer", "allenai/longformer-base-4096"),
|
||||
("yolos", "hustvl/yolos-tiny"),
|
||||
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
|
||||
("swin", "microsoft/swin-tiny-patch4-window7-224"),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
|
|
|
|||
Loading…
Reference in a new issue