Doc improvements (#11347)

Summary:
1. Remove cudnn* symbols from C++ docs
2. Fix code examples for `nn::Module` and `jit::compile`
3. Document Dropout
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11347

Differential Revision: D9716751

Pulled By: goldsborough

fbshipit-source-id: e0566cec35848335cac3eb9196cb244bb0c8fa45
This commit is contained in:
Peter Goldsborough 2018-09-07 14:29:04 -07:00 committed by Facebook Github Bot
parent 7de0332e10
commit 77b6d7d255
4 changed files with 57 additions and 32 deletions

View file

@@ -841,7 +841,7 @@ EXCLUDE_PATTERNS =
# Note that the wildcards are matched against the file with absolute path, so to
# exclude all test directories use the pattern */test/*
EXCLUDE_SYMBOLS = c10::* caffe2::* cereal* DL* TH*
EXCLUDE_SYMBOLS = c10::* caffe2::* cereal* DL* TH* cudnn*
# The EXAMPLE_PATH tag can be used to specify one or more files or directories
# that contain example code fragments that are included (see the \include

View file

@@ -9,10 +9,15 @@
namespace torch {
namespace jit {
/// Compiles Python JIT code into a graph that can be executed.
/// Compiles script code into an executable graph.
///
/// Takes a string containing functions in script syntax and compiles them into
/// a module (graph). The returned module provides a `run_method` function
/// that may be used to invoke the compiled functions.
///
/// For example:
/// @code
/// \rst
/// .. code-block:: cpp
/// auto module = torch::jit::compile(R"JIT(
/// def relu_script(a, b):
/// return torch.relu(a + b)
@@ -23,11 +28,7 @@ namespace jit {
/// return a
/// )JIT");
/// IValue output = module->run_method("relu_script", a, b);
/// @endcode
///
/// @param source A string containing functions containing script code to
/// compile
/// @return A module containing the compiled functions
/// \endrst
std::shared_ptr<script::Module> compile(const std::string& source);
} // namespace jit

View file

@@ -28,7 +28,6 @@ namespace nn {
///
/// \rst
/// .. note::
///
/// The design and implementation of this class is largely based on the Python
/// API. You may want to consult [its
/// documentation](https://pytorch.org/docs/master/nn.html#torch.nn.Module)
@@ -219,13 +218,12 @@ class Module {
/// This method is useful when calling `apply()` on a `ModuleCursor`.
/// \rst
/// .. code-block:: cpp
///
/// void initialize_weights(nn::Module& module) {
/// torch::NoGradGuard no_grad;
/// if (auto* linear = module.as<nn::Linear>()) {
/// linear->weight.normal_(0.0, 0.02);
/// void initialize_weights(nn::Module& module) {
/// torch::NoGradGuard no_grad;
/// if (auto* linear = module.as<nn::Linear>()) {
/// linear->weight.normal_(0.0, 0.02);
/// }
/// }
/// }
///
/// MyModule module;
/// module->modules().apply(initialize_weights);
@@ -239,12 +237,12 @@ class Module {
/// \rst
/// .. code-block:: cpp
///
/// void initialize_weights(nn::Module& module) {
/// torch::NoGradGuard no_grad;
/// if (auto* linear = module.as<nn::Linear>()) {
/// linear->weight.normal_(0.0, 0.02);
/// void initialize_weights(nn::Module& module) {
/// torch::NoGradGuard no_grad;
/// if (auto* linear = module.as<nn::Linear>()) {
/// linear->weight.normal_(0.0, 0.02);
/// }
/// }
/// }
///
/// MyModule module;
/// module->modules().apply(initialize_weights);
@@ -263,9 +261,9 @@ class Module {
///
/// \rst
/// .. code-block:: cpp
/// MyModule::MyModule() {
/// weight_ = register_parameter("weight", torch::randn({A, B}));
/// }
/// MyModule::MyModule() {
/// weight_ = register_parameter("weight", torch::randn({A, B}));
/// }
/// \endrst
Tensor& register_parameter(
std::string name,
@@ -280,9 +278,9 @@ class Module {
///
/// \rst
/// .. code-block:: cpp
/// MyModule::MyModule() {
/// mean_ = register_buffer("mean", torch::empty({num_features_}));
/// }
/// MyModule::MyModule() {
/// mean_ = register_buffer("mean", torch::empty({num_features_}));
/// }
/// \endrst
Tensor& register_buffer(std::string name, Tensor tensor);
@@ -293,9 +291,9 @@ class Module {
///
/// \rst
/// .. code-block:: cpp
/// MyModule::MyModule() {
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
/// }
/// MyModule::MyModule() {
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
/// }
/// \endrst
template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(
@@ -311,9 +309,9 @@ class Module {
///
/// \rst
/// .. code-block:: cpp
/// MyModule::MyModule() {
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
/// }
/// MyModule::MyModule() {
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
/// }
/// \endrst
template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(

View file

@@ -11,6 +11,8 @@ namespace torch {
namespace nn {
struct DropoutOptions {
DropoutOptions(double rate);
/// The probability with which a particular component of the input is set to
/// zero.
TORCH_ARG(double, rate) = 0.5;
};
@@ -23,26 +25,50 @@ class DropoutImplBase : public torch::nn::Cloneable<Derived> {
explicit DropoutImplBase(DropoutOptions options_);
void reset() override;
/// During training, applies a noise mask to the input tensor.
/// During evaluation, applies an identity function.
Tensor forward(Tensor input);
/// Returns a noise mask that can be applied to the given input tensor.
/// Used inside `forward()` to generate the noise mask for dropout.
virtual Tensor noise_mask(Tensor input) const = 0;
DropoutOptions options;
};
} // namespace detail
/// Applies [Dropout](https://arxiv.org/abs/1207.0580) during training.
///
/// See https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout to learn more
/// about the exact semantics of this module.
class DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
public:
using detail::DropoutImplBase<DropoutImpl>::DropoutImplBase;
Tensor noise_mask(Tensor input) const override;
};
/// Applies [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
/// 2-dimensional features.
///
/// See https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout2d to learn more
/// about the exact semantics of this module.
class Dropout2dImpl : public detail::DropoutImplBase<Dropout2dImpl> {
public:
using detail::DropoutImplBase<Dropout2dImpl>::DropoutImplBase;
Tensor noise_mask(Tensor input) const override;
};
/// A `ModuleHolder` subclass for `DropoutImpl`.
/// See the documentation for `DropoutImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Dropout);
/// A `ModuleHolder` subclass for `Dropout2dImpl`.
/// See the documentation for `Dropout2dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Dropout2d);
} // namespace nn
} // namespace torch