diff --git a/src/transformers/models/bit/configuration_bit.py b/src/transformers/models/bit/configuration_bit.py index 6418549ab..7c1e10510 100644 --- a/src/transformers/models/bit/configuration_bit.py +++ b/src/transformers/models/bit/configuration_bit.py @@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig): The width factor for the model. out_features (`List[str]`, *optional*): If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. - (depending on how many stages the model has). + (depending on how many stages the model has). Will default to the last stage if unset. Example: ```python diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index f71b008d6..71caabf91 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -851,7 +851,7 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin): self.stage_names = config.stage_names self.bit = BitModel(config) - self.out_features = config.out_features + self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] out_feature_channels = {} out_feature_channels["stem"] = config.embedding_size diff --git a/src/transformers/models/maskformer/configuration_maskformer_swin.py b/src/transformers/models/maskformer/configuration_maskformer_swin.py index 4c9f1a4ca..36e074655 100644 --- a/src/transformers/models/maskformer/configuration_maskformer_swin.py +++ b/src/transformers/models/maskformer/configuration_maskformer_swin.py @@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig): layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. out_features (`List[str]`, *optional*): - If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`. + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). Will default to the last stage if unset. Example: diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 05af1b90f..f3c5577ab 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -855,7 +855,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin): self.stage_names = config.stage_names self.model = MaskFormerSwinModel(config) - self.out_features = config.out_features + self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] if "stem" in self.out_features: raise ValueError("This backbone does not support 'stem' in the `out_features`.") diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 2d0dbc3b0..74f6c6939 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig): downsample_in_first_stage (`bool`, *optional*, defaults to `False`): If `True`, the first stage will downsample the inputs using a `stride` of 2. out_features (`List[str]`, *optional*): - If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, - `"stage3"`, `"stage4"`. + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). Will default to the last stage if unset. Example: ```python diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index ebd134be5..4c737c218 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel): nn.init.constant_(module.bias, 0) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (ResNetModel, ResNetBackbone)): + if isinstance(module, ResNetEncoder): module.gradient_checkpointing = value @@ -439,7 +439,7 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin): self.embedder = ResNetEmbeddings(config) self.encoder = ResNetEncoder(config) - self.out_features = config.out_features + self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] out_feature_channels = {} out_feature_channels["stem"] = config.embedding_size diff --git a/tests/models/bit/test_modeling_bit.py b/tests/models/bit/test_modeling_bit.py index 0c3bf147c..7b7e07cb8 100644 --- a/tests/models/bit/test_modeling_bit.py +++ b/tests/models/bit/test_modeling_bit.py @@ -119,7 +119,7 @@ class BitModelTester: model.eval() result = model(pixel_values) - # verify hidden states + # verify feature maps self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) @@ -127,6 +127,21 @@ class BitModelTester: self.parent.assertEqual(len(model.channels), len(config.out_features)) self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) + # verify backbone works with out_features=None + config.out_features = None + model = BitBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1]) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = config_and_inputs diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index 53777d27c..15d3dca3c 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -119,7 +119,7 @@ class ResNetModelTester: model.eval() result = model(pixel_values) - # verify hidden states + # verify feature maps self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) @@ -127,6 +127,21 @@ class ResNetModelTester: self.parent.assertEqual(len(model.channels), len(config.out_features)) self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) + # verify backbone works with out_features=None + config.out_features = None + model = ResNetBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1]) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = config_and_inputs