pytorch/test/cpp/tensorexpr/test_graph_opt.cpp
Mikhail Zolotukhin 5d7cc8f22a [TensorExpr] Add some graph-rewrite passes to prepare models for AOT compilation. (#66515)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66515

These passes should not be used generally as they change API of the
model's forward method, but they help experimenting with the model and
ironing out all the kinks before it can be compiled properly. In the
long run ideally we should provide a better way to enable such
experiments.

Differential Revision: D31590862

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: 74ded34c6c871d4cafa29f43dc27c7e71daff8fc
2022-01-07 01:03:53 -08:00

320 lines
10 KiB
C++

#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <limits>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
class GraphOpt : public ::testing::Test {
public:
// NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
void SetUp() {
old_cat_wo_conditionals_ = getCatWoConditionals();
getCatWoConditionals() = true;
}
void TearDown() {
getCatWoConditionals() = old_cat_wo_conditionals_;
}
private:
bool old_cat_wo_conditionals_;
};
TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  // Graph with a single element-wise user (`aten::log`) of `aten::cat`.
  const auto ir = R"IR(
graph(%x : Float(10, strides=[1], device=cpu),
%y : Float(20, strides=[1], device=cpu),
%z : Float(30, strides=[1], device=cpu)):
%dim : int = prim::Constant[value=0]()
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
%5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
return (%5))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // After the rewrite, `aten::log` must have been sunk into each of the three
  // inputs of `aten::cat`: three logs before the cat, none after it.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference.
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto expected = at::log(at::cat({x, y, z}, 0));

  std::vector<IValue> stack = {x, y, z};
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  // A chain of two element-wise ops (`aten::log` then `aten::tanh`)
  // consuming the result of `aten::cat`.
  const auto ir = R"IR(
graph(%x : Float(10, strides=[1], device=cpu),
%y : Float(20, strides=[1], device=cpu),
%z : Float(30, strides=[1], device=cpu)):
%dim : int = prim::Constant[value=0]()
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
%5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
%6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Both element-wise ops must be sunk past `aten::cat` onto each of its
  // three inputs; neither may remain after the cat.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference.
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto expected = at::tanh(at::log(at::cat({x, y, z}, 0)));

  std::vector<IValue> stack = {x, y, z};
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  // `aten::cat` is followed by a single-tensor op (`aten::tanh`) and then a
  // two-tensor op (`aten::mul` with an extra input `%a`).
  const auto ir = R"IR(
graph(%a : Float(60, strides=[1], device=cpu),
%x : Float(10, strides=[1], device=cpu),
%y : Float(20, strides=[1], device=cpu),
%z : Float(30, strides=[1], device=cpu)):
%dim : int = prim::Constant[value=0]()
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
%5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
%6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Only `aten::tanh` is sunk into the inputs of `aten::cat`. `aten::mul`
  // has two tensor inputs, so it must stay where it is, after the cat.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference.
  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto expected = at::tanh(at::cat({x, y, z}, 0)) * a;

  std::vector<IValue> stack = {a, x, y, z};
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  // Int inputs flow through `aten::cat` into `aten::tanh`, which promotes
  // the result to Float.
  const auto ir = R"IR(
graph(%x : Int(10, strides=[1], device=cpu),
%y : Int(20, strides=[1], device=cpu),
%z : Int(30, strides=[1], device=cpu)):
%dim : int = prim::Constant[value=0]()
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
%cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
%5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
return (%5))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // `aten::tanh` is sunk into each input of `aten::cat`. The cat now
  // concatenates Float tensors, because each input already went through the
  // type-promoting `tanh`.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference, using the full Int
  // range to exercise the promotion.
  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto expected = at::tanh(at::cat({x, y, z}, 0));

  std::vector<IValue> stack = {x, y, z};
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  // Here `aten::cat` itself promotes its mixed Float/Double inputs to
  // Double.
  const auto ir = R"IR(
graph(%x : Float(10, strides=[1], device=cpu),
%y : Float(20, strides=[1], device=cpu),
%z : Double(30, strides=[1], device=cpu)):
%dim : int = prim::Constant[value=0]()
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
%cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
%5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
return (%5))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Because the cat performs type promotion, the optimization must leave the
  // graph alone: a single cat followed by a single log, nothing moved.
  // This case is currently not handled.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  // The only consumer of `aten::cat` is `aten::mul`, which takes two tensor
  // inputs.
  const auto ir = R"IR(
graph(%0 : Float(60, strides=[1], device=cpu),
%x : Float(10, strides=[1], device=cpu),
%y : Float(20, strides=[1], device=cpu),
%z : Float(30, strides=[1], device=cpu)):
%dim : int = prim::Constant[value=0]()
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
%5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
return (%5))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // No single-tensor element-wise consumer means no transformation: the
  // graph must still contain exactly one cat followed by one mul.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  // The result of `aten::cat` feeds a chain of multi-tensor ops
  // (`aten::mul` and then `aten::add`).
  const auto ir = R"IR(
graph(%0 : Float(60, strides=[1], device=cpu),
%1 : Float(60, strides=[1], device=cpu),
%x : Float(10, strides=[1], device=cpu),
%y : Float(20, strides=[1], device=cpu),
%z : Float(30, strides=[1], device=cpu)):
%one : int = prim::Constant[value=1]()
%dim : int = prim::Constant[value=0]()
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
%5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
%6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // None of cat's consumers are single-tensor element-wise ops, so the
  // graph must be left untouched: one cat, one mul, one add, in order.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add")
      ->run(*kernel.graph());
#endif
}
TEST_F(GraphOpt, AOTGraphPrepPasses) {
  // A graph returning a tensor list plus a scalar; the AOT prep passes
  // should strip the scalar output and flatten the list output into
  // individual tensor outputs.
  const auto ir = R"IR(
graph(%x, %y, %z, %i : int):
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
return (%xyz_list, %i))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // Drop the scalar output, turn the list output into a tuple, then lower
  // the tuple so the graph returns the three tensors directly.
  removeGraphOutput(graph, 1);
  replaceListOutputWithTuple(graph);
  LowerAllTuples(graph);

  testing::FileCheck().check("return (%x, %y, %z)")->run(*graph);
}
} // namespace jit
} // namespace torch