Summary: The goal of this PR was to add support for dropout descriptors in the C++ API's RNN class. The end result is a 4x-5x speedup for our RNN integration tests, since they can now use cuDNN instead of autograd when dropout is set. To achieve this, I had to move `_cudnn_init_dropout_state` to the `TensorOptions` API. I also fixed a bug where `RNN::cuda()` did not flatten parameters for cuDNN.

ebetica ezyang

Closes https://github.com/pytorch/pytorch/pull/9012

Reviewed By: pjh5

Differential Revision: D8689786

Pulled By: goldsborough

fbshipit-source-id: 44fb191f5a38e41c4ded5417306b5bbc012cd56c
| File |
|---|
| any.cpp |
| cursor.cpp |
| integration.cpp |
| main.cpp |
| misc.cpp |
| module.cpp |
| modules.cpp |
| optim.cpp |
| optim_baseline.h |
| optim_baseline.py |
| README.md |
| rnn.cpp |
| sequential.cpp |
| serialization.cpp |
| static.cpp |
| tensor.cpp |
| tensor_cuda.cpp |
| tensor_options.cpp |
| tensor_options_cuda.cpp |
| util.h |
C++ API Tests
This folder contains the tests for PyTorch's C++ API (formerly known as autogradpp). They use the Catch2 test framework.
CUDA Tests
We handle CUDA tests by putting them in a separate TEST_CASE
(e.g. the optim and optim_cuda test cases in optim.cpp) and giving
them the [cuda] tag. Then, inside main.cpp, we detect at runtime whether
CUDA is available. If it is not, we disable these CUDA tests by appending
~[cuda] to the test specification. The ~ negates the tag, so every test
case carrying [cuda] is excluded.
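The runtime switch described above can be sketched as follows. This is a hypothetical illustration, not the actual main.cpp: `make_test_args` and `as_argv` are invented names, and the `cuda_available` flag stands in for whatever runtime check the test binary performs (in current libtorch that would be something like `torch::cuda::is_available()`).

```cpp
#include <string>
#include <vector>

// Hypothetical helper: returns the command-line arguments to hand to the
// test runner, adding the "~[cuda]" exclusion when CUDA is unavailable.
std::vector<std::string> make_test_args(int argc, const char* const argv[],
                                        bool cuda_available) {
  std::vector<std::string> args(argv, argv + argc);
  if (!cuda_available) {
    // "~" negates the tag: every TEST_CASE tagged [cuda] is excluded.
    args.push_back("~[cuda]");
  }
  return args;
}

// A test runner's run(argc, argv) expects a C-style argv, so expose
// pointers into the strings owned by `args` (which must outlive the result).
std::vector<const char*> as_argv(const std::vector<std::string>& args) {
  std::vector<const char*> argv;
  argv.reserve(args.size());
  for (const auto& arg : args) {
    argv.push_back(arg.c_str());
  }
  return argv;
}
```

In a Catch2 `main`, the result could then be forwarded as `session.run(static_cast<int>(cargs.size()), cargs.data())` with `cargs = as_argv(args)`. Keeping `args` alive for the whole call matters, because `cargs` only borrows its strings.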
One annoying limitation is that Catch only lets you tag (and therefore filter
by tag) test cases, not sections. Ideally, one could have a section like LSTM
inside the RNN test case and give that section a [cuda] tag so it only runs
when CUDA is available. Instead, we have to create a whole separate RNN_cuda
test case and put all the CUDA sections in there.
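To make the granularity problem concrete, here is a toy model of tag-based exclusion. It is not Catch2's implementation, and `TestCase` and `sections_to_run` are hypothetical names. Because tags live on the test case, excluding [cuda] can only drop an entire test case, never an individual section.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Toy model: tags attach to a whole test case, while sections are just
// named blocks inside its body and carry no tags of their own.
struct TestCase {
  std::string name;
  std::vector<std::string> tags;
  std::vector<std::string> sections;
};

// Returns "TestCase/Section" for every section that would run when test
// cases carrying `excluded_tag` are filtered out (empty tag: exclude nothing).
std::vector<std::string> sections_to_run(const std::vector<TestCase>& cases,
                                         const std::string& excluded_tag) {
  std::vector<std::string> selected;
  for (const auto& test_case : cases) {
    const bool excluded =
        !excluded_tag.empty() &&
        std::find(test_case.tags.begin(), test_case.tags.end(),
                  excluded_tag) != test_case.tags.end();
    if (excluded) {
      continue;  // The whole test case is skipped, all sections included.
    }
    for (const auto& section : test_case.sections) {
      selected.push_back(test_case.name + "/" + section);
    }
  }
  return selected;
}
```

With an untagged RNN case and an RNN_cuda case tagged [cuda], each holding LSTM and GRU sections, excluding [cuda] leaves only RNN/LSTM and RNN/GRU. That is why the CUDA sections have to be grouped into a test case of their own.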
Integration Tests
Integration tests use the MNIST dataset. You must download it by running the following command from the PyTorch root folder:
$ python tools/download_mnist.py -d test/cpp/api/mnist