mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Doc improvements (#11347)
Summary: 1. Remove cudnn* symbols from C++ docs 2. Fix code examples for `nn::Module` and `jit::compile` 3. Document Dropout Pull Request resolved: https://github.com/pytorch/pytorch/pull/11347 Differential Revision: D9716751 Pulled By: goldsborough fbshipit-source-id: e0566cec35848335cac3eb9196cb244bb0c8fa45
This commit is contained in:
parent
7de0332e10
commit
77b6d7d255
4 changed files with 57 additions and 32 deletions
|
|
@ -841,7 +841,7 @@ EXCLUDE_PATTERNS =
|
|||
# Note that the wildcards are matched against the file with absolute path, so to
|
||||
# exclude all test directories use the pattern */test/*
|
||||
|
||||
EXCLUDE_SYMBOLS = c10::* caffe2::* cereal* DL* TH*
|
||||
EXCLUDE_SYMBOLS = c10::* caffe2::* cereal* DL* TH* cudnn*
|
||||
|
||||
# The EXAMPLE_PATH tag can be used to specify one or more files or directories
|
||||
# that contain example code fragments that are included (see the \include
|
||||
|
|
|
|||
|
|
@ -9,10 +9,15 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
/// Compiles Python JIT code into a graph that can be executed.
|
||||
/// Compiles script code into an executable graph.
|
||||
///
|
||||
/// Takes a string containing functions in script syntax and compiles them into
|
||||
/// a module (graph). The returned module provides a `run_method` function
|
||||
/// that may be used to invoke the compiled functions.
|
||||
///
|
||||
/// For example:
|
||||
/// @code
|
||||
/// \rst
|
||||
/// .. code-block::
|
||||
/// auto module = torch::jit::compile(R"JIT(
|
||||
/// def relu_script(a, b):
|
||||
/// return torch.relu(a + b)
|
||||
|
|
@ -23,11 +28,7 @@ namespace jit {
|
|||
/// return a
|
||||
/// )JIT");
|
||||
/// IValue output = module->run_method("relu_script", a, b);
|
||||
/// @endcode
|
||||
///
|
||||
/// @param source A string containing functions containing script code to
|
||||
/// compile
|
||||
/// @return A module containing the compiled functions
|
||||
/// \endrst
|
||||
std::shared_ptr<script::Module> compile(const std::string& source);
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ namespace nn {
|
|||
///
|
||||
/// \rst
|
||||
/// .. note::
|
||||
///
|
||||
/// The design and implementation of this class is largely based on the Python
|
||||
/// API. You may want to consult [its
|
||||
/// documentation](https://pytorch.org/docs/master/nn.html#torch.nn.Module)
|
||||
|
|
@ -219,13 +218,12 @@ class Module {
|
|||
/// This method is useful when calling `apply()` on a `ModuleCursor`.
|
||||
/// \rst
|
||||
/// .. code-block:: cpp
|
||||
///
|
||||
/// void initialize_weights(nn::Module& module) {
|
||||
/// torch::NoGradGuard no_grad;
|
||||
/// if (auto* linear = module.as<nn::Linear>()) {
|
||||
/// linear->weight.normal_(0.0, 0.02);
|
||||
/// void initialize_weights(nn::Module& module) {
|
||||
/// torch::NoGradGuard no_grad;
|
||||
/// if (auto* linear = module.as<nn::Linear>()) {
|
||||
/// linear->weight.normal_(0.0, 0.02);
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// MyModule module;
|
||||
/// module->modules().apply(initialize_weights);
|
||||
|
|
@ -239,12 +237,12 @@ class Module {
|
|||
/// \rst
|
||||
/// .. code-block:: cpp
|
||||
///
|
||||
/// void initialize_weights(nn::Module& module) {
|
||||
/// torch::NoGradGuard no_grad;
|
||||
/// if (auto* linear = module.as<nn::Linear>()) {
|
||||
/// linear->weight.normal_(0.0, 0.02);
|
||||
/// void initialize_weights(nn::Module& module) {
|
||||
/// torch::NoGradGuard no_grad;
|
||||
/// if (auto* linear = module.as<nn::Linear>()) {
|
||||
/// linear->weight.normal_(0.0, 0.02);
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// MyModule module;
|
||||
/// module->modules().apply(initialize_weights);
|
||||
|
|
@ -263,9 +261,9 @@ class Module {
|
|||
///
|
||||
/// \rst
|
||||
/// .. code-block: cpp
|
||||
/// MyModule::MyModule() {
|
||||
/// weight_ = register_parameter("weight", torch::randn({A, B}));
|
||||
/// }
|
||||
/// MyModule::MyModule() {
|
||||
/// weight_ = register_parameter("weight", torch::randn({A, B}));
|
||||
/// }
|
||||
/// \endrst
|
||||
Tensor& register_parameter(
|
||||
std::string name,
|
||||
|
|
@ -280,9 +278,9 @@ class Module {
|
|||
///
|
||||
/// \rst
|
||||
/// .. code-block: cpp
|
||||
/// MyModule::MyModule() {
|
||||
/// mean_ = register_buffer("mean", torch::empty({num_features_}));
|
||||
/// }
|
||||
/// MyModule::MyModule() {
|
||||
/// mean_ = register_buffer("mean", torch::empty({num_features_}));
|
||||
/// }
|
||||
/// \endrst
|
||||
Tensor& register_buffer(std::string name, Tensor tensor);
|
||||
|
||||
|
|
@ -293,9 +291,9 @@ class Module {
|
|||
///
|
||||
/// \rst
|
||||
/// .. code-block: cpp
|
||||
/// MyModule::MyModule() {
|
||||
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
|
||||
/// }
|
||||
/// MyModule::MyModule() {
|
||||
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
|
||||
/// }
|
||||
/// \endrst
|
||||
template <typename ModuleType>
|
||||
std::shared_ptr<ModuleType> register_module(
|
||||
|
|
@ -311,9 +309,9 @@ class Module {
|
|||
///
|
||||
/// \rst
|
||||
/// .. code-block: cpp
|
||||
/// MyModule::MyModule() {
|
||||
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
|
||||
/// }
|
||||
/// MyModule::MyModule() {
|
||||
/// submodule_ = register_module("linear", torch::nn::Linear(3, 4));
|
||||
/// }
|
||||
/// \endrst
|
||||
template <typename ModuleType>
|
||||
std::shared_ptr<ModuleType> register_module(
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ namespace torch {
|
|||
namespace nn {
|
||||
struct DropoutOptions {
|
||||
DropoutOptions(double rate);
|
||||
/// The probability with which a particular component of the input is set to
|
||||
/// zero.
|
||||
TORCH_ARG(double, rate) = 0.5;
|
||||
};
|
||||
|
||||
|
|
@ -23,26 +25,50 @@ class DropoutImplBase : public torch::nn::Cloneable<Derived> {
|
|||
explicit DropoutImplBase(DropoutOptions options_);
|
||||
|
||||
void reset() override;
|
||||
|
||||
/// During training, applies a noise mask to the input tensor.
|
||||
/// During evaluation, applies an identity function.
|
||||
Tensor forward(Tensor input);
|
||||
|
||||
/// Returns a noise mask that can be applied to the given input tensor.
|
||||
/// Used inside `forward()` to generate the noise mask for dropout.
|
||||
virtual Tensor noise_mask(Tensor input) const = 0;
|
||||
|
||||
DropoutOptions options;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Applies [Dropout](https://arxiv.org/abs/1207.0580) during training.
|
||||
///
|
||||
/// See https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout to learn more
|
||||
/// about the exact semantics of this module.
|
||||
class DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
|
||||
public:
|
||||
using detail::DropoutImplBase<DropoutImpl>::DropoutImplBase;
|
||||
Tensor noise_mask(Tensor input) const override;
|
||||
};
|
||||
|
||||
/// Applies [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
|
||||
/// 2-dimensional features.
|
||||
///
|
||||
/// See https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout2d to learn more
|
||||
/// about the exact semantics of this module.
|
||||
class Dropout2dImpl : public detail::DropoutImplBase<Dropout2dImpl> {
|
||||
public:
|
||||
using detail::DropoutImplBase<Dropout2dImpl>::DropoutImplBase;
|
||||
Tensor noise_mask(Tensor input) const override;
|
||||
};
|
||||
|
||||
/// A `ModuleHolder` subclass for `DropoutImpl`.
|
||||
/// See the documentation for `DropoutImpl` class to learn what methods it
|
||||
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
|
||||
/// module storage semantics.
|
||||
TORCH_MODULE(Dropout);
|
||||
|
||||
/// A `ModuleHolder` subclass for `Dropout2dImpl`.
|
||||
/// See the documentation for `Dropout2dImpl` class to learn what methods it
|
||||
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
|
||||
/// module storage semantics.
|
||||
TORCH_MODULE(Dropout2d);
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in a new issue