diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py
index c9e54d3b8..0f5bef571 100644
--- a/src/transformers/models/deit/modeling_deit.py
+++ b/src/transformers/models/deit/modeling_deit.py
@@ -73,9 +73,52 @@ class DeiTEmbeddings(nn.Module):
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
 
-    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows the model to interpolate the pre-trained position encodings so that it can be used on
+        higher-resolution images.
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 2
+        num_positions = self.position_embeddings.shape[1] - 2
+
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, 0, :]
+        dist_pos_embed = self.position_embeddings[:, 1, :]
+        patch_pos_embed = self.position_embeddings[:, 2:, :]
+        dim = embeddings.shape[-1]
+        h0 = height // self.patch_size
+        w0 = width // self.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        _, _, height, width = pixel_values.shape
         embeddings = self.patch_embeddings(pixel_values)
+
         batch_size, seq_length, _ = embeddings.size()
 
         if bool_masked_pos is not None:
@@ -85,9 +129,16 @@ class DeiTEmbeddings(nn.Module):
             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
 
         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
-        embeddings = embeddings + self.position_embeddings
+        position_embedding = self.position_embeddings
+
+        if interpolate_pos_encoding:
+            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+        embeddings = embeddings + position_embedding
         embeddings = self.dropout(embeddings)
         return embeddings
@@ -120,10 +171,6 @@ class DeiTPatchEmbeddings(nn.Module):
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
-        if height != self.image_size[0] or width != self.image_size[1]:
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
-            )
         x = self.projection(pixel_values).flatten(2).transpose(1, 2)
         return x
@@ -480,6 +527,8 @@ DEIT_INPUTS_DOCSTRING = r"""
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
 """
@@ -528,6 +577,7 @@ class DeiTModel(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -554,7 +604,9 @@ class DeiTModel(DeiTPreTrainedModel):
         if pixel_values.dtype != expected_dtype:
             pixel_values = pixel_values.to(expected_dtype)
 
-        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
 
         encoder_outputs = self.encoder(
             embedding_output,
@@ -635,6 +687,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[tuple, MaskedImageModelingOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -674,6 +727,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
         sequence_output = outputs[0]
@@ -742,6 +796,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[tuple, ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -784,6 +839,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
         sequence_output = outputs[0]
@@ -901,6 +957,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -910,6 +967,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
         sequence_output = outputs[0]
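For reference, a minimal usage sketch of the new PyTorch path (the checkpoint and the 480x640 input mirror the integration test added below; the image file name is a placeholder):

```python
import torch
from PIL import Image

from transformers import AutoImageProcessor, DeiTForImageClassificationWithTeacher

# any RGB image works; "example.jpg" is a placeholder path
image = Image.open("example.jpg").convert("RGB")

processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")

# resize to 480x640 and skip the center crop so the patch grid (30x40)
# no longer matches the 14x14 grid the position embeddings were trained on
processor.size = {"height": 480, "width": 640}
inputs = processor(images=image, return_tensors="pt", do_center_crop=False)

# without interpolate_pos_encoding=True the addition of position embeddings
# would fail: 2 + 30*40 = 1202 tokens vs. the pre-trained 2 + 14*14 = 198 positions
with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.logits.shape)  # torch.Size([1, 1000])
```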
diff --git a/src/transformers/models/deit/modeling_tf_deit.py b/src/transformers/models/deit/modeling_tf_deit.py
index e5faff2a4..03ad1385d 100644
--- a/src/transformers/models/deit/modeling_tf_deit.py
+++ b/src/transformers/models/deit/modeling_tf_deit.py
@@ -146,9 +146,42 @@ class TFDeiTEmbeddings(keras.layers.Layer):
             with tf.name_scope(self.dropout.name):
                 self.dropout.build(None)
 
+    def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
+        """Interpolates the pre-trained position encodings so the model can run on higher-resolution images."""
+        num_patches = embeddings.shape[1] - 2
+        num_positions = self.position_embeddings.shape[1] - 2
+
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, 0, :]
+        dist_pos_embed = self.position_embeddings[:, 1, :]
+        patch_pos_embed = self.position_embeddings[:, 2:, :]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = tf.reshape(
+            patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        )
+        patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic")
+        patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
+
+        return tf.concat(
+            [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
+        )
+
     def call(
-        self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None, training: bool = False
+        self,
+        pixel_values: tf.Tensor,
+        bool_masked_pos: tf.Tensor | None = None,
+        training: bool = False,
+        interpolate_pos_encoding: bool = False,
     ) -> tf.Tensor:
+        _, height, width, _ = pixel_values.shape
+
         embeddings = self.patch_embeddings(pixel_values)
         batch_size, seq_length, _ = shape_list(embeddings)
@@ -162,7 +195,11 @@ class TFDeiTEmbeddings(keras.layers.Layer):
         cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
         distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
         embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
-        embeddings = embeddings + self.position_embeddings
+        position_embedding = self.position_embeddings
+        if interpolate_pos_encoding:
+            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+        embeddings = embeddings + position_embedding
         embeddings = self.dropout(embeddings, training=training)
         return embeddings
@@ -197,10 +234,7 @@ class TFDeiTPatchEmbeddings(keras.layers.Layer):
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
-        if tf.executing_eagerly() and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
-            )
+
         x = self.projection(pixel_values)
         batch_size, height, width, num_channels = shape_list(x)
         x = tf.reshape(x, (batch_size, height * width, num_channels))
@@ -599,6 +633,7 @@ class TFDeiTMainLayer(keras.layers.Layer):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -621,7 +656,12 @@ class TFDeiTMainLayer(keras.layers.Layer):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask)
 
-        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, training=training)
+        embedding_output = self.embeddings(
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            training=training,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
 
         encoder_outputs = self.encoder(
             embedding_output,
@@ -705,6 +745,8 @@ DEIT_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -741,6 +783,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
         outputs = self.deit(
@@ -750,6 +793,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
         return outputs
@@ -869,6 +913,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[tuple, TFMaskedImageModelingOutput]:
         r"""
@@ -909,6 +954,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
 
@@ -1003,6 +1049,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[tf.Tensor, TFImageClassifierOutput]:
         r"""
@@ -1046,6 +1093,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
 
@@ -1126,6 +1174,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[tuple, TFDeiTForImageClassificationWithTeacherOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1136,6 +1185,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
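The TF port of `interpolate_pos_encoding` relies on `tf.image.resize`, which takes and returns channels-last (NHWC) tensors, so the NCHW permute round-trip of the PyTorch version is unnecessary. A standalone shape check of that path, assuming deit-base numbers (hidden size 768, 14x14 pre-trained grid, patch size 16, 480x640 input):

```python
import math

import tensorflow as tf

dim, num_positions = 768, 196          # deit-base: hidden size, 14*14 patch positions
grid = int(math.sqrt(num_positions))   # 14
h0, w0 = 480 // 16, 640 // 16          # 30x40 target patch grid

patch_pos_embed = tf.random.normal((1, num_positions, dim))
patch_pos_embed = tf.reshape(patch_pos_embed, (1, grid, grid, dim))  # already NHWC
patch_pos_embed = tf.image.resize(patch_pos_embed, size=(h0, w0), method="bicubic")
patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))

print(patch_pos_embed.shape)  # (1, 1200, 768); plus the 2 special tokens -> 1202
```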
diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py
index b251e5904..d1357d129 100644
--- a/tests/models/deit/test_modeling_deit.py
+++ b/tests/models/deit/test_modeling_deit.py
@@ -423,6 +423,28 @@ class DeiTModelIntegrationTest(unittest.TestCase):
 
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
 
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224").to(
+            torch_device
+        )
+
+        image_processor = self.default_image_processor
+
+        # image size is {"height": 480, "width": 640}
+        image = prepare_img()
+        image_processor.size = {"height": 480, "width": 640}
+        # center crop set to False so image is not center cropped to 224x224
+        inputs = image_processor(images=image, return_tensors="pt", do_center_crop=False).to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs, interpolate_pos_encoding=True)
+
+        # verify the logits
+        expected_shape = torch.Size((1, 1000))
+        self.assertEqual(outputs.logits.shape, expected_shape)
+
     @slow
     @require_accelerate
     @require_torch_accelerator
diff --git a/tests/models/deit/test_modeling_tf_deit.py b/tests/models/deit/test_modeling_tf_deit.py
index fefd50680..537d97517 100644
--- a/tests/models/deit/test_modeling_tf_deit.py
+++ b/tests/models/deit/test_modeling_tf_deit.py
@@ -293,3 +293,20 @@ class DeiTModelIntegrationTest(unittest.TestCase):
         expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
 
         self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+        image_processor = self.default_image_processor
+        # image size is {"height": 480, "width": 640}
+        image = prepare_img()
+        image_processor.size = {"height": 480, "width": 640}
+        # center crop set to False so image is not center cropped to 224x224
+        inputs = image_processor(images=image, return_tensors="tf", do_center_crop=False)
+        # forward pass
+        outputs = model(**inputs, interpolate_pos_encoding=True)
+
+        # verify the logits
+        expected_shape = tf.TensorShape((1, 1000))
+        self.assertEqual(outputs.logits.shape, expected_shape)
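Both integration tests exercise the same arithmetic: a 480x640 input with patch size 16 yields a 30x40 patch grid, so the 14x14 pre-trained grid must be resized and the sequence grows to 2 + 1200 = 1202 tokens. A minimal sketch of the PyTorch scale-factor computation, including the +0.1 guard carried over from DINO:

```python
import math

import torch
from torch import nn

dim, num_positions = 768, 196              # deit-base: hidden size, 14*14 positions
grid = math.sqrt(num_positions)            # 14.0
h0, w0 = 480 // 16 + 0.1, 640 // 16 + 0.1  # 30.1, 40.1: the 0.1 keeps
# floor(grid * scale) from landing on 29 or 39 through floating point error

patch_pos_embed = torch.randn(1, num_positions, dim)
patch_pos_embed = patch_pos_embed.reshape(1, int(grid), int(grid), dim).permute(0, 3, 1, 2)  # NCHW
patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed, scale_factor=(h0 / grid, w0 / grid), mode="bicubic", align_corners=False
)
print(patch_pos_embed.shape)  # torch.Size([1, 768, 30, 40]) -> flattened to 1200 positions
```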