Add temporary torch::k{name} enum declarations (#27051)

Summary:
This PR adds temporary declarations for the `torch::k{name}` enums, so that we can submit a PR to rename the enum usage in torchvision. Once the torchvision changes are done, we can remove the temporary declarations in https://github.com/pytorch/pytorch/pull/26837 and officially move over to using `c10::variant` for enums.
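As context for the migration, here is a minimal sketch of what a call site looks like before and after this change. It is not part of the diff itself; it assumes a program built against libtorch with this patch applied:

    #include <torch/torch.h>

    int main() {
      // Old enum-class spelling (slated for removal in #26837):
      double g_old = torch::nn::init::calculate_gain(
          torch::nn::init::Nonlinearity::Tanh);

      // New tag-style constant introduced by this PR; torchvision can
      // switch to this spelling before the old enums are deleted:
      double g_new = torch::nn::init::calculate_gain(torch::kTanh);

      // During the transition both spellings name the same enum value.
      TORCH_CHECK(g_old == g_new);
      return 0;
    }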
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27051

Differential Revision: D17672220

Pulled By: yf225

fbshipit-source-id: 4ae77634e8c7efa3404698f7c1a69177cbb5dab3
Will Feng authored on 2019-09-30 13:35:06 -07:00; committed by Facebook Github Bot
parent 9159a601ca
commit 27d4b34ea6
3 changed files with 52 additions and 14 deletions

@@ -110,19 +110,19 @@ TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
 TEST(InitTest, CalculateGainWithTanh) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
+      torch::nn::init::calculate_gain(torch::kTanh);
   ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
 }

 TEST(InitTest, CalculateGainWithRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
+      torch::nn::init::calculate_gain(torch::kReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
 }

 TEST(InitTest, CalculateGainWithLeakyRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
+      torch::nn::init::calculate_gain(torch::kLeakyReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
 }

@@ -23,6 +23,28 @@ enum class Nonlinearity {
 enum class FanMode { FanIn, FanOut };
+} // namespace init
+} // namespace nn
+
+// TODO: Remove the declarations here in https://github.com/pytorch/pytorch/pull/26837.
+TORCH_API extern const nn::init::Nonlinearity kLinear;
+TORCH_API extern const nn::init::Nonlinearity kConv1D;
+TORCH_API extern const nn::init::Nonlinearity kConv2D;
+TORCH_API extern const nn::init::Nonlinearity kConv3D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose1D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose2D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose3D;
+TORCH_API extern const nn::init::Nonlinearity kSigmoid;
+TORCH_API extern const nn::init::Nonlinearity kTanh;
+TORCH_API extern const nn::init::Nonlinearity kReLU;
+TORCH_API extern const nn::init::Nonlinearity kLeakyReLU;
+TORCH_API extern const nn::init::FanMode kFanIn;
+TORCH_API extern const nn::init::FanMode kFanOut;
+
+namespace nn {
+namespace init {

 /// Return the recommended gain value for the given nonlinearity function.
 TORCH_API double calculate_gain(Nonlinearity nonlinearity, double param = 0.01);
@@ -77,8 +99,8 @@ TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
 TORCH_API Tensor kaiming_normal_(
     Tensor tensor,
     double a = 0,
-    FanMode mode = FanMode::FanIn,
-    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+    FanMode mode = torch::kFanIn,
+    Nonlinearity nonlinearity = torch::kLeakyReLU);

 /// Fills the input `Tensor` with values according to the method
 /// described in "Delving deep into rectifiers: Surpassing human-level
@@ -88,8 +110,8 @@ TORCH_API Tensor kaiming_normal_(
 TORCH_API Tensor kaiming_uniform_(
     Tensor tensor,
     double a = 0,
-    FanMode mode = FanMode::FanIn,
-    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+    FanMode mode = torch::kFanIn,
+    Nonlinearity nonlinearity = torch::kLeakyReLU);

 /// Fills the input `Tensor` with values according to the method
 /// described in "Understanding the difficulty of training deep feedforward

@@ -12,6 +12,22 @@
 #include <tuple>

 namespace torch {
+
+const nn::init::Nonlinearity kLinear = nn::init::Nonlinearity::Linear;
+const nn::init::Nonlinearity kConv1D = nn::init::Nonlinearity::Conv1D;
+const nn::init::Nonlinearity kConv2D = nn::init::Nonlinearity::Conv2D;
+const nn::init::Nonlinearity kConv3D = nn::init::Nonlinearity::Conv3D;
+const nn::init::Nonlinearity kConvTranspose1D = nn::init::Nonlinearity::ConvTranspose1D;
+const nn::init::Nonlinearity kConvTranspose2D = nn::init::Nonlinearity::ConvTranspose2D;
+const nn::init::Nonlinearity kConvTranspose3D = nn::init::Nonlinearity::ConvTranspose3D;
+const nn::init::Nonlinearity kSigmoid = nn::init::Nonlinearity::Sigmoid;
+const nn::init::Nonlinearity kTanh = nn::init::Nonlinearity::Tanh;
+const nn::init::Nonlinearity kReLU = nn::init::Nonlinearity::ReLU;
+const nn::init::Nonlinearity kLeakyReLU = nn::init::Nonlinearity::LeakyReLU;
+const nn::init::FanMode kFanIn = nn::init::FanMode::FanIn;
+const nn::init::FanMode kFanOut = nn::init::FanMode::FanOut;
+
 namespace nn {
 namespace init {
 namespace {
@@ -44,7 +60,7 @@ double calculate_kaiming_std(
   Fan fan(tensor);
   const auto gain = calculate_gain(nonlinearity, a);
   double std = 0.0;
-  if (mode == FanMode::FanIn) {
+  if (mode == torch::kFanIn) {
     std = gain / std::sqrt(fan.in);
   } else {
     std = gain / std::sqrt(fan.out);
@@ -54,12 +70,12 @@ double calculate_kaiming_std(
 } // namespace

 double calculate_gain(Nonlinearity nonlinearity, double param) {
-  if (nonlinearity == Nonlinearity::Tanh) {
-    return 5.0 / 3.0;
-  } else if (nonlinearity == Nonlinearity::ReLU) {
-    return std::sqrt(2.0);
-  } else if (nonlinearity == Nonlinearity::LeakyReLU) {
-    return std::sqrt(2.0 / (1 + pow(param, 2)));
+  if (nonlinearity == torch::kTanh) {
+    return 5.0 / 3.0; // NOLINT
+  } else if (nonlinearity == torch::kReLU) {
+    return std::sqrt(2.0); // NOLINT
+  } else if (nonlinearity == torch::kLeakyReLU) {
+    return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT
   }
   return 1.0;
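For completeness, a small usage sketch of the `kaiming_normal_` defaults retargeted in the header diff above. It is not part of the diff; the tensor shape is arbitrary and the snippet assumes a program built against libtorch with this patch applied:

    #include <torch/torch.h>

    int main() {
      auto w = torch::empty({64, 128});
      // Passing the new constants explicitly; this matches the new
      // defaults (torch::kFanIn, torch::kLeakyReLU) from the header.
      torch::nn::init::kaiming_normal_(
          w, /*a=*/0, torch::kFanIn, torch::kLeakyReLU);
      return 0;
    }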