From 27d4b34ea6ce4ae863a81feb8740eececdd4ec81 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 30 Sep 2019 13:35:06 -0700
Subject: [PATCH] Add temporary torch::k{name} enum declarations (#27051)

Summary:
This PR adds temporary declarations for `torch::k{name}` enums, so that we can
submit a PR to rename the enum usage in torchvision. And then, after the changes
to torchvision are done, we can remove the temporary declarations in
https://github.com/pytorch/pytorch/pull/26837 to officially move over to using
`c10::variant` for enums.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/27051

Differential Revision: D17672220

Pulled By: yf225

fbshipit-source-id: 4ae77634e8c7efa3404698f7c1a69177cbb5dab3
---
 test/cpp/api/init.cpp                  |  6 +++---
 torch/csrc/api/include/torch/nn/init.h | 30 ++++++++++++++++++++++----
 torch/csrc/api/src/nn/init.cpp         | 30 ++++++++++++++++++++------
 3 files changed, 52 insertions(+), 14 deletions(-)

diff --git a/test/cpp/api/init.cpp b/test/cpp/api/init.cpp
index 5527d725376..12faf0385d3 100644
--- a/test/cpp/api/init.cpp
+++ b/test/cpp/api/init.cpp
@@ -110,19 +110,19 @@ TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
 
 TEST(InitTest, CalculateGainWithTanh) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
+      torch::nn::init::calculate_gain(torch::kTanh);
   ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
 }
 
 TEST(InitTest, CalculateGainWithRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
+      torch::nn::init::calculate_gain(torch::kReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
 }
 
 TEST(InitTest, CalculateGainWithLeakyRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
+      torch::nn::init::calculate_gain(torch::kLeakyReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
 }
 
diff --git a/torch/csrc/api/include/torch/nn/init.h b/torch/csrc/api/include/torch/nn/init.h
index 4e43462ee14..5c5c9087619 100644
--- a/torch/csrc/api/include/torch/nn/init.h
+++ b/torch/csrc/api/include/torch/nn/init.h
@@ -23,6 +23,28 @@ enum class Nonlinearity {
 
 enum class FanMode { FanIn, FanOut };
 
+} // namespace init
+} // nn
+
+// TODO: Remove the declarations here in https://github.com/pytorch/pytorch/pull/26837.
+TORCH_API extern const nn::init::Nonlinearity kLinear;
+TORCH_API extern const nn::init::Nonlinearity kConv1D;
+TORCH_API extern const nn::init::Nonlinearity kConv2D;
+TORCH_API extern const nn::init::Nonlinearity kConv3D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose1D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose2D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose3D;
+TORCH_API extern const nn::init::Nonlinearity kSigmoid;
+TORCH_API extern const nn::init::Nonlinearity kTanh;
+TORCH_API extern const nn::init::Nonlinearity kReLU;
+TORCH_API extern const nn::init::Nonlinearity kLeakyReLU;
+
+TORCH_API extern const nn::init::FanMode kFanIn;
+TORCH_API extern const nn::init::FanMode kFanOut;
+
+namespace nn {
+namespace init {
+
 /// Return the recommended gain value for the given nonlinearity function.
 TORCH_API double calculate_gain(Nonlinearity nonlinearity, double param = 0.01);
 
@@ -77,8 +99,8 @@ TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
 TORCH_API Tensor kaiming_normal_(
     Tensor tensor,
     double a = 0,
-    FanMode mode = FanMode::FanIn,
-    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+    FanMode mode = torch::kFanIn,
+    Nonlinearity nonlinearity = torch::kLeakyReLU);
 
 /// Fills the input `Tensor` with values according to the method
 /// described in "Delving deep into rectifiers: Surpassing human-level
@@ -88,8 +110,8 @@ TORCH_API Tensor kaiming_normal_(
 TORCH_API Tensor kaiming_uniform_(
     Tensor tensor,
     double a = 0,
-    FanMode mode = FanMode::FanIn,
-    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+    FanMode mode = torch::kFanIn,
+    Nonlinearity nonlinearity = torch::kLeakyReLU);
 
 /// Fills the input `Tensor` with values according to the method
 /// described in "Understanding the difficulty of training deep feedforward
diff --git a/torch/csrc/api/src/nn/init.cpp b/torch/csrc/api/src/nn/init.cpp
index c16f2b2b9ab..388e64c6c8b 100644
--- a/torch/csrc/api/src/nn/init.cpp
+++ b/torch/csrc/api/src/nn/init.cpp
@@ -12,6 +12,22 @@
 #include
 
 namespace torch {
+
+const nn::init::Nonlinearity kLinear = nn::init::Nonlinearity::Linear;
+const nn::init::Nonlinearity kConv1D = nn::init::Nonlinearity::Conv1D;
+const nn::init::Nonlinearity kConv2D = nn::init::Nonlinearity::Conv2D;
+const nn::init::Nonlinearity kConv3D = nn::init::Nonlinearity::Conv3D;
+const nn::init::Nonlinearity kConvTranspose1D = nn::init::Nonlinearity::ConvTranspose1D;
+const nn::init::Nonlinearity kConvTranspose2D = nn::init::Nonlinearity::ConvTranspose2D;
+const nn::init::Nonlinearity kConvTranspose3D = nn::init::Nonlinearity::ConvTranspose3D;
+const nn::init::Nonlinearity kSigmoid = nn::init::Nonlinearity::Sigmoid;
+const nn::init::Nonlinearity kTanh = nn::init::Nonlinearity::Tanh;
+const nn::init::Nonlinearity kReLU = nn::init::Nonlinearity::ReLU;
+const nn::init::Nonlinearity kLeakyReLU = nn::init::Nonlinearity::LeakyReLU;
+
+const nn::init::FanMode kFanIn = nn::init::FanMode::FanIn;
+const nn::init::FanMode kFanOut = nn::init::FanMode::FanOut;
+
 namespace nn {
 namespace init {
 namespace {
@@ -44,7 +60,7 @@ double calculate_kaiming_std(
   Fan fan(tensor);
   const auto gain = calculate_gain(nonlinearity, a);
   double std = 0.0;
-  if (mode == FanMode::FanIn) {
+  if (mode == torch::kFanIn) {
     std = gain / std::sqrt(fan.in);
   } else {
     std = gain / std::sqrt(fan.out);
@@ -54,12 +70,12 @@
 } // namespace
 
 double calculate_gain(Nonlinearity nonlinearity, double param) {
-  if (nonlinearity == Nonlinearity::Tanh) {
-    return 5.0 / 3.0;
-  } else if (nonlinearity == Nonlinearity::ReLU) {
-    return std::sqrt(2.0);
-  } else if (nonlinearity == Nonlinearity::LeakyReLU) {
-    return std::sqrt(2.0 / (1 + pow(param, 2)));
+  if (nonlinearity == torch::kTanh) {
+    return 5.0 / 3.0; // NOLINT
+  } else if (nonlinearity == torch::kReLU) {
+    return std::sqrt(2.0); // NOLINT
+  } else if (nonlinearity == torch::kLeakyReLU) {
+    return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT
   }
   return 1.0;
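
For reference, a minimal standalone usage sketch of the temporary `torch::k{name}` constants (not part of the patch itself); it assumes the C++ frontend umbrella header `<torch/torch.h>` pulls in `torch/nn/init.h`, and it mirrors the `calculate_gain` and `kaiming_uniform_` signatures shown in the hunks above:

    // Usage sketch only; assumes <torch/torch.h> re-exports torch/nn/init.h.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // New-style calls enabled by the temporary constants (old style:
      // torch::nn::init::Nonlinearity::Tanh, torch::nn::init::FanMode::FanIn, ...).
      std::cout << torch::nn::init::calculate_gain(torch::kTanh) << "\n"       // 5.0 / 3.0
                << torch::nn::init::calculate_gain(torch::kReLU) << "\n"       // sqrt(2.0)
                << torch::nn::init::calculate_gain(torch::kLeakyReLU) << "\n"; // sqrt(2.0 / (1 + 0.01^2))

      // kaiming_uniform_ accepts the new fan-mode / nonlinearity constants,
      // matching the defaults declared in torch/csrc/api/include/torch/nn/init.h above.
      auto w = torch::empty({64, 128});
      torch::nn::init::kaiming_uniform_(w, /*a=*/0, torch::kFanIn, torch::kLeakyReLU);
      return 0;
    }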