Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-15 21:00:47 +00:00)
Summary:
This PR refactors the RNN / GRU / LSTM layers in the C++ API to exactly match their implementation in the Python API.
**BC-breaking changes:**
- Instead of returning `RNNOutput`, the RNN / GRU `forward` method now returns `std::tuple<Tensor, Tensor>`, and the LSTM `forward` method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching the Python API.
- The RNN / LSTM / GRU `forward` method now accepts the same inputs (an input tensor and, optionally, a hidden state), matching the Python API.
- RNN / LSTM / GRU layers now have a `forward_with_packed_input` method, which accepts a `PackedSequence` as input and, optionally, a hidden state, matching the `forward(PackedSequence, ...)` variant in the Python API.
- RNN / LSTM / GRU layers no longer have the fields `w_ih` / `w_hh` / `b_ih` / `b_hh`. Instead, to access the weights and biases of the gates, users should write e.g. `rnn->named_parameters()["weight_ih_l0"]`, which mirrors the Python API's `rnn.weight_ih_l0`.
- In `RNNOptions`:
  - `tanh()` / `relu()` / `activation` are removed. Instead, `nonlinearity` is added, which takes either `torch::kTanh` or `torch::kReLU`.
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`
- In `LSTMOptions`:
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`
- In `GRUOptions`:
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`
The majority of the changes in this PR focus on refactoring the implementations in `torch/csrc/api/src/nn/modules/rnn.cpp` to match the Python API. The RNN tests are then updated to reflect the revised API design.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34322
Differential Revision: D20458302
Pulled By: yf225
fbshipit-source-id: ffff2ae1ddb1c742c966956f6ad4d7fba03dc54d
# C++ / Python API parity tracker

## torch::nn
| API | Implementation Parity | Doc Parity |
|---|---|---|
| torch::nn::Sequential | Yes | No |
| torch::nn::ModuleList | Yes | No |
| torch::nn::ModuleDict | No | No |
| torch::nn::ParameterList | No | No |
| torch::nn::ParameterDict | No | No |
| torch::nn::Conv1d | Yes | No |
| torch::nn::Conv2d | Yes | No |
| torch::nn::Conv3d | Yes | No |
| torch::nn::ConvTranspose1d | Yes | No |
| torch::nn::ConvTranspose2d | Yes | No |
| torch::nn::ConvTranspose3d | Yes | No |
| torch::nn::Unfold | Yes | No |
| torch::nn::Fold | Yes | No |
| torch::nn::MaxPool1d | Yes | No |
| torch::nn::MaxPool2d | Yes | No |
| torch::nn::MaxPool3d | Yes | No |
| torch::nn::MaxUnpool1d | Yes | No |
| torch::nn::MaxUnpool2d | Yes | No |
| torch::nn::MaxUnpool3d | Yes | No |
| torch::nn::AvgPool1d | Yes | No |
| torch::nn::AvgPool2d | Yes | No |
| torch::nn::AvgPool3d | Yes | No |
| torch::nn::FractionalMaxPool2d | Yes | No |
| torch::nn::FractionalMaxPool3d | Yes | No |
| torch::nn::LPPool1d | Yes | No |
| torch::nn::LPPool2d | Yes | No |
| torch::nn::AdaptiveMaxPool1d | Yes | No |
| torch::nn::AdaptiveMaxPool2d | Yes | No |
| torch::nn::AdaptiveMaxPool3d | Yes | No |
| torch::nn::AdaptiveAvgPool1d | Yes | No |
| torch::nn::AdaptiveAvgPool2d | Yes | No |
| torch::nn::AdaptiveAvgPool3d | Yes | No |
| torch::nn::ReflectionPad1d | Yes | No |
| torch::nn::ReflectionPad2d | Yes | No |
| torch::nn::ReplicationPad1d | Yes | No |
| torch::nn::ReplicationPad2d | Yes | No |
| torch::nn::ReplicationPad3d | Yes | No |
| torch::nn::ZeroPad2d | Yes | No |
| torch::nn::ConstantPad1d | Yes | No |
| torch::nn::ConstantPad2d | Yes | No |
| torch::nn::ConstantPad3d | Yes | No |
| torch::nn::ELU | Yes | No |
| torch::nn::Hardshrink | Yes | No |
| torch::nn::Hardtanh | Yes | No |
| torch::nn::LeakyReLU | Yes | No |
| torch::nn::LogSigmoid | Yes | No |
| torch::nn::MultiheadAttention | No | No |
| torch::nn::PReLU | Yes | No |
| torch::nn::ReLU | Yes | No |
| torch::nn::ReLU6 | Yes | No |
| torch::nn::RReLU | Yes | No |
| torch::nn::SELU | Yes | No |
| torch::nn::CELU | Yes | No |
| torch::nn::GELU | Yes | No |
| torch::nn::Sigmoid | Yes | No |
| torch::nn::Softplus | Yes | No |
| torch::nn::Softshrink | Yes | No |
| torch::nn::Softsign | Yes | No |
| torch::nn::Tanh | Yes | No |
| torch::nn::Tanhshrink | Yes | No |
| torch::nn::Threshold | Yes | No |
| torch::nn::GLU | Yes | No |
| torch::nn::Softmin | Yes | No |
| torch::nn::Softmax | Yes | No |
| torch::nn::Softmax2d | Yes | No |
| torch::nn::LogSoftmax | Yes | No |
| torch::nn::AdaptiveLogSoftmaxWithLoss | Yes | No |
| torch::nn::BatchNorm1d | Yes | No |
| torch::nn::BatchNorm2d | Yes | No |
| torch::nn::BatchNorm3d | Yes | No |
| torch::nn::GroupNorm | Yes | No |
| torch::nn::SyncBatchNorm | No | No |
| torch::nn::InstanceNorm1d | Yes | No |
| torch::nn::InstanceNorm2d | Yes | No |
| torch::nn::InstanceNorm3d | Yes | No |
| torch::nn::LayerNorm | Yes | No |
| torch::nn::LocalResponseNorm | Yes | No |
| torch::nn::CrossMapLRN2d | Yes | No |
| torch::nn::RNN | Yes | No |
| torch::nn::LSTM | Yes | No |
| torch::nn::GRU | Yes | No |
| torch::nn::RNNCell | Yes | No |
| torch::nn::LSTMCell | Yes | No |
| torch::nn::GRUCell | Yes | No |
| torch::nn::Transformer | No | No |
| torch::nn::TransformerEncoder | No | No |
| torch::nn::TransformerDecoder | No | No |
| torch::nn::TransformerEncoderLayer | No | No |
| torch::nn::TransformerDecoderLayer | No | No |
| torch::nn::Identity | Yes | No |
| torch::nn::Linear | Yes | No |
| torch::nn::Bilinear | Yes | No |
| torch::nn::Flatten | Yes | No |
| torch::nn::Dropout | Yes | No |
| torch::nn::Dropout2d | Yes | No |
| torch::nn::Dropout3d | Yes | No |
| torch::nn::AlphaDropout | Yes | No |
| torch::nn::FeatureAlphaDropout | Yes | No |
| torch::nn::Embedding | Yes | No |
| torch::nn::EmbeddingBag | Yes | No |
| torch::nn::CosineSimilarity | Yes | No |
| torch::nn::PairwiseDistance | Yes | No |
| torch::nn::L1Loss | Yes | No |
| torch::nn::MSELoss | Yes | No |
| torch::nn::CrossEntropyLoss | Yes | No |
| torch::nn::CTCLoss | Yes | No |
| torch::nn::NLLLoss | Yes | No |
| torch::nn::PoissonNLLLoss | Yes | No |
| torch::nn::KLDivLoss | Yes | No |
| torch::nn::BCELoss | Yes | No |
| torch::nn::BCEWithLogitsLoss | Yes | No |
| torch::nn::MarginRankingLoss | Yes | No |
| torch::nn::HingeEmbeddingLoss | Yes | No |
| torch::nn::MultiLabelMarginLoss | Yes | No |
| torch::nn::SmoothL1Loss | Yes | No |
| torch::nn::SoftMarginLoss | Yes | No |
| torch::nn::MultiLabelSoftMarginLoss | Yes | No |
| torch::nn::CosineEmbeddingLoss | Yes | No |
| torch::nn::MultiMarginLoss | Yes | No |
| torch::nn::TripletMarginLoss | Yes | No |
| torch::nn::PixelShuffle | Yes | No |
| torch::nn::Upsample | Yes | No |
| torch::nn::DataParallel | No | No |
| torch::nn::parallel::DistributedDataParallel | No | No |
| torch::nn::utils::clip_grad_norm_ | Yes | No |
| torch::nn::utils::clip_grad_value_ | Yes | No |
| torch::nn::utils::parameters_to_vector | Yes | No |
| torch::nn::utils::vector_to_parameters | Yes | No |
| torch::nn::utils::weight_norm | No | No |
| torch::nn::utils::remove_weight_norm | No | No |
| torch::nn::utils::spectral_norm | No | No |
| torch::nn::utils::remove_spectral_norm | No | No |
| torch::nn::utils::rnn::PackedSequence | Yes | No |
| torch::nn::utils::rnn::pack_padded_sequence | Yes | No |
| torch::nn::utils::rnn::pad_packed_sequence | Yes | No |
| torch::nn::utils::rnn::pad_sequence | Yes | No |
| torch::nn::utils::rnn::pack_sequence | Yes | No |
| torch::nn::SampleModule | Yes | Yes |
## torch::nn::functional
| API | Implementation Parity | Doc Parity |
|---|---|---|
| F::conv1d | Yes | No |
| F::conv2d | Yes | No |
| F::conv3d | Yes | No |
| F::conv_transpose1d | Yes | No |
| F::conv_transpose2d | Yes | No |
| F::conv_transpose3d | Yes | No |
| F::unfold | Yes | No |
| F::fold | Yes | No |
| F::avg_pool1d | Yes | No |
| F::avg_pool2d | Yes | No |
| F::avg_pool3d | Yes | No |
| F::max_pool1d | Yes | No |
| F::max_pool2d | Yes | No |
| F::max_pool3d | Yes | No |
| F::max_unpool1d | Yes | No |
| F::max_unpool2d | Yes | No |
| F::max_unpool3d | Yes | No |
| F::lp_pool1d | Yes | No |
| F::lp_pool2d | Yes | No |
| F::adaptive_max_pool1d | Yes | No |
| F::adaptive_max_pool2d | Yes | No |
| F::adaptive_max_pool3d | Yes | No |
| F::adaptive_avg_pool1d | Yes | No |
| F::adaptive_avg_pool2d | Yes | No |
| F::adaptive_avg_pool3d | Yes | No |
| F::threshold | Yes | No |
| F::relu | Yes | No |
| F::hardtanh | Yes | No |
| F::relu6 | Yes | No |
| F::elu | Yes | No |
| F::selu | Yes | No |
| F::celu | Yes | No |
| F::leaky_relu | Yes | No |
| F::prelu | Yes | No |
| F::rrelu | Yes | No |
| F::glu | Yes | No |
| F::gelu | Yes | No |
| F::logsigmoid | Yes | No |
| F::hardshrink | Yes | No |
| F::tanhshrink | Yes | No |
| F::softsign | Yes | No |
| F::softplus | Yes | No |
| F::softmin | Yes | No |
| F::softmax | Yes | No |
| F::softshrink | Yes | No |
| F::gumbel_softmax | Yes | No |
| F::log_softmax | Yes | No |
| F::batch_norm | Yes | No |
| F::instance_norm | Yes | No |
| F::layer_norm | Yes | No |
| F::local_response_norm | Yes | No |
| F::normalize | Yes | No |
| F::linear | Yes | No |
| F::bilinear | Yes | No |
| F::dropout | Yes | No |
| F::alpha_dropout | Yes | No |
| F::dropout2d | Yes | No |
| F::dropout3d | Yes | No |
| F::embedding | Yes | No |
| F::embedding_bag | Yes | No |
| F::one_hot | Yes | No |
| F::pairwise_distance | Yes | No |
| F::cosine_similarity | Yes | No |
| F::pdist | Yes | No |
| F::binary_cross_entropy | Yes | No |
| F::binary_cross_entropy_with_logits | Yes | No |
| F::poisson_nll_loss | Yes | No |
| F::cosine_embedding_loss | Yes | No |
| F::cross_entropy | Yes | No |
| F::ctc_loss | Yes | No |
| F::hinge_embedding_loss | Yes | No |
| F::kl_div | Yes | No |
| F::l1_loss | Yes | No |
| F::mse_loss | Yes | No |
| F::margin_ranking_loss | Yes | No |
| F::multilabel_margin_loss | Yes | No |
| F::multilabel_soft_margin_loss | Yes | No |
| F::multi_margin_loss | Yes | No |
| F::nll_loss | Yes | No |
| F::smooth_l1_loss | Yes | No |
| F::soft_margin_loss | Yes | No |
| F::triplet_margin_loss | Yes | No |
| F::pixel_shuffle | Yes | No |
| F::pad | Yes | No |
| F::interpolate | Yes | No |
| F::grid_sample | Yes | No |
| F::affine_grid | Yes | No |