2020-09-25 18:35:39 +00:00
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
|
2021-05-19 18:37:29 +00:00
|
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
2019-03-15 20:53:23 +00:00
|
|
|
|
|
|
|
|
namespace torch {
|
|
|
|
|
namespace jit {
|
|
|
|
|
|
2020-09-25 18:35:39 +00:00
|
|
|
TEST(IRTest, Attributes) {
  // Exercises the Node attribute API: chained setters, typed getters,
  // type-replacing re-assignment, presence queries, list attributes,
  // and independent copies via copyAttributes.
  Graph g;
  auto kAlpha = attr::alpha;
  auto kDevice = attr::device;
  auto kEnd = attr::end;
  auto kPerm = attr::perm;
  Node* node = g.create(Symbol::fromQualString("foo::bar"));
  Node& attrs = *node;
  // Setters return the node, so they chain.
  attrs.f_(kAlpha, 3.4)->i_(kDevice, 5)->s_(kEnd, "what");
  ASSERT_EQ(attrs.f(kAlpha), 3.4);
  ASSERT_EQ(attrs.s(kEnd), "what");
  ASSERT_EQ(attrs.i(kDevice), 5);
  // Re-setting the same name with a different type replaces the old value.
  attrs.s_(kAlpha, "no");
  ASSERT_EQ(attrs.s(kAlpha), "no");
  ASSERT_TRUE(attrs.hasAttribute(kEnd));
  ASSERT_TRUE(!attrs.hasAttribute(kPerm));
  // List-valued (string-list) attributes.
  attrs.ss_(kDevice, {"hi", "now"});
  ASSERT_EQ(attrs.ss(kDevice).at(1), "now");

  // copyAttributes produces an independent copy: mutating the destination
  // node must leave the source node untouched.
  Node* other = g.create(Symbol::fromQualString("foo::baz"));
  Node& attrs2 = *other;
  attrs2.copyAttributes(attrs);
  ASSERT_EQ(attrs2.s(kAlpha), "no");
  attrs2.f_(kAlpha, 5);
  ASSERT_EQ(attrs.s(kAlpha), "no");
  ASSERT_EQ(attrs2.f(kAlpha), 5);
}
|
|
|
|
|
|
2020-09-25 18:35:39 +00:00
|
|
|
TEST(IRTest, Blocks) {
  // Parses a graph containing a prim::If with two blocks, verifies its
  // printed structure, erases one block, and checks that a recursive
  // graph copy reproduces the remaining block structure.
  auto graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
    graph(%a : Tensor,
          %b : Tensor,
          %c : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::add(%a, %b, %2)
      %5 : Tensor = prim::If(%c)
        block0():
          %6 : int = prim::Constant[value=1]()
          %7 : Tensor = aten::add(%3, %3, %6)
          -> (%7)
        block1():
          %8 : int = prim::Constant[value=1]()
          %9 : Tensor = aten::add(%b, %3, %8)
          %10 : int = prim::Constant[value=1]()
          %11 : Tensor = aten::add(%9, %3, %10)
          -> (%11)
      %12 : int = prim::Constant[value=1]()
      %13 : Tensor = aten::add(%5, %3, %12)
      return (%13))IR";
  torch::jit::parseIR(graph_string, graph.get());

  graph->lint();
  // Both blocks of the If should be present, with three adds after block1.
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check("aten::add")
      ->check("block1")
      ->check_count("aten::add", 3)
      ->run(*graph);

  // Locate the first prim::If in the top-level block and remove its block0.
  Node* if_node = nullptr;
  for (Node* node : graph->block()->nodes()) {
    if (node->kind() == prim::If) {
      if_node = node;
      break;
    }
  }
  if (if_node != nullptr) {
    if_node->eraseBlock(0);
  }

  // The surviving block (former block1) is renumbered to block0 and no
  // second block remains.
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check_not("block")
      ->run(*graph);
  graph->lint();
  // A recursive copy must reproduce the same (single-block) structure.
  auto copied = graph->copy();
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check_not("block")
      ->run(*copied);
}
|
|
|
|
|
|
2020-09-25 18:35:39 +00:00
|
|
|
TEST(IRTest, CommonAncestor) {
  // Builds a graph of nested prim::If blocks and, for every pair of leaf
  // constants, checks that the block returned by
  // findCommonAncestorBlockWith sits at the expected depth below the
  // graph's top-level block.
  std::string input_str = R"(
graph(%x : Tensor,
      %a.1 : bool,
      %b.1 : bool,
      %c.1 : bool):
  %4 : int = prim::If(%a.1)
    block0():
      %5 : int = prim::If(%b.1)
        block0():
          %6 : int = prim::Constant[value=2]()
          -> (%6)
        block1():
          %7 : int = prim::Constant[value=3]()
          -> (%7)
      -> (%5)
    block1():
      %8 : int = prim::If(%c.1)
        block0():
          %9 : int = prim::Constant[value=4]()
          -> (%9)
        block1():
          %10 : int = prim::Constant[value=5]()
          -> (%10)
      -> (%8)
  return (%4)
)";

  torch::jit::Graph g;
  std::unordered_map<std::string, torch::jit::Value*> name_to_value;
  torch::jit::parseIR(input_str, &g, name_to_value);

  // The four constants, each nested two Ifs deep.
  std::vector<std::string> leaf_names{"6", "7", "9", "10"};
  std::unordered_set<std::string> leaf_names_set(
      leaf_names.begin(), leaf_names.end());

  /* clang-format off */
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  int ref_blocks_from_graph[4][4] = {
    /* (6, 6), (6, 7), (6, 9), (6, 10) */
    { 2, 1, 0, 0 },
    /* (7, 6), (7, 7), (7, 9), (7, 10) */
    { 1, 2, 0, 0 },
    /* (9, 6), (9, 7), (9, 9), (9, 10) */
    { 0, 0, 2, 1, },
    /* (10, 6),(10, 7),(10, 9),(10, 10) */
    { 0, 0, 1, 2 }
  };
  /* clang-format on */

  // Check every ordered pair against the reference depth table.
  for (size_t row = 0; row < leaf_names.size(); ++row) {
    Value* lhs = name_to_value[leaf_names[row]];
    for (size_t col = 0; col < leaf_names.size(); ++col) {
      Value* rhs = name_to_value[leaf_names[col]];
      Block* common_ancestor =
          lhs->node()->findCommonAncestorBlockWith(rhs->node());
      int depth = common_ancestor->param_node()->blocksFromGraphBlock();
      ASSERT_EQ(depth, ref_blocks_from_graph[row][col]);
    }
  }
}
|
|
|
|
|
|
2021-05-19 18:37:29 +00:00
|
|
|
TEST(IRTest, OperatorMap) {
  // OperatorMap keys entries by operator schema, so overloads sharing a
  // name (bernoulli, normal) are distinct keys. Exercises single and
  // batched insert, contains, erase, re-insert, and find.
  OperatorMap<int> op_map;
  const char* dropout_schema =
      "aten::dropout(Tensor input, float p, bool train) -> Tensor";
  const char* bernoulli_schema =
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor";
  const char* bernoulli_p_schema =
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor";
  const char* normal_tt_schema =
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor";
  const char* normal_ft_schema =
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor";
  const char* normal_tf_schema =
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor";
  std::shared_ptr<Operator> op1 = getOperatorForLiteral(dropout_schema);
  std::shared_ptr<Operator> op2 = getOperatorForLiteral(bernoulli_schema);
  std::shared_ptr<Operator> op3 = getOperatorForLiteral(bernoulli_p_schema);
  std::shared_ptr<Operator> op4 = getOperatorForLiteral(normal_tt_schema);
  std::shared_ptr<Operator> op5 = getOperatorForLiteral(normal_ft_schema);
  std::shared_ptr<Operator> op6 = getOperatorForLiteral(normal_tf_schema);
  // Mix the single-entry and initializer-list insert overloads.
  op_map.insert(op1, 1);
  op_map.insert({{op2, 2}, {op3, 3}});
  op_map.insert({{op4, 4}, {op5, 5}});
  op_map.insert(op6, 6);
  ASSERT_TRUE(op_map.contains(*op1));
  ASSERT_TRUE(op_map.contains(*op2));
  ASSERT_TRUE(op_map.contains(*op3));
  ASSERT_TRUE(op_map.contains(*op4));
  ASSERT_TRUE(op_map.contains(*op5));
  ASSERT_TRUE(op_map.contains(*op6));
  // Erasing one overload must not disturb its siblings.
  op_map.erase(op6);
  op_map.erase(op3);
  op_map.erase(op1);
  ASSERT_FALSE(op_map.contains(*op1));
  ASSERT_FALSE(op_map.contains(*op3));
  ASSERT_FALSE(op_map.contains(*op6));
  // An erased key can be re-inserted.
  op_map.insert(op1, 1);
  ASSERT_TRUE(op_map.contains(*op1));
  // find returns an engaged optional only for present keys.
  std::optional<int> o1 = op_map.find(*op1);
  ASSERT_TRUE(o1.has_value());
  std::optional<int> o2 = op_map.find(*op2);
  ASSERT_TRUE(o2.has_value());
  std::optional<int> o3 = op_map.find(*op3);
  ASSERT_FALSE(o3.has_value());
  std::optional<int> o4 = op_map.find(*op4);
  ASSERT_TRUE(o4.has_value());
  std::optional<int> o5 = op_map.find(*op5);
  ASSERT_TRUE(o5.has_value());
  std::optional<int> o6 = op_map.find(*op6);
  ASSERT_FALSE(o6.has_value());
}
|
|
|
|
|
|
2019-03-15 20:53:23 +00:00
|
|
|
} // namespace jit
|
|
|
|
|
} // namespace torch
|