2018-09-22 04:12:37 +00:00
|
|
|
#include <gtest/gtest.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
|
Re-organize C++ API `torch::nn` folder structure (#26262)
Summary:
This PR aims to re-organize C++ API `torch::nn` folder structure in the following way:
- Every module in `torch/csrc/api/include/torch/nn/modules/` (except `any.h`, `named_any.h`, `modulelist.h`, `sequential.h`, `embedding.h`) has a strictly equivalent Python file in `torch/nn/modules/`. For example:
`torch/csrc/api/include/torch/nn/modules/pooling.h` -> `torch/nn/modules/pooling.py`
`torch/csrc/api/include/torch/nn/modules/conv.h` -> `torch/nn/modules/conv.py`
`torch/csrc/api/include/torch/nn/modules/batchnorm.h` -> `torch/nn/modules/batchnorm.py`
`torch/csrc/api/include/torch/nn/modules/sparse.h` -> `torch/nn/modules/sparse.py`
- Containers such as `any.h`, `named_any.h`, `modulelist.h`, `sequential.h` are moved into `torch/csrc/api/include/torch/nn/modules/container/`, because their implementations are too long to be combined into one file (like `torch/nn/modules/container.py` in Python API)
- `embedding.h` is not renamed to `sparse.h` yet, because we have another work stream that works on API parity for Embedding and EmbeddingBag, and renaming the file would cause conflict. After the embedding API parity work is done, we will rename `embedding.h` to `sparse.h` to match the Python file name, and move the embedding options out to options/ folder.
- `torch/csrc/api/include/torch/nn/functional/` is added, and the folder structure mirrors that of `torch/csrc/api/include/torch/nn/modules/`. For example, `torch/csrc/api/include/torch/nn/functional/pooling.h` contains the functions for pooling, which are then used by the pooling modules in `torch/csrc/api/include/torch/nn/modules/pooling.h`.
- `torch/csrc/api/include/torch/nn/options/` is added, and the folder structure mirrors that of `torch/csrc/api/include/torch/nn/modules/`. For example, `torch/csrc/api/include/torch/nn/options/pooling.h` contains MaxPoolOptions, which is used by both MaxPool modules in `torch/csrc/api/include/torch/nn/modules/pooling.h`, and max_pool functions in `torch/csrc/api/include/torch/nn/functional/pooling.h`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26262
Differential Revision: D17422426
Pulled By: yf225
fbshipit-source-id: c413d2a374ba716dac81db31516619bbd879db7f
2019-09-17 17:05:11 +00:00
|
|
|
#include <torch/torch.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
|
|
2018-05-24 19:46:51 +00:00
|
|
|
#include <algorithm>
|
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
using namespace torch::nn;
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
// Fixture that seeds the RNG before each AnyModule test for determinism.
struct AnyModuleTest : torch::test::SeedingFixture {};
|
2018-09-12 22:39:27 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, SimpleReturnType) {
  // A module whose forward() takes no arguments and returns a plain int.
  struct ReturnsConstant : torch::nn::Module {
    int forward() {
      return 123;
    }
  };
  // AnyModule must dispatch a zero-argument forward() and recover the
  // declared return type via forward<int>().
  AnyModule any(ReturnsConstant{});
  ASSERT_EQ(any.forward<int>(), 123);
}
|
2018-09-12 22:39:27 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
  // forward() is the identity function on a single int argument.
  struct IdentityInt : torch::nn::Module {
    int forward(int value) {
      return value;
    }
  };
  AnyModule any(IdentityInt{});
  // The argument must be forwarded through the type-erased call unchanged.
  ASSERT_EQ(any.forward<int>(5), 5);
}
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
  // forward() passes a C-string pointer through unchanged.
  struct EchoCString : torch::nn::Module {
    const char* forward(const char* text) {
      return text;
    }
  };
  AnyModule any(EchoCString{});
  // Compare via std::string so the assertion checks contents, not pointers.
  ASSERT_EQ(any.forward<const char*>("hello"), std::string("hello"));
}
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
  // forward() mixes an int with a top-level-const double parameter and
  // returns a std::string; AnyModule must cope with the const qualifier.
  struct SumToString : torch::nn::Module {
    std::string forward(int lhs, const double rhs) {
      return std::to_string(static_cast<int>(lhs + rhs));
    }
  };
  AnyModule any(SumToString{});
  int lhs = 4;
  // 4 + 3.14 truncates to 7 after the static_cast<int>.
  ASSERT_EQ(any.forward<std::string>(lhs, 3.14), std::string("7"));
}
|
|
|
|
|
|
|
|
|
|
TEST_F(
    AnyModuleTest,
    TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
  // forward() takes strings by value, by const lvalue reference and by
  // rvalue reference; AnyModule's argument forwarding must honor all three.
  struct ConcatLength : torch::nn::Module {
    torch::Tensor forward(
        std::string a,
        const std::string& b,
        std::string&& c) {
      const auto joined = a + b + c;
      return torch::ones({static_cast<int64_t>(joined.size())});
    }
  };
  AnyModule any(ConcatLength{});
  // "a" + "ab" + "abc" has 6 characters, so the ones-tensor sums to 6.
  ASSERT_TRUE(
      any.forward(std::string("a"), std::string("ab"), std::string("abc"))
          .sum()
          .item<int32_t>() == 6);
}
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, WrongArgumentType) {
  // forward() expects a float; AnyModule performs no implicit
  // floating-point conversions, so passing a double must throw.
  struct M : torch::nn::Module {
    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return x;
    }
  };
  AnyModule any(M{});
  // The error message names both the expected and the received type.
  ASSERT_THROWS_WITH(
      any.forward(5.0),
      "Expected argument #0 to be of type float, "
      "but received value of type double");
}
|
|
|
|
|
|
2020-02-18 04:33:51 +00:00
|
|
|
// Module used by the WrongNumberOfArguments test below; its forward()
// requires exactly two int arguments and has no defaults.
struct M_test_wrong_number_of_arguments : torch::nn::Module {
  int forward(int a, int b) {
    return a + b;
  }
};
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, WrongNumberOfArguments) {
  AnyModule any(M_test_wrong_number_of_arguments{});
#if defined(_MSC_VER)
  // MSVC's typeid name includes the class-key ("struct ") prefix.
  std::string module_name = "struct M_test_wrong_number_of_arguments";
#else
  std::string module_name = "M_test_wrong_number_of_arguments";
#endif
  // Too few arguments: 0 of the required 2.
  ASSERT_THROWS_WITH(
      any.forward(),
      module_name +
          "'s forward() method expects 2 argument(s), but received 0. "
          "If " +
          module_name +
          "'s forward() method has default arguments, "
          "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
  // Too few arguments: 1 of the required 2.
  ASSERT_THROWS_WITH(
      any.forward(5),
      module_name +
          "'s forward() method expects 2 argument(s), but received 1. "
          "If " +
          module_name +
          "'s forward() method has default arguments, "
          "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
  // Too many arguments: 3 where only 2 are accepted.
  ASSERT_THROWS_WITH(
      any.forward(1, 2, 3),
      module_name +
          "'s forward() method expects 2 argument(s), but received 3.");
}
|
|
|
|
|
|
|
|
|
|
// Module whose forward() has default arguments that ARE declared to
// AnyModule via FORWARD_HAS_DEFAULT_ARGS, so calls may omit trailing args.
struct M_default_arg_with_macro : torch::nn::Module {
  double forward(int a, int b = 2, double c = 3.0) {
    return a + b + c;
  }

 protected:
  // Each entry is {parameter index, default value as an AnyValue}.
  FORWARD_HAS_DEFAULT_ARGS(
      {1, torch::nn::AnyValue(2)},
      {2, torch::nn::AnyValue(3.0)})
};
|
|
|
|
|
|
|
|
|
|
// Module whose forward() has default arguments but does NOT declare them via
// FORWARD_HAS_DEFAULT_ARGS; AnyModule therefore requires all 3 arguments.
struct M_default_arg_without_macro : torch::nn::Module {
  double forward(int a, int b = 2, double c = 3.0) {
    return a + b + c;
  }
};
|
|
|
|
|
|
|
|
|
|
TEST_F(
    AnyModuleTest,
    PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
  {
    // With the macro, trailing arguments may be omitted and the declared
    // defaults are filled in (a=1, b=2, c=3.0 -> 6.0, etc.).
    AnyModule any(M_default_arg_with_macro{});

    ASSERT_EQ(any.forward<double>(1), 6.0);
    ASSERT_EQ(any.forward<double>(1, 3), 7.0);
    ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);

    // Arity is still checked against the declared min/max.
    ASSERT_THROWS_WITH(
        any.forward(),
        "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0.");
    ASSERT_THROWS_WITH(
        any.forward(1, 2, 3.0, 4),
        "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4.");
  }
  {
    // Without the macro, AnyModule cannot see the C++ defaults, so exactly
    // 3 arguments are required and any other count throws with a hint to
    // add FORWARD_HAS_DEFAULT_ARGS.
    AnyModule any(M_default_arg_without_macro{});

    ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);

#if defined(_MSC_VER)
    // MSVC's typeid name includes the class-key ("struct ") prefix.
    std::string module_name = "struct M_default_arg_without_macro";
#else
    std::string module_name = "M_default_arg_without_macro";
#endif

    ASSERT_THROWS_WITH(
        any.forward(),
        module_name +
            "'s forward() method expects 3 argument(s), but received 0. "
            "If " +
            module_name +
            "'s forward() method has default arguments, "
            "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
    ASSERT_THROWS_WITH(
        any.forward<double>(1),
        module_name +
            "'s forward() method expects 3 argument(s), but received 1. "
            "If " +
            module_name +
            "'s forward() method has default arguments, "
            "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
    ASSERT_THROWS_WITH(
        any.forward<double>(1, 3),
        module_name +
            "'s forward() method expects 3 argument(s), but received 2. "
            "If " +
            module_name +
            "'s forward() method has default arguments, "
            "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
    ASSERT_THROWS_WITH(
        any.forward(1, 2, 3.0, 4),
        module_name +
            "'s forward() method expects 3 argument(s), but received 4.");
  }
}
|
|
|
|
|
|
|
|
|
|
// File-scope module shared by the get()/ptr() tests below. It stores a value
// so tests can verify they retrieved the right concrete instance, and names
// itself "M" so ptr()->name() can be checked.
struct M : torch::nn::Module {
  explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
  int value;
  int forward(float x) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return x;
  }
};
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
  AnyModule any(M{5});
  // get<M>() returns a reference to the stored concrete module.
  ASSERT_EQ(any.get<M>().value, 5);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
  // N is a distinct module type from the stored M.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };
  AnyModule any(M{5});
  // Requesting the wrong concrete type must throw a cast error.
  ASSERT_THROWS_WITH(any.get<N>(), "Attempted to cast module");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
  AnyModule any(M{5});
  // Untyped ptr() returns a non-null shared_ptr to the Module base; the
  // stored module registered itself under the name "M".
  auto ptr = any.ptr();
  ASSERT_NE(ptr, nullptr);
  ASSERT_EQ(ptr->name(), "M");
}
|
|
|
|
|
|
|
|
|
|
// NOTE(review): "Succceeds" is a long-standing typo for "Succeeds"; kept
// as-is because test names may be referenced by --gtest_filter in CI.
TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
  AnyModule any(M{5});
  // ptr<M>() performs a checked downcast to the stored concrete type.
  auto ptr = any.ptr<M>();
  ASSERT_NE(ptr, nullptr);
  ASSERT_EQ(ptr->value, 5);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
  // N is a distinct module type from the stored M.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };
  AnyModule any(M{5});
  // ptr<N>() on an AnyModule holding an M must throw a cast error.
  ASSERT_THROWS_WITH(any.ptr<N>(), "Attempted to cast module");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
  // Local M shadows the file-scope M; it has no registered name.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return x;
    }
  };
  // A default-constructed AnyModule holds nothing until assigned a module.
  AnyModule any;
  ASSERT_TRUE(any.is_empty());
  any = std::make_shared<M>(5);
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.get<M>().value, 5);
}
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
  struct M : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  // Every accessor on an empty AnyModule must throw with a message that
  // names the offending method.
  AnyModule any;
  ASSERT_TRUE(any.is_empty());
  ASSERT_THROWS_WITH(any.get<M>(), "Cannot call get() on an empty AnyModule");
  ASSERT_THROWS_WITH(any.ptr<M>(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(any.ptr(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      any.type_info(), "Cannot call type_info() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      any.forward<int>(5), "Cannot call forward() on an empty AnyModule");
}
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
  struct M : torch::nn::Module {
    std::string forward(int x) {
      return std::to_string(x);
    }
  };
  struct N : torch::nn::Module {
    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return 3 + x;
    }
  };
  // The same AnyModule can be reassigned modules with entirely different
  // forward() signatures; the erased call adapts each time.
  AnyModule any;
  ASSERT_TRUE(any.is_empty());
  any = std::make_shared<M>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<std::string>(5), "5");
  any = std::make_shared<N>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<int>(5.0f), 8);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
    int value;
    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return x;
    }
  };

  // Holder in the Impl/Holder idiom used by torch::nn built-in modules.
  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  // Constructing from the holder stores the Impl; both get<MImpl>() (direct)
  // and get<M>() (holder, accessed via ->) must see the same instance.
  AnyModule any(M{5});
  ASSERT_EQ(any.get<MImpl>().value, 5);
  ASSERT_EQ(any.get<M>()->value, 5);

  // Also works for a real built-in holder module; these statements only
  // need to compile and run without throwing.
  AnyModule module(Linear(3, 4));
  std::shared_ptr<Module> ptr = module.ptr();
  Linear linear(module.get<Linear>());
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };
  // When you have an autograd::Variable, it should be converted to a
  // torch::Tensor before being passed to the function (to avoid a type
  // mismatch).
  AnyModule any(M{});
  ASSERT_TRUE(
      any.forward(torch::autograd::Variable(torch::ones(5)))
          .sum()
          .item<float>() == 5);
  // at::Tensors that are not variables work too.
  ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
}
|
|
|
|
|
|
|
|
|
|
namespace torch {
namespace nn {
// Test helper that wraps an arbitrary value and produces an AnyValue when
// invoked; placed in torch::nn so the AnyValue tests can construct values
// the same way AnyModule does internally.
struct TestAnyValue {
  template <typename T>
  // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
  explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {}
  // Single-shot: moves the stored AnyValue out.
  AnyValue operator()() {
    return std::move(value_);
  }
  AnyValue value_;
};
// Convenience factory: make_value(x) == AnyValue holding x.
template <typename T>
AnyValue make_value(T&& value) {
  return TestAnyValue(std::forward<T>(value))();
}
} // namespace nn
} // namespace torch
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
// Fixture that seeds the RNG before each AnyValue test for determinism.
struct AnyValueTest : torch::test::SeedingFixture {};
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
  auto value = make_value<int>(5);
  ASSERT_NE(value.try_get<int>(), nullptr);
  // const and non-const types have the same typeid(),
  // but casting Holder<int> to Holder<const int> is undefined
  // behavior according to UBSAN:
  // https://github.com/pytorch/pytorch/issues/26964
  // ASSERT_NE(value.try_get<const int>(), nullptr);
  ASSERT_EQ(value.get<int>(), 5);
}
|
2019-10-02 15:01:17 +00:00
|
|
|
// This test does not work at all, because it looks like make_value
|
|
|
|
|
// decays const int into int.
|
|
|
|
|
// TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
|
|
|
|
|
// auto value = make_value<const int>(5);
|
|
|
|
|
// ASSERT_NE(value.try_get<const int>(), nullptr);
|
|
|
|
|
// // ASSERT_NE(value.try_get<int>(), nullptr);
|
|
|
|
|
// ASSERT_EQ(value.get<const int>(), 5);
|
|
|
|
|
//}
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
  // A string literal is stored as const char*, not std::string.
  auto value = make_value("hello");
  ASSERT_NE(value.try_get<const char*>(), nullptr);
  ASSERT_EQ(value.get<const char*>(), std::string("hello"));
}
|
|
|
|
|
TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
  // An actual std::string is stored and retrieved as std::string.
  auto value = make_value(std::string("hello"));
  ASSERT_NE(value.try_get<std::string>(), nullptr);
  ASSERT_EQ(value.get<std::string>(), "hello");
}
|
|
|
|
|
TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
  std::string s("hello");
  std::string* p = &s;
  // Pointers are stored by value; the pointee is shared with the caller.
  auto value = make_value(p);
  ASSERT_NE(value.try_get<std::string*>(), nullptr);
  ASSERT_EQ(*value.get<std::string*>(), "hello");
}
|
|
|
|
|
TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
  std::string s("hello");
  const std::string& t = s;
  // A const reference decays: the value is stored as a plain std::string.
  auto value = make_value(t);
  ASSERT_NE(value.try_get<std::string>(), nullptr);
  ASSERT_EQ(value.get<std::string>(), "hello");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
  auto value = make_value(5);
  ASSERT_NE(value.try_get<int>(), nullptr);
  // try_get<> is exact-type: no numeric or string conversions are applied.
  ASSERT_EQ(value.try_get<float>(), nullptr);
  ASSERT_EQ(value.try_get<long>(), nullptr);
  ASSERT_EQ(value.try_get<std::string>(), nullptr);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
  auto value = make_value(5);
  ASSERT_NE(value.try_get<int>(), nullptr);
  // get<> reports both the requested type and the stored type on mismatch.
  ASSERT_THROWS_WITH(
      value.get<float>(),
      "Attempted to cast AnyValue to float, "
      "but its actual type is int");
  ASSERT_THROWS_WITH(
      value.get<long>(),
      "Attempted to cast AnyValue to long, "
      "but its actual type is int");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
  auto value = make_value(5);
  // Moving one AnyValue into another transfers the held value.
  auto copy = make_value(std::move(value));
  ASSERT_NE(copy.try_get<int>(), nullptr);
  ASSERT_EQ(copy.get<int>(), 5);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
  auto value = make_value(5);
  auto copy = make_value(10);
  // Move-assignment replaces the target's held value (10 -> 5).
  copy = std::move(value);
  ASSERT_NE(copy.try_get<int>(), nullptr);
  ASSERT_EQ(copy.get<int>(), 5);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
  auto value = make_value(5);
  // type_info() reflects the erased type exactly.
  ASSERT_EQ(value.type_info().hash_code(), typeid(int).hash_code());
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
  auto value = make_value("hello");
  // String literals decay to const char* in the stored type_info.
  ASSERT_EQ(value.type_info().hash_code(), typeid(const char*).hash_code());
}
|
|
|
|
|
|
|
|
|
|
TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
  auto value = make_value(std::string("hello"));
  // An owned std::string keeps its exact type_info.
  ASSERT_EQ(value.type_info().hash_code(), typeid(std::string).hash_code());
}
|