2018-09-22 04:12:37 +00:00
|
|
|
#include <gtest/gtest.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
|
|
|
|
|
#include <torch/detail/static.h>
|
|
|
|
|
#include <torch/csrc/utils/variadic.h>
|
Re-organize C++ API `torch::nn` folder structure (#26262)
Summary:
This PR aims to re-organize C++ API `torch::nn` folder structure in the following way:
- Every module in `torch/csrc/api/include/torch/nn/modules/` (except `any.h`, `named_any.h`, `modulelist.h`, `sequential.h`, `embedding.h`) has a strictly equivalent Python file in `torch/nn/modules/`. For example:
`torch/csrc/api/include/torch/nn/modules/pooling.h` -> `torch/nn/modules/pooling.py`
`torch/csrc/api/include/torch/nn/modules/conv.h` -> `torch/nn/modules/conv.py`
`torch/csrc/api/include/torch/nn/modules/batchnorm.h` -> `torch/nn/modules/batchnorm.py`
`torch/csrc/api/include/torch/nn/modules/sparse.h` -> `torch/nn/modules/sparse.py`
- Containers such as `any.h`, `named_any.h`, `modulelist.h`, `sequential.h` are moved into `torch/csrc/api/include/torch/nn/modules/container/`, because their implementations are too long to be combined into one file (like `torch/nn/modules/container.py` in Python API)
- `embedding.h` is not renamed to `sparse.h` yet, because another work stream is working on API parity for Embedding and EmbeddingBag, and renaming the file now would cause conflicts. After the embedding API parity work is done, we will rename `embedding.h` to `sparse.h` to match the Python file name, and move the embedding options into the `options/` folder.
- `torch/csrc/api/include/torch/nn/functional/` is added, and the folder structure mirrors that of `torch/csrc/api/include/torch/nn/modules/`. For example, `torch/csrc/api/include/torch/nn/functional/pooling.h` contains the functions for pooling, which are then used by the pooling modules in `torch/csrc/api/include/torch/nn/modules/pooling.h`.
- `torch/csrc/api/include/torch/nn/options/` is added, and the folder structure mirrors that of `torch/csrc/api/include/torch/nn/modules/`. For example, `torch/csrc/api/include/torch/nn/options/pooling.h` contains MaxPoolOptions, which is used by both MaxPool modules in `torch/csrc/api/include/torch/nn/modules/pooling.h`, and max_pool functions in `torch/csrc/api/include/torch/nn/functional/pooling.h`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26262
Differential Revision: D17422426
Pulled By: yf225
fbshipit-source-id: c413d2a374ba716dac81db31516619bbd879db7f
2019-09-17 17:05:11 +00:00
|
|
|
#include <torch/torch.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
|
|
typename T,
|
|
|
|
|
typename = torch::enable_if_t<!torch::detail::is_module<T>::value>>
|
|
|
|
|
bool f(T&& m) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
torch::detail::enable_if_module_t<T, bool> f(T&& m) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST(TestStatic, AllOf) {
  // Vacuous truth: an empty pack, and packs made only of `true`, are true.
  ASSERT_TRUE(torch::all_of<>::value);
  ASSERT_TRUE(torch::all_of<true>::value);
  ASSERT_TRUE((torch::all_of<true, true, true>::value));

  // Any `false` in the pack -- alone, repeated, or trailing -- makes it false.
  // (Multi-argument packs are double-parenthesized so the macro does not split
  // them at the commas.)
  ASSERT_FALSE(torch::all_of<false>::value);
  ASSERT_FALSE((torch::all_of<false, false, false>::value));
  ASSERT_FALSE((torch::all_of<true, true, false>::value));
}
|
2018-09-22 04:12:37 +00:00
|
|
|
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST(TestStatic, AnyOf) {
  // NOTE: the bool(...) wrappers turn `value` into a prvalue before it reaches
  // ASSERT_*, presumably so gtest does not bind a reference to (ODR-use) a
  // static constexpr data member that, before C++17, may have no out-of-line
  // definition. The multi-argument packs are additionally parenthesized so the
  // macro does not split them at the commas.
  ASSERT_FALSE(torch::any_of<>::value);
  ASSERT_TRUE(bool((torch::any_of<true>::value)));
  ASSERT_TRUE(bool((torch::any_of<true, true, true>::value)));
  ASSERT_FALSE(bool((torch::any_of<false>::value)));

  // Mixed packs, mirroring the AllOf cases: a `true` that is not at the head
  // of the pack must still be found, and a pack of several `false` values
  // must be false. Without these, an implementation that only inspected the
  // first element would pass every assertion above.
  ASSERT_TRUE(bool((torch::any_of<false, false, true>::value)));
  ASSERT_TRUE(bool((torch::any_of<false, true, false>::value)));
  ASSERT_FALSE(bool((torch::any_of<false, false, false>::value)));
}
|
2018-09-22 04:12:37 +00:00
|
|
|
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST(TestStatic, EnableIfModule) {
  // Overload selection for the probe `f`: a module argument must reach the
  // module overload (true); a plain int must reach the fallback (false).
  ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2)));
  ASSERT_FALSE(f(5));

  // check_not_lvalue_references<Ts...>(): as exercised here, it holds for
  // packs of non-reference types and fails as soon as any member of the pack
  // is an lvalue reference.
  ASSERT_TRUE(torch::detail::check_not_lvalue_references<int>());
  ASSERT_TRUE((torch::detail::check_not_lvalue_references<float, int, char>()));
  ASSERT_FALSE(
      (torch::detail::check_not_lvalue_references<float, int&, char>()));
  ASSERT_TRUE(torch::detail::check_not_lvalue_references<std::string>());
  ASSERT_FALSE(torch::detail::check_not_lvalue_references<std::string&>());
}
|
2018-09-22 04:12:37 +00:00
|
|
|
|
2019-01-14 22:32:32 +00:00
|
|
|
// Fixture for ReturnTypeOfForward: a nullary forward() returning int.
struct A : torch::nn::Module {
  int forward() { return 5; }
};
|
|
|
|
|
|
|
|
|
|
// Fixture for ReturnTypeOfForward: forward() takes a Tensor by value and
// returns std::string. Only the signature matters, so the argument is unused.
struct B : torch::nn::Module {
  std::string forward(torch::Tensor /*tensor*/) { return ""; }
};
|
|
|
|
|
|
|
|
|
|
// Fixture for ReturnTypeOfForward: forward() takes a Tensor by lvalue
// reference and returns float. Only the signature matters.
struct C : torch::nn::Module {
  float forward(torch::Tensor& /*tensor*/) { return 5.0f; }
};
|
|
|
|
|
|
|
|
|
|
// Fixture for ReturnTypeOfForward: forward() takes a Tensor by rvalue
// reference and returns char. Only the signature matters.
struct D : torch::nn::Module {
  char forward(torch::Tensor&& /*tensor*/) { return 'x'; }
};
|
|
|
|
|
|
|
|
|
|
// Fixture for ReturnTypeOfForward: a module that declares no forward() at all
// (the test expects the return-type trait to report void for it).
struct E : public torch::nn::Module {};
|
|
|
|
|
|
|
|
|
|
// Put in a function because macros don't handle the comma between arguments to
|
|
|
|
|
// is_same well ...
|
|
|
|
|
template <typename Module, typename ExpectedType, typename... Args>
|
|
|
|
|
void assert_has_expected_type() {
|
|
|
|
|
using ReturnType =
|
|
|
|
|
typename torch::detail::return_type_of_forward<Module, Args...>::type;
|
|
|
|
|
constexpr bool is_expected_type =
|
|
|
|
|
std::is_same<ReturnType, ExpectedType>::value;
|
|
|
|
|
ASSERT_TRUE(is_expected_type) << Module().name();
|
|
|
|
|
}
|
|
|
|
|
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
2019-01-14 22:32:32 +00:00
|
|
|
TEST(TestStatic, ReturnTypeOfForward) {
  // Template arguments: <module, expected return type, forward() arg types...>
  assert_has_expected_type<A, int>(); // no arguments
  assert_has_expected_type<B, std::string, torch::Tensor>(); // by value
  assert_has_expected_type<C, float, torch::Tensor&>(); // lvalue reference
  assert_has_expected_type<D, char, torch::Tensor&&>(); // rvalue reference

  // E has no forward(); the trait is expected to fall back to void.
  assert_has_expected_type<E, void>();
}
|
|
|
|
|
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST(TestStatic, Apply) {
  // torch::apply should invoke the callable once for each trailing argument,
  // in argument order.
  std::vector<int> v;
  torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);

  // Compare whole vectors so both the count and the order are checked.
  // Unlike comparing v.size() (size_t) against an int literal, or v.at(i)
  // (int) against i + 1 (size_t), this never mixes signed and unsigned
  // operands inside ASSERT_EQ (-Wsign-compare). gtest prints both vectors on
  // mismatch.
  const std::vector<int> expected{1, 2, 3, 4, 5};
  ASSERT_EQ(v, expected);
}
|