diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 743856ca0..8684a58a4 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -864,7 +864,9 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin): else: self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features) self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] - self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels]) + self.hidden_states_norms = nn.ModuleList( + [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]] + ) # Initialize weights and apply final processing self.post_init() @@ -889,10 +891,10 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin): # we skip the stem hidden_states = outputs.hidden_states[1:] - feature_maps = () # we need to reshape the hidden states to their original spatial dimensions # spatial dimensions contains all the heights and widths of each stage, including after the embeddings spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions + feature_maps = () for i, (hidden_state, stage, (height, width)) in enumerate( zip(hidden_states, self.stage_names[1:], spatial_dimensions) ): diff --git a/tests/models/bit/test_modeling_bit.py b/tests/models/bit/test_modeling_bit.py index ef7a6dbb2..cec997e59 100644 --- a/tests/models/bit/test_modeling_bit.py +++ b/tests/models/bit/test_modeling_bit.py @@ -22,6 +22,7 @@ from transformers import BitConfig from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import 
BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -55,6 +56,7 @@ class BitModelTester: num_labels=3, scope=None, out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], num_groups=1, ): self.parent = parent @@ -71,6 +73,7 @@ class BitModelTester: self.scope = scope self.num_stages = len(hidden_sizes) self.out_features = out_features + self.out_indices = out_indices self.num_groups = num_groups def prepare_config_and_inputs(self): @@ -93,6 +96,7 @@ class BitModelTester: hidden_act=self.hidden_act, num_labels=self.num_labels, out_features=self.out_features, + out_indices=self.out_indices, num_groups=self.num_groups, ) @@ -317,3 +321,14 @@ class BitModelIntegrationTest(unittest.TestCase): expected_slice = torch.tensor([[-0.6526, -0.5263, -1.4398]]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + +@require_torch +class BitBackboneTest(BackboneTesterMixin, unittest.TestCase): + all_model_classes = (BitBackbone,) if is_torch_available() else () + config_class = BitConfig + + has_attentions = False + + def setUp(self): + self.model_tester = BitModelTester(self) diff --git a/tests/models/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py index 9021c5e5e..16d57d7e9 100644 --- a/tests/models/convnext/test_modeling_convnext.py +++ b/tests/models/convnext/test_modeling_convnext.py @@ -22,6 +22,7 @@ from transformers import ConvNextConfig from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from 
...test_pipeline_mixin import PipelineTesterMixin @@ -57,6 +58,7 @@ class ConvNextModelTester: num_labels=10, initializer_range=0.02, out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], scope=None, ): self.parent = parent @@ -73,6 +75,7 @@ class ConvNextModelTester: self.num_labels = num_labels self.initializer_range = initializer_range self.out_features = out_features + self.out_indices = out_indices self.scope = scope def prepare_config_and_inputs(self): @@ -95,6 +98,7 @@ class ConvNextModelTester: is_decoder=False, initializer_range=self.initializer_range, out_features=self.out_features, + out_indices=self.out_indices, num_labels=self.num_labels, ) @@ -224,6 +228,10 @@ class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_backbone(*config_and_inputs) + def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): model = model_class(config) @@ -299,3 +307,14 @@ class ConvNextModelIntegrationTest(unittest.TestCase): expected_slice = torch.tensor([-0.0260, -0.4739, 0.1911]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + +@require_torch +class ConvNextBackboneTest(unittest.TestCase, BackboneTesterMixin): + all_model_classes = (ConvNextBackbone,) if is_torch_available() else () + config_class = ConvNextConfig + + has_attentions = False + + def setUp(self): + self.model_tester = ConvNextModelTester(self) diff --git a/tests/models/convnextv2/test_modeling_convnextv2.py b/tests/models/convnextv2/test_modeling_convnextv2.py index 10ae34c22..008481ab3 100644 --- a/tests/models/convnextv2/test_modeling_convnextv2.py +++ b/tests/models/convnextv2/test_modeling_convnextv2.py @@ 
-58,6 +58,7 @@ class ConvNextV2ModelTester: num_labels=10, initializer_range=0.02, out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], scope=None, ): self.parent = parent @@ -74,6 +75,7 @@ class ConvNextV2ModelTester: self.num_labels = num_labels self.initializer_range = initializer_range self.out_features = out_features + self.out_indices = out_indices self.scope = scope def prepare_config_and_inputs(self): @@ -97,6 +99,7 @@ class ConvNextV2ModelTester: is_decoder=False, initializer_range=self.initializer_range, out_features=self.out_features, + out_indices=self.out_indices, num_labels=self.num_labels, ) diff --git a/tests/models/dinat/test_modeling_dinat.py b/tests/models/dinat/test_modeling_dinat.py index 0ba3a808b..c08abf3d1 100644 --- a/tests/models/dinat/test_modeling_dinat.py +++ b/tests/models/dinat/test_modeling_dinat.py @@ -22,6 +22,7 @@ from transformers import DinatConfig from transformers.testing_utils import require_natten, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -67,6 +68,7 @@ class DinatModelTester: use_labels=True, num_labels=10, out_features=["stage1", "stage2"], + out_indices=[1, 2], ): self.parent = parent self.batch_size = batch_size @@ -92,6 +94,7 @@ class DinatModelTester: self.use_labels = use_labels self.num_labels = num_labels self.out_features = out_features + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -125,6 +128,7 @@ class DinatModelTester: layer_norm_eps=self.layer_norm_eps, initializer_range=self.initializer_range, 
out_features=self.out_features, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ -383,3 +387,13 @@ class DinatModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.logits.shape, expected_shape) expected_slice = torch.tensor([-0.1545, -0.7667, 0.4642]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + +@require_torch +@require_natten +class DinatBackboneTest(unittest.TestCase, BackboneTesterMixin): + all_model_classes = (DinatBackbone,) if is_torch_available() else () + config_class = DinatConfig + + def setUp(self): + self.model_tester = DinatModelTester(self) diff --git a/tests/models/maskformer/test_modeling_maskformer_swin.py b/tests/models/maskformer/test_modeling_maskformer_swin.py index 9285c444a..4125f36db 100644 --- a/tests/models/maskformer/test_modeling_maskformer_swin.py +++ b/tests/models/maskformer/test_modeling_maskformer_swin.py @@ -23,6 +23,7 @@ from transformers import MaskFormerSwinConfig from transformers.testing_utils import require_torch, require_torch_multi_gpu, torch_device from transformers.utils import is_torch_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -64,6 +65,7 @@ class MaskFormerSwinModelTester: type_sequence_label_size=10, encoder_stride=8, out_features=["stage1", "stage2", "stage3"], + out_indices=[1, 2, 3], ): self.parent = parent self.batch_size = batch_size @@ -90,6 +92,7 @@ class MaskFormerSwinModelTester: self.type_sequence_label_size = type_sequence_label_size self.encoder_stride = encoder_stride self.out_features = out_features + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ 
-123,6 +126,7 @@ class MaskFormerSwinModelTester: initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, out_features=self.out_features, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ -395,3 +399,48 @@ class MaskFormerSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + +@require_torch +class MaskFormerSwinBackboneTest(unittest.TestCase, BackboneTesterMixin): + all_model_classes = (MaskFormerSwinBackbone,) if is_torch_available() else () + config_class = MaskFormerSwinConfig + + def setUp(self): + self.model_tester = MaskFormerSwinModelTester(self) + + # Overriding as returned hidden states are tuples of tensors instead of a single tensor + def test_backbone_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + batch_size = inputs_dict["pixel_values"].shape[0] + + for backbone_class in self.all_model_classes: + backbone = backbone_class(config) + backbone.to(torch_device) + backbone.eval() + + outputs = backbone(**inputs_dict) + + # Test default outputs and verify feature maps + self.assertIsInstance(outputs.feature_maps, tuple) + self.assertTrue(len(outputs.feature_maps) == len(backbone.channels)) + for feature_map, n_channels in zip(outputs.feature_maps, backbone.channels): + self.assertEqual(feature_map.shape[:2], (batch_size, n_channels)) + self.assertIsNone(outputs.hidden_states) + self.assertIsNone(outputs.attentions) + + # Test output_hidden_states=True + outputs = backbone(**inputs_dict, output_hidden_states=True) + self.assertIsNotNone(outputs.hidden_states) + self.assertEqual(len(outputs.hidden_states), len(backbone.stage_names)) + # We skip the stem layer + for 
hidden_states, n_channels in zip(outputs.hidden_states[1:], backbone.channels): + for hidden_state in hidden_states: + # Hidden states are in the format (batch_size, (height * width), n_channels) + h_batch_size, _, h_n_channels = hidden_state.shape + self.assertEqual((h_batch_size, h_n_channels), (batch_size, n_channels)) + + # Test output_attentions=True + if self.has_attentions: + outputs = backbone(**inputs_dict, output_attentions=True) + self.assertIsNotNone(outputs.attentions) diff --git a/tests/models/nat/test_modeling_nat.py b/tests/models/nat/test_modeling_nat.py index dff0a4323..cd4f0bf96 100644 --- a/tests/models/nat/test_modeling_nat.py +++ b/tests/models/nat/test_modeling_nat.py @@ -22,6 +22,7 @@ from transformers import NatConfig from transformers.testing_utils import require_natten, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -66,6 +67,7 @@ class NatModelTester: use_labels=True, num_labels=10, out_features=["stage1", "stage2"], + out_indices=[1, 2], ): self.parent = parent self.batch_size = batch_size @@ -90,6 +92,7 @@ class NatModelTester: self.use_labels = use_labels self.num_labels = num_labels self.out_features = out_features + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -122,6 +125,7 @@ class NatModelTester: layer_norm_eps=self.layer_norm_eps, initializer_range=self.initializer_range, out_features=self.out_features, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ -380,3 +384,13 @@ class 
NatModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.logits.shape, expected_shape) expected_slice = torch.tensor([0.3805, -0.8676, -0.3912]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + +@require_torch +@require_natten +class NatBackboneTest(unittest.TestCase, BackboneTesterMixin): + all_model_classes = (NatBackbone,) if is_torch_available() else () + config_class = NatConfig + + def setUp(self): + self.model_tester = NatModelTester(self) diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index 31d19d4bd..fe7a1a045 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -22,6 +22,7 @@ from transformers import ResNetConfig from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -57,6 +58,7 @@ class ResNetModelTester: num_labels=3, scope=None, out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], ): self.parent = parent self.batch_size = batch_size @@ -72,6 +74,7 @@ class ResNetModelTester: self.scope = scope self.num_stages = len(hidden_sizes) self.out_features = out_features + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -93,6 +96,7 @@ class ResNetModelTester: hidden_act=self.hidden_act, num_labels=self.num_labels, out_features=self.out_features, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ -323,3 +327,13 @@ class 
ResNetModelIntegrationTest(unittest.TestCase): expected_slice = torch.tensor([-11.1069, -9.7877, -8.3777]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + +@require_torch +class ResNetBackboneTest(BackboneTesterMixin, unittest.TestCase): + all_model_classes = (ResNetBackbone,) if is_torch_available() else () + has_attentions = False + config_class = ResNetConfig + + def setUp(self): + self.model_tester = ResNetModelTester(self) diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index f519a0204..9dcb1bfdc 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -22,6 +22,7 @@ from transformers import SwinConfig from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available +from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -69,6 +70,7 @@ class SwinModelTester: type_sequence_label_size=10, encoder_stride=8, out_features=["stage1", "stage2"], + out_indices=[1, 2], ): self.parent = parent self.batch_size = batch_size @@ -95,6 +97,7 @@ class SwinModelTester: self.type_sequence_label_size = type_sequence_label_size self.encoder_stride = encoder_stride self.out_features = out_features + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -128,6 +131,7 @@ class SwinModelTester: initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, out_features=self.out_features, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ 
-502,3 +506,12 @@ class SwinModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.logits.shape, expected_shape) expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + +@require_torch +class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin): + all_model_classes = (SwinBackbone,) if is_torch_available() else () + config_class = SwinConfig + + def setUp(self): + self.model_tester = SwinModelTester(self) diff --git a/tests/test_backbone_common.py b/tests/test_backbone_common.py new file mode 100644 index 000000000..80e68a2f4 --- /dev/null +++ b/tests/test_backbone_common.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import inspect + +from transformers.testing_utils import require_torch, torch_device + + +@require_torch +class BackboneTesterMixin: + all_model_classes = () + has_attentions = True + + def test_config(self): + config_class = self.config_class + + # test default config + config = config_class() + self.assertIsNotNone(config) + expected_stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(config.depths) + 1)] + self.assertEqual(config.stage_names, expected_stage_names) + self.assertTrue(set(config.out_features).issubset(set(config.stage_names))) + + # Test out_features and out_indices are correctly set + # out_features and out_indices both None + config = config_class(out_features=None, out_indices=None) + self.assertEqual(config.out_features, [config.stage_names[-1]]) + self.assertEqual(config.out_indices, [len(config.stage_names) - 1]) + + # out_features and out_indices both set + config = config_class(out_features=["stem", "stage1"], out_indices=[0, 1]) + self.assertEqual(config.out_features, ["stem", "stage1"]) + self.assertEqual(config.out_indices, [0, 1]) + + # Only out_features set + config = config_class(out_features=["stage1", "stage3"]) + self.assertEqual(config.out_features, ["stage1", "stage3"]) + self.assertEqual(config.out_indices, [1, 3]) + + # Only out_indices set + config = config_class(out_indices=[0, 2]) + self.assertEqual(config.out_features, [config.stage_names[0], config.stage_names[2]]) + self.assertEqual(config.out_indices, [0, 2]) + + # Error raised when out_indices do not correspond to out_features + with self.assertRaises(ValueError): + config = config_class(out_features=["stage1", "stage2"], out_indices=[0, 2]) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is 
deterministic + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_channels(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertEqual(len(model.channels), len(config.out_features)) + num_features = model.num_features + out_indices = [config.stage_names.index(feat) for feat in config.out_features] + out_channels = [num_features[idx] for idx in out_indices] + self.assertListEqual(model.channels, out_channels) + + config.out_features = None + config.out_indices = None + model = model_class(config) + self.assertEqual(len(model.channels), 1) + self.assertListEqual(model.channels, [num_features[-1]]) + + def test_create_from_modified_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + result = model(**inputs_dict) + + self.assertEqual(len(result.feature_maps), len(config.out_features)) + self.assertEqual(len(model.channels), len(config.out_features)) + + # Check output of last stage is taken if out_features=None, out_indices=None + modified_config = copy.deepcopy(config) + modified_config.out_features = None + modified_config.out_indices = None + model = model_class(modified_config) + model.to(torch_device) + model.eval() + result = model(**inputs_dict) + + self.assertEqual(len(result.feature_maps), 1) + self.assertEqual(len(model.channels), 1) + + # Check backbone can be initialized with fresh weights + modified_config = copy.deepcopy(config) + modified_config.use_pretrained_backbone = False + model = model_class(modified_config) + model.to(torch_device) + model.eval() + result = model(**inputs_dict) + + def test_backbone_common_attributes(self): + config, _ = 
self.model_tester.prepare_config_and_inputs_for_common() + + for backbone_class in self.all_model_classes: + backbone = backbone_class(config) + + self.assertTrue(hasattr(backbone, "stage_names")) + self.assertTrue(hasattr(backbone, "num_features")) + self.assertTrue(hasattr(backbone, "out_indices")) + self.assertTrue(hasattr(backbone, "out_features")) + self.assertTrue(hasattr(backbone, "out_feature_channels")) + self.assertTrue(hasattr(backbone, "channels")) + + # Verify num_features has been initialized in the backbone init + self.assertIsNotNone(backbone.num_features) + self.assertTrue(len(backbone.channels) == len(backbone.out_indices)) + self.assertTrue(len(backbone.stage_names) == len(backbone.num_features)) + self.assertTrue(len(backbone.channels) <= len(backbone.num_features)) + self.assertTrue(len(backbone.out_feature_channels) == len(backbone.stage_names)) + + def test_backbone_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + batch_size = inputs_dict["pixel_values"].shape[0] + + for backbone_class in self.all_model_classes: + backbone = backbone_class(config) + backbone.to(torch_device) + backbone.eval() + + outputs = backbone(**inputs_dict) + + # Test default outputs and verify feature maps + self.assertIsInstance(outputs.feature_maps, tuple) + self.assertTrue(len(outputs.feature_maps) == len(backbone.channels)) + for feature_map, n_channels in zip(outputs.feature_maps, backbone.channels): + self.assertEqual(feature_map.shape[:2], (batch_size, n_channels)) + self.assertIsNone(outputs.hidden_states) + self.assertIsNone(outputs.attentions) + + # Test output_hidden_states=True + outputs = backbone(**inputs_dict, output_hidden_states=True) + self.assertIsNotNone(outputs.hidden_states) + self.assertEqual(len(outputs.hidden_states), len(backbone.stage_names)) + for hidden_state, n_channels in zip(outputs.hidden_states, backbone.channels): + self.assertTrue(hidden_state.shape[:2], (batch_size, n_channels)) + + # 
Test output_attentions=True + if self.has_attentions: + outputs = backbone(**inputs_dict, output_attentions=True) + self.assertIsNotNone(outputs.attentions)