Add temporary torch::k{name} enum declarations (#27051)
Summary:
This PR adds temporary declarations for the `torch::k{name}` enums, so that we can submit a PR to rename the enum usage in torchvision. Then, once the torchvision changes are done, we can remove the temporary declarations in https://github.com/pytorch/pytorch/pull/26837 to officially move over to using `c10::variant` for enums.
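For background, here is a minimal sketch of the `c10::variant` style this migration is heading toward, assuming the design represents each enum value as an empty tag type. The tag names, the `NonlinearityType` alias, and the `gain_sketch` helper below are illustrative assumptions for this note, not declarations from this PR or from #26837:

#include <cmath>
#include <c10/util/variant.h> // provides c10::variant and c10::get_if

// Assumed design: each enum value becomes its own empty struct ("tag"),
// so the set of values a parameter accepts is a variant over tags.
struct Tanh_t {};
struct ReLU_t {};

using NonlinearityType = c10::variant<Tanh_t, ReLU_t>;

// Callers pass a tag; the callee branches on which alternative is held.
double gain_sketch(const NonlinearityType& nonlinearity) {
  if (c10::get_if<Tanh_t>(&nonlinearity)) {
    return 5.0 / 3.0; // tanh gain
  }
  return std::sqrt(2.0); // ReLU gain
}

Compared with a plain `enum class`, the variant approach lets different functions accept different subsets of values while sharing one set of `torch::k{name}` constants.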
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27051
Differential Revision: D17672220
Pulled By: yf225
fbshipit-source-id: 4ae77634e8c7efa3404698f7c1a69177cbb5dab3
Parent: 9159a601ca
Commit: 27d4b34ea6
3 changed files with 52 additions and 14 deletions
test/cpp/api/init.cpp

@@ -110,19 +110,19 @@ TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
 
 TEST(InitTest, CalculateGainWithTanh) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
+      torch::nn::init::calculate_gain(torch::kTanh);
   ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
 }
 
 TEST(InitTest, CalculateGainWithRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
+      torch::nn::init::calculate_gain(torch::kReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
 }
 
 TEST(InitTest, CalculateGainWithLeakyRelu) {
   double gain =
-      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
+      torch::nn::init::calculate_gain(torch::kLeakyReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
 }
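For reference, the expected values in these tests work out to: tanh gain 5/3 ≈ 1.6667, ReLU gain sqrt(2) ≈ 1.4142, and LeakyReLU gain at the default negative slope 0.01 is sqrt(2 / (1 + 0.01^2)) = sqrt(2 / 1.0001) ≈ 1.41414.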
torch/csrc/api/include/torch/nn/init.h

@@ -23,6 +23,28 @@ enum class Nonlinearity {
 
 enum class FanMode { FanIn, FanOut };
 
+} // namespace init
+} // nn
+
+// TODO: Remove the declarations here in https://github.com/pytorch/pytorch/pull/26837.
+TORCH_API extern const nn::init::Nonlinearity kLinear;
+TORCH_API extern const nn::init::Nonlinearity kConv1D;
+TORCH_API extern const nn::init::Nonlinearity kConv2D;
+TORCH_API extern const nn::init::Nonlinearity kConv3D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose1D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose2D;
+TORCH_API extern const nn::init::Nonlinearity kConvTranspose3D;
+TORCH_API extern const nn::init::Nonlinearity kSigmoid;
+TORCH_API extern const nn::init::Nonlinearity kTanh;
+TORCH_API extern const nn::init::Nonlinearity kReLU;
+TORCH_API extern const nn::init::Nonlinearity kLeakyReLU;
+
+TORCH_API extern const nn::init::FanMode kFanIn;
+TORCH_API extern const nn::init::FanMode kFanOut;
+
+namespace nn {
+namespace init {
+
 /// Return the recommended gain value for the given nonlinearity function.
 TORCH_API double calculate_gain(Nonlinearity nonlinearity, double param = 0.01);

@@ -77,8 +99,8 @@ TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
 TORCH_API Tensor kaiming_normal_(
     Tensor tensor,
     double a = 0,
-    FanMode mode = FanMode::FanIn,
-    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+    FanMode mode = torch::kFanIn,
+    Nonlinearity nonlinearity = torch::kLeakyReLU);
 
 /// Fills the input `Tensor` with values according to the method
 /// described in "Delving deep into rectifiers: Surpassing human-level

@@ -88,8 +110,8 @@ TORCH_API Tensor kaiming_normal_(
 TORCH_API Tensor kaiming_uniform_(
     Tensor tensor,
     double a = 0,
-    FanMode mode = FanMode::FanIn,
-    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+    FanMode mode = torch::kFanIn,
+    Nonlinearity nonlinearity = torch::kLeakyReLU);
 
 /// Fills the input `Tensor` with values according to the method
 /// described in "Understanding the difficulty of training deep feedforward
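With these defaults in place, the shorthand constants can also be passed explicitly at call sites. A small usage sketch (the tensor shape and the includes are illustrative, not part of the PR):

#include <torch/torch.h>

int main() {
  auto w = torch::empty({256, 128});
  // Override the defaults: fan-out mode with a ReLU gain.
  torch::nn::init::kaiming_normal_(w, /*a=*/0, torch::kFanOut, torch::kReLU);
  // Defaults (torch::kFanIn, torch::kLeakyReLU) apply when arguments are omitted.
  torch::nn::init::kaiming_uniform_(w);
}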
torch/csrc/api/src/nn/init.cpp

@@ -12,6 +12,22 @@
 #include <tuple>
 
 namespace torch {
+
+const nn::init::Nonlinearity kLinear = nn::init::Nonlinearity::Linear;
+const nn::init::Nonlinearity kConv1D = nn::init::Nonlinearity::Conv1D;
+const nn::init::Nonlinearity kConv2D = nn::init::Nonlinearity::Conv2D;
+const nn::init::Nonlinearity kConv3D = nn::init::Nonlinearity::Conv3D;
+const nn::init::Nonlinearity kConvTranspose1D = nn::init::Nonlinearity::ConvTranspose1D;
+const nn::init::Nonlinearity kConvTranspose2D = nn::init::Nonlinearity::ConvTranspose2D;
+const nn::init::Nonlinearity kConvTranspose3D = nn::init::Nonlinearity::ConvTranspose3D;
+const nn::init::Nonlinearity kSigmoid = nn::init::Nonlinearity::Sigmoid;
+const nn::init::Nonlinearity kTanh = nn::init::Nonlinearity::Tanh;
+const nn::init::Nonlinearity kReLU = nn::init::Nonlinearity::ReLU;
+const nn::init::Nonlinearity kLeakyReLU = nn::init::Nonlinearity::LeakyReLU;
+
+const nn::init::FanMode kFanIn = nn::init::FanMode::FanIn;
+const nn::init::FanMode kFanOut = nn::init::FanMode::FanOut;
+
 namespace nn {
 namespace init {
 namespace {

@@ -44,7 +60,7 @@ double calculate_kaiming_std(
   Fan fan(tensor);
   const auto gain = calculate_gain(nonlinearity, a);
   double std = 0.0;
-  if (mode == FanMode::FanIn) {
+  if (mode == torch::kFanIn) {
     std = gain / std::sqrt(fan.in);
   } else {
     std = gain / std::sqrt(fan.out);

@@ -54,12 +70,12 @@ double calculate_kaiming_std(
 } // namespace
 
 double calculate_gain(Nonlinearity nonlinearity, double param) {
-  if (nonlinearity == Nonlinearity::Tanh) {
-    return 5.0 / 3.0;
-  } else if (nonlinearity == Nonlinearity::ReLU) {
-    return std::sqrt(2.0);
-  } else if (nonlinearity == Nonlinearity::LeakyReLU) {
-    return std::sqrt(2.0 / (1 + pow(param, 2)));
+  if (nonlinearity == torch::kTanh) {
+    return 5.0 / 3.0; // NOLINT
+  } else if (nonlinearity == torch::kReLU) {
+    return std::sqrt(2.0); // NOLINT
+  } else if (nonlinearity == torch::kLeakyReLU) {
+    return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT
   }
 
   return 1.0;
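As a quick illustration of the rewritten branches (this snippet is illustrative and assumes `<torch/torch.h>`; the values follow the formulas above):

#include <torch/torch.h>
#include <iostream>

int main() {
  // torch::kTanh branch: 5/3 ≈ 1.6667.
  std::cout << torch::nn::init::calculate_gain(torch::kTanh) << "\n";
  // torch::kLeakyReLU branch with slope 0.2: sqrt(2 / (1 + 0.2^2)) ≈ 1.3868.
  std::cout << torch::nn::init::calculate_gain(torch::kLeakyReLU, 0.2) << "\n";
  // Anything else (e.g. torch::kSigmoid) falls through to the final return 1.0.
  std::cout << torch::nn::init::calculate_gain(torch::kSigmoid) << "\n";
}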