mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Improve perceiver (#14750)
* First draft
* Improve docstring + clean up tests
* Remove unused code
* Add check in case one doesn't provide a preprocessor
This commit is contained in:
parent
971e36667a
commit
e926ea2bdd
3 changed files with 34 additions and 45 deletions
src/transformers/models/perceiver/configuration_perceiver.py

@@ -42,7 +42,8 @@ class PerceiverConfig(PretrainedConfig):
         d_latents (:obj:`int`, `optional`, defaults to 1280):
             Dimension of the latent embeddings.
         d_model (:obj:`int`, `optional`, defaults to 768):
-            Dimension of the inputs.
+            Dimension of the inputs. Should only be provided in case [`PerceiverTextPreprocessor`] is used or no
+            preprocessor is provided.
         num_blocks (:obj:`int`, `optional`, defaults to 1):
             Number of blocks in the Transformer encoder.
         num_self_attends_per_block (:obj:`int`, `optional`, defaults to 26):
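For context (illustration only, not part of this commit): `d_model` sizes the inputs that reach the encoder's first cross-attention, while `d_latents` sizes the latent array. A minimal sketch of the documented defaults, assuming a transformers release that ships the Perceiver model (v4.15+):

```python
from transformers import PerceiverConfig

config = PerceiverConfig()
print(config.d_model)     # 768  - input width, used by PerceiverTextPreprocessor
print(config.d_latents)   # 1280 - dimension of the latent embeddings
print(config.num_blocks)  # 1    - blocks in the Transformer encoder
```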
src/transformers/models/perceiver/modeling_perceiver.py
@@ -499,7 +499,7 @@ class PerceiverLayer(nn.Module):
 class PerceiverEncoder(nn.Module):
     """The Perceiver Encoder: a scalable, fully attentional encoder."""
 
-    def __init__(self, config):
+    def __init__(self, config, kv_dim=None):
         super().__init__()
         self.config = config
 
@@ -523,7 +523,7 @@ class PerceiverEncoder(nn.Module):
             v_channels=config.v_channels,
             num_heads=config.num_cross_attention_heads,
             q_dim=config.d_latents,
-            kv_dim=config.d_model,
+            kv_dim=kv_dim,
             widening_factor=config.cross_attention_widening_factor,
             use_query_residual=config.use_query_residual,
         )
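The point of the new `kv_dim` argument: in the first cross-attention, queries come from the latents (width `d_latents`) while keys/values come from the inputs, whose width need not equal `config.d_model` once a preprocessor is involved. A toy shape sketch with illustrative sizes (not the library's internals):

```python
import torch
from torch import nn

# Illustrative sizes: latents attend over the inputs, so queries are d_latents
# wide while keys/values are kv_dim wide; kv_dim now follows the preprocessor.
d_latents, kv_dim, num_latents, seq_len = 1280, 704, 256, 2048
latents = torch.randn(1, num_latents, d_latents)  # queries
inputs = torch.randn(1, seq_len, kv_dim)          # keys/values

q = nn.Linear(d_latents, d_latents)(latents)
k = nn.Linear(kv_dim, d_latents)(inputs)
v = nn.Linear(kv_dim, d_latents)(inputs)
attn = (q @ k.transpose(-1, -2) / d_latents**0.5).softmax(dim=-1)
print((attn @ v).shape)  # torch.Size([1, 256, 1280]) - back to latent width
```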
@@ -734,7 +734,9 @@ class PerceiverModel(PerceiverPreTrainedModel):
         self.input_preprocessor = input_preprocessor
         self.output_postprocessor = output_postprocessor
         self.embeddings = PerceiverEmbeddings(config)
-        self.encoder = PerceiverEncoder(config)
+        self.encoder = PerceiverEncoder(
+            config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
+        )
         self.decoder = decoder
 
         # Initialize weights and apply final processing
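A minimal sketch (hypothetical small sizes, not from the commit) of the fallback path wired above: with no preprocessor, `kv_dim` is `config.d_model`, so raw inputs must already be `d_model`-wide:

```python
import torch
from transformers import PerceiverConfig, PerceiverModel

config = PerceiverConfig(
    d_model=64, d_latents=32, num_latents=16, num_blocks=1, num_self_attends_per_block=2
)
model = PerceiverModel(config)                  # no preprocessor, no decoder
outputs = model(inputs=torch.randn(1, 10, 64))  # last dim == config.d_model
print(outputs.last_hidden_state.shape)          # torch.Size([1, 16, 32])
```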
@@ -782,16 +784,13 @@ class PerceiverModel(PerceiverPreTrainedModel):
         else:
             modality_sizes = None
             inputs_without_pos = None
+            if inputs.size()[-1] != self.config.d_model:
+                raise ValueError(
+                    f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
+                    "Make sure to set config.d_model appropriately."
+                )
 
-        if inputs.size()[-1] != self.config.d_model:
-            raise ValueError(
-                f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
-                "Please update config.d_model appropriately."
-            )
-        else:
-            input_shape = inputs.size()
-
-        batch_size, seq_length, _ = input_shape
+        batch_size, seq_length, _ = inputs.size()
         device = inputs.device
 
         # If no attention mask is provided, make them all ones
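Continuing the sketch above, the relocated check now fires only when no preprocessor is given, and fails fast with a readable message instead of a shape error deep inside the attention:

```python
try:
    model(inputs=torch.randn(1, 10, 63))  # 63 != config.d_model (64)
except ValueError as err:
    print(err)  # "Last dimension of the inputs: 63 doesn't correspond to config.d_model: 64. ..."
```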
@@ -874,20 +873,22 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
+        text_preprocessor = PerceiverTextPreprocessor(config)
+
         trainable_position_encoding_kwargs_decoder = dict(
-            num_channels=config.d_model, index_dims=config.max_position_embeddings
+            num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings
         )
 
         self.perceiver = PerceiverModel(
             config,
-            input_preprocessor=PerceiverTextPreprocessor(config),
+            input_preprocessor=text_preprocessor,
             decoder=PerceiverBasicDecoder(
                 config,
                 output_num_channels=config.d_latents,
                 output_index_dims=config.max_position_embeddings,  # we need to define the seq_len of the inputs beforehand
-                num_channels=config.d_model,
+                num_channels=text_preprocessor.num_channels,
                 qk_channels=8 * 32,
-                v_channels=config.d_model,
+                v_channels=text_preprocessor.num_channels,
                 num_heads=8,
                 use_query_residual=False,
                 final_project=False,
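Usage is unchanged by this refactor; a hedged sketch along the lines of the library's docs, using the DeepMind checkpoint published on the Hub:

```python
import torch
from transformers import PerceiverTokenizer, PerceiverForMaskedLM

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")
model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver")

text = "This is an incomplete sentence where some words are missing."
encoding = tokenizer(text, padding="max_length", return_tensors="pt")
with torch.no_grad():
    logits = model(inputs=encoding.input_ids, attention_mask=encoding.attention_mask).logits
print(logits.shape)  # torch.Size([1, 2048, 262]) - byte-level vocabulary
```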
@@ -1502,22 +1503,24 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
             concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
         )
 
+        image_preprocessor = PerceiverImagePreprocessor(
+            config,
+            prep_type="patches",
+            spatial_downsample=1,
+            conv_after_patching=True,
+            conv_after_patching_in_channels=54,
+            temporal_downsample=2,
+            position_encoding_type="fourier",
+            # position_encoding_kwargs
+            fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
+        )
+
         self.perceiver = PerceiverModel(
             config,
-            input_preprocessor=PerceiverImagePreprocessor(
-                config,
-                prep_type="patches",
-                spatial_downsample=1,
-                conv_after_patching=True,
-                conv_after_patching_in_channels=54,
-                temporal_downsample=2,
-                position_encoding_type="fourier",
-                # position_encoding_kwargs
-                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
-            ),
+            input_preprocessor=image_preprocessor,
             decoder=PerceiverOpticalFlowDecoder(
                 config,
-                num_channels=config.d_model,
+                num_channels=image_preprocessor.num_channels,
                 output_image_shape=config.train_size,
                 rescale_factor=100.0,
                 # decoder kwargs
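For reference, the optical-flow model consumes per-pixel 3x3 patch features over 3 color channels (27 values) for two frames; a hedged sketch following the library's documented example, with the output shape as I understand it:

```python
import torch
from transformers import PerceiverForOpticalFlow

model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
# (batch, num_frames, 27 patch channels, height, width); train_size is [368, 496]
patches = torch.randn(1, 2, 27, 368, 496)
with torch.no_grad():
    flow = model(inputs=patches).logits
print(flow.shape)  # torch.Size([1, 368, 496, 2]) - per-pixel (dx, dy)
```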
@@ -2631,6 +2634,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
 
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
 
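Storing `self.config` lets the preprocessor expose its output width through the `num_channels` property this commit relies on (see the `text_preprocessor.num_channels` uses above; the property definition itself is not shown in this excerpt). A quick sketch:

```python
from transformers import PerceiverConfig
from transformers.models.perceiver.modeling_perceiver import PerceiverTextPreprocessor

config = PerceiverConfig()
preprocessor = PerceiverTextPreprocessor(config)
print(preprocessor.num_channels)  # 768 - equals config.d_model for text inputs
```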
tests/test_modeling_perceiver.py
@@ -147,19 +147,14 @@ class PerceiverModelTester:
             if self.use_input_mask:
                 input_mask = random_attention_mask([self.batch_size, self.seq_length])
         elif model_class.__name__ == "PerceiverForImageClassificationLearned":
-            config.d_model = 512
             inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
         elif model_class.__name__ == "PerceiverForImageClassificationFourier":
-            config.d_model = 261
             inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
         elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing":
-            config.d_model = 322
             inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
         elif model_class.__name__ == "PerceiverForOpticalFlow":
-            config.d_model = 322
             inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]])
         elif model_class.__name__ == "PerceiverForMultimodalAutoencoding":
-            config.d_model = 409
             images = torch.randn(
                 (self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size),
                 device=torch_device,
@@ -211,8 +206,6 @@ class PerceiverModelTester:
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
     def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
-        # set num_labels
-        config.num_labels = self.num_labels
         model = PerceiverForSequenceClassification(config=config)
         model.to(torch_device)
         model.eval()
@@ -222,9 +215,6 @@ class PerceiverModelTester:
     def create_and_check_for_image_classification_learned(
         self, config, inputs, input_mask, sequence_labels, token_labels
     ):
-        # set d_model and num_labels
-        config.d_model = 512
-        config.num_labels = self.num_labels
         model = PerceiverForImageClassificationLearned(config=config)
         model.to(torch_device)
         model.eval()
@@ -234,9 +224,6 @@ class PerceiverModelTester:
     def create_and_check_for_image_classification_fourier(
         self, config, inputs, input_mask, sequence_labels, token_labels
     ):
-        # set d_model and num_labels
-        config.d_model = 261
-        config.num_labels = self.num_labels
         model = PerceiverForImageClassificationFourier(config=config)
         model.to(torch_device)
         model.eval()
@@ -246,9 +233,6 @@ class PerceiverModelTester:
     def create_and_check_for_image_classification_conv(
         self, config, inputs, input_mask, sequence_labels, token_labels
    ):
-        # set d_model and num_labels
-        config.d_model = 322
-        config.num_labels = self.num_labels
         model = PerceiverForImageClassificationConvProcessing(config=config)
         model.to(torch_device)
         model.eval()
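These test simplifications follow directly from the modeling change: since each model now derives the encoder's `kv_dim` from its preprocessor, the suite no longer has to patch `config.d_model` per model class. A hedged sketch with hypothetical small sizes:

```python
import torch
from transformers import PerceiverConfig, PerceiverForImageClassificationLearned

# No manual config.d_model = 512 needed: the image preprocessor reports its
# own num_channels, which the model wires into the encoder itself.
config = PerceiverConfig(
    d_latents=32, num_latents=16, num_blocks=1, num_self_attends_per_block=2,
    image_size=56, num_labels=3
)
model = PerceiverForImageClassificationLearned(config).eval()
logits = model(inputs=torch.randn(1, 3, 56, 56)).logits
print(logits.shape)  # torch.Size([1, 3])
```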