// Mirror of https://github.com/saymrwulf/pytorch.git
// (synced 2026-05-15 21:00:47 +00:00)
//
// Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63776
// I reverted this out of an abundance of caution because some test failures
// occurred, but they were all due to precision issues fixed lower in this
// stack. Let's try again. I've rolled the elimination of the
// allow-parallelism-in-fusions toggle into this diff since they're pretty
// tightly coupled.
// ghstack-source-id: 136529847
// Test Plan: CI
// Reviewed By: huiguoo
// Differential Revision: D30484555
// fbshipit-source-id: 38fd33520f710585d1130c365a8c60c9ce794a59
#include <gtest/gtest.h>
|
|
|
|
#include <test/cpp/tensorexpr/test_base.h>
|
|
#include <torch/csrc/jit/codegen/fuser/interface.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include <sstream>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
struct WithCPUFuser {
|
|
WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
|
|
overrideCanFuseOnCPU(val);
|
|
}
|
|
|
|
~WithCPUFuser() {
|
|
overrideCanFuseOnCPU(cpuFuserEnabled);
|
|
}
|
|
|
|
bool cpuFuserEnabled;
|
|
};
|
|
|
|
TEST(TEFuserPass, FuserPass_1) {
  WithCPUFuser cpuFuser;
  // Element-wise chain with an in-place aten::add_ in the middle.
  const auto ir = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1)
      %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12)
      %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1)
      %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());

  graph->lint();
  FuseTensorExprs(graph);

  // The in-place op is a fusion barrier, so we expect two separate fusion
  // groups with the aten::add_ left standing between them.
  testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("aten::add_")
      ->check("prim::TensorExprGroup_")
      ->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_2) {
  WithCPUFuser cpuFuser;
  // The in-place aten::add_ sits between the producers and the final mul.
  const auto ir = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12)
      %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12)
      %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a)
      return (%d))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());

  graph->lint();
  FuseTensorExprs(graph);

  // Fusion must not cross the in-place operation: aten::add_ stays outside
  // and is followed by a single fusion group.
  testing::FileCheck()
      .check("aten::add_")
      ->check("prim::TensorExprGroup_0")
      ->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_3) {
  WithCPUFuser cpuFuser;
  // A single fusible op — exercises the min_group_size threshold.
  const auto ir = R"IR(
    graph(%x : Float(128, strides=[1], device=cpu),
          %y : Float(128, strides=[1], device=cpu)):
      %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%r))IR";
  {
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 2);

    // One op is below the threshold of two — no group should be formed.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
  }
  {
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 1);

    // With the threshold lowered to one, the single op is fused.
    testing::FileCheck().check("prim::TensorExprGroup")->run(*graph);
  }
}
|
|
|
|
TEST(TEFuserPass, FuserPass_0DimInput) {
  WithCPUFuser cpuFuser;
  // Scalar (0-dim) tensor inputs.
  const auto ir = R"IR(
    graph(%x : Float(device=cpu),
          %y : Float(device=cpu)):
      %one : int = prim::Constant[value=1]()
      %a : Float(device=cpu) = aten::mul(%x, %y)
      %b : Float(device=cpu) = aten::add(%x, %a, %one)
      return (%b))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());

  graph->lint();
  FuseTensorExprs(graph);

  // 0-dim tensors are fusible as well.
  testing::FileCheck().check("prim::TensorExprGroup")->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_UnfusibleDevice) {
  // Explicitly disable the CPU fuser: the cpu device becomes unfusible.
  WithCPUFuser cpuFuser(false);
  const auto ir = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(10, strides=[1], device=cpu)):
      %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%a))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());

  graph->lint();
  FuseTensorExprs(graph, /* min_group_size= */ 1);

  // No fusion group may be seeded from a node on an unfusible device.
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_UnknownShapes) {
  WithCPUFuser cpuFuser;
  // Inputs typed as bare Tensor — no shape information available.
  const auto ir = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %a : Tensor = aten::mul(%x, %y)
      %b : Tensor = aten::mul(%x, %a)
      return (%b))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());

  graph->lint();
  FuseTensorExprs(graph);

  // Without known shapes no fusion group is created.
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_Multidevice) {
  {
    WithCPUFuser cpuFuser;
    // All cat inputs live on the same (cpu) device.
    const auto ir = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(30, strides=[1], device=cpu)):
        %dim : int = prim::Constant[value=0]()
        %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
        return (%cat))IR";
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 1);

    // Same-device cat is fusible.
    testing::FileCheck().check("prim::TensorExprGroup")->run(*graph);
  }
  {
    WithCPUFuser cpuFuser;
    // One cat input is on cuda:0 while the others are on cpu.
    const auto ir = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cuda:0),
            %z : Float(30, strides=[1], device=cpu)):
        %dim : int = prim::Constant[value=0]()
        %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
        return (%cat))IR";
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 1);

    // Mixed-device inputs make the aten::cat unfusible.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
  }
  {
    WithCPUFuser cpuFuser;
    // cat(cpu, cpu) feeds a mul whose other input is on cuda:0.
    const auto ir = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(10, strides=[1], device=cuda:0)):
        %dim : int = prim::Constant[value=0]()
        %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
        %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
        %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 2);

    // The device check must fire before merging the cat into the mul.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
  }
  {
    WithCPUFuser cpuFuser;
    // A cuda mul feeds the cat of otherwise-cpu inputs.
    const auto ir = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(10, strides=[1], device=cuda:0)):
        %z2 : Tensor = aten::mul(%z, %z)
        %dim : int = prim::Constant[value=0]()
        %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
        return (%cat))IR";
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 2);

    // The device check must fire before merging the mul into the cat.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
  }
  {
    WithCPUFuser cpuFuser;
    // A single op whose two inputs disagree on device.
    const auto ir = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cuda:0)):
        %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 1);

    // Cross-device inputs rule out fusion entirely.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
  }
  {
    WithCPUFuser cpuFuser;
    // Three independent muls, each on a different device.
    const auto ir = R"IR(
      graph(%x : Float(10, strides=[1], device=cuda:0),
            %y : Float(20, strides=[1], device=cuda:1),
            %z : Float(20, strides=[1], device=cpu)):
        %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x)
        %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y)
        %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z)
        return (%x2, %y2, %z2))IR";
    auto graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir, graph.get());

    graph->lint();
    FuseTensorExprs(graph, /* min_group_size= */ 2);

    // Computations on different devices must not be merged together.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
  }
}
|
|
|
|
TEST(TEFuserPass, FuserPass_MergeGroups) {
  WithCPUFuser cpuFuser;
  // Two fully independent element-wise computations.
  const auto ir = R"IR(
    graph(%a : Float(128, strides=[1], device=cpu),
          %b : Float(128, strides=[1], device=cpu)):
      %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a)
      %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b)
      return (%x, %y))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());

  graph->lint();
  FuseTensorExprs(graph, /* min_group_size= */ 1);

  // Even though %x and %y are independent, the pass should merge them into
  // one fusion group instead of emitting two.
  testing::FileCheck()
      .check("= prim::TensorExprGroup_")
      ->check_not("= prim::TensorExprGroup_")
      ->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_UnknownShapesIgnored) {
  WithCPUFuser cpuFuser;
  // 0-dim inputs with no stride info; shape checks will be disabled.
  const auto ir = R"IR(
    graph(%x : Float(device=cpu),
          %y : Float(device=cpu)):
      %a : Float(device=cpu) = aten::mul(%x, %y)
      %b : Float(device=cpu) = aten::mul(%x, %a)
      return (%b))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());

  graph->lint();
  FuseTensorExprs(
      graph, /* min_group_size= */ 2, /* disable_shape_checks= */ true);

  // With shape checks disabled, fusion proceeds despite unknown shapes.
  testing::FileCheck().check("prim::TensorExprGroup")->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
  WithCPUFuser cpuFuser;
  // The second op produces a bare Tensor (unknown shape).
  const auto ir = R"IR(
    graph(%x : Bool(8, strides=[1], device=cpu),
          %y : Bool(8, strides=[1], device=cpu)):
      %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y)
      %b : Tensor = aten::__or__(%a, %y)
      return (%b)
  )IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();
  FuseTensorExprs(graph, /* min_group_size= */ 2);
  // The shape-less op blocks fusion of the pair.
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_Where) {
  WithCPUFuser cpuFuser;
  // Ternary aten::where(cond, a, b) fed by a fusible comparison.
  const auto ir = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z)
      return (%b)
  )IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();
  FuseTensorExprs(graph, /* min_group_size= */ 2);
  // The eq + where pair is fusible.
  testing::FileCheck().check("prim::TensorExprGroup")->run(*graph);
}
|
|
|
|
TEST(TEFuserPass, FuserPass_WhereList) {
  WithCPUFuser cpuFuser;
  // Single-argument aten::where returns a Tensor[] — not supported by the
  // fuser.
  const auto ir = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Tensor[] = aten::where(%cond)
      return (%b)
  )IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();
  FuseTensorExprs(graph, /* min_group_size= */ 2);
  // The list-returning where must not be pulled into a fusion group.
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*graph);
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|