mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66515 These passes should not be used generally as they change API of the model's forward method, but they help experimenting with the model and ironing out all the kinks before it can be compiled properly. In the long run ideally we should provide a better way to enable such experiments. Differential Revision: D31590862 Test Plan: Imported from OSS Reviewed By: navahgar Pulled By: ZolotukhinM fbshipit-source-id: 74ded34c6c871d4cafa29f43dc27c7e71daff8fc
320 lines
10 KiB
C++
320 lines
10 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <test/cpp/tensorexpr/test_base.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/passes/lower_tuples.h>
|
|
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
|
|
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include <torch/torch.h>
|
|
|
|
#include <limits>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
class GraphOpt : public ::testing::Test {
|
|
public:
|
|
// NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
|
|
void SetUp() {
|
|
old_cat_wo_conditionals_ = getCatWoConditionals();
|
|
getCatWoConditionals() = true;
|
|
}
|
|
|
|
void TearDown() {
|
|
getCatWoConditionals() = old_cat_wo_conditionals_;
|
|
}
|
|
|
|
private:
|
|
bool old_cat_wo_conditionals_;
|
|
};
|
|
|
|
TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  // Input graph: a single `aten::log` consuming the result of `aten::cat`.
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // After optimization the `aten::log` must be applied to each of the three
  // inputs of `aten::cat`, and no log may remain downstream of the cat.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());

  // Execute the kernel and compare against an eager-mode reference.
  auto t0 = at::rand({10}, at::kFloat);
  auto t1 = at::rand({20}, at::kFloat);
  auto t2 = at::rand({30}, at::kFloat);
  auto expected = at::log(at::cat({t0, t1, t2}, 0));

  std::vector<at::Tensor> kernel_inputs = {t0, t1, t2};
  auto stack = fmap<IValue>(kernel_inputs);
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  // Input graph: `aten::tanh(aten::log(aten::cat(...)))` — a chain of two
  // single-tensor element-wise ops after the cat.
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Both `aten::log` and `aten::tanh` must be pushed onto each of the three
  // cat inputs; neither may remain after the cat.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Execute the kernel and compare against an eager-mode reference.
  auto t0 = at::rand({10}, at::kFloat);
  auto t1 = at::rand({20}, at::kFloat);
  auto t2 = at::rand({30}, at::kFloat);
  auto expected = at::tanh(at::log(at::cat({t0, t1, t2}, 0)));

  std::vector<at::Tensor> kernel_inputs = {t0, t1, t2};
  auto stack = fmap<IValue>(kernel_inputs);
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  // Input graph: `aten::mul(%a, aten::tanh(aten::cat(...)))` — the tanh is
  // movable, but the mul has two tensor inputs and is not.
  const auto graph_string = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // `aten::tanh` must be replicated onto each cat input, while `aten::mul`
  // stays put because it is not a single-tensor element-wise op.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Execute the kernel and compare against an eager-mode reference.
  auto scale = at::rand({60}, at::kFloat);
  auto t0 = at::rand({10}, at::kFloat);
  auto t1 = at::rand({20}, at::kFloat);
  auto t2 = at::rand({30}, at::kFloat);
  auto expected = at::tanh(at::cat({t0, t1, t2}, 0)) * scale;

  std::vector<at::Tensor> kernel_inputs = {scale, t0, t1, t2};
  auto stack = fmap<IValue>(kernel_inputs);
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  // Input graph: Int inputs are concatenated, then `aten::tanh` promotes the
  // result to Float.
  const auto graph_string = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // `aten::tanh` must be replicated onto each cat input; the cat inputs then
  // become Float because tanh performs the type promotion itself.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Execute the kernel on full-range int inputs and compare against eager.
  auto t0 = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto t1 = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto t2 = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto expected = at::tanh(at::cat({t0, t1, t2}, 0));

  std::vector<at::Tensor> kernel_inputs = {t0, t1, t2};
  auto stack = fmap<IValue>(kernel_inputs);
  kernel.run(stack);
  auto result = stack[0].toTensor();
  ASSERT_EQ(result.sizes(), expected.sizes());
  ASSERT_EQ(result.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(result, expected));
#endif
}
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  // Input graph: the cat itself promotes Float + Double inputs to Double.
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // The graph must be left untouched: `aten::cat` performs type promotion
  // here, and that case is not handled by the optimization yet.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  // Input graph: the only user of `aten::cat` is a two-tensor `aten::mul`.
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // The graph must be left untouched: the cat's consumer is not a
  // single-tensor element-wise op, so nothing can be moved.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  // Input graph: `aten::cat` feeds a two-tensor `aten::mul`, which in turn
  // feeds a multi-input `aten::add`.
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // The graph must be left untouched: none of the cat's (transitive)
  // consumers are single-tensor element-wise ops.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add")
      ->run(*kernel.graph());
#endif
}
TEST_F(GraphOpt, AOTGraphPrepPasses) {
  // Input graph returns a tensor list plus a scalar.
  const auto graph_string = R"IR(
    graph(%x, %y, %z, %i : int):
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      return (%xyz_list, %i))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  // AOT preparation: drop the scalar output, turn the list output into a
  // tuple, then lower all tuples so the graph returns plain tensors.
  // NOTE: these passes change the model's forward-method API and are meant
  // for experimentation only.
  removeGraphOutput(graph, 1);
  replaceListOutputWithTuple(graph);
  LowerAllTuples(graph);

  testing::FileCheck().check("return (%x, %y, %z)")->run(*graph);
}
} // namespace jit
|
|
} // namespace torch
|